/* ====cMCF.H ===================================================================================
 *
 * Author: Patrice Koehl, June 2018
 * Department of Computer Science
 * University of California, Davis
 *
 * This file applies a conformalized Mean Curvature Flow, or a regular Mean curvature flow
 * to parameterize a surface onto the sphere
 *
 * It is based on the paper:
 *      [1] M. Kazhdan, J. Solomon, and M. Ben-Chen. "Can Mean-Curvature Flow be Modified to be Non-singular?"
 *		Eurographics Symp. Geom. Proc., 31 (2012).
 =============================================================================================== */

#pragma once

  /*================================================================================================
   BLAS prototypes
  ================================================================================================== */

  extern "C" {
	// Fortran BLAS level-1 routines (all arguments passed by pointer).

	// Y := alpha*X + Y
	void daxpy_(int *n, double *alpha, double *X, int *incx, double *Y, int *incy);

	// Y := X   (the first array argument is the SOURCE, the second the DESTINATION)
	void dcopy_(int *n, double *X, int *incx, double *Y, int *incy);

	// X := scale*X
	void dscal_(int *n, double *scale, double *X, int *incx);
  }

  /* ===== INCLUDES AND TYPEDEFS =================================================================
   *
   =============================================================================================== */

  #define _USE_MATH_DEFINES // for M_PI

  #include <cmath>
  #include <cstdlib>     // exit
  #include <cstring>     // std::memset
  #include <iomanip>     // std::setw, std::setprecision
  #include <iostream>    // std::cout
  #include <vector>
  #include <pthread.h>
  #include <unistd.h>    // sysconf
  #include <Eigen/Dense>
  #include "ConformalError.h"

  // (row, column, value) entry used to build Eigen sparse matrices
  using Triplet = Eigen::Triplet<double>;

  // Capacity of the per-thread work tables (Mat, threadids, threads)
  #define NUM_THREADS 32

  // Work order for one worker thread (stiffness_thread / mass_thread):
  // the thread processes mesh->edges[N1 .. N2-1] and accumulates its
  // diagonal contributions into Diag, an array of n doubles.
  struct Mat_data {
	int N1;        // first edge handled (inclusive)
	int N2;        // end of the edge range (exclusive)
	int n;         // number of vertices = length of Diag
	Mesh *mesh;    // mesh being processed
	double *Diag;  // per-thread diagonal accumulator
  };

  // Shared work tables for the worker threads launched by cMCF::computeStiffness and
  // cMCF::computeMass: slot i holds the work order, the id and the handle of thread i.
  // NOTE(review): these globals are *defined* (not just declared) in this header, so
  // including it from more than one translation unit gives duplicate-symbol link errors.
  Mat_data Mat[NUM_THREADS];
  int threadids[NUM_THREADS];
  pthread_t threads[NUM_THREADS];

  // Mass and stiffness matrices (filled by mass_thread / stiffness_thread) and the
  // sparse Cholesky (CHOLMOD) solver used by cMCF::solveOneStep.
  Eigen::SparseMatrix<double> Mass;
  Eigen::SparseMatrix<double> Stiff;
  Eigen::CholmodDecomposition<Eigen::SparseMatrix<double>> solver;

  // Count of solveOneStep calls; solveOneStep runs solver.analyzePattern only while it is 0.
  int ntype = 0;
	
  /* =============================================================================================
     Define class
   =============================================================================================== */

  class cMCF{

  public:
	// Initialize the flow: size the sparse operators and build their pattern; when
	// init == 0, copy each vertex's position into position2 (the coordinates that are
	// flowed); compute the stiffness matrix and print the initial statistics row.
	void initFlow(Mesh& mesh, int init, double dt);

	// Perform one step of the flow with step size dt (niter is only used for printing).
	// type selects the operators recomputed: 0 = mass only, 1 = stiffness only,
	// anything else = both. Returns (sphericity, quasi-conformal measure, 0).
	Vector solveOneStep(Mesh& mesh, int niter, double dt, int type);

	// Terminate the flow: close the progress table and project position2 onto the unit sphere.
	void stopFlow(Mesh& mesh);

  private:

	// Build the sparsity pattern of Mat (vertex diagonal + both entries of every edge), all zeros
	void initMatrix(Mesh& mesh, Eigen::SparseMatrix<double> &Mat);

	// Reset the diagonal and edge entries of Mat to zero
	void resetMatrix(Mesh& mesh, Eigen::SparseMatrix<double> &Mat);

	// Worker thread: stiffness entries for one range of edges
	static void* stiffness_thread(void* data);

	// Compute Stiffness matrix (multithreaded)
	void computeStiffness(Mesh& mesh);

	// Worker thread: mass entries for one range of edges
	static void* mass_thread(void* data);

	// Compute Mass matrix (multithreaded)
	void computeMass(Mesh& mesh);

	// Compute total surface area from position2
	double computeArea(Mesh& mesh);

	// Compute Mesh info: area, volume, sphericity
	void computeAVS(Mesh& mesh, double *Area, double *Vol, double *S);

	// Copy coordinate icoord of every vertex's position2 into X
	void transferCoord(Mesh& mesh, Eigen::VectorXd &X, int icoord);

	// Copy X back into coordinate icoord of every vertex's position2
	void transferBack(Mesh& mesh, Eigen::VectorXd &X, int icoord);

	// Fit a sphere, move its center to the origin and project onto the unit sphere
	void fitSphere(Mesh& mesh);

	// Scale position2 by Scale, then recenter on the vertex centroid
	void scaleMesh(Mesh& mesh, double Scale);

   protected:

	Eigen::SparseMatrix<double> H;   // system matrix Mass - dt*Stiff
	Eigen::VectorXd B;               // right-hand side for one coordinate
	Eigen::VectorXd Sol;             // solution for one coordinate

	double Area0, S;                 // Area0: area at initFlow (rescaling target); S: last sphericity

  };

  /* =============================================================================================
   Init matrix
   =============================================================================================== */

  void cMCF::initMatrix(Mesh& mesh, Eigen::SparseMatrix<double> &Mat)
  {
	/* Build the sparsity pattern of a vertex-by-vertex matrix, storing explicit zeros:
	   one entry on the diagonal for every vertex and the symmetric pair (i,j)/(j,i)
	   for every edge. Values are filled in later with coeffRef (presumably the
	   pattern is built up-front so those writes never have to insert entries). */

	std::vector<Triplet> entries;
	entries.reserve(mesh.vertices.size() + 2*mesh.edges.size());

	for (VertexIter vit = mesh.vertices.begin(); vit != mesh.vertices.end(); ++vit)
	{
		const int row = vit->index;
		entries.push_back(Triplet(row, row, 0.0));
	}

	for (EdgeIter eit = mesh.edges.begin(); eit != mesh.edges.end(); ++eit)
	{
		HalfEdgeIter half = eit->he;
		const int a = half->vertex->index;
		const int b = half->flip->vertex->index;
		entries.push_back(Triplet(a, b, 0.0));
		entries.push_back(Triplet(b, a, 0.0));
	}

	Mat.setFromTriplets(entries.begin(), entries.end());
  }

  /* =============================================================================================
   Reset matrix to zero
   =============================================================================================== */

  void cMCF::resetMatrix(Mesh& mesh, Eigen::SparseMatrix<double> &Mat)
  {
	// Zero the diagonal entry of every vertex ...
	for (VertexIter vit = mesh.vertices.begin(); vit != mesh.vertices.end(); ++vit)
	{
		const int d = vit->index;
		Mat.coeffRef(d, d) = 0;
	}

	// ... and both off-diagonal entries of every edge.
	for (EdgeIter eit = mesh.edges.begin(); eit != mesh.edges.end(); ++eit)
	{
		HalfEdgeIter half = eit->he;
		const int a = half->vertex->index;
		const int b = half->flip->vertex->index;
		Mat.coeffRef(a, b) = 0;
		Mat.coeffRef(b, a) = 0;
	}
  }


