#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This module contains code generation tools for the ufc::dofmap class.
"""

# Copyright (C) 2008-2009 Martin Sandve Alnes and Simula Research Laboratory
#
# This file is part of SyFi.
#
# SyFi is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# SyFi is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with SyFi. If not, see <http://www.gnu.org/licenses/>.
#
# First added:  2008-08-13
# Last changed: 2009-05-14

import ufl
from sfc.codegeneration.codeformatting import indent, CodeFormatter, gen_token_assignments
from sfc.geometry import gen_geometry_code
from sfc.symbolic_utils import symbol, symbols
from sfc.common import sfc_assert, sfc_error, sfc_warning, sfc_info

class DofMapCG:
    def __init__(self, elementrep):
        self.rep = elementrep
        self.classname  = elementrep.dof_map_classname
        self.signature  = repr(self.rep.ufl_element)
        self.options = self.rep.options.code.dof_map
        
        if self.options.enable_dof_ptv:
            # variables for full initialization by construction, reused a few places:
            vars = ["global_component_stride", "loc2glob_size"]
            self.constructor_vars = vars
            self.constructor_arg_string  = ", ".join(["unsigned int %s_" % v for v in vars])
            self.constructor_arg_string2 = ", ".join(vars)
    
    def hincludes(self):
        l = []
        if self.options.enable_dof_ptv:
            l.extend(["Ptv.h", "DofT.h", "Dof_Ptv.h"])
        return l

    def cincludes(self):
        """Return a list of include file names; always a new empty list."""
        return []

    def generate_code_dict(self):
        """Collect all generated code snippets in a dict.

        The dict holds 'classname' plus one entry per snippet name, where the
        entry for a name is indent(self.gen_<name>()). The generators are
        called in the order listed below.
        """
        snippet_names = (
            "constructor",
            "constructor_arguments",
            "initializer_list",
            "destructor",
            "create",
            "signature",
            "needs_mesh_entities",
            "init_mesh",
            "init_cell",
            "init_cell_finalize",
            "global_dimension",
            "local_dimension",
            "max_local_dimension",
            "geometric_dimension",
            "topological_dimension",
            "num_facet_dofs",
            "num_entity_dofs",
            "tabulate_dofs",
            "tabulate_facet_dofs",
            "tabulate_entity_dofs",
            "tabulate_coordinates",
            "num_sub_dofmaps",
            "create_sub_dofmap",
            "members",
        )
        code_dict = {'classname': self.classname}
        for name in snippet_names:
            generator = getattr(self, "gen_" + name)
            code_dict[name] = indent(generator())
        return code_dict

    def gen_constructor(self):
        """Return the code for the constructor snippet, which is empty."""
        code = ""
        return code

    def gen_constructor_arguments(self):
        """Return the code for the constructor arguments snippet, which is empty."""
        code = ""
        return code

    def gen_initializer_list(self):
        """Return the code for the initializer list snippet, which is empty."""
        code = ""
        return code

    def gen_destructor(self):
        """Return the code for the destructor snippet, which is empty."""
        code = ""
        return code

    def gen_create(self):
        """Return code that returns a new, default constructed object of the generated class."""
        template = "return new %s();"
        return template % self.classname

    def gen_signature(self):
        """const char* signature() const

        Return a C++ return statement with self.signature as a string
        literal. Backslashes and double quotes are escaped, so the emitted
        literal stays valid C++ and evaluates to the unmodified signature.
        Signatures without those characters are emitted unchanged.
        """
        # Escape backslashes first so the ones added for quotes are not doubled.
        escaped = self.signature.replace('\\', '\\\\').replace('"', '\\"')
        return 'return "%s";' % escaped

    def gen_needs_mesh_entities(self):
        """bool needs_mesh_entities(unsigned int d) const"""
        if self.rep.ufl_element.family() == "Real":
            return 'return false;'
        
        # pick the mesh entities we need
        needs = tuple( [ ('true' if n else 'false') for n in self.rep.num_entity_dofs] )
        # return false when a type of mesh entities are not needed
        code = CodeFormatter()
        code.begin_switch("d")
        for i, n in enumerate(needs):
            code += "case %d: return %s;" % (i, n)
        code.end_switch()
        code += 'throw std::runtime_error("Invalid dimension in needs_mesh_entities.");'
        return str(code)

    def gen_init_mesh(self):
        """bool init_mesh(const mesh& m)"""
        nsd = self.rep.cell.nsd
        if self.rep.ufl_element.family() == "Real":
            return 'return false;'
        
        if not self.options.enable_dof_ptv:
            # compute and store global dimension
            num_entities = symbols(["m.num_entities[%d]" % i for i in range(nsd+1)])
            global_dimension = sum(self.rep.num_entity_dofs[i]*num_entities[i] for i in range(nsd+1))
            code = '_global_dimension = %s;\n' % global_dimension.printc()
            code += "return false;"
            return code

        # This code doesn't work for general mixed elements!
        if isinstance(self.rep.ufl_element, ufl.MixedElement):
            assert isinstance(self.rep.ufl_element, (ufl.VectorElement, ufl.TensorElement))
        local_component_stride = (self.rep.local_dimension // len(self.rep.sub_elements))
        code= CodeFormatter()
        code += "// allocating space for loc2glob map"
        code += "dof.init(m.num_entities[%d], %d);" % (nsd, local_component_stride)
        code += "loc2glob_size = m.num_entities[%d] * %d;\n" % (nsd, local_component_stride)
        code += "return true;"
        return str(code)
    
    def gen_init_cell(self):
        """void init_cell(const mesh& m, const cell& c)"""
        
        if not self.options.enable_dof_ptv:
            return ""
        if self.rep.ufl_element.family() == "Real":
            return ""
        
        # FIXME: This code needs updating, e.g. it doesn't handle mixed elements.
        
        #fe    = self.rep.syfi_element
        nsd   = self.rep.cell.nsd
        nbf   = self.rep.local_dimension

        code = CodeFormatter()
        code.new_text( gen_geometry_code(nsd, detG=False) )
        code += "unsigned int element = c.entity_indices[%d][0];" % (nsd)

        nsc = len(self.rep.sub_elements)

        if nsc > 1:
            code += "// ASSUMING HERE THAT THE DOFS FOR EACH SUB COMPONENT ARE GROUPED"
            code += "// Only counting and numbering dofs for a single sub components"

        for i in range(nbf // nsc):
            x_strings = [self.rep.dof_x[i][d].printc() for d in range(nsd)]
            dof_vals = ", ".join( x_strings )
            num_dof_vals = nsd

            if nsc > 1:
                assert nsc == self.rep.value_size
                # This code could be useful later if for some reason we need to revert
                # to the old numbering scheme where sub component dofs are intertwined:
                #dofi = fe.dof(i)
                #assert isinstance(dofi[0], list)
                #for d in dofi[1:]:
                #    dof_vals += ", " + d.printc()
                #    num_dof_vals += 1

            code += ""
            code += "double dof%d[%d] = { %s };" % (i, num_dof_vals, dof_vals)
            code += "Ptv pdof%d(%d, dof%d);" % (i, num_dof_vals, i)
            code += "dof.insert_dof(element, %d, pdof%d);" % (i, i)

        return str(code)
    
    def gen_init_cell_finalize(self):
        """void init_cell_finalize()"""
        if self.rep.ufl_element.family() == "Real":
            return ""

        code = ""
        if self.options.enable_dof_ptv:
            #code += "dof.build_loc2glob();\n"
            code += "loc2glob = dof.get_loc2glob_array();\n"
            # constant for tabulating vector element dofs from a scalar element numbering 
            code += "global_component_stride = dof.global_dimension();\n"
            # store global dimension
            code += '_global_dimension = global_component_stride * %d;\n' % len(self.rep.sub_elements)
            # clear temporary datastructures
            code += 'dof.clear();\n'
        return code

    def gen_global_dimension(self):
        """unsigned int global_dimension() const"""
        if self.rep.ufl_element.family() == "Real":
            return 'return %d;' % self.rep.value_size
        return 'return _global_dimension;'

    def gen_local_dimension(self):
        """unsigned int local_dimension(const cell& c) const"""
        return 'return %d;' % self.rep.local_dimension

    def gen_max_local_dimension(self):
        """unsigned int max_local_dimension() const"""
        return 'return %d;' % self.rep.local_dimension

    def gen_geometric_dimension(self):
        """unsigned int geometric_dimension() const"""
        return 'return %d;' % self.rep.cell.nsd

    def gen_topological_dimension(self):
        return "return %d;" % self.rep.cell.nsd

    def gen_num_facet_dofs(self):
        """unsigned int num_facet_dofs() const"""
        return "return %d;" % self.rep.num_facet_dofs

    def gen_num_entity_dofs(self):
        """unsigned int num_entity_dofs(unsigned int d) const"""
        if self.rep.ufl_element.family() == "Real":
            return 'return 0;'
        code = CodeFormatter()
        code.begin_switch("d")
        for i in range(self.rep.cell.nsd+1):
            code.begin_case(i)
            code += "return %d;" % self.rep.num_entity_dofs[i]
            code.end_case()
        code.end_switch()
        code += 'throw std::runtime_error("Invalid entity dimension.");'
        return str(code)
    
    def gen_tabulate_dofs(self):
        if self.rep.ufl_element.family() == "Real":
            return '\n'.join('dofs[%d] = %d;' % (i, i) for i in range(self.rep.value_size))

        if self.options.enable_dof_ptv:
            return self.gen_tabulate_dofs__dof_ptv()
        else:
            return self.gen_tabulate_dofs__implicit()

    def gen_tabulate_dofs__implicit(self):
        """void tabulate_dofs(unsigned int* dofs,
                               const mesh& m,
                               const cell& c) const

        Generate the tabulate_dofs body used when the dof_ptv option is
        disabled. Global dof numbers are symbolic expressions in the mesh
        entity counts and the cell's entity indices.

        For each basic (leaf) element, the dofs on entities of dimension d
        are numbered entity_index * num_entity_dofs[d] + j. That number is
        offset by the dofs on all lower dimensional entities of the mesh and
        by the total size of the preceding basic elements.
        """
        code = CodeFormatter()

        cell = self.rep.cell
        nsd = cell.nsd

        # symbols referencing entity index arrays
        mesh_num_entities = symbols("m.num_entities[%d]" % d for d in range(nsd+1))
        # cell_entity_indices[d][i] is the symbol for c.entity_indices[d][i]
        cell_entity_indices = []
        for d in range(nsd+1):
            cell_entity_indices += [symbols( "c.entity_indices[%d][%d]" % (d, i) for i in range(cell.num_entities[d]) )]

        def iter_sub_elements(rep):
            "Yield the elements without sub elements (the leaves) of the sub element hierarchy, depth first."
            if rep.sub_elements:
                for r in rep.sub_elements:
                    for s in iter_sub_elements(r):
                        yield s
            else:
                yield rep

        # (A) Iterate over all basic elements in order
        local_subelement_offset = 0
        global_subelement_offset = symbol("global_subelement_offset")
        # the running offset is a variable in the generated code, declared as int
        code += "int %s = 0;" % global_subelement_offset
        for rep in iter_sub_elements(self.rep):
            # (B) Loop over entity dimensions d in order
            # NOTE: local_entity_offset is accumulated below but never read in this method
            local_entity_offset = 0
            global_entity_offset = 0
            # (name, value) pairs, one per local dof of this basic element
            tokens = []
            for d in range(nsd+1):
                # The offset for dofs in this loop is (global_subelement_offset + global_entity_offset)
                # (C) Loop over entities (d,i) in order
                for i in range(cell.num_entities[d]):

                    # this is the global mesh index of cell entity (d,i)
                    entity_index = cell_entity_indices[d][i]

                    # For each entity (d,i) we have a list of dofs
                    entity_dofs = rep.entity_dofs[d][i]
                    sfc_assert(len(entity_dofs) == rep.num_entity_dofs[d], "Inconsistency in entity dofs.")

                    for (j,dof) in enumerate(entity_dofs):
                        # number of the dof among all dofs on entities of dimension d
                        local_value = entity_index * rep.num_entity_dofs[d] + j
                        value = global_subelement_offset + global_entity_offset + local_value
                        # dof is local to the basic element, shift it to the full element numbering
                        name = symbol("dofs[%d]" % (local_subelement_offset + dof))
                        tokens.append((name, value))

                # (B) Accumulate offsets to dofs on entities of dimension d
                local_entity_offset  += cell.num_entities[d] * rep.num_entity_dofs[d]
                global_entity_offset += mesh_num_entities[d] * rep.num_entity_dofs[d]

            # (A) Accumulate subelement offsets
            sfc_assert(rep.local_dimension == len(tokens), "Collected too few dof tokens!")
            local_subelement_offset += rep.local_dimension
            # after all dimensions, the entity offset equals the global size of this basic element
            global_subelement_size = global_entity_offset

            # one block of generated code per basic element; it ends by advancing the global offset
            code.begin_block()
            code += "// Subelement with signature: %s" % rep.signature
            code += gen_token_assignments(tokens)
            code += "%s += %s;" % (global_subelement_offset.printc(), global_subelement_size.printc())
            code.end_block()

        sfc_assert(local_subelement_offset == self.rep.local_dimension,
                   "Dof computation didn't accumulate correctly!")
        return str(code)
    
    def gen_tabulate_dofs__dof_ptv(self):
        """void tabulate_dofs(unsigned int* dofs,
                               const mesh& m,
                               const cell& c) const

        Generate the tabulate_dofs body used when the dof_ptv option is
        enabled. The generated loop reads the scalar numbering of the cell
        from loc2glob and writes, for each component i, the scalar dof
        shifted by global_component_stride * i.

        An element without sub elements is treated as a single component.
        Using len(sub_elements) directly would raise ZeroDivisionError for
        such an element, and the loop body would assign no dofs.
        """
        if isinstance(self.rep.ufl_element, ufl.MixedElement):
            assert isinstance(self.rep.ufl_element, (ufl.VectorElement, ufl.TensorElement))

        # number of local dofs belonging to one component
        num_components = max(1, len(self.rep.sub_elements))
        local_component_stride = self.rep.local_dimension // num_components

        code = CodeFormatter()
        code += "const unsigned int global_element_offset = %d * c.entity_indices[%d][0];" % (local_component_stride, self.rep.cell.nsd)
        code += "const unsigned int *scalar_dofs = loc2glob.get() + global_element_offset;"

        code += "for(unsigned int iloc=0; iloc<%d; iloc++)" % local_component_stride
        code.begin_block()

        code += "const unsigned int global_scalar_dof = scalar_dofs[iloc];"
        for i in range(num_components):
            code += "dofs[iloc + %d * %d] = global_scalar_dof + global_component_stride * %d;" % (local_component_stride, i, i)

        code.end_block()

        return str(code)

    def gen_tabulate_facet_dofs(self):
        """void tabulate_facet_dofs(unsigned int* dofs,
                                     unsigned int facet) const
        This implementation should be general for elements with point evaluation dofs on simplices.
        """
        if self.rep.ufl_element.family() == "Real":
            return 'throw std::runtime_error("tabulate_facet_dofs not implemented for Real elements.");'
        # generate code for each facet: for each facet i, tabulate local dofs[j]
        code = CodeFormatter()
        code.begin_switch("facet")
        for i, fd in enumerate(self.rep.facet_dofs):
            code.begin_case(i)
            for j, d in enumerate(fd):
                code += "dofs[%d] = %d;" % (j, d)
            code.end_case()
        code += "default:"
        code.indent()
        code += 'throw std::runtime_error("Invalid facet number.");'
        code.dedent()
        code.end_switch()

        return str(code)
    
    def gen_tabulate_entity_dofs(self):
        """void tabulate_entity_dofs(unsigned int* dofs,
                                     unsigned int d, unsigned int i) const
        """
        if self.rep.ufl_element.family() == "Real":
            return 'throw std::runtime_error("tabulate_entity_dofs not implemented for Real elements.");'
        code = CodeFormatter()
        # define one case for each cell entity (d, i)
        code.begin_switch("d")
        for d in range(self.rep.cell.nsd+1):
            if any(self.rep.entity_dofs[d]):
                code.begin_case(d)
                code.begin_switch("i")
                n = self.rep.cell.num_entities[d]
                for i in range(n):
                    # get list of local dofs associated with cell entity (d, i)
                    dofs_on_entity = self.rep.entity_dofs[d][i]
                    sfc_assert(len(dofs_on_entity) == self.rep.num_entity_dofs[d], "Inconsistency in entity dofs.")
                    code.begin_case(i)
                    for k, ed in enumerate(dofs_on_entity):
                        code += "dofs[%d] = %d;" % (k, ed)
                    code.end_case()
                code.end_switch()
                code.end_case()
        code.end_switch()
        return str(code)

    def gen_tabulate_coordinates(self):
        """void tabulate_coordinates(double** coordinates,
                                      const cell& c) const"""
        if self.rep.ufl_element.family() == "Real":
            return 'throw std::runtime_error("tabulate_coordinates not implemented for Real elements.");'
        code = CodeFormatter()
        code += gen_geometry_code(self.rep.cell.nsd, detG=False)
        for i in range(self.rep.local_dimension):
            for k in range(self.rep.cell.nsd):
                # generate code to compute component k of the coordinate for dof i
                code += "coordinates[%d][%d] = %s;" % (i, k, self.rep.dof_x[i][k].printc())
        return str(code)

    def gen_num_sub_dofmaps(self):
        """unsigned int num_sub_dofmaps() const"""
        return "return %d;" % len(self.rep.sub_elements)

    def gen_create_sub_dofmap(self):
        """dofmap* create_sub_dofmap(unsigned int i) const"""
        if self.options.enable_dof_ptv:
            if len(self.rep.sub_elements) > 1:
                code = CodeFormatter()
                code.begin_switch("i")
                for i, fe in enumerate(self.rep.sub_elements):
                    code += "case %d: return new %s(loc2glob, %s);" % (i, fe.dof_map_classname, self.constructor_arg_string2)
                code.end_switch()
                code += 'throw std::runtime_error("Invalid index in create_sub_dofmap.");'
            else:
                code = "return new %s(loc2glob, %s);" % (self.classname, self.constructor_arg_string2)
        else:
            if len(self.rep.sub_elements) > 1:
                code = CodeFormatter()
                code.begin_switch("i")
                for i, fe in enumerate(self.rep.sub_elements):
                    code += "case %d: return new %s();" % (i, fe.dof_map_classname)
                code.end_switch()
                code += 'throw std::runtime_error("Invalid index in create_sub_dofmap.");'
            else:
                code = "return new %s();" % self.classname # FIXME: Should we throw error here instead now?
        return str(code)

    def gen_members(self):
        """Generate the member declarations of the generated class.

        All members are declared public. Elements other than "Real" get the
        _global_dimension member. With dof_ptv enabled the dof structure, the
        shared loc2glob array and its bookkeeping variables are added, plus
        the declaration of an additional constructor taking them.
        """
        code = CodeFormatter()

        # dof data structures
        #code += "protected:"
        code += "public:"
        code.indent()
        if self.rep.ufl_element.family() != "Real":
            code += "unsigned int _global_dimension;"

        if self.options.enable_dof_ptv:
            code += "Dof_Ptv dof;"
            code += "std::tr1::shared_ptr<unsigned int> loc2glob;"
            # for tabulating vector element dofs from a scalar element numbering
            code += "unsigned int global_component_stride;"
            code += 'unsigned int loc2glob_size;'
        code.dedent()

        if self.options.enable_dof_ptv:
            # add additional constructor to pass initialization info to share initialized dofmap memory
            code += "public:"
            code.indent()
            code += "%s(std::tr1::shared_ptr<unsigned int> loc2glob, %s);" % (self.classname, self.constructor_arg_string)
            code.dedent()

        return str(code)

    def generate_support_code(self):
        """Generate local utility functions.

        With dof_ptv enabled this emits the signature and initializer list of
        the additional constructor that shares memory between initialized
        dofmaps; otherwise the result contains no code lines.
        """
        code = CodeFormatter()
        #code += "namespace { // local namespace"
        #code += "  // code private to this compilation unit (.cpp file) goes here"
        #code += "} // end local namespace"
        #code += ""

        if self.options.enable_dof_ptv:
            # Implement additional constructor for shared memory between initialized dofmaps
            # NOTE(review): no constructor body is emitted here; presumably it
            # is appended elsewhere -- confirm against the caller.
            code += "%s::%s(std::tr1::shared_ptr<unsigned int> loc2glob_, %s):" % (self.classname, self.classname, self.constructor_arg_string)
            code.indent()
            initializers = ["loc2glob(loc2glob_)"]
            initializers.extend("%s(%s_)" % (v, v) for v in self.constructor_vars)
            # every initializer but the last is followed by a comma
            for initializer in initializers[:-1]:
                code += "%s," % initializer
            code += initializers[-1]
            code.dedent()

        return str(code)

