/* ====YAMABEFLOW.H===============================================================================
 *
 * Author: Patrice Koehl, June 2019
 * Department of Computer Science
 * University of California, Davis
 *
 * This file applies a variational method for computing a conformal map from a genus zero
 * surface onto the sphere
 *
 * It is based on the paper:
 *      B. Springborn, P. Schroeder, and U. Pinkall. "Conformal equivalence of triangle meshes",
 *      ACM Trans. Graph. 27, 3, Article 77 (2008).
 =============================================================================================== */

#pragma once

  /* ===== INCLUDES AND TYPEDEFS =================================================================
   *
   * Third party libraries included: Eigen
   *
   =============================================================================================== */

  #define _USE_MATH_DEFINES // for M_PI

  #include <vector>
  #include <cmath>

  #include "Layout2D.h"

  // (row, col, value) entry used to assemble the sparse Hessian from triplets.
  typedef Eigen::Triplet<double> Triplet;
  // Sparse solver used by Yamabe::solveSystem for the Newton system H * Sol = B.
  // It is a single global object shared by every Yamabe instance, so presumably only one
  // flow should run at a time (there is no synchronization around it).
  // NOTE(review): solverY and nY are non-inline definitions in a header; including this
  // header from more than one translation unit would give multiple-definition link errors.
//  Eigen::CholmodDecomposition<Eigen::SparseMatrix<double>> solverY;
  Eigen::SimplicialCholesky<Eigen::SparseMatrix<double>> solverY;
  // Number of calls made to Yamabe::solveSystem so far; solveSystem uses it to decide
  // when to run the symbolic analysis (analyzePattern) of solverY.
  int nY = 0;

  /* =============================================================================================
     Define class
   =============================================================================================== */

  class Yamabe{

  public:
	// Select the vertex of a genus-0 surface that will be removed temporarily
	// (type 1: highest valence, type 2: lowest valence, otherwise: first vertex).
	VertexIter pickPole(Mesh& mesh, int type);

	// Tag `pole` as the north pole, tag its one-ring as inNorthPoleVicinity,
	// and number the vertices (indexN).
	void punctureMesh(Mesh& mesh, VertexIter pole);

	// Performs Yamabe flow to the plane; on success the layout is then projected onto
	// the unit sphere. Sets *SUCCESS to false on failure. Returns the pole vertex.
	VertexIter yamabeFlow(Mesh& mesh, int vtype, bool *SUCCESS);

  private:

	// Build the sparsity pattern of the Hessian (all entries zero)
	void initHessian(Mesh& mesh);

	// Zero the Hessian entries and the right-hand side, keeping the pattern
	void resetHessian(Mesh& mesh);

	// Compute Hessian and right-hand side
	void computeHessian(Mesh& mesh);

	// Solve the Newton linear system
	bool solveSystem(Mesh& mesh);

	// Compute edge lengths from the 3D vertex positions
	void computeLength0(Mesh& mesh);

	// Set the conformal factor of every vertex to 0
	void initU(Mesh& mesh);

	// Store the current edge lengths as reference lengths
	void setLength0(Mesh& mesh);

	// Compute edge lengths from reference lengths and conformal factors
	void computeLength(Mesh& mesh);

	// Compute angles in a triangle from edge lengths
	void trigAngles(double lAB, double lBC, double lCA, double *ang, double *cotan, int *n);

	// Compute curvatures using edge lengths
	double computeCurv(Mesh& mesh, double *maxK, int *nbad);

	// Inverse stereographic projection of the planar layout onto the sphere
	void stereo2Sphere(Mesh& mesh);

   protected:

	Eigen::SparseMatrix<double> H;   // Hessian over the free vertices
	Eigen::VectorXd B, Sol;          // right-hand side and Newton step

	// True once the sparsity pattern of H has been built for the current flow.
	bool H_is_set = false;

	// Per-vertex work arrays (indexed by vertex index), allocated by yamabeFlow.
	double *V_curv = nullptr, *U_fact = nullptr, *U_keep = nullptr;
	// Per-edge reference lengths (indexed by edge index), allocated by yamabeFlow.
	double *E_length0 = nullptr;

  };

  /* ===== PICK VERTEX    ========================================================================
   * Select the vertex that will be (temporarily) removed from the mesh so that the punctured
   * surface can be flattened onto the plane.
   * Selection picks the vertex with either the highest valence (type = 1) or the lowest
   * valence (type = 2); on ties the first such vertex encountered is kept. For type = 0 (or
   * any other value), the first vertex of the mesh is returned.
   =============================================================================================== */

  VertexIter Yamabe::pickPole(Mesh& mesh, int type)
  {
	// The first vertex is both the default answer and the starting candidate.
	VertexIter best = mesh.vertices.begin();
	if (type == 0) {
		return best;
	}

	// Scan all vertices, keeping the first one with the strictly lowest (type 2)
	// or strictly highest (type 1) degree.
	int bestDegree = best->degree();
	for (VertexIter it = mesh.vertices.begin(); it != mesh.vertices.end(); ++it)
	{
		const int deg = it->degree();
		const bool lower  = (type == 2) && (deg < bestDegree);
		const bool higher = (type == 1) && (deg > bestDegree);
		if (lower || higher) {
			bestDegree = deg;
			best = it;
		}
	}
	return best;
  }

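  /* ===== PUNCTURE MESH  ========================================================================
   * Puncture the mesh at vertex `pole`: the vertex is not deleted, but tagged NorthPole, and
   * all of its incident faces and neighbouring vertices are tagged inNorthPoleVicinity.
   * The neighbours are numbered 0..k-1 in indexN, and the remaining (free) vertices are
   * numbered separately in indexN, also starting from 0.
   =============================================================================================== */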

  void Yamabe::punctureMesh(Mesh& mesh, VertexIter pole)
  {
	// Clear the pole tags on every vertex.
	for (VertexIter it = mesh.vertices.begin(); it != mesh.vertices.end(); ++it)
	{
		it->NorthPole           = false;
		it->inNorthPoleVicinity = false;
	}

	pole->NorthPole = true;

	// Walk once around the pole: tag each incident face and each neighbour, numbering
	// the neighbours 0, 1, 2, ... in the order they are visited.
	int ringCount = 0;
	HalfEdgeIter h = pole->he;
	do {
		HalfEdgeIter twin = h->flip;
		h->face->inNorthPoleVicinity      = true;
		twin->vertex->indexN              = ringCount++;
		twin->vertex->inNorthPoleVicinity = true;
		h = twin->next;
	} while (h != pole->he);

	// Number the remaining (free) vertices consecutively from 0.
	int freeCount = 0;
	for (VertexIter it = mesh.vertices.begin(); it != mesh.vertices.end(); ++it)
	{
		if (it->NorthPole || it->inNorthPoleVicinity) continue;
		it->indexN = freeCount++;
	}
  }

  /* ===== Stereo2Sphere   ============================================================================
   * Project planar mesh onto unit sphere
   ==================================================================================================*/

  void Yamabe::stereo2Sphere(Mesh& mesh)
  {
	for (VertexIter it = mesh.vertices.begin(); it != mesh.vertices.end(); ++it)
	{
		if (it->NorthPole) {
			// The pole is sent to (0,0,-1), the limit of the projection as |(x,y)| -> infinity.
			it->position2.x = 0.;
			it->position2.y = 0.;
			it->position2.z = -1.;
			continue;
		}

		// Inverse stereographic projection of the planar point (x, y) onto the unit sphere:
		// (2x, 2y, 1 - r^2) / (1 + r^2).
		const double x  = it->position2.x;
		const double y  = it->position2.y;
		const double r2 = x*x + y*y;
		const double s  = 1.0/(1.0 + r2);

		it->position2.x = 2.0*x*s;
		it->position2.y = 2.0*y*s;
		it->position2.z = (1.0 - r2)*s;
	}
  }


  /* ===============================================================================================
   computeLength0: computes the length of every edge from the current 3D vertex positions
   =============================================================================================== */

  void Yamabe::computeLength0(Mesh& mesh)
  {
	// Euclidean length of every edge, measured between the two end points of its half-edge.
	for (EdgeIter edge = mesh.edges.begin(); edge != mesh.edges.end(); ++edge)
	{
		HalfEdgeIter h = edge->he;
		edge->length = (h->vertex->position - h->flip->vertex->position).norm();
	}
  }
  /* =============================================================================================
   Init Conformal factors
   =============================================================================================== */

  void Yamabe::initU(Mesh& mesh)
  {
	// Every conformal factor starts at 0, so computeLength initially reproduces the
	// reference lengths exactly (exp(0) = 1).
	int remaining = mesh.vertices.size();
	while (remaining-- > 0) {
		U_fact[remaining] = 0.0;
	}
  }

  /* =============================================================================================
   Init Hessian matrix
   =============================================================================================== */

  void Yamabe::initHessian(Mesh& mesh)
  {
	// A vertex belongs to the linear system only if it is neither the pole nor one of its
	// neighbours.
	auto isFree = [](VertexIter v) { return !v->NorthPole && !v->inNorthPoleVicinity; };

	std::vector<Triplet> entries;

	// One (zero) diagonal entry per free vertex; also zero its right-hand side.
	for (VertexIter it = mesh.vertices.begin(); it != mesh.vertices.end(); ++it)
	{
		if (!isFree(it)) continue;
		const int i = it->indexN;
		entries.push_back(Triplet(i, i, 0.0));
		B[i] = 0.0;
	}

	// Two symmetric (zero) off-diagonal entries per edge joining two free vertices.
	for (EdgeIter e = mesh.edges.begin(); e != mesh.edges.end(); ++e)
	{
		VertexIter a = e->he->vertex;
		VertexIter b = e->he->flip->vertex;
		if (!isFree(a) || !isFree(b)) continue;
		const int i = a->indexN;
		const int j = b->indexN;
		entries.push_back(Triplet(i, j, 0.0));
		entries.push_back(Triplet(j, i, 0.0));
	}

	// Fix the sparsity pattern; computeHessian fills in the values.
	H.setFromTriplets(entries.begin(), entries.end());
	H_is_set = true;
  }

  /* =============================================================================================
   Reset matrix to zero
   =============================================================================================== */

  void Yamabe::resetHessian(Mesh& mesh)
  {
	// Free vertex: neither the pole nor one of its neighbours.
	auto isFree = [](VertexIter v) { return !v->NorthPole && !v->inNorthPoleVicinity; };

	// Zero the diagonal entry and right-hand side of every free vertex.
	for (VertexIter it = mesh.vertices.begin(); it != mesh.vertices.end(); ++it)
	{
		if (!isFree(it)) continue;
		const int i = it->indexN;
		H.coeffRef(i, i) = 0;
		B[i] = 0.0;
	}

	// Zero both off-diagonal entries of every edge joining two free vertices
	// (the same diagonal/edge entries that initHessian creates).
	for (EdgeIter e = mesh.edges.begin(); e != mesh.edges.end(); ++e)
	{
		VertexIter a = e->he->vertex;
		VertexIter b = e->he->flip->vertex;
		if (!isFree(a) || !isFree(b)) continue;
		const int i = a->indexN;
		const int j = b->indexN;
		H.coeffRef(i, j) = 0;
		H.coeffRef(j, i) = 0;
	}
  }

  /* ===============================================================================================
   setLength0: stores the current edge lengths as the reference lengths E_length0
   =============================================================================================== */

  void Yamabe::setLength0(Mesh& mesh)
  {
	// Snapshot each edge's current length as its reference length l0 (read by computeLength).
	for (EdgeIter e = mesh.edges.begin(); e != mesh.edges.end(); ++e)
		E_length0[e->index] = e->length;
  }

  /* ===============================================================================================
   computeLength: computes edge lengths from the reference lengths and the current conformal
   factors, l_ij = l0_ij * exp((u_i + u_j)/2)
   =============================================================================================== */

  void Yamabe::computeLength(Mesh& mesh)
  {
	// Conformal factors are clamped to [-30, 30] to keep the exponential bounded.
	auto clampU = [](double u) -> double { return std::max(-30.0, std::min(30.0, u)); };

	for (EdgeIter e = mesh.edges.begin(); e != mesh.edges.end(); ++e)
	{
		const double uA = clampU(U_fact[e->he->vertex->index]);
		const double uB = clampU(U_fact[e->he->flip->vertex->index]);

		// l_ij = l0_ij * exp((u_i + u_j) / 2)
		e->length = E_length0[e->index] * std::exp((uA + uB) / 2.0);
	}
  }

  /* ===============================================================================================
   trigAngles: Computes the three angles of a triangle and their cotans
   =============================================================================================== */

  // Angles of triangle ABC from its edge lengths.
  //   lAB, lBC, lCA : edge lengths
  //   ang[0..2]     : angles at A, B, C (the angle at A is opposite edge BC, etc.)
  //   cotan[0..2]   : cotangents of those angles
  //   *n            : 1 if the lengths violate the triangle inequality (degenerate angles
  //                   0/0/pi and zero cotangents are returned), 0 otherwise
  void Yamabe::trigAngles(double lAB, double lBC, double lCA, double *ang, double *cotan, int *n)
  {
	// Rounding can push a law-of-cosines ratio just outside [-1,1] (nearly flat triangles),
	// where std::acos returns NaN and would poison curvatures and the Hessian. Snap such
	// values back onto the boundary; NaN inputs fail both comparisons and are left untouched.
	auto clampCos = [](double c) -> double {
		if (c > 1.0) return 1.0;
		if (c < -1.0) return -1.0;
		return c;
	};

	double alpha_A, alpha_B, alpha_C;
	double cot_A, cot_B, cot_C;

	int n1 = 0;
	if( lAB > lBC + lCA) {
		alpha_A = 0; alpha_B = 0; alpha_C = M_PI;
		cot_A = 0; cot_B = 0; cot_C = 0;
		n1 = 1;
	} else if (lBC > lAB + lCA) {
		alpha_A = M_PI; alpha_B = 0; alpha_C = 0;
		cot_A = 0; cot_B = 0; cot_C = 0;
		n1 = 1;
	} else if (lCA > lAB + lBC) {
		alpha_A = 0; alpha_B = M_PI; alpha_C = 0;
		cot_A = 0; cot_B = 0; cot_C = 0;
		n1 = 1;
	} else {
		// Law of cosines at each corner.
		// NOTE(review): an exactly flat triangle still yields an infinite cotangent.
		alpha_A = std::acos(clampCos((lAB*lAB + lCA*lCA - lBC*lBC)/(2.0*lAB*lCA)));
		cot_A   = 1.0/std::tan(alpha_A);
		alpha_B = std::acos(clampCos((lAB*lAB + lBC*lBC - lCA*lCA)/(2.0*lAB*lBC)));
		cot_B   = 1.0/std::tan(alpha_B);
		alpha_C = std::acos(clampCos((lCA*lCA + lBC*lBC - lAB*lAB)/(2.0*lCA*lBC)));
		cot_C   = 1.0/std::tan(alpha_C);
	}

	ang[0] = alpha_A; ang[1] = alpha_B; ang[2] = alpha_C;
	cotan[0] = cot_A; cotan[1] = cot_B; cotan[2] = cot_C;
	*n = n1;
  }


  /* ===============================================================================================
   computeCurv: computes the curvatures of all vertices

   The curvature is defined as the angle defect at each vertex i, i.e.
   K_i = 2pi - sum_t Theta(i,ijk)
   where sum_t indicates a sum over all triangles (ijk) that are incident to i, and
         Theta(i,ijk) is the angle at vertex i in the triangle (ijk).
   Returns the L2 norm of K over the free vertices (neither the pole nor its neighbours);
   *maxK receives the largest |K_i| over those vertices and *nbad the number of triangles
   whose edge lengths violate the triangle inequality.
   =============================================================================================== */

  double Yamabe::computeCurv(Mesh& mesh, double *maxK, int *nbad)
  {
	const double TwoPI = 2*M_PI;

	// Every vertex starts with a full angle of 2*pi ...
	for (VertexCIter v = mesh.vertices.begin(); v != mesh.vertices.end(); v++)
		V_curv[v->index] = TwoPI;

	// ... and each incident triangle subtracts its corner angle at that vertex.
	int badCount = 0;
	double angs[3], cotans[3];
	for (FaceCIter f = mesh.faces.begin(); f != mesh.faces.end(); f++)
	{
		HalfEdgeIter h0 = f->he;
		HalfEdgeIter h1 = h0->next;
		HalfEdgeIter h2 = h1->next;

		int degenerate;
		trigAngles(h0->edge->length, h1->edge->length, h2->edge->length, angs, cotans, &degenerate);

		V_curv[h0->vertex->index] -= angs[0];
		V_curv[h1->vertex->index] -= angs[1];
		V_curv[h2->vertex->index] -= angs[2];
		badCount += degenerate;
	}

	// Curvature error, measured on free vertices only: L2 norm and max absolute value.
	double sumSq = 0, maxAbs = 0;
	for (VertexCIter v = mesh.vertices.begin(); v != mesh.vertices.end(); v++)
	{
		if (v->inNorthPoleVicinity || v->NorthPole) continue;
		const double k = V_curv[v->index];
		sumSq += k*k;
		maxAbs = std::max(maxAbs, std::abs(k));
	}

	*maxK = maxAbs;
	*nbad = badCount;
	return std::sqrt(sumSq);
  }

  /* ===== Jacobian system          =================================================================
	computeHessian: assemble the Hessian H and right-hand side B of the Newton system
	over the free vertices
   ==================================================================================================*/

  void Yamabe::computeHessian(Mesh& mesh)
  {
	// Assembles the Newton system H * Sol = B over the free vertices (those that are neither
	// the pole nor one of its neighbours), from the current edge lengths:
	//   B_i = 0.5 * (2*pi - sum of the angles at i), i.e. half the angle defect of computeCurv
	//   H   = 0.25 * (cotangent-weight matrix): for every triangle, H(i,j) += cot of the angle
	//         opposite edge ij, and H(i,i) -= cots of the two angles opposite the edges at i.
	// Rows/columns of vertices that are not free are never touched.

	// First call for this flow: build the sparsity pattern; afterwards just zero the entries.
	if(H_is_set) {
		resetHessian(mesh);
	} else {
		initHessian(mesh);
		H_is_set = true;   // initHessian already sets this; kept for clarity
	}

	double lAB, lBC, lCA;
	int idxA, idxB, idxC;

	double TwoPI = 2*M_PI;

	double alpha_A, alpha_B, alpha_C;
	double cotan_A, cotan_B, cotan_C;
	bool b1, b2, b3;
	double angs[3], cotans[3];
	int n;

	HalfEdgeIter hAB, hBC, hCA;
	EdgeIter e1, eAB, eBC, eCA;
	VertexIter v1, v2, v3;

	// Right-hand side starts at -2*pi for each free vertex; corner angles are added below.
	for (VertexIter v_it = mesh.vertices.begin(); v_it != mesh.vertices.end(); v_it++)
	{
		if(!v_it->inNorthPoleVicinity && !v_it->NorthPole)
		{
			idxA = v_it->indexN;
			B[idxA] = -TwoPI;
		}
	}

	for(FaceIter f_iter = mesh.faces.begin(); f_iter != mesh.faces.end(); f_iter++)
	{
		// Corners A, B, C of the face, reached through its three half-edges.
		hAB =f_iter->he;
		hBC =hAB->next;
		hCA =hBC->next;

		v1 = hAB->vertex;
		v2 = hBC->vertex;
		v3 = hCA->vertex;

		// b1..b3: true when the corner is excluded from the system (pole or neighbour).
		b1 = v1->inNorthPoleVicinity || v1->NorthPole;
		b2 = v2->inNorthPoleVicinity || v2->NorthPole;
		b3 = v3->inNorthPoleVicinity || v3->NorthPole;

		eAB = hAB->edge;
		eBC = hBC->edge;
		eCA = hCA->edge;

		// indexN is read for all three corners but only used for free ones.
		idxA = v1->indexN;
		idxB = v2->indexN;
		idxC = v3->indexN;

		lAB = eAB->length;
		lBC = eBC->length;
		lCA = eCA->length;

		// angs/cotans: angle (and cotangent) at A, B, C; the degeneracy flag n is not used here.
		trigAngles(lAB, lBC, lCA, angs, cotans, &n);
		alpha_A = angs[0]; alpha_B = angs[1]; alpha_C = angs[2];
		cotan_A = cotans[0]; cotan_B = cotans[1]; cotan_C = cotans[2];

		// Diagonal: the two edges at A are AB (opposite C) and CA (opposite B), etc.
		if(!b1) {
			B[idxA] += alpha_A;
			H.coeffRef(idxA, idxA) -= cotan_B+cotan_C;
		}
		if(!b2) {
			B[idxB] += alpha_B;
			H.coeffRef(idxB, idxB) -= cotan_A+cotan_C;
		}
		if(!b3) {
			B[idxC] += alpha_C;
			H.coeffRef(idxC, idxC) -= cotan_A+cotan_B;
		}

		// Off-diagonal: each edge gets the cotangent of the opposite angle, symmetrically.
		if( !b1 && !b2 ) {
			H.coeffRef(idxA, idxB) += cotan_C;
			H.coeffRef(idxB, idxA) += cotan_C;
		}

		if( !b1 && !b3 ) {
			H.coeffRef(idxA, idxC) += cotan_B;
			H.coeffRef(idxC, idxA) += cotan_B;
		}

		if( !b2 && !b3 ) {
			H.coeffRef(idxB, idxC) += cotan_A;
			H.coeffRef(idxC, idxB) += cotan_A;
		}
	}
	// Final scaling: H by 1/4; B becomes -0.5*(sum angles - 2*pi) = 0.5*(2*pi - sum angles).
	H = 0.25*H;
	B = -0.5*B;
  }

  /* =============================================================================================
  Solve Linear System

   =============================================================================================== */

  bool Yamabe::solveSystem(Mesh& mesh)
  {

  /* ==================================================================================================
	Compute Hessian and gradient
   ==================================================================================================*/

	// computeHessian() rebuilds the sparsity pattern of H (initHessian) whenever H_is_set is
	// false, which yamabeFlow does at the start of every flow. The pattern and size of H then
	// change (different mesh or different pole), so the symbolic analysis must be redone;
	// gating it only on the global call counter nY would reuse a stale analysis.
	const bool newPattern = !H_is_set;

	computeHessian(mesh);

  /* ==================================================================================================
	Solve system
   ==================================================================================================*/

	if(nY == 0 || newPattern) solverY.analyzePattern(H);
	nY++;
	solverY.factorize(H);

	const double TOL = 1.e-8;

	// Covers NumericalIssue as well as any other factorization failure.
	if(solverY.info() != Eigen::Success)
	{
		std::cout << "Problem with Cholesky factorization!!" << std::endl;
		return false;
	}

	Sol = solverY.solve(B);
	if(solverY.info() != Eigen::Success)
	{
		std::cout << "Problem with Cholesky solver!!" << std::endl;
		return false;
	}

	// Accept the step only if the absolute residual is below TOL (a NaN residual is rejected).
	const double err = (H*Sol - B).norm();
	if(err > TOL || std::isnan(err)) return false;

	return true;

  }

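  /* ===== yamabeFlow  ===============================================================================
   * Puncture the mesh, minimize the Yamabe functional over the conformal factors of the free
   * vertices with Newton steps and step-halving, then lay the result out in the plane and
   * project it onto the unit sphere.
   ==================================================================================================*/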

  VertexIter Yamabe::yamabeFlow(Mesh& mesh, int vtype, bool *SUCCESS)
  {

	int niter_max = 500;
	double TOL = 1.e-8;

	*SUCCESS = true;

	int n_vertices = mesh.vertices.size();
	int n_edges    = mesh.edges.size();

	// Reference edge lengths from the 3D positions
	computeLength0(mesh);

  /* ==================================================================================================
        Puncture Mesh
   ==================================================================================================*/

	VertexIter pole = pickPole(mesh, vtype);
	punctureMesh(mesh, pole);

  /* ==================================================================================================
	Count number of active vertices (neither the pole nor one of its neighbours)
   ==================================================================================================*/

	int ntot = 0;
	for (VertexIter v_it = mesh.vertices.begin(); v_it != mesh.vertices.end(); v_it++)
	{
		if(!v_it->inNorthPoleVicinity && !v_it->NorthPole) ntot++;
	}

  /* ==================================================================================================
	Create Eigen matrices and arrays needed to minimize energy
   ==================================================================================================*/

	H.resize(ntot, ntot);
	B.resize(ntot);
	Sol.resize(ntot);
	H_is_set = false;   // forces a new sparsity pattern for H on the first solve

	U_fact    = new double[n_vertices];
	U_keep    = new double[n_vertices];
	V_curv    = new double[n_vertices];
	E_length0 = new double[n_edges];

	// The work arrays are only used during this call; they are released on every exit
	// path so that repeated calls (e.g. retries with another pole) do not leak.
	auto releaseBuffers = [this]() {
		delete [] U_fact;    U_fact    = nullptr;
		delete [] U_keep;    U_keep    = nullptr;
		delete [] V_curv;    V_curv    = nullptr;
		delete [] E_length0; E_length0 = nullptr;
	};

	H.setZero();

  /* ==================================================================================================
	Initialize Yamabe discrete metric
   ==================================================================================================*/

	initU(mesh);
	setLength0(mesh);

  /* ==================================================================================================
	Initial energy that will be refined
   ==================================================================================================*/

	double ene0, ene1;
	int nbad;
	computeLength(mesh);
	computeCurv(mesh, &ene0, &nbad);

	// ene1 must hold a defined value even if the first Newton solve fails: the loop then
	// exits before assigning it, and it is still read by the failure test below.
	ene1 = ene0;

	// Set U = U_keep + s*Sol on the free vertices, then refresh edge lengths, the max
	// curvature (ene1) and the number of bad triangles (nbad).
	auto applyStep = [&](double s) {
		for (VertexIter v_it = mesh.vertices.begin(); v_it != mesh.vertices.end(); v_it++)
		{
			if(!v_it->inNorthPoleVicinity && !v_it->NorthPole) {
				int idx = v_it->index;
				U_fact[idx] = U_keep[idx] + s*Sol[v_it->indexN];
			}
		}
		computeLength(mesh);
		computeCurv(mesh, &ene1, &nbad);
	};

  /* ==================================================================================================
	Now perform refinement
   ==================================================================================================*/

	double step = 0;
	int zero = 0;

	std::cout << " " << std::endl;
	std::cout << "Minimize Yamabe functional:" << std::endl;
	std::cout << "==========================" << std::endl;
	std::cout << " " << std::endl;

	std::cout << "        " << "===========================================================================" << std::endl;
	std::cout << "        " << "       Iter       Step size          Max curv.     Bad triangles           " << std::endl;
	std::cout << "        " << "===========================================================================" << std::endl;
	std::cout << "        " << "   " << std::setw(8)<< zero << "    " << std::setw(12) << step;
	std::cout << "    " << std::setw(12) << ene0 ;
	std::cout << "      " << std::setw(8) << zero << std::endl;

	bool info = true;
	for(int niter = 0; niter < niter_max; niter++)
	{

		for(int i = 0; i < n_vertices; i++) U_keep[i] = U_fact[i];

		// Newton direction
		info = solveSystem(mesh);
		if(!info) break;

		// Step size: try 1, halve while the max curvature grows or triangles degenerate.
		step = 1.0;
		applyStep(step);

		while (((ene1 > ene0) || (nbad > 0)) && (step > 0.00001)) {
			step = 0.5*step;
			applyStep(step);
		}

		std::cout << "        " << "   " << std::setw(8)<< niter+1 << "    " << std::setw(12) << step;
		std::cout << "    " << std::setw(12) << ene1 ;
		std::cout << "      " << std::setw(8) << nbad << std::endl;

		if(std::abs(ene1 - ene0) < TOL*ene1 || ene1 < TOL) break;
		ene0 = ene1;
	}
	std::cout << "        " << "===========================================================================" << std::endl;
	std::cout << " " << std::endl;

	if(std::abs(ene1) > TOL || nbad > 0 || !info) {
		std::cout << " " << std::endl;
		std::cout << "Yamabe flow has failed!" << std::endl;
		std::cout << " " << std::endl;
		*SUCCESS = false;
		// NOTE(review): on this path the pole and its star keep their puncture tags.
		releaseBuffers();
		return pole;
	}

  /* ==================================================================================================
	Now layout on plane, and project onto sphere
   ==================================================================================================*/

	Layout::layout2D(mesh);

	// Project onto sphere
	stereo2Sphere(mesh);

	// restore vertex star
	pole->NorthPole = false;
	HalfEdgeIter he = pole->he;
	do {
		he->face->inNorthPoleVicinity = false;

		he = he->flip->next;
	} while (he != pole->he);

	releaseBuffers();
	return pole;

  }