/*================================================================================================
 stiffness_thread
================================================================================================== */


void* cMCF::stiffness_thread(void* data)
{
	int threadid = *((int *) data);

	int N1    = Mat[threadid].N1;
	int N2    = Mat[threadid].N2;
	int n     = Mat[threadid].n;

	int idxA, idxB;
	double val = 1.0/2.;
	double val2, cotC, cotD;

	HalfEdgeIter hAB, hBA;
	Vector A, B, C, D;
	Vector u, v;

	std::memset(Mat[threadid].Diag, 0, n*sizeof(double));

	for (int pair = N1; pair < N2; pair++)
	{
		hAB = Mat[threadid].mesh->edges[pair].he;
		hBA = hAB->flip;

		idxA = hAB->vertex->index;
		idxB = hBA->vertex->index;

		A = hAB->vertex->position;
		B = hAB->next->vertex->position;
		C = hAB->prev->vertex->position;
		D = hBA->prev->vertex->position;

		u = A-C;
		v = B-C;
		cotC = dot(u,v)/cross(u,v).norm();

		u = A-D;
		v = B-D;
		cotD = dot(u,v)/cross(u,v).norm();

		val2 = val*(cotC+cotD);

		Stiff.coeffRef(idxA, idxB) = val2;
		Stiff.coeffRef(idxB, idxA) = val2;

		Mat[threadid].Diag[idxA] -= val2;
		Mat[threadid].Diag[idxB] -= val2;

	}
	return 0;
}

