"""This module defines a Form class that wraps FFC forms and UFC forms
into a cpp.Form (dolfin::Form)."""

# Copyright (C) 2008 Johan Hake
#
# This file is part of DOLFIN.
#
# DOLFIN is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# DOLFIN is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with DOLFIN. If not, see <http://www.gnu.org/licenses/>.
#
# Modified by Anders Logg, 2008-2011
# Modified by Marie E. Rognes, 2011
#
# First added:  2008-12-04
# Last changed: 2009-12-11

__all__ = ["Form"]

# Import SWIG-generated extension module (DOLFIN C++)
import dolfin.cpp as cpp

# Import JIT compiler
from dolfin.compilemodules.jit import jit

# Note that we need to store _compiled_form and _compiled_coefficients
# to prevent Python from garbage-collecting these while still in use.
# FIXME: Figure out how to solve this with shared_ptr

class Form(cpp.Form):
    """A cpp.Form built from a UFL form (JIT-compiled on construction) or
    from an object that is already a compiled form."""

    def __init__(self, form,
                 function_spaces=None,
                 coefficients=None,
                 subdomains=None,
                 form_compiler_parameters=None,
                 common_cell=None):
        """Create JIT-compiled form from any given form (compiled or not).

        Arguments:
            form: treated as already compiled when it has a
                ``create_cell_integral`` attribute; otherwise it is passed
                to ``jit`` together with *form_compiler_parameters* and
                *common_cell*.
            function_spaces: None (taken from the form data), a list with
                one space per form argument, or a single space reused for
                every argument.
            coefficients: None (taken from the form data) or a list of
                cpp.GenericFunctions.
            subdomains: None or a dict whose keys are among "cell",
                "exterior_facet" and "interior_facet"; entries override the
                subdomains found in the form data.
        """
        already_compiled = hasattr(form, "create_cell_integral")
        if already_compiled:
            # A precompiled form carries no UFL form data
            self._compiled_form = form
            self.form_data = None
        else:
            # jit also hands back the generated module and its prefix,
            # neither of which is needed here
            self._compiled_form, _module, self.form_data, _prefix = \
                jit(form, form_compiler_parameters, common_cell)

        self.function_spaces = _extract_function_spaces(
            self.form_data, self._compiled_form, function_spaces)

        # The second list is kept on self so the given coefficients are not
        # garbage-collected while the C++ form still refers to them
        self.coefficients, self._compiled_coefficients = \
            _extract_coefficients(self.form_data, coefficients)

        cpp.Form.__init__(self, self._compiled_form,
                          self.function_spaces, self.coefficients)

        # Explicitly given subdomains win over those in the form data
        self.subdomains = _extract_subdomains(self.form_data, subdomains)

        # Attach each kind of subdomain markers that is present, in order:
        # cell, exterior facet, interior facet
        attachers = (("cell", "set_cell_domains"),
                     ("exterior_facet", "set_exterior_facet_domains"),
                     ("interior_facet", "set_interior_facet_domains"))
        for domain_type, setter_name in attachers:
            markers = self.subdomains.get(domain_type)
            if markers is not None:
                getattr(self, setter_name)(markers)

def _extract_function_spaces(form_data, compiled_form, given_function_spaces):
    "Extract list of test spaces."

    function_space_error = "Error while extracting test and/or trial spaces. "

    function_spaces = []

    if given_function_spaces is None:
        if not hasattr(form_data,"original_arguments"):
            raise TypeError, function_space_error + \
                  "Missing data about basis functions in form data."
        for func in form_data.original_arguments:
            if not isinstance(func.function_space(), cpp.FunctionSpace):
                raise TypeError, function_space_error
            function_spaces.append(func.function_space())
    else:
        if not isinstance(given_function_spaces, (list, cpp.FunctionSpace)):
            raise TypeError, function_space_error
        if isinstance(given_function_spaces, list):
            if len(given_function_spaces) != compiled_form.rank():
                raise ValueError, function_space_error + \
                      " Wrong number of test spaces (should be %d)." % compiled_form.rank()
            for V in given_function_spaces:
                function_spaces.append(V)
        else:
            for i in xrange(compiled_form.rank()):
                function_spaces.append(given_function_spaces)

    return function_spaces

def _extract_coefficients(form_data, given_coefficients):
    "Extract list of coefficients."

    coefficient_error = "Error while extracting coefficients. "

    coefficients = []
    _compiled_coefficients = []

    # Return if nothing to extract
    if form_data is None and given_coefficients is None:
        return (coefficients,  _compiled_coefficients)

    if given_coefficients is None:
        if not hasattr(form_data, "original_coefficients"):
            raise TypeError, coefficient_error + \
                  "Missing data about coefficients in form data."
        for c in form_data.original_coefficients:
            if not isinstance(c, cpp.GenericFunction):
                raise TypeError, coefficient_error + \
                      "Either provide a dict of cpp.GenericFunctions, or use Function to define your form."
            coefficients.append(c)
    else:
        # FIXME: I have disabled compiled_functions based on strings for now
        #       We could ofcourse add it back, but they need a FunctionSpace to
        #       be initialized.
        ## Compile all strings as dolfin::Function
        #string_expressions = []
        #for c in coefficients:
        #    # Note: To allow tuples of floats or ints below, this logic becomes more involved...
        #    if isinstance(c, (tuple, str)):
        #        string_expressions.append(c)
        #if string_expressions:
        #    compiled_functions = compile_functions(string_expressions, mesh)
        #    compiled_functions.reverse()
        #
        # Build list of coefficients
        error_info = "Provide a 'list' with cpp.GenericFunctions"
        if not isinstance(given_coefficients, list):
            raise TypeError, coefficient_error + error_info
        for c in given_coefficients:
            # FIXME: I have turned of these for now. Should probably add something for
            #       at least constant functions
            # Note: We could generalize this to support more objects
            # like sympy expressions, tuples for constant vectors, etc...
            #if isinstance(c, (float, int)):
            #    c = cpp.Function(mesh, float(c))
            #elif isinstance(c, (tuple, str)):
            #    c = compiled_functions.pop()
            if not isinstance(c, cpp.GenericFunction):
                raise TypeError, coefficient_error
            coefficients.append(c)
            _compiled_coefficients.append(c)

    return (coefficients, _compiled_coefficients)

def _extract_subdomains(form_data, override_subdomains):
    "Extract list of subdomains."

    override_subdomains = override_subdomains or {}
    if form_data is None:
        return override_subdomains

    domain_types = ("cell", "exterior_facet", "interior_facet")

    additional_keys = set(override_subdomains.keys()) - set(domain_types)
    if additional_keys:
        raise TypeError, "Invalid keys in domain_types: %s" % additional_keys

    subdomains = {}
    for domain in domain_types:

        domains = override_subdomains.get(domain)
        if domains is None:
            domains = form_data.domain_data.get(domain)

        # FIXME: Add test for MeshFunctions here.
        # Like this?
        #if (domains is not None) and (not isinstance(domains, MeshFunction)):
        #    raise TypeError, "Invalid subdomains type %s" % type(domains)

        subdomains[domain] = domains

    return subdomains
