# -*- coding: utf-8 -*-
"""
This module contains representation classes for integrals.
"""

# Copyright (C) 2008-2009 Martin Sandve Alnes and Simula Resarch Laboratory
#
# This file is part of SyFi.
#
# SyFi is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# SyFi is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with SyFi. If not, see <http://www.gnu.org/licenses/>.
#
# Modified by Kent-Andre Mardal, 2010.
#
# First added:  2008-08-13
# Last changed: 2009-03-12

from itertools import izip

#import SyFi
import swiginac

#import ufl
from ufl.permutation import compute_indices
from ufl.algorithms import Graph, partition, expand_indices2, expand_indices, strip_variables, tree_format
from ufl.classes import Expr, Terminal, UtilityType, SpatialDerivative, Indexed

from sfc.common import sfc_assert, sfc_warning, sfc_debug, sfc_error
from sfc.common.utilities import indices_subset
from sfc.codegeneration.codeformatting import indent, row_major_index_string
from sfc.symbolic_utils import symbols, symbol
from sfc.representation.swiginac_eval import SwiginacEvaluator, EvaluateAsSwiginac

def find_locals(Vin, Vout, P):
    """Report, for each vertex of partition P, whether it is only used inside P.

    Vin[k] lists the vertices that depend on vertex k and Vout[k] lists the
    vertices vertex k depends on. A vertex of P is "local" when, after
    discounting every dependent that itself belongs to P, no dependents
    remain.

    Returns a dict mapping each vertex index in P to a bool.
    """
    inside = frozenset(P)
    # Start from the total number of dependents of every vertex
    outstanding = list(map(len, Vin))
    # Every (i -> j) dependency with both ends inside P is internal
    internal_uses = (j for i in P for j in Vout[i] if j in inside)
    for j in internal_uses:
        outstanding[j] -= 1
    # Local <=> no dependents left outside the partition
    return dict((i, not outstanding[i]) for i in P)

def is_simple(e):
    """Return True when expression e has no operands (e.nops() is zero).

    Such expressions are stored directly instead of being assigned a
    generated symbol.
    """
    operand_count = e.nops()
    return not operand_count
    # TODO: Try something like this:
    #return (e.nops() == 0) or \
    #    (isinstance(e, (swiginac.add, swiginac.mul)) \
    #     and all(o.nops() for o in e.ops()))

class IntegralData:
    """Bookkeeping for a single integral handled by an IntegralRepresentation.

    Holds the integration settings read from the owning representation's
    options, quadrature rules, the computational graph with its partitions
    and dependency sets, and per-vertex results grouped by tuples of basis
    function indices.
    """

    def __init__(self, itgrep, integral):
        self.itgrep = itgrep
        self.integral = integral

        # TODO: Override these and other options with metadata.
        # The measure metadata is fetched here but not used yet.
        metadata = integral.measure().metadata()
        self.integration_method = itgrep.options.integration_method
        self.integration_order = itgrep.options.integration_order

        # Quadrature rules (assigned by the owner for quadrature integrals)
        self.quad_rule = self.facet_quad_rules = None

        # Graph data (assigned by the owner when building quadrature code)
        self.G = self.Vdeps = self.Vin_count = None
        self.partitions = {}
        self.evaluator = self.evaluate = None

        # vertex index -> allocated symbol
        self.V_symbols = {}
        # basis function index tuple -> {vertex index -> stored value}
        self.vertex_data_set = {}

    def store(self, value, i, basis_functions):
        """Remember value as the result for vertex i under basis_functions."""
        per_bf = self.vertex_data_set.get(basis_functions)
        if per_bf is None:
            per_bf = self.vertex_data_set[basis_functions] = {}
        per_bf[i] = value

    def get_vd(self, j, basis_functions):
        """Return the storage dict in which vertex j's value was stored.

        Only the indices of the arguments vertex j depends on (entries
        "v0", "v1", ... of self.Vdeps[j]) are kept from basis_functions,
        via indices_subset. If every kept entry is None, the empty tuple
        is used as the key.
        """
        rank = self.itgrep.formrep.rank
        keep = [("v%d" % k) in self.Vdeps[j] for k in range(rank)]
        (iota,) = indices_subset([basis_functions], keep)
        if not any(idx is not None for idx in iota):
            iota = ()
        return self.vertex_data_set[iota]

    def fetch_storage(self, j, basis_functions):
        """Return the value stored for vertex j under basis_functions."""
        return self.get_vd(j, basis_functions)[j]

    def free_storage(self, j, basis_functions):
        """Intended to delete stored data for j.

        Deletion is currently disabled (see FIXME); only the storage lookup
        is performed.
        """
        vd = self.get_vd(j, basis_functions)
        #del vd[j] # FIXME: Enable when counts are corrected