/*================================================================================================
 Stiffness
================================================================================================== */

void cMCF::computeStiffness(Mesh& mesh)
{
	/* Assemble the stiffness matrix Stiff in parallel. The edges are split into
	   contiguous ranges, one per worker (stiffness_thread); each worker writes the
	   off-diagonal entries of its edges straight into Stiff and accumulates its diagonal
	   contributions in a private array. After the join the private arrays are summed
	   (BLAS daxpy) and stored on the diagonal of Stiff. */

	int n     = mesh.vertices.size();
	int Npair = mesh.edges.size();
	if(n == 0) return;   // empty mesh: nothing to assemble

	// One thread per online core (sysconf returns -1 on error), but never more than
	// NUM_THREADS: Mat[], threadids[] and threads[] only have NUM_THREADS slots.
	long ncores  = sysconf( _SC_NPROCESSORS_ONLN );
	int nthreads = 1;
	if(ncores > NUM_THREADS) nthreads = NUM_THREADS;
	else if(ncores > 1)      nthreads = static_cast<int>(ncores);

	int nval = Npair / nthreads;

	// Private diagonal accumulators, owned here so they are released on return.
	std::vector<std::vector<double> > threadDiag(nthreads, std::vector<double>(n, 0.0));
	bool launched[NUM_THREADS];

	for(int i = 0; i < nthreads; i++)
	{
		int N1 = i*nval;
		int N2 = (i == nthreads-1) ? Npair : N1 + nval;   // last thread takes the remainder
		threadids[i] = i;
		Mat[i].N1   = N1;
		Mat[i].N2   = N2;
		Mat[i].n    = n;
		Mat[i].mesh = &mesh;
		Mat[i].Diag = threadDiag[i].data();

		launched[i] = (pthread_create(&threads[i], NULL, stiffness_thread, &threadids[i]) == 0);
		// If the thread could not be created, do its share of the work right here.
		if(!launched[i]) stiffness_thread(&threadids[i]);
	}

	/* Join all the threads (to make sure they are all finished) and sum their diagonals */
	std::vector<double> Diag(n, 0.0);
	double alpha = 1.0;
	int inc = 1;
	for (int i = 0; i < nthreads; i++)
	{
		if(launched[i]) pthread_join(threads[i], NULL);
		daxpy_(&n, &alpha, Mat[i].Diag, &inc, Diag.data(), &inc);
		Mat[i].Diag = NULL;   // storage is released when threadDiag goes out of scope
	}

	for(VertexIter v_iter = mesh.vertices.begin(); v_iter != mesh.vertices.end(); v_iter++)
	{
		int i = v_iter->index;
		Stiff.coeffRef(i, i) = Diag[i];
	}
}

