// Copyright (C) 2009 Martin Sandve Alnes
//
// This file is part of SyFi.
//
// SyFi is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 2 of the License, or
// (at your option) any later version.
//
// SyFi is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with SyFi. If not, see <http://www.gnu.org/licenses/>.
//
// First added:  2009-01-01
// Last changed: 2009-04-01
//
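// This demo program solves Poisson's equation in 3D on the unit cube, with a
// Dirichlet condition on the face x = 0 and a boundary source term g on the
// remaining faces, and compares the result against a known exact solution.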

#include <cmath>
#include <dolfin.h>
#include "generated_code/Poisson3D.h"

using std::sqrt;
using std::sin;
using std::cos;
using std::acos;
using namespace dolfin;
using namespace Poisson3D;

// Parameters of the manufactured exact solution (see the Solution class):
//   u(x, y, z) = b cos(pi k x) cos(pi k y) cos(pi k z)
//              + a sin(pi k x) sin(pi k y) sin(pi k z)
const double a = 1.0;   // amplitude of the sine-product term
const double b = 1.0;   // amplitude of the cosine-product term
const double k = 1.0;   // wave number, in multiples of pi
// The double literal matters: acos(int) is an ambiguous overload call
// under pre-C++11 <cmath>.
const double Pi = std::acos(-1.0);

class Source: public Expression 
{
public:
  void eval(Array<double>& values, const Array<double>& x) const
  {
    double dx = x[0] - 0.5;
    double dy = x[1] - 0.5;
    double dz = x[2] - 0.5;
    values[0] =  b * 3.0 * (Pi*Pi)*(k*k) * cos(Pi*k*x[0])*cos(x[1]*Pi*k)*cos(Pi*k*x[2])
              +  a * 3.0 * (Pi*Pi)*(k*k) * sin(x[1]*Pi*k)*sin(Pi*k*x[2])*sin(Pi*k*x[0]);
  }
};

class BoundarySource: public Expression 
{
public:
  void eval(Array<double>& values, const Array<double>& x) const
  {
    double dx = x[0] - 0.5;
    double dy = x[1] - 0.5;
    double dz = x[2] - 0.5;

    values[0] = 0.0;

    bool gx0 = x[0] < DOLFIN_EPS;
    bool gx1 = x[0] > 1.0-DOLFIN_EPS;
    bool gy0 = x[1] < DOLFIN_EPS;
    bool gy1 = x[1] > 1.0-DOLFIN_EPS;
    bool gz0 = x[2] < DOLFIN_EPS;
    bool gz1 = x[2] > 1.0-DOLFIN_EPS;

    if     (gx0) values[0] = -sin(x[1]*Pi*k)*sin(Pi*k*x[2])*Pi*k*a;
    else if(gx1) values[0] =  sin(x[1]*Pi*k)*sin(Pi*k*x[2])*Pi*k*cos(Pi*k)*a-sin(Pi*k)*b*Pi*k*cos(x[1]*Pi*k)*cos(Pi*k*x[2]);
    else if(gy0) values[0] = -sin(Pi*k*x[2])*Pi*k*sin(Pi*k*x[0])*a;
    else if(gy1) values[0] =  sin(Pi*k*x[2])*Pi*k*cos(Pi*k)*sin(Pi*k*x[0])*a-cos(Pi*k*x[0])*sin(Pi*k)*b*Pi*k*cos(Pi*k*x[2]);
    else if(gz0) values[0] = -sin(x[1]*Pi*k)*Pi*k*sin(Pi*k*x[0])*a;
    else if(gz1) values[0] =  sin(x[1]*Pi*k)*Pi*k*cos(Pi*k)*sin(Pi*k*x[0])*a-cos(Pi*k*x[0])*sin(Pi*k)*b*Pi*k*cos(x[1]*Pi*k);
  }
};

/// Exact (manufactured) solution
///   u(x, y, z) = b cos(pi k x) cos(pi k y) cos(pi k z)
///              + a sin(pi k x) sin(pi k y) sin(pi k z).
class Solution: public Expression
{
public:
  void eval(Array<double>& values, const Array<double>& x) const
  {
    const double angle_x = Pi*k*x[0];
    const double angle_y = x[1]*Pi*k;
    const double angle_z = Pi*k*x[2];

    const double cos_term = b * cos(angle_x) * cos(angle_y) * cos(angle_z);
    const double sin_term = a * sin(angle_y) * sin(angle_z) * sin(angle_x);

    values[0] = cos_term + sin_term;
  }
};

/// Marks the Dirichlet part of the boundary: boundary points on the face x = 0.
class DirichletBoundary: public SubDomain
{
  bool inside(const Array<double>& x, bool on_boundary) const
  {
    if (!on_boundary)
      return false;
    return x[0] < DOLFIN_EPS;
  }
};

/// Solves the Poisson problem on an n x n x n unit-cube mesh, compares the
/// discrete solution with the interpolated exact solution, prints error
/// statistics, and writes u, the exact solution and the error to .pvd files.
int main()
{
    // Geometry
    info("Mesh");
    const unsigned n = 30;
    UnitCube mesh(n, n, n);

    // Function spaces
    info("Function spaces");
    BilinearForm::TrialSpace V(mesh);

    // Forms. Note: this local `a` (the bilinear form) shadows the global
    // coefficient `a` used by the Expression classes.
    info("Forms");
    BilinearForm a(V, V);
    LinearForm L(V);

    // Coefficient functions
    info("Functions");
    Source f;
    BoundarySource g;
    Solution usol;

    // Discrete solution on V
    Function u(V);

    // Attach functions to forms
    L.f = f;
    L.g = g;

    // Dirichlet condition: prescribe the exact solution on x = 0
    info("Boundary conditions");
    DirichletBoundary boundary;
    DirichletBC bc(V, usol, boundary);

    std::vector<const BoundaryCondition*> bcs;
    bcs.push_back(&bc);

    info("Solving");
    solve(a == L, u, bcs);

    // Interpolate exact solution into V
    Function usol_h(V);
    usol_h = usol;

    // Nodal error e = u - usol_h
    Function e(V);
    e.vector() = u.vector();
    e.vector() -= usol_h.vector();

    // Report the root-mean-square nodal error (l2 norm / sqrt(#dofs)).
    // Use a double so sqrt() resolves unambiguously and no unsigned-to-int
    // narrowing occurs.
    const double num_dofs = static_cast<double>(u.vector().size());
    cout << endl;
    cout << "==========================" << endl;
    cout << "Norm of e = " << e.vector().norm("l2") / sqrt(num_dofs) << endl;
    cout << "Min  of e = " << e.vector().min() << endl;
    cout << "Max  of e = " << e.vector().max() << endl;
    cout << "==========================" << endl;

    // Write results to file
    info("Writing to file");
    File ufile("u.pvd");
    ufile << u;
    File usolfile("usol.pvd");
    usolfile << usol_h;
    File efile("e.pvd");
    efile << e;

    //plot(u);

    return 0;
}