class IntegralRepresentation(object):
    """Code-generation representation of a group of UFL integrals.

    For each integral an IntegralData object is created, holding either a
    symbolic evaluator or, for quadrature, an evaluation function plus the
    partitioned computational graph of the index-expanded integrand.
    The iter_* generators yield (symbol, expression) token pairs used for
    generated code. Subclasses must implement compute_A.
    """

    def __init__(self, integrals, formrep, on_facet):
        sfc_debug("Entering IntegralRepresentation.__init__")

        self.integrals = integrals
        self.formrep   = formrep
        self._on_facet = on_facet
        self.classname = formrep.itg_names[integrals[0]]

        self.options = self.formrep.options.code.integral

        # Shape of element tensor: one local dimension per form argument
        A_shape = []
        for i in range(self.formrep.rank):
            element = self.formrep.formdata.elements[i]
            rep = self.formrep.element_reps[element]
            A_shape.append(rep.local_dimension)
        self.A_shape = tuple(A_shape)

        # All permutations of element tensor indices TODO: Not same for interior_facet_integrals, move to subclasses
        self.indices = compute_indices(self.A_shape)

        # Symbols for output tensor entry A[iota]
        Asym = symbols("A[%s]" % row_major_index_string(i, self.A_shape) for i in self.indices)
        self.A_sym = dict(izip(self.indices, Asym))

        # Data structures for token symbols: a pool of freed symbols that
        # allocate() reuses, and a counter for naming new ones
        self._free_symbols   = []
        self._symbol_counter = 0

        # Data structures to hold integrals for each measure
        self.symbolic_integral = None  # Only a single symbolic integral is allowed
        self.quadrature_integrals = [] # Multiple quadrature integrals are possible, each with different quadrature rules
        self.integral_data = {}        # Convenient mapping from measure to IntegralData object

        # Compute data about integrals for each measure
        for integral in self.integrals:
            data = IntegralData(self, integral)

            if data.integration_method == "symbolic":
                self.symbolic_integral = integral
                data.evaluator = SwiginacEvaluator(self.formrep, use_symbols=False, on_facet=self._on_facet)

            elif data.integration_method == "quadrature":

                # FIXME: Build a separate quadrature rule for each integral, if necessary
                data.quad_rule = self.formrep.quad_rule
                data.facet_quad_rules = self.formrep.facet_quad_rules

                self.quadrature_integrals.append(integral)
                if self.options.safemode:
                    # Produce naive code, good for debugging and verification
                    data.evaluator = SwiginacEvaluator(self.formrep, use_symbols=True, on_facet=self._on_facet)

                else:
                    # Create an evaluation multifunction
                    data.evaluate = EvaluateAsSwiginac(self.formrep, self, data, on_facet=self._on_facet)

                    # Expand all indices before building graph
                    integrand = integral.integrand()
                    integrand = strip_variables(integrand)
                    if self.options.use_expand_indices2:
                        integrand = expand_indices2(integrand)
                    else:
                        integrand = expand_indices(integrand)

                    # Build linearized computational graph
                    data.G = Graph(integrand)

                    # Count dependencies (number of dependents) for each node
                    V = data.G.V()
                    n = len(V)
                    data.Vin_count = [0]*n
                    for i, vs in enumerate(data.G.Vin()):
                        for j in vs:
                            # FIXME: Need to store one counter for each iota relevant for i,
                            #        and to add "size(iota_space_j)" instead of 1
                            data.Vin_count[i] += 1
                    #data.Vin_count[-1] += ... # FIXME: Increment further for integrand root vertex

                    # Partition graph based on dependencies
                    #formdata = self.formrep.formdata
                    #basis_function_deps = [set(("v%d" % i, "x")) \
                    #                       for (i, bf) in enumerate(formdata.basis_functions)]
                    #function_deps = [set(("w", "c", "x",)) \
                    #                 for (i, bf) in enumerate(formdata.functions)]
                    # TODO: Input actual dependencies of functions and
                    #       basis functions and their derivatives here
                    data.partitions, data.Vdeps = partition(data.G)

            self.integral_data[integral.measure()] = data

        sfc_debug("Leaving IntegralRepresentation.__init__")

    # --- Printing ---

    def __str__(self):
        """Return a multi-line summary of shapes, symbols and UFL integrals."""
        s = ""
        s += "IntegralRepresentation:\n"
        s += "  classname: %s\n" % self.classname
        s += "  A_shape: %s\n" % self.A_shape
        s += "  indices: %s\n" % self.indices
        s += "  A_sym:   %s\n" % self.A_sym

        # Not computing whole A at any point because of memory usage
        #s += "  A:\n"
        #for ii in self.indices:
        #    s += "    %s = %s\n" % (self.A_sym[ii], self.A[ii])

        s += "  UFL Integrals:\n"
        for integral in self.integrals:
            s += indent(str(integral)) + "\n\n"
        return s.strip("\n")

    # --- Utilities for allocation of symbols and association of data with graph vertices ---

    def allocate(self, data, j):
        """Allocate a symbol for vertex j and register it in data.V_symbols.

        Reuses a symbol from the free-symbol pool when one is available,
        otherwise creates a new symbol "s[<counter>]" and increments the
        counter. The symbol is registered for vertex j in both cases so that
        free_symbols(data, j) can later return it to the pool.
        """
        if self._free_symbols:
            s = self._free_symbols.pop()
        else:
            s = symbol("s[%d]" % self._symbol_counter)
            self._symbol_counter += 1
        data.V_symbols[j] = s
        return s

    def free_symbols(self, data, j):
        "Delete stored data for j and make its symbols available again."
        s = data.V_symbols.get(j)
        if s is None:
            sfc_debug("Trying to deallocate symbols that are not allocated!")
            return
        self._free_symbols.append(s)
        del data.V_symbols[j]

    # --- Generators for tokens of various kinds ---

    def iter_partition(self, data, deps, basis_functions=()):
        """Evaluate the vertices of one graph partition and yield code tokens.

        Looks up the partition of data.G keyed by frozenset(deps); nothing is
        yielded if it is missing or empty. Each vertex is evaluated with
        data.evaluate (which __init__ only sets for non-safemode quadrature
        integrals), either from its UFL operands (Indexed, SpatialDerivative)
        or from already stored operand results. Results are stored in data
        under basis_functions. A result that is not simple (see is_simple)
        is replaced by a newly allocated symbol, and (symbol, expression)
        is yielded.

        Utility-type vertices, vertices with shape, and vertices whose only
        dependents are SpatialDerivative vertices are skipped.
        """
        sfc_debug("Entering IntegralRepresentation.iter_partition")

        deps = frozenset(deps)

        P = data.partitions.get(deps)
        if not P:
            sfc_debug("Leaving IntegralRepresentation.iter_partition, empty")
            return

        data.evaluate.current_basis_function = basis_functions

        # Get graph components
        G = data.G
        V = G.V()
        Vin  = G.Vin()
        Vout = G.Vout()

        # Figure out which variables are only needed within the partition
        is_local = find_locals(Vin, Vout, P)

        # For all vertices in partition P
        for i in P:
            v = V[i]

            if isinstance(v, UtilityType):
                # Skip labels and indices
                continue

            if v.shape():
                # Skip expressions with shape: they will be evaluated when indexed.
                if not isinstance(v, (Terminal, SpatialDerivative)):
                    print("=" * 30)
                    print("type: %s" % (type(v),))
                    print("str: %s" % (str(v),))
                    print("child types: %s" % ([type(V[j]) for j in Vout[i]],))
                    print("child str:")
                    print("\n".join("    vertex %d: %s" % (j, str(V[j])) for j in Vout[i]))
                    print("number of parents: %d" % len(Vin[i]))
                    if len(Vin[i]) < 5:
                        print("parent types: %s" % ([type(V[j]) for j in Vin[i]],))
                    sfc_error("Expecting all indexing to have been propagated to terminals?")
                continue

            if Vin[i] and all(isinstance(V[j], SpatialDerivative) for j in Vin[i]):
                # Skipping expressions that only occur later in
                # differentiated form: we don't need their non-differentiated
                # expression so they will be evaluated when indexed.
                if not isinstance(v, (Terminal, SpatialDerivative, Indexed)):
                    print("=" * 30)
                    print(str(type(v)))
                    print(str(v))
                    sfc_error("Expecting all indexing to have been propagated to terminals?")
                continue

            if isinstance(v, (Indexed, SpatialDerivative)):
                ops = v.operands()

                # Evaluate vertex v with given mapped ops
                if not all(isinstance(o, (Expr, swiginac.basic)) for o in ops):
                    # Debug dump only; evaluation is still attempted below
                    print(";" * 80)
                    print(str(tree_format(v)))
                    print(str(v))
                    print(str(type(ops)))
                    print(str(ops))
                    print(repr(ops))
                    print("types:")
                    print("\n".join(str(type(o)) for o in ops))
                    print(";" * 80)

                e = data.evaluate(v, *ops)

            else:
                # Get already computed operands
                # (if a vertex isn't already computed
                # at this point, that is a bug)
                ops = []
                for j in Vout[i]:
                    try:
                        # Fetch expression or symbol for vertex j
                        e = data.fetch_storage(j, basis_functions)
                    except Exception:
                        print("Failed to fetch expression for vertex %d," % j)
                        print("    V[%d] = %s" % (j, repr(V[j])))
                        print("    parent V[%d] = %s" % (i, repr(V[i])))
                        raise RuntimeError("Failed to fetch expression for vertex %d." % j)
                    ops.append(e)
                ops = tuple(ops)
                e = data.evaluate(v, *ops)

            # TODO: Make sure that dependencies of skipped expressions
            #       are counted down when evaluated later!
            #       Or do we need that? We don't make symbols for them anyway.

            # Count down number of uses of v's dependencies
            for j in Vout[i]: # For each vertex j that vertex i depends on
                data.Vin_count[j] -= 1
                if False: # data.Vin_count[j] == 0: # FIXME: This doesn't work yet, counters are wrong
                    # Make the symbols for j available again
                    self.free_symbols(data, j)
                    data.free_storage(j, basis_functions)

            # Store some result for vertex i and yield token based on some heuristic
            if is_simple(e):
                # Store simple expression associated with
                # vertex i, no code generation necessary
                data.store(e, i, basis_functions)
            else:
                if is_local[i]:
                    pass # TODO: Use this information to make a local set of symbols "sl[%d]" for each partition

                # Allocate a symbol and remember it, just throw away the expression
                s = self.allocate(data, i)
                data.store(s, i, basis_functions)
                # Only yield when we have a variable to generate code for!
                yield (s, e)

        sfc_debug("Leaving IntegralRepresentation.iter_partition")

    def iter_member_quad_tokens(self, data): # FIXME: Make this into precomputation of a static array of constants
        """Return an iterator over member tokens dependent of spatial variables. Overload in subclasses!

        For every argument (basis functions and coefficients), local basis
        function and value component, yields (v_sym, v_expr) and the
        first-order local derivative tokens (dv_sym, dv_expr). Symbols that
        are zero or already yielded are skipped.
        """

        assert data.integration_method == "quadrature"

        # TODO: Precompute basis function values symbolically here instead of generating a loop in the constructor

        # TODO: Skip what's not needed!

        # Precompute all basis functions
        fr = self.formrep
        fd = fr.formdata
        generated = set()
        for iarg in range(fr.rank + fr.num_coefficients):
            elm = fd.elements[iarg]
            rep = fr.element_reps[elm]
            for i in range(rep.local_dimension):
                for component in rep.value_components:
                    # Yield basis function itself
                    s = fr.v_sym(iarg, i, component, self._on_facet)
                    if not (s == 0 or s in generated):
                        e = fr.v_expr(iarg, i, component)
                        yield (s, e)
                        generated.add(s)
                    # Yield its derivatives w.r.t. local coordinates
                    for d in range(fr.cell.nsd):
                        der = (d,)
                        s = fr.dv_sym(iarg, i, component, der, self._on_facet)
                        if not (s == 0 or s in generated):
                            e = fr.dv_expr(iarg, i, component, der)
                            yield (s, e)
                            generated.add(s)

    def iter_geometry_tokens(self):
        """Return an iterator over geometry tokens independent of spatial variables. Overload in subclasses!

        Yields, in order: vertex coordinates (vx), G entries, detGtmp, detG,
        Ginv entries, then detG_sign on facets or, for cells with a symbolic
        integral present, the (D, detG) scaling token.
        """
        fr = self.formrep

        # TODO: Skip what's not needed!

        # vx
        for (ss, ee) in zip(fr.vx_sym, fr.vx_expr):
            for i in range(ss.nops()):
                yield (ss.op(i), ee.op(i))

        # G
        (ss, ee) = (fr.G_sym, fr.G_expr)
        for i in range(ss.nops()):
            yield (ss.op(i), ee.op(i))

        # detG
        yield (fr.detGtmp_sym, fr.detGtmp_expr)
        yield (fr.detG_sym, fr.detG_expr)

        # Ginv
        (ss, ee) = (fr.Ginv_sym, fr.Ginv_expr)
        for i in range(ss.nops()):
            yield (ss.op(i), ee.op(i))

        if self._on_facet:
            # Needed for normal vector TODO: Skip if not needed
            yield (fr.detG_sign_sym, fr.detG_sign_expr)
        else:
            if self.symbolic_integral is not None:
                # Scaling by cell volume factor, determinant of coordinate mapping
                yield (fr.D_sym, fr.detG_sym)

    def iter_runtime_quad_tokens(self, data):
        """Return an iterator over runtime tokens dependent of spatial variables. Overload in subclasses!

        Yields first-order global derivatives of all basis functions, then
        every coefficient value and its first-order global derivatives, and
        finally the (D, quad_weight * scale) token, where the scale is
        facet_D on facets and detG otherwise.
        """
        assert data.integration_method == "quadrature"

        # TODO: yield geometry tokens like G etc here instead if they depend on x,y,z

        # TODO: Skip what's not needed!

        # Generate all basis function derivatives
        fr = self.formrep
        fd = fr.formdata
        generated = set()
        for iarg in range(fr.rank + fr.num_coefficients):
            elm = fd.elements[iarg]
            rep = fr.element_reps[elm]
            for i in range(rep.local_dimension):
                for component in rep.value_components:
                    # Yield first order derivatives w.r.t. global coordinates
                    for d in range(fr.cell.nsd):
                        der = (d,)
                        s = fr.Dv_sym(iarg, i, component, der, self._on_facet)
                        if not (s == 0 or s in generated):
                            e = fr.Dv_expr(iarg, i, component, der, True, self._on_facet)
                            yield (s, e)
                            generated.add(s)

        # Currently placing all w in here:
        generated = set()
        for iarg in range(fr.num_coefficients):
            elm = fd.elements[fr.rank+iarg]
            rep = fr.element_reps[elm]
            for component in rep.value_components:
                # Yield coefficient function itself
                s = fr.w_sym(iarg, component)
                if s not in generated:
                    e = fr.w_expr(iarg, component, True, self._on_facet)
                    yield (s, e)
                    generated.add(s)
                # Yield first order derivatives w.r.t. global coordinates
                for d in range(fr.cell.nsd):
                    der = (d,)
                    s = fr.Dw_sym(iarg, component, der)
                    if not (s == 0 or s in generated):
                        e = fr.Dw_expr(iarg, component, der, True, self._on_facet)
                        yield (s, e)
                        generated.add(s)

        # Scaling factor
        if self._on_facet:
            # TODO: yield tokens that depend on both x and facet here

            # Scaling by facet area factor
            D_expr = fr.quad_weight_sym*fr.facet_D_sym
        else:
            # Scaling by cell volume factor, determinant of coordinate mapping
            D_expr = fr.quad_weight_sym*fr.detG_sym
        yield (fr.D_sym, D_expr)

    # --- Element tensor producers

    def iter_A_tokens(self, data, facet=None):
        "Iterate over all A[iota] tokens, as (A_sym[iota], compute_A(data, iota, facet))."
        for iota in self.indices:
            A_sym = self.A_sym[iota]
            A_expr = self.compute_A(data, iota, facet)
            yield (A_sym, A_expr)

    def compute_A(self, data, iota, facet=None):
        "Compute expression for A[iota]. Overload in subclasses!"
        raise NotImplementedError