/*================================================================================================
 mass_thread
================================================================================================== */


void* cMCF::mass_thread(void* data)
{
	int threadid = *((int *) data);

	int N1    = Mat[threadid].N1;
	int N2    = Mat[threadid].N2;
	int n     = Mat[threadid].n;

	int idxA, idxB;
	double  area1, area2, area;
	double val = 1.0/24.;

	HalfEdgeIter hAB, hBA;
	VertexIter A, B;
	Vector pA, pB, pC, pD;

	std::memset(Mat[threadid].Diag, 0, n*sizeof(double));

	for (int pair = N1; pair < N2; pair++)
	{
		hAB = Mat[threadid].mesh->edges[pair].he;
		hBA = hAB->flip;

		A = hAB->vertex;
		B = hBA->vertex;

		pA = A->position2;
		pB = B->position2;
		pC = hAB->prev->vertex->position2;
		pD = hBA->prev->vertex->position2;

		idxA = A->index;
		idxB = B->index;

		area1 = cross(pA-pB,pA-pC).norm();
		area2 = cross(pA-pB,pA-pD).norm();

		area = val*(area1+area2);

		Mass.coeffRef(idxA, idxB) = area;
		Mass.coeffRef(idxB, idxA) = area;

		Mat[threadid].Diag[idxA] += area;
		Mat[threadid].Diag[idxB] += area;

	}
	return 0;
}
/*================================================================================================
 Mass
================================================================================================== */

