/* ====RICCIFLOW.H =================================================================================
 *
 * Author: Patrice Koehl, June 2019
 * Department of Computer Science
 * University of California, Davis
 *
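 * This file implements a Euclidean Ricci flow method for computing a conformal map from a
 * genus-zero surface onto the sphere: the surface is punctured at one vertex (the "pole"), the
 * Ricci flow flattens the punctured surface, the result is laid out in the plane, and the planar
 * layout is mapped onto the unit sphere by inverse stereographic projection.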
 *
 * It is based on the papers:
 *      [1] M. Jin, J. Kim, F. Luo, D. Gu. "Discrete Surface Ricci Flow", IEEE Trans. Viz. Comput. Graph.
 *              14, p1030-1043 (2008)
 *      [2] W. Zeng, D. Samaras, D. Gu. "Ricci Flow for 3D Shape Analysis", IEEE. Trans. Pattern. Anal.
 *              Mach. Intell, 32, p662-677 (2010)
 *      [3] X. Chen, H. He, G. Zou, X. Zhang, X. Gu, J. Hua. "Ricci flow-based spherical parameterization
 *              and surface registration", Comput. Vis. Image Understanding, 117, p1107-1118 (2013)
 *      [4] RicciFlowExtremalLength (Gu's code) 2010
 *          -> http://www3.cs.stonybrook.edu/~gu/tutorial/RicciFlow.html
 *
 =============================================================================================== */

#pragma once

  /* ===== INCLUDES AND TYPEDEFS =================================================================
   *
   * Third party libraries included: Eigen
   *
   =============================================================================================== */

  #define _USE_MATH_DEFINES // for M_PI

  #include <vector>
  #include <cmath>

  #include "Layout2D.h"

  // (row, column, value) entry used to assemble the sparse Hessian
  using Triplet = Eigen::Triplet<double>;

  // Sparse Cholesky solver used by Ricci::solveSystem, and the number of times it was called
  Eigen::SimplicialCholesky<Eigen::SparseMatrix<double>> solverR;
  int nR = 0;

  /* =============================================================================================
     Define class Ricci: Euclidean Ricci flow on a punctured genus-zero mesh
   =============================================================================================== */

  class Ricci{

  public:
	// Select the vertex of a genus-0 surface that will be removed temporarily (the "pole")
	VertexIter pickPole(Mesh& mesh, int type);

	// Tag the pole and its star, and number the remaining (active) vertices
	void punctureMesh(Mesh& mesh, VertexIter pole);
	
	// Performs Ricci flow to the plane, then maps the result onto the sphere.
	// *SUCCESS reports whether it worked; the pole is returned in all cases.
	VertexIter euclideanRicci(Mesh& mesh, int vtype, bool *SUCCESS);

  private:

	// Build the sparsity pattern of the Hessian (all values set to 0)
	void initHessian(Mesh& mesh);

	// Reset the Hessian values and the right-hand side to zero
	void resetHessian(Mesh& mesh);

	// Compute Hessian and right hand side
	void computeHessian(Mesh& mesh);

	// Solve Jacobian linear system
	bool solveSystem(Mesh& mesh);

	// Compute initial lengths of all edges (from the 3D positions)
	void computeLength0(Mesh& mesh);

	// Init radius for all vertices
	void initRadii(Mesh& mesh);

	// Init weights (inversive distances) for all edges
	void initWeight(Mesh& mesh);

	// Compute new lengths for all edges from the radii and the weights
	void computeLength(Mesh& mesh);

	// Compute angles in a triangle from edge lengths
	void trigAngles(double lAB, double lBC, double lCA, double *ang, double *cotan, int *n);

	// Compute curvatures using edge lengths
	double computeCurv(Mesh& mesh, double *maxK, int *nbad);

	// Projection to sphere
	void stereo2Sphere(Mesh& mesh);

  protected:

	Eigen::SparseMatrix<double> H;      // Hessian over the active vertices (indexed by indexN)
	Eigen::VectorXd B, Sol;             // right-hand side and solution of H * Sol = B

	// true once the sparsity pattern of H has been built
	bool H_is_set = false;

	// Work arrays: Radii, Vcurv and U are indexed by vertex->index, Eweight by edge->index.
	// Default-initialized so that a freshly constructed object never holds indeterminate values.
	double *Radii = nullptr, *Vcurv = nullptr, *U = nullptr;
	double *Eweight = nullptr;

  };

  /* ===== PICK VERTEX    ========================================================================
   * Select the vertex ("pole") that will be removed temporarily from the mesh before the
   * Ricci flow is applied. Selection rule:
   *      type = 1 : vertex with the highest valence (first one found in case of ties)
   *      type = 2 : vertex with the lowest valence  (first one found in case of ties)
   *      other    : the first vertex of the mesh (e.g. type = 0)
   =============================================================================================== */

  VertexIter Ricci::pickPole(Mesh& mesh, int type)
  {
	VertexIter best = mesh.vertices.begin();

	// Only types 1 and 2 trigger a search; anything else keeps the first vertex
	if (type != 1 && type != 2) return best;

	int bestDegree = best->degree();
	for (VertexIter v = mesh.vertices.begin(); v != mesh.vertices.end(); ++v)
	{
		const int d = v->degree();
		const bool better = (type == 1) ? (d > bestDegree) : (d < bestDegree);
		if (better) {
			bestDegree = d;
			best = v;
		}
	}
	return best;
  }

  /* ===== PUNCTURE MESH  ========================================================================
   * Remove vertex "pole" from the mesh (logically): the pole is tagged NorthPole, and every
   * face incident to it and every vertex adjacent to it is tagged inNorthPoleVicinity.
   * The neighbours of the pole receive an indexN in the order the star is visited; all other
   * ("active") vertices then receive a compact indexN = 0, 1, ... used as their row/column in
   * the Hessian.
   *
   * All previous tags, on vertices AND on faces, are cleared first, so that the mesh can be
   * punctured again (e.g. at another pole after a failed run) without stale face tags.
   =============================================================================================== */

  void Ricci::punctureMesh(Mesh& mesh, VertexIter pole)
  {

	// Clear tags left by any previous puncture
	for (VertexIter v = mesh.vertices.begin(); v != mesh.vertices.end(); v++)
	{
		v->NorthPole = false;
		v->inNorthPoleVicinity = false;
	}
	for (FaceIter f = mesh.faces.begin(); f != mesh.faces.end(); f++)
	{
		f->inNorthPoleVicinity = false;
	}

	// Tag the pole, then walk around its star tagging incident faces and adjacent vertices
	int idx = 0;
	pole->NorthPole = true;
	HalfEdgeIter he = pole->he;
	do {
		he->face->inNorthPoleVicinity = true;
		he->flip->vertex->indexN = idx;
		he->flip->vertex->inNorthPoleVicinity = true;
		idx++;

		he = he->flip->next;
	} while (he != pole->he);

	// Compact numbering of the active vertices
	idx=0;
	for (VertexIter v = mesh.vertices.begin(); v != mesh.vertices.end(); v++)
	{
		if(!v->NorthPole && !v->inNorthPoleVicinity) {
			v->indexN = idx;
			idx++;
		}
	}
  }

  /* ===== Stereo2Sphere   ============================================================================
   * Map the planar layout (position2.x, position2.y) onto the unit sphere by inverse
   * stereographic projection:
   *         (x, y)  ->  ( 2x, 2y, 1 - x^2 - y^2 ) / ( 1 + x^2 + y^2 )
   * The origin goes to (0,0,1). The north-pole vertex, which is not part of the planar layout,
   * is placed at (0,0,-1), the limit of the formula when x^2 + y^2 grows without bound.
   ==================================================================================================*/

  void Ricci::stereo2Sphere(Mesh& mesh)
  {
	for (VertexIter v = mesh.vertices.begin(); v != mesh.vertices.end(); v++)
	{
		if (v->NorthPole) {
			v->position2.x = 0.;
			v->position2.y = 0.;
			v->position2.z = -1.;
			continue;
		}

		const double x   = v->position2.x;
		const double y   = v->position2.y;
		const double r2  = x*x + y*y;
		const double inv = 1.0/(1.0 + r2);

		v->position2.x = 2.0*x*inv;
		v->position2.y = 2.0*y*inv;
		v->position2.z = (1.0 - r2)*inv;
	}
  }

  /* =============================================================================================
   Init Hessian matrix: build the sparsity pattern of H with all values equal to 0:
     - one diagonal entry per active vertex,
     - both symmetric entries (i,j) and (j,i) for every edge joining two active vertices.
   The right-hand side B is reset to 0. A vertex is "active" when it is neither the NorthPole
   nor inNorthPoleVicinity; rows/columns are indexed by vertex->indexN.
   =============================================================================================== */

  void Ricci::initHessian(Mesh& mesh)
  {
	std::vector<Triplet> pattern;

	// Diagonal entries
	for (VertexIter v = mesh.vertices.begin(); v != mesh.vertices.end(); v++)
	{
		if (v->NorthPole || v->inNorthPoleVicinity) continue;
		const int row = v->indexN;
		pattern.push_back(Triplet(row, row, 0.0));
		B[row] = 0.0;
	}

	// Off-diagonal entries, one pair per active-active edge
	for (EdgeIter e = mesh.edges.begin(); e != mesh.edges.end(); e++)
	{
		VertexIter vi = e->he->vertex;
		VertexIter vj = e->he->flip->vertex;

		const bool iActive = !(vi->NorthPole || vi->inNorthPoleVicinity);
		const bool jActive = !(vj->NorthPole || vj->inNorthPoleVicinity);
		if (!(iActive && jActive)) continue;

		pattern.push_back(Triplet(vi->indexN, vj->indexN, 0.0));
		pattern.push_back(Triplet(vj->indexN, vi->indexN, 0.0));
	}

	H.setFromTriplets(pattern.begin(), pattern.end());
	H_is_set = true;
  }

  /* =============================================================================================
   Reset matrix to zero: set to 0 the diagonal entry of every active vertex and both entries of
   every edge joining two active vertices, and set the matching entries of B to 0.
   Assumes initHessian already created these entries (coeffRef would otherwise insert them).
   =============================================================================================== */

  void Ricci::resetHessian(Mesh& mesh)
  {
	auto isActive = [](VertexIter v) { return !v->NorthPole && !v->inNorthPoleVicinity; };

	for (VertexIter v = mesh.vertices.begin(); v != mesh.vertices.end(); v++)
	{
		if (!isActive(v)) continue;
		const int k = v->indexN;
		H.coeffRef(k, k) = 0;
		B[k] = 0.0;
	}

	for (EdgeIter e = mesh.edges.begin(); e != mesh.edges.end(); e++)
	{
		VertexIter a = e->he->vertex;
		VertexIter b = e->he->flip->vertex;
		if (!isActive(a) || !isActive(b)) continue;

		H.coeffRef(a->indexN, b->indexN) = 0;
		H.coeffRef(b->indexN, a->indexN) = 0;
	}
  }

  /* ===============================================================================================
   computeLength0: set the length of every edge to the Euclidean distance between its two end
   vertices, measured on the original 3D positions (vertex->position).
   =============================================================================================== */

  void Ricci::computeLength0(Mesh& mesh)
  {
	for (EdgeIter e = mesh.edges.begin(); e != mesh.edges.end(); ++e)
	{
		Vector from = e->he->vertex->position;
		Vector to   = e->he->flip->vertex->position;
		e->length = (from - to).norm();
	}
  }

  /* ===============================================================================================
   initRadii: initialize radii for Ricci metric

   In triangle ABC, the radii of three mutually tangent circles centred at the vertices are
        rA = (lAB + lCA - lBC)/2,  rB = (lAB + lBC - lCA)/2,  rC = (lCA + lBC - lAB)/2.
   The initial radius of a vertex is the average of these values over all triangles incident to
   it (a vertex with no incident triangle keeps radius 0). Radii is indexed by vertex->index;
   edge lengths are read from edge->length.
   =============================================================================================== */

  void Ricci::initRadii(Mesh& mesh)
  {

	// Number of triangles incident to each vertex (owned by the vector, freed automatically)
	std::vector<int> ncount(mesh.vertices.size(), 0);

	for (VertexIter v_iter = mesh.vertices.begin(); v_iter != mesh.vertices.end(); v_iter++)
	{
		Radii[v_iter->index] = 0.0;
	}

	// Accumulate the tangent-circle radius of each corner
	for (FaceIter f_iter = mesh.faces.begin(); f_iter != mesh.faces.end(); f_iter++)
	{
		HalfEdgeIter hAB = f_iter->he;
		HalfEdgeIter hBC = hAB->next;
		HalfEdgeIter hCA = hBC->next;

		int idxA = hAB->vertex->index;
		int idxB = hBC->vertex->index;
		int idxC = hCA->vertex->index;

		double lAB = hAB->edge->length;
		double lBC = hBC->edge->length;
		double lCA = hCA->edge->length;

		Radii[idxA] += (lAB+lCA-lBC)/2; ncount[idxA]++;
		Radii[idxB] += (lAB+lBC-lCA)/2; ncount[idxB]++;
		Radii[idxC] += (lCA+lBC-lAB)/2; ncount[idxC]++;
	}

	// Average; skip vertices without incident triangles to avoid 0/0
	for (VertexIter v_iter = mesh.vertices.begin(); v_iter != mesh.vertices.end(); v_iter++)
	{
		int idx = v_iter->index;
		if (ncount[idx] > 0) Radii[idx] /= ncount[idx];
	}
  }

  /* ===============================================================================================
   InitWeight: initialize the inversive distance of each edge

   For edge (A,B) of length l, with vertex radii rA and rB:
        I_AB = (l^2 - rA^2 - rB^2) / (2 rA rB)
   Negative values are clamped to 0. Eweight is indexed by edge->index.
   As a diagnostic, the edge lengths implied by (radii, weights) are rebuilt triangle by triangle
   and the number of triangles that violate the triangle inequality is printed ("nbad2").
   =============================================================================================== */

  void Ricci::initWeight(Mesh& mesh)
  {
	// Inversive distance of every edge
	for (EdgeIter e = mesh.edges.begin(); e != mesh.edges.end(); e++)
	{
		const double l  = e->length;
		const double r0 = Radii[e->he->vertex->index];
		const double r1 = Radii[e->he->flip->vertex->index];

		const double I = (l*l - r0*r0 - r1*r1)/(2*r0*r1);
		Eweight[e->index] = (I < 0) ? 0. : I;
	}

	// Diagnostic: triangles whose rebuilt lengths violate the triangle inequality
	int nbad2 = 0;
	for (FaceIter f = mesh.faces.begin(); f != mesh.faces.end(); f++)
	{
		HalfEdgeIter hAB = f->he;
		HalfEdgeIter hBC = hAB->next;
		HalfEdgeIter hCA = hBC->next;

		const double rA = Radii[hAB->vertex->index];
		const double rB = Radii[hBC->vertex->index];
		const double rC = Radii[hCA->vertex->index];

		const double wAB = Eweight[hAB->edge->index];
		const double wBC = Eweight[hBC->edge->index];
		const double wCA = Eweight[hCA->edge->index];

		const double lAB = std::sqrt( rA*rA + rB*rB + 2*rA*rB*wAB);
		const double lBC = std::sqrt( rC*rC + rB*rB + 2*rC*rB*wBC);
		const double lCA = std::sqrt( rA*rA + rC*rC + 2*rA*rC*wCA);

		const bool violated = lAB > lBC+lCA || lBC > lAB+lCA || lCA > lAB + lBC;
		if (violated) nbad2++;
	}
	std::cout << "nbad2 = " << nbad2 << std::endl;
  }

  /* ===============================================================================================
   RicciNewLength: recompute every edge length from the current radii and the inversive
   distances stored in Eweight:
        l_AB = sqrt( rA^2 + rB^2 + 2 rA rB I_AB )
   =============================================================================================== */

  void Ricci::computeLength(Mesh& mesh) 
  {
	for (EdgeIter e = mesh.edges.begin(); e != mesh.edges.end(); ++e)
	{
		const double rA  = Radii[e->he->vertex->index];
		const double rB  = Radii[e->he->flip->vertex->index];
		const double IAB = Eweight[e->index];

		const double squared = rA*rA + rB*rB + 2*rA*rB*IAB;
		e->length = std::sqrt(squared);
	}
  }

  /* ===============================================================================================
   trigAngles: Computes the three angles of a triangle and their cotans

   Input : lAB, lBC, lCA   lengths of the three sides of triangle ABC
   Output: ang[0..2]       angles at A, B and C (law of cosines)
           cotan[0..2]     cotangents of these angles
           *n              1 if the lengths violate the triangle inequality, 0 otherwise.
                           In that case the angle opposite the too-long side is set to pi, the two
                           others to 0, and all cotangents to 0.
   The cosines are clamped to [-1, 1] before acos: for (nearly) flat triangles rounding can push
   them slightly outside this range, where acos would return NaN.
   =============================================================================================== */

  void Ricci::trigAngles(double lAB, double lBC, double lCA, double *ang, double *cotan, int *n)
  {
	// acos with its argument clamped to the valid domain [-1, 1]
	auto clampedAcos = [](double c) {
		if (c > 1.0) c = 1.0;
		else if (c < -1.0) c = -1.0;
		return std::acos(c);
	};

	double alpha_A, alpha_B, alpha_C;
	double cot_A, cot_B, cot_C;

	int n1 = 0;
	if( lAB > lBC + lCA) {
		// AB too long: flat triangle, angle at C (opposite AB) is pi
		alpha_A = 0; alpha_B = 0; alpha_C = M_PI;
		cot_A = 0; cot_B = 0; cot_C = 0;
		n1 = 1;
	} else if (lBC > lAB + lCA) {
		// BC too long: angle at A (opposite BC) is pi
		alpha_A = M_PI; alpha_B = 0; alpha_C = 0;
		cot_A = 0; cot_B = 0; cot_C = 0;
		n1 = 1;
	} else if (lCA > lAB + lBC) {
		// CA too long: angle at B (opposite CA) is pi
		alpha_A = 0; alpha_B = M_PI; alpha_C = 0;
		cot_A = 0; cot_B = 0; cot_C = 0;
		n1 = 1;
	} else {
		// Law of cosines: angle at A is opposite BC, at B opposite CA, at C opposite AB
		alpha_A = clampedAcos((lAB*lAB + lCA*lCA - lBC*lBC)/(2.0*lAB*lCA));
		cot_A   = 1.0/std::tan(alpha_A);
		alpha_B = clampedAcos((lAB*lAB + lBC*lBC - lCA*lCA)/(2.0*lAB*lBC));
		cot_B   = 1.0/std::tan(alpha_B);
		alpha_C = clampedAcos((lCA*lCA + lBC*lBC - lAB*lAB)/(2.0*lCA*lBC));
		cot_C   = 1.0/std::tan(alpha_C);
	}

	ang[0] = alpha_A; ang[1] = alpha_B; ang[2] = alpha_C;
	cotan[0] = cot_A; cotan[1] = cot_B; cotan[2] = cot_C;
	*n = n1;
  }


  /* ===============================================================================================
   computeCurv: computes the curvatures of all vertices

   The discrete curvature of vertex i is its angle defect
        K(i) = 2pi - sum_t Theta(i,ijk)
   where the sum runs over all triangles (ijk) incident to i and Theta(i,ijk) is the angle at
   vertex i in triangle (ijk). Vcurv is filled for every vertex (indexed by vertex->index).

   Returns the L2 norm of K over the active vertices (neither NorthPole nor
   inNorthPoleVicinity); *maxK receives the max |K| over the same vertices, and *nbad the number
   of triangles whose edge lengths violate the triangle inequality.
   =============================================================================================== */

  double Ricci::computeCurv(Mesh& mesh, double *maxK, int *nbad)
  {
	const double TwoPI = 2*M_PI;
	double angs[3], cotans[3];

	// Every vertex starts with an angle sum deficit of 2*pi
	for (VertexCIter v = mesh.vertices.begin(); v != mesh.vertices.end(); v++)
	{
		Vcurv[v->index] = TwoPI;
	}

	// Remove the corner angles of every triangle from its three vertices
	int degenerate = 0;
	for (FaceCIter f = mesh.faces.begin(); f != mesh.faces.end(); f++)
	{
		HalfEdgeIter hAB = f->he;
		HalfEdgeIter hBC = hAB->next;
		HalfEdgeIter hCA = hBC->next;

		int flag;
		trigAngles(hAB->edge->length, hBC->edge->length, hCA->edge->length, angs, cotans, &flag);

		Vcurv[hAB->vertex->index] -= angs[0];
		Vcurv[hBC->vertex->index] -= angs[1];
		Vcurv[hCA->vertex->index] -= angs[2];
		degenerate += flag;
	}

	// L2 and max norms of the curvature, restricted to active vertices
	double sumSq = 0, maxAbs = 0;
	for (VertexCIter v = mesh.vertices.begin(); v != mesh.vertices.end(); v++)
	{
		if (v->inNorthPoleVicinity || v->NorthPole) continue;
		const double k = Vcurv[v->index];
		sumSq += k*k;
		maxAbs = std::max(maxAbs, std::abs(k));
	}

	*maxK = maxAbs;
	*nbad = degenerate;
	return std::sqrt(sumSq);
  }

  /* ===== Jacobian system          =================================================================
	Hessian: assemble the linear system  H * du = B  solved at each Newton step, restricted to
	the active vertices (neither NorthPole nor inNorthPoleVicinity), indexed by vertex->indexN.

	B[i] = -2*pi + (sum of the corner angles at vertex i), i.e. minus the angle defect
	(curvature) of vertex i; the target curvature of active vertices is 0.

	For each triangle ABC, one coefficient is computed per edge (valAB, valCA, valBC) from the
	radii (Radii), the inversive distances (Eweight), the edge lengths and the corner angles.
	It is added to the off-diagonal entries of the edge when both end vertices are active,
	and subtracted from the diagonal entry of each active end vertex.
	NOTE(review): these coefficients appear to be derivatives of the corner angles with respect
	to u = log(r) for the inversive-distance metric (see [1], [2]) - TODO confirm.
	NOTE(review): for a degenerate triangle trigAngles returns angles of 0 or pi, so sin(alpha_A)
	or sin(alpha_B) is 0 or nearly 0 and the divisions below blow up (inf/NaN or huge values).
   ==================================================================================================*/

  void Ricci::computeHessian(Mesh& mesh)
  {

	// First call: build the sparsity pattern; later calls: only zero the stored values
	if(H_is_set) {
		resetHessian(mesh);
	} else {
		initHessian(mesh);
		H_is_set = true;	// (initHessian already sets it)
	}

	double lAB, lBC, lCA;
	double wAB, wBC, wCA;
	double valAB, valBC, valCA, val2;
	double rA, rB, rC;
	int idxA, idxB, idxC;
	int n;

	double TwoPI = 2*M_PI;

	double val, alpha_A, alpha_B, alpha_C;
	double sA, sB, sC, cA, cB, cC;	// sC and cC are computed below but not used
	bool b1, b2, b3;

	HalfEdgeIter hAB, hBC, hCA;
	EdgeIter e1, eAB, eBC, eCA;	// e1 is unused
	VertexIter v1, v2, v3;

	double angs[3], cotans[3];

	// Right-hand side starts at -2*pi for every active vertex
	for (VertexIter v_it = mesh.vertices.begin(); v_it != mesh.vertices.end(); v_it++)
	{
		if(!v_it->inNorthPoleVicinity && !v_it->NorthPole)
		{
			idxA = v_it->indexN;
			B[idxA] = -TwoPI;
		}
	}

	// Accumulate the contribution of each triangle ABC
	for(FaceIter f_iter = mesh.faces.begin(); f_iter != mesh.faces.end(); f_iter++)
	{
		hAB =f_iter->he;
		hBC =hAB->next;
		hCA =hBC->next;

		v1 = hAB->vertex;
		v2 = hBC->vertex;
		v3 = hCA->vertex;

		// true for inactive vertices (pole or its neighbours): they have no row/column in H
		b1 = v1->inNorthPoleVicinity || v1->NorthPole;
		b2 = v2->inNorthPoleVicinity || v2->NorthPole;
		b3 = v3->inNorthPoleVicinity || v3->NorthPole;

		eAB = hAB->edge;
		eBC = hBC->edge;
		eCA = hCA->edge;

		// Row/column indices in H (only meaningful for active vertices)
		idxA = v1->indexN;
		idxB = v2->indexN;
		idxC = v3->indexN;

		lAB = eAB->length;
		lBC = eBC->length;
		lCA = eCA->length;
		wAB = Eweight[eAB->index];
		wBC = Eweight[eBC->index];
		wCA = Eweight[eCA->index];

		// Corner angles at A, B, C (the cotangents are not used here)
		trigAngles(lAB, lBC, lCA, angs, cotans, &n);
		alpha_A = angs[0]; alpha_B = angs[1]; alpha_C = angs[2];

		// B accumulates the corner angles of the active vertices
		if(!b1) {
			B[idxA] += alpha_A;
		}
		if(!b2) {
			B[idxB] += alpha_B;
		}
		if(!b3) {
			B[idxC] += alpha_C;
		}

		// Radii are indexed by the global vertex index, not by indexN
		rA = Radii[v1->index];
		rB = Radii[v2->index];
		rC = Radii[v3->index];

		sA = std::sin(alpha_A); cA = std::cos(alpha_A);
		sB = std::sin(alpha_B); cB = std::cos(alpha_B);
		sC = std::sin(alpha_C); cC = std::cos(alpha_C);

		// Off-diagonal coefficient for edge AB
		val = rB/(lAB*lCA*sA);
		val2 = (rA*wAB - rC*wBC - lCA*cA*(rB+rA*wAB)/lAB);
		valAB = val*val2;
		if( !b1 && !b2 ) {
			H.coeffRef(idxA, idxB) += valAB;
			H.coeffRef(idxB, idxA) += valAB;
		}

		// Off-diagonal coefficient for edge CA
		val = rC/(lCA*lAB*sA);
		val2 = (rA*wCA - rB*wBC - lAB*cA*(rC+rA*wCA)/lCA);
		valCA = val*val2;
		if( !b1 && !b3 ) {
			H.coeffRef(idxA, idxC) += valCA;
			H.coeffRef(idxC, idxA) += valCA;
		}

		// Off-diagonal coefficient for edge BC
		val = rC/(lBC*lAB*sB);
		val2 = (rB*wBC - rA*wCA - lAB*cB*(rC+rB*wBC)/lBC);
		valBC = val*val2;
		if( !b2 && !b3 ) {
			H.coeffRef(idxB, idxC) += valBC;
			H.coeffRef(idxC, idxB) += valBC;
		}

		// Diagonal: subtract the coefficients of the two triangle edges incident to each
		// active vertex (also when the other end vertex is inactive)
		if(!b1) {
			H.coeffRef(idxA, idxA) -= valAB;
			H.coeffRef(idxA, idxA) -= valCA;
		}
		if(!b2) {
			H.coeffRef(idxB, idxB) -= valAB;
			H.coeffRef(idxB, idxB) -= valBC;
		}
		if(!b3) {
			H.coeffRef(idxC, idxC) -= valCA;
			H.coeffRef(idxC, idxC) -= valBC;
		}

	}
  }

  /* =============================================================================================
  Solve Linear System: assemble H and B (computeHessian) and solve H * Sol = B with the sparse
  Cholesky solver solverR.

  Returns false if the factorization or the solve reports a failure, or if the residual
  ||H*Sol - B|| is larger than 1e-8 or NaN; true otherwise.

  The symbolic analysis (analyzePattern) is redone whenever computeHessian has just built a new
  sparsity pattern (H_is_set was false), i.e. on the first solve of every run. Reusing the
  analysis of an earlier run (other mesh or other pole) would factorize H with a stale pattern.
   =============================================================================================== */

  bool Ricci::solveSystem(Mesh& mesh)
  {

	// computeHessian (re)builds the sparsity pattern of H when H_is_set is false
	const bool newPattern = !H_is_set;

	// Compute Hessian and gradient
	computeHessian(mesh);

	// Solve system
	if(nR == 0 || newPattern) solverR.analyzePattern(H);
	nR++;
	solverR.factorize(H);

	const double TOL=1.e-8;

	// Any status other than Success (including Eigen::NumericalIssue) is a failure
	if(solverR.info()!=Eigen::Success)
	{
		std::cout << "Problem with Cholesky factorization!!" << std::endl;
		return false;
	}
	Sol = solverR.solve(B); 
	if(solverR.info()!=Eigen::Success)
	{
		std::cout << "Problem with Cholesky solver!!" << std::endl;
		return false;
	}

	// Reject inaccurate (or NaN) solutions
	const double err = (H*Sol - B).norm();
	if(err > TOL || std::isnan(err) ) return false;

	return true;
  }

  /* ===== MinimizeRicci  ===========================================================================
   * Euclidean Ricci flow on a genus-zero mesh, followed by planar layout and projection onto the
   * unit sphere:
   *   1. compute edge lengths from the 3D positions, select a pole (pickPole with type vtype)
   *      and puncture the mesh at it;
   *   2. initialize the discrete metric: vertex radii and edge inversive distances;
   *   3. Newton iterations on u = log(radius) of the active vertices: each step solves
   *      H * du = B (solveSystem), then halves the step until the max curvature does not
   *      increase and no triangle violates the triangle inequality (or step <= 1e-5);
   *   4. if the max curvature of the active vertices is below TOL and no triangle is bad, lay
   *      the mesh out in the plane (Layout::layout2D), project it onto the sphere
   *      (stereo2Sphere) and un-tag the pole and the faces of its star.
   * *SUCCESS tells whether the procedure succeeded; the pole is returned in all cases.
   * The work arrays Radii, U, Vcurv and Eweight are released (and set to nullptr) once the
   * iterations are over.
   ==================================================================================================*/

  VertexIter Ricci::euclideanRicci(Mesh& mesh, int vtype, bool *SUCCESS)
  {

	int niter_max = 1000;
	double TOL = 1.e-8;
	*SUCCESS = true;

	int n_vertices = mesh.vertices.size();
	int n_edges    = mesh.edges.size();

	// Compute edge lengths
	computeLength0(mesh);

	// Puncture mesh
	VertexIter pole = pickPole(mesh, vtype);
	punctureMesh(mesh, pole);

	// Count number of active vertices (neither the pole nor one of its neighbours)
	int ntot = 0;
	for (VertexIter v_it = mesh.vertices.begin(); v_it != mesh.vertices.end(); v_it++)
	{
		if(!v_it->inNorthPoleVicinity && !v_it->NorthPole) ntot++;
	}

	// Create Eigen matrices and work arrays needed to minimize energy.
	// Radii, U and Vcurv are indexed by vertex->index, Eweight by edge->index.
	H.resize(ntot, ntot);
	B.resize(ntot);
	Sol.resize(ntot);
	H_is_set = false;

	Radii   = new double[n_vertices];
	U       = new double[n_vertices];
	Vcurv   = new double[n_vertices];
	Eweight = new double[n_edges];

	H.setZero();
	for(int i = 0; i < n_vertices; i++) Radii[i] = 1.0;
	for(int i = 0; i < n_vertices; i++) U[i] = 0;

	// Initialize Ricci discrete metric
	initRadii(mesh);
	initWeight(mesh);

	// Initial energy that will be refined (ene0 = max curvature of the active vertices).
	// ene1 gets a value here so it is defined even if the first linear solve fails.
	double ene0, ene1;
	int nbad;
	computeLength(mesh);
	computeCurv(mesh, &ene0, &nbad);
	ene1 = ene0;

	// Set the radii of the active vertices to exp(U + s*Sol), then recompute edge lengths and
	// curvatures (max curvature -> ene1, number of bad triangles -> nbad)
	auto applyStep = [&](double s) {
		for (VertexIter v_it = mesh.vertices.begin(); v_it != mesh.vertices.end(); v_it++)
		{
			if(!v_it->inNorthPoleVicinity && !v_it->NorthPole) {
				int idx = v_it->index;
				double u = U[idx] + s*Sol[v_it->indexN];
				Radii[idx] = std::exp(u);
			}
		}
		computeLength(mesh);
		computeCurv(mesh, &ene1, &nbad);
	};

	// Now perform refinement
	double step = 0;
	int zero = 0;

	std::cout << "Minimize Ricci functional:" << std::endl;
	std::cout << "==========================" << std::endl;
	std::cout << " " << std::endl;

	std::cout << "        " << "===========================================================================" << std::endl;
	std::cout << "        " << "       Iter       Step size          Max curv.     Bad triangles           " << std::endl;
	std::cout << "        " << "===========================================================================" << std::endl;
	std::cout << "        " << "   " << std::setw(8)<< zero << "    " << std::setw(12) << step;
	std::cout << "    " << std::setw(12) << ene0 ;
	std::cout << "      " << std::setw(8) << zero << std::endl;

	bool info = true;

	for(int niter = 0; niter < niter_max; niter++)
	{

		for(int i = 0; i < n_vertices; i++) U[i] = std::log(Radii[i]);

		// Solve Jacobian system
		info = solveSystem(mesh);
		if(!info) break;

		// Find "optimal" step size:
		//	- try first step = 1; keep it if ene1 <= ene0 and no triangle is bad
		//	- otherwise set step = 0.5*step and try again (down to step ~ 1e-5)
		step = 1.0;
		applyStep(step);
		while (((ene1 > ene0) || (nbad > 0)) && (step > 0.00001)) {
			step = 0.5*step;
			applyStep(step);
		}

		std::cout << "        " << "   " << std::setw(8)<< niter+1 << "    " << std::setw(12) << step;
		std::cout << "    " << std::setw(12) << ene1 ;
		std::cout << "      " << std::setw(8) << nbad << std::endl;

		if(std::abs(ene1 - ene0) < TOL*ene1 || ene1 < TOL) break;
		ene0 = ene1;
	}
	std::cout << "        " << "===========================================================================" << std::endl;
	std::cout << " " << std::endl;

	// The work arrays are not needed past this point: release them on every exit path
	delete [] Radii;   Radii   = nullptr;
	delete [] U;       U       = nullptr;
	delete [] Vcurv;   Vcurv   = nullptr;
	delete [] Eweight; Eweight = nullptr;

	// Now layout on plane, and project onto sphere
	if(std::abs(ene1) > TOL || nbad > 0 || !info) {
		std::cout << " " << std::endl;
		std::cout << "Ricci flow has failed!" << std::endl;
		std::cout << " " << std::endl;
		*SUCCESS = false;
		return pole;
	}

	info = Layout::layout2D(mesh);
	if(!info) {
		*SUCCESS = false;
		return pole;
	}

	// Project onto sphere
	stereo2Sphere(mesh);

	// restore vertex star
	pole->NorthPole = false;
	HalfEdgeIter he = pole->he;
	do {
		he->face->inNorthPoleVicinity = false;

		he = he->flip->next;
	} while (he != pole->he);


	return pole;

  }