void cMCF::computeMass(Mesh& mesh)
{
	/* Assemble the mass matrix Mass in parallel. The edges are split into contiguous
	   ranges, one per worker (mass_thread); each worker writes the off-diagonal entries
	   of its edges straight into Mass and accumulates its diagonal contributions in a
	   private array. After the join the private arrays are summed (BLAS daxpy) and
	   stored on the diagonal of Mass. */

	int n     = mesh.vertices.size();
	int Npair = mesh.edges.size();
	if(n == 0) return;   // empty mesh: nothing to assemble

	// One thread per online core (sysconf returns -1 on error), but never more than
	// NUM_THREADS: Mat[], threadids[] and threads[] only have NUM_THREADS slots.
	long ncores  = sysconf( _SC_NPROCESSORS_ONLN );
	int nthreads = 1;
	if(ncores > NUM_THREADS) nthreads = NUM_THREADS;
	else if(ncores > 1)      nthreads = static_cast<int>(ncores);

	int nval = Npair / nthreads;

	// Private diagonal accumulators, owned here so they are released on return.
	std::vector<std::vector<double> > threadDiag(nthreads, std::vector<double>(n, 0.0));
	bool launched[NUM_THREADS];

	for(int i = 0; i < nthreads; i++)
	{
		int N1 = i*nval;
		int N2 = (i == nthreads-1) ? Npair : N1 + nval;   // last thread takes the remainder
		threadids[i] = i;
		Mat[i].N1   = N1;
		Mat[i].N2   = N2;
		Mat[i].n    = n;
		Mat[i].mesh = &mesh;
		Mat[i].Diag = threadDiag[i].data();

		launched[i] = (pthread_create(&threads[i], NULL, mass_thread, &threadids[i]) == 0);
		// If the thread could not be created, do its share of the work right here.
		if(!launched[i]) mass_thread(&threadids[i]);
	}

	/* Join all the threads (to make sure they are all finished) and sum their diagonals */
	std::vector<double> Diag(n, 0.0);
	double alpha = 1.0;
	int inc = 1;
	for (int i = 0; i < nthreads; i++)
	{
		if(launched[i]) pthread_join(threads[i], NULL);
		daxpy_(&n, &alpha, Mat[i].Diag, &inc, Diag.data(), &inc);
		Mat[i].Diag = NULL;   // storage is released when threadDiag goes out of scope
	}

	for(VertexIter v_iter = mesh.vertices.begin(); v_iter != mesh.vertices.end(); v_iter++)
	{
		int i = v_iter->index;
		Mass.coeffRef(i, i) = Diag[i];
	}
}

  /* ===== Mesh area           =======================================================================
   * @Mesh - Mesh considered
   *
   ==================================================================================================*/

  double cMCF::computeArea(Mesh& mesh)
  {
	// Total surface area of the mesh, measured on position2. For each face the first
	// three vertices along its half-edge loop are used (triangle faces assumed); the
	// cross-product norm is twice the triangle area, hence the final factor 1/2.
	double twiceArea = 0.0;

	for (FaceIter fit = mesh.faces.begin(); fit != mesh.faces.end(); ++fit)
	{
		HalfEdgeIter h0 = fit->he;
		HalfEdgeIter h1 = h0->next;
		HalfEdgeIter h2 = h1->next;

		Vector p1 = h0->vertex->position2;
		Vector p2 = h1->vertex->position2;
		Vector p3 = h2->vertex->position2;

		twiceArea += cross(p1-p2, p1-p3).norm();
	}

	return 0.5*twiceArea;
  }


  /* ===== Mesh area, Vol, sphericity ================================================================
   ==================================================================================================*/

  void cMCF::computeAVS(Mesh& mesh, double *Area, double *Vol, double *S)
  {
	/* Area, volume and sphericity of the mesh, measured on position2 (triangle faces assumed).
	     Area : sum of triangle areas
	     Vol  : |sum of p1.(p2 x p3)| / 6 (signed-tetrahedron sum; meaningful for a closed mesh)
	     S    : pi^(1/3) (6 Vol)^(2/3) / Area, which is 1 for a sphere */

	double sumArea = 0;
	double sumVol  = 0;

	for (FaceIter fit = mesh.faces.begin(); fit != mesh.faces.end(); ++fit)
	{
		HalfEdgeIter h0 = fit->he;
		HalfEdgeIter h1 = h0->next;
		HalfEdgeIter h2 = h1->next;

		Vector p1 = h0->vertex->position2;
		Vector p2 = h1->vertex->position2;
		Vector p3 = h2->vertex->position2;

		sumArea += cross(p1-p2, p1-p3).norm();
		sumVol  += dot(p1, cross(p2, p3));
	}

	const double area = 0.5*sumArea;
	const double vol  = std::abs(sumVol)/6.0;

	*Area = area;
	*Vol  = vol;
	*S    = std::pow(M_PI,1.0/3.0) * std::pow(6.0*vol,2.0/3.0) / area;
  }

  /* =============================================================================================
  Transfer one coordinate (position2) of all vertices to an Eigen vector
   =============================================================================================== */

  void cMCF::transferCoord(Mesh& mesh, Eigen::VectorXd &X, int icoord)
  {
	// Gather component icoord of every vertex's position2 into X, at the vertex's index.
	// X is not resized here: it must already hold one entry per vertex.
	for (VertexIter vit = mesh.vertices.begin(); vit != mesh.vertices.end(); ++vit)
		X[vit->index] = vit->position2[icoord];
  }

  /* =============================================================================================
  Transfer one coordinate (position2) of all vertices back from an Eigen vector
   =============================================================================================== */

  void cMCF::transferBack(Mesh& mesh, Eigen::VectorXd &X, int icoord)
  {
	// Scatter X back into component icoord of every vertex's position2,
	// reading the entry at the vertex's index.
	for (VertexIter vit = mesh.vertices.begin(); vit != mesh.vertices.end(); ++vit)
		vit->position2[icoord] = X[vit->index];
  }

  /* ===============================================================================================
   scaleMesh: Scale coordinates (position2) of all vertices in a mesh by a constant, Scale, then
		translate the mesh so that the centroid of its vertices is at 0,0,0.
   =============================================================================================== */

  void cMCF::scaleMesh(Mesh& mesh, double Scale)
  {
	// Pass 1: scale every position2 about the origin and accumulate the
	// sum of the scaled positions.
	Vector centroid;
	centroid[0] = 0; centroid[1] = 0; centroid[2] = 0;

	for (VertexIter vit = mesh.vertices.begin(); vit != mesh.vertices.end(); ++vit)
	{
		vit->position2 *= Scale;
		for (int k = 0; k < 3; ++k)
			centroid[k] += vit->position2[k];
	}

	const int nv = mesh.vertices.size();
	centroid /= nv;

	// Pass 2: translate so that the vertex centroid sits at the origin.
	for (VertexIter vit = mesh.vertices.begin(); vit != mesh.vertices.end(); ++vit)
		vit->position2 -= centroid;
  }

  /* ===============================================================================================
   FitSphere: Fit a sphere to the vertices of a mesh (linear least squares), translate the mesh
		such that the center of the sphere is at 0,0,0, and project every vertex onto the unit sphere

   Input:
	  mesh:	 the mesh data structure (passed by reference; position2 is modified in place)
   =============================================================================================== */

  void cMCF::fitSphere(Mesh& mesh)
  {
	/* The sphere |p - c|^2 = r^2 is linear in its unknowns once rewritten as
	       2 c.p + (r^2 - |c|^2) = |p|^2 ,
	   giving one row [2x 2y 2z 1] per vertex with right-hand side x^2+y^2+z^2.
	   The least-squares solution (via SVD) yields the center c in its first three
	   components; the vertices are then centered on c and normalized to unit length. */

	const int nv = mesh.vertices.size();

	Eigen::MatrixXd design(nv, 4);
	Eigen::VectorXd rhs(nv);

	int row = 0;
	for (VertexIter vit = mesh.vertices.begin(); vit != mesh.vertices.end(); ++vit, ++row)
	{
		Vector p = vit->position2;
		const double x = p[0];
		const double y = p[1];
		const double z = p[2];

		design(row, 0) = 2*x;
		design(row, 1) = 2*y;
		design(row, 2) = 2*z;
		design(row, 3) = 1.0;
		rhs[row]       = x*x + y*y + z*z;
	}

	Eigen::VectorXd sol(4);
	sol = design.bdcSvd(Eigen::ComputeThinU | Eigen::ComputeThinV).solve(rhs);

	Vector center;
	center[0] = sol[0]; center[1] = sol[1]; center[2] = sol[2];

	for (VertexIter vit = mesh.vertices.begin(); vit != mesh.vertices.end(); ++vit)
	{
		Vector q = vit->position2;
		q = q - center;
		q = q / q.norm();
		for (int k = 0; k < 3; ++k)
			vit->position2[k] = q[k];
	}
  }
  /* =============================================================================================
  Solve one step of the flow: one sparse linear solve per coordinate

   =============================================================================================== */

  Vector cMCF::solveOneStep(Mesh& mesh, int niter, double dt, int type)
  {
	/* One implicit step of the flow:  (Mass - dt*Stiff) X_new = Mass X_old,
	   solved separately for x, y and z on position2. The result is rescaled so its area
	   matches Area0 (scaleMesh also recenters it), statistics are printed, and
	   (sphericity, quasi-conformal measure, 0) is returned. Exits the program if the
	   Cholesky factorization fails. */

	// type: 0 -> refresh mass only, 1 -> stiffness only, anything else -> both.
	switch (type)
	{
		case 0:
			computeMass(mesh);
			break;
		case 1:
			computeStiffness(mesh);
			break;
		default:
			computeMass(mesh);
			computeStiffness(mesh);
			break;
	}

	H = Mass - dt*Stiff;

	// Symbolic analysis only on the very first call (global ntype == 0);
	// afterwards only the numeric factorization is redone.
	if (ntype == 0) solver.analyzePattern(H);
	ntype++;
	solver.factorize(H);

	if (solver.info() != Eigen::Success)
	{
		std::cout << "Singularity: Problem with Cholesky factorization!!" << std::endl;
		exit(1);
	}

	for (int k = 0; k < 3; ++k)
	{
		transferCoord(mesh, B, k);
		B = Mass*B;
		Sol = solver.solve(B);
		transferBack(mesh, Sol, k);
	}

	// Restore the initial area, then measure the new shape.
	double area = computeArea(mesh);
	double vol;
	scaleMesh(mesh, std::sqrt(Area0/area));
	computeAVS(mesh, &area, &vol, &S);

	Vector conf = ConformalError::quasiConformalError(mesh);
	double quasi = conf[2];

	std::cout << "        " << "   " << std::setw(8)<< niter+1 << "    " << std::setw(12) << dt;
	std::cout << "      " << std::setw(8) << area << "      " << std::fixed << std::setprecision(6) << std::setw(8) << vol ;
	std::cout << "      " << std::fixed << std::setprecision(6) << std::setw(8) << quasi;
	std::cout << "      " << std::fixed << std::setprecision(6) << std::setw(8) << S << std::endl;

	return Vector(S, quasi, 0);
  }

  /* ===== initFlow: set up the flow =================================================================
   *
   ==================================================================================================*/

  void cMCF::initFlow(Mesh& mesh, int init, double dt)
  {
	/* Prepare a flow on `mesh`:
	     - size Stiff, Mass, H, B, Sol to the vertex count and build the sparsity
	       pattern of the three matrices;
	     - when init == 0, copy each vertex's position into position2 (the coordinates
	       the flow moves);
	     - record the initial area Area0 (rescaling target for solveOneStep), compute the
	       stiffness matrix and print the header plus the iteration-0 statistics row. */

	int n_vertices = mesh.vertices.size();

	Stiff.resize(n_vertices, n_vertices);
	Mass.resize(n_vertices, n_vertices);
	H.resize(n_vertices, n_vertices);
	B.resize(n_vertices);
	Sol.resize(n_vertices);

	initMatrix(mesh, Stiff);
	initMatrix(mesh, Mass);
	initMatrix(mesh, H);

	// H's sparsity pattern has just been rebuilt, so any symbolic analysis the solver holds
	// from a previous flow (e.g. on another mesh) is stale. solveOneStep only re-runs
	// analyzePattern when ntype == 0, so reset it here.
	ntype = 0;

	if(init==0) {
		for(VertexIter v_iter = mesh.vertices.begin(); v_iter != mesh.vertices.end(); v_iter++)
		{
			v_iter->position2[0] = v_iter->position[0];
			v_iter->position2[1] = v_iter->position[1];
			v_iter->position2[2] = v_iter->position[2];
		}
	}

	int zero = 0;
	double Vol0, S0;
	computeAVS(mesh, &Area0, &Vol0, &S0);
	Vector r = ConformalError::quasiConformalError(mesh);
	double quasi = r[2];

	computeStiffness(mesh);

	std::cout << "        " << "=================================================================================================" << std::endl;
	std::cout << "        " << "       Iter       Step size        Area      Volume       QuasiConf. ratio   Sphericity          " << std::endl;
	std::cout << "        " << "=================================================================================================" << std::endl;
	std::cout << "        " << "   " << std::setw(8)<< zero << "    " << std::setw(12) << dt;
	std::cout << "      " << std::setw(8) << Area0 << "      " << std::fixed << std::setprecision(6) << std::setw(8) << Vol0 ;
	std::cout << "      " << std::fixed << std::setprecision(6) << std::setw(8) << quasi;
	std::cout << "      " << std::fixed << std::setprecision(6) << std::setw(8) << S0 << std::endl;
  }

  /* ===== stopFlow: finish the flow =================================================================
   *
   ==================================================================================================*/

  void cMCF::stopFlow(Mesh& mesh)
  {
	// Close the progress table opened by initFlow, leave a spacer line,
	// then center the result and project it onto the unit sphere.
	std::cout << "        " << "=================================================================================================" << std::endl
	          << " " << std::endl;

	fitSphere(mesh);
  }
