
/* ====RICCIFLOW.H =================================================================================
 *
 * Author: Patrice Koehl, June 2019
 * Department of Computer Science
 * University of California, Davis
 *
 * This file applies a Euclidean Ricci flow method for computing a conformal map from a genus zero
 * surface onto the sphere
 *
 * It is based on the papers:
 *      [1] M. Jin, J. Kim, F. Luo, D. Gu. "Discrete Surface Ricci Flow", IEEE Trans. Viz. Comput. Graph.
 *              14, p1030-1043 (2008)
 *      [2] W. Zeng, D. Samaras, D. Gu. "Ricci Flow for 3D Shape Analysis", IEEE. Trans. Pattern. Anal.
 *              Mach. Intell, 32, p662-677 (2010)
 *      [3] X. Chen, H. He, G. Zou, X. Zhang, X. Gu, J. Hua. "Ricci flow-based spherical parameterization
 *              and surface registration", Comput. Vis. Image Understanding, 117, p1107-1118 (2013)
 *      [4] RicciFlowExtremalLength (Gu's code) 2010
 *          -> http://www3.cs.stonybrook.edu/~gu/tutorial/RicciFlow.html
 *
 =============================================================================================== */

#pragma once

  /* ===== INCLUDES AND TYPEDEFS =================================================================
   *
   * Third party libraries included: Eigen
   *
   =============================================================================================== */

  #define _USE_MATH_DEFINES // for M_PI

  #include <algorithm>
  #include <cmath>
  #include <iomanip>
  #include <iostream>
  #include <iterator>
  #include <vector>

  #include "iDT.h"
  #include "Layout2D.h"

  // Sparse-matrix entry (row, column, value) used to build the Hessian pattern.
  using Triplet = Eigen::Triplet<double>;

  // Newton system shared by the RicciFlowIDT member functions:
  //   H_RicciIDT * Sol_RicciIDT = B_RicciIDT, sized by the number of active vertices.
  Eigen::SparseMatrix<double> H_RicciIDT;
  Eigen::VectorXd B_RicciIDT;
  Eigen::VectorXd Sol_RicciIDT;

  // Per-vertex work arrays (indexed by vertex index): radii, curvatures, log-radii.
  double *RadiiIDT;
  double *VcurvIDT;
  double *UIDT;

  // Per-edge work array (indexed by edge index): inversive distances.
  double *EweightIDT;

  /* =============================================================================================
     Define class
   =============================================================================================== */

  class RicciFlowIDT {

  public:

	/// Runs the Euclidean Ricci flow on a punctured copy of the mesh, lays the result out in
	/// the plane and projects it onto the unit sphere (into position2). *SUCCESS reports
	/// whether the flow and the layout succeeded. Returns the vertex chosen as pole.
	static VertexIter euclideanRicciIDT(Mesh& mesh, int vtype, bool *SUCCESS);

	/// Chooses the vertex that is temporarily taken out of the genus-zero surface:
	/// type 1 = largest valence, type 2 = smallest valence, otherwise the first vertex.
	static VertexIter pickPole(Mesh& mesh, int type);

	/// Flags the pole and its one-ring, and numbers the remaining (active) vertices.
	static void punctureMesh(Mesh& mesh, VertexIter pole);

  private:

	// ---- Newton linear system --------------------------------------------------------------

	/// Builds the sparsity pattern of the Hessian (all entries zero) and zeroes the RHS.
	static void initHessian(Mesh& mesh);

	/// Declared only; no definition is visible in this file (TODO: confirm it is still needed).
	static void resetHessian(Mesh& mesh);

	/// Assembles the Hessian and the right-hand side for the current metric.
	static void computeHessian(Mesh& mesh);

	/// Factorizes the Hessian and solves for the Newton direction.
	static void solveSystem(Mesh& mesh);

	// ---- Discrete metric -------------------------------------------------------------------

	/// Edge lengths from the 3D vertex positions.
	static void computeLength0(Mesh& mesh);

	/// Initial vertex radii, averaged over incident faces.
	static void initRadii(Mesh& mesh);

	/// Inversive distance of every edge.
	static void initWeight(Mesh& mesh);

	/// Edge lengths from the current radii and inversive distances.
	static void computeLength(Mesh& mesh);

	// ---- Geometry --------------------------------------------------------------------------

	/// Declared only; no definition is visible in this file (TODO: confirm it is still needed).
	static void trigAngles(double lAB, double lBC, double lCA, double *ang, double *cotan, int *n);

	/// Vertex curvatures from the edge lengths, plus the list of edges of degenerate triangles.
	static double computeCurv(Mesh& mesh, double *maxK, int *nbad, std::vector<EdgeIter>& list_bad);

	/// Flips the edges collected by computeCurv.
	static void flipMesh(Mesh& mesh, std::vector<EdgeIter>& list_bad);

	/// Inverse stereographic projection of the planar layout onto the unit sphere.
	static void stereo2Sphere(Mesh& mesh);
  };

  /* ===== PICK VERTEX    ========================================================================
   * Select vertex that will be removed from the mesh to enable planar Tutte embedding.
   * At this stage, selection is based on finding the vertex with either the highest valence
   * (type = 1), or the lowest valence (type = 2). If type is set to 0, vertex # 0 is picked
   * arbitrarily
   =============================================================================================== */

  VertexIter RicciFlowIDT::pickPole(Mesh& mesh, int type)
  {
	VertexIter chosen = mesh.vertices.begin();

	// type 0 asks for the first vertex; any type other than 1 or 2 also ends up with it,
	// because the search below would never replace the initial candidate.
	if (type != 1 && type != 2) return chosen;

	// type 1: a strictly larger valence wins; type 2: a strictly smaller valence wins.
	// Ties keep the earliest vertex in container order.
	int chosenDegree = chosen->degree();
	for (VertexIter cand = mesh.vertices.begin(); cand != mesh.vertices.end(); ++cand)
	{
		const int d = cand->degree();
		const bool improves = (type == 1) ? (d > chosenDegree) : (d < chosenDegree);
		if (improves)
		{
			chosenDegree = d;
			chosen = cand;
		}
	}

	return chosen;
  }

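  /* ===== PUNCTURE MESH  ========================================================================
   * "Remove" vertex pole: it is not deleted from the mesh but flagged as NorthPole. All
   * vertices and faces incident to it are flagged inNorthPoleVicinity, and the remaining
   * vertices are numbered consecutively through indexN.
   =============================================================================================== */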

  void RicciFlowIDT::punctureMesh(Mesh& mesh, VertexIter pole)
  {
	// Start from a clean state: no pole, no vicinity.
	for (VertexIter v = mesh.vertices.begin(); v != mesh.vertices.end(); ++v)
	{
		v->NorthPole = false;
		v->inNorthPoleVicinity = false;
	}

	pole->NorthPole = true;

	// Walk the outgoing half-edges around the pole: flag each incident face and each
	// neighbour, numbering the neighbours 0, 1, 2, ... in the order they are met.
	int ringIndex = 0;
	HalfEdgeIter h = pole->he;
	do {
		VertexIter neighbour = h->flip->vertex;
		h->face->inNorthPoleVicinity = true;
		neighbour->indexN = ringIndex;
		neighbour->inNorthPoleVicinity = true;
		++ringIndex;

		h = h->flip->next;
	} while (h != pole->he);

	// Every vertex outside the pole star is an unknown of the linear system.
	int activeIndex = 0;
	for (VertexIter v = mesh.vertices.begin(); v != mesh.vertices.end(); ++v)
	{
		if (v->NorthPole || v->inNorthPoleVicinity) continue;
		v->indexN = activeIndex;
		++activeIndex;
	}
  }

  /* ===== Stereo2Sphere   ============================================================================
   * Project planar mesh onto unit sphere
   ==================================================================================================*/

  void RicciFlowIDT::stereo2Sphere(Mesh& mesh)
  {
	// Inverse stereographic projection of (x, y) in position2 onto the unit sphere:
	//   (x, y) -> (2x, 2y, 1 - x^2 - y^2) / (1 + x^2 + y^2)
	// The vertex flagged NorthPole stands for the point at infinity and goes to (0, 0, -1).
	for (VertexIter v = mesh.vertices.begin(); v != mesh.vertices.end(); ++v)
	{
		if (v->NorthPole) {
			v->position2.x = 0.;
			v->position2.y = 0.;
			v->position2.z = -1.;
			continue;
		}

		const double x = v->position2.x;
		const double y = v->position2.y;
		const double r2 = x*x + y*y;
		const double scale = 1.0/(1.0+r2);

		v->position2.x = 2.0*x*scale;
		v->position2.y = 2.0*y*scale;
		v->position2.z = (1.0-r2)*scale;
	}
  }

  /* =============================================================================================
   Init Hessian matrix
   =============================================================================================== */

  void RicciFlowIDT::initHessian(Mesh& mesh)
  {
	// Builds the sparsity pattern of H_RicciIDT with every stored entry equal to zero:
	// one diagonal entry per active vertex (neither pole nor pole neighbour) and two
	// symmetric entries per edge joining two active vertices. Also zeroes B_RicciIDT.
	std::vector<Triplet> entries;

	for (VertexIter v = mesh.vertices.begin(); v != mesh.vertices.end(); ++v)
	{
		if (v->NorthPole || v->inNorthPoleVicinity) continue;

		const int row = v->indexN;
		entries.push_back(Triplet(row, row, 0.0));
		B_RicciIDT[row] = 0.0;
	}

	for (EdgeIter e = mesh.edges.begin(); e != mesh.edges.end(); ++e)
	{
		VertexIter a = e->he->vertex;
		VertexIter b = e->he->flip->vertex;

		const bool aActive = !a->NorthPole && !a->inNorthPoleVicinity;
		const bool bActive = !b->NorthPole && !b->inNorthPoleVicinity;
		if (!(aActive && bActive)) continue;

		entries.push_back(Triplet(a->indexN, b->indexN, 0.0));
		entries.push_back(Triplet(b->indexN, a->indexN, 0.0));
	}

	H_RicciIDT.setFromTriplets(entries.begin(), entries.end());
  }

  /* ===============================================================================================
   computeLength0: computes the initial length of all edges from the 3D vertex positions
   =============================================================================================== */

  void RicciFlowIDT::computeLength0(Mesh& mesh)
  {
	// Euclidean distance between the two endpoints (position) of every edge.
	for (EdgeIter e = mesh.edges.begin(); e != mesh.edges.end(); ++e)
	{
		const HalfEdgeIter h = e->he;
		e->length = (h->vertex->position - h->flip->vertex->position).norm();
	}
  }

  /* ===============================================================================================
   initRadii: initialize Ricci metric

   =============================================================================================== */

  void RicciFlowIDT::initRadii(Mesh& mesh)
  {
	// Initial radius of every vertex: the average, over its incident faces, of
	// (l1 + l2 - l_opposite)/2, where l1 and l2 are the two edges of the face meeting at the
	// vertex (the tangent-circle radius for that corner).
	// Writes RadiiIDT[v->index]; RadiiIDT must hold mesh.vertices.size() entries, and vertex
	// indices are assumed to lie in [0, mesh.vertices.size()).
	// A vertex with no incident face ends up with 0/0 (NaN), as before.

	// Number of incident faces per vertex. Owned by a vector so it is released on every path.
	std::vector<int> ncount(mesh.vertices.size(), 0);

	for(VertexIter v_iter = mesh.vertices.begin(); v_iter != mesh.vertices.end(); v_iter++)
	{
		RadiiIDT[v_iter->index] = 0.0;
	}

	// Accumulate the corner radii of every face
	for(FaceIter f_iter = mesh.faces.begin(); f_iter != mesh.faces.end(); f_iter++)
	{
		HalfEdgeIter hAB = f_iter->he;
		HalfEdgeIter hBC = hAB->next;
		HalfEdgeIter hCA = hBC->next;

		int idxA = hAB->vertex->index;
		int idxB = hBC->vertex->index;
		int idxC = hCA->vertex->index;

		double lAB = hAB->edge->length;
		double lBC = hBC->edge->length;
		double lCA = hCA->edge->length;

		RadiiIDT[idxA] += (lAB+lCA-lBC)/2; ncount[idxA]++;
		RadiiIDT[idxB] += (lAB+lBC-lCA)/2; ncount[idxB]++;
		RadiiIDT[idxC] += (lCA+lBC-lAB)/2; ncount[idxC]++;
	}

	// Turn the sums into averages
	for(VertexIter v_iter = mesh.vertices.begin(); v_iter != mesh.vertices.end(); v_iter++)
	{
		int idx = v_iter->index;
		RadiiIDT[idx] /= ncount[idx];
	}
  }

  /* ===============================================================================================
   InitWeight: initialize Inversive distance for each edge
   =============================================================================================== */

  void RicciFlowIDT::initWeight(Mesh& mesh)
  {
	// Inversive distance of each edge from its current length and the radii at its ends:
	//   I = (l^2 - r0^2 - r1^2) / (2 r0 r1)
	// Negative values are clamped to 0 and counted; the count is printed at the end.
	int nbad = 0;

	for (EdgeIter e = mesh.edges.begin(); e != mesh.edges.end(); ++e)
	{
		const double l  = e->length;
		const double r0 = RadiiIDT[e->he->vertex->index];
		const double r1 = RadiiIDT[e->he->flip->vertex->index];

		double inversive = (l*l - r0*r0 - r1*r1)/(2*r0*r1);
		if (inversive < 0) {
			inversive = 0.;
			++nbad;
		}

		EweightIDT[e->index] = inversive;
	}

	std::cout << "nbad = " << nbad << std::endl;
  }
  /* ===============================================================================================
   computeLength: compute edge lengths from the current radii and inversive distances
   =============================================================================================== */

  void RicciFlowIDT::computeLength(Mesh& mesh) 
  {
	// Inversive-distance metric: l^2 = r0^2 + r1^2 + 2 r0 r1 I for every edge.
	for (EdgeIter e = mesh.edges.begin(); e != mesh.edges.end(); ++e)
	{
		const double r0 = RadiiIDT[e->he->vertex->index];
		const double r1 = RadiiIDT[e->he->flip->vertex->index];
		const double inversive = EweightIDT[e->index];

		e->length = std::sqrt(r0*r0 + r1*r1 + 2*r0*r1*inversive);
	}
  }

  /* ===== Jacobian system          =================================================================
	Hessian 
   ==================================================================================================*/

  void RicciFlowIDT::computeHessian(Mesh& mesh)
  {
	// Assembles the Newton system for the active vertices (neither the pole nor one of its
	// neighbours), indexed by indexN:
	//   B_i  = -2*pi + sum of the angles at vertex i (i.e. minus the curvature)
	//   H_ij = sum over faces of the angle derivatives of the inversive-distance metric
	// initHessian rebuilds the sparsity pattern from the current connectivity (edges may have
	// been flipped since the last call) and zeroes H and B.

	initHessian(mesh);

	const double TwoPI = 2*M_PI;

	for (VertexIter v_it = mesh.vertices.begin(); v_it != mesh.vertices.end(); v_it++)
	{
		if(!v_it->inNorthPoleVicinity && !v_it->NorthPole)
		{
			B_RicciIDT[v_it->indexN] = -TwoPI;
		}
	}

	for(FaceIter f_iter = mesh.faces.begin(); f_iter != mesh.faces.end(); f_iter++)
	{
		HalfEdgeIter hAB = f_iter->he;
		HalfEdgeIter hBC = hAB->next;
		HalfEdgeIter hCA = hBC->next;

		VertexIter v1 = hAB->vertex;
		VertexIter v2 = hBC->vertex;
		VertexIter v3 = hCA->vertex;

		// b_k is true when the vertex is NOT an unknown of the linear system
		bool b1 = v1->inNorthPoleVicinity || v1->NorthPole;
		bool b2 = v2->inNorthPoleVicinity || v2->NorthPole;
		bool b3 = v3->inNorthPoleVicinity || v3->NorthPole;

		EdgeIter eAB = hAB->edge;
		EdgeIter eBC = hBC->edge;
		EdgeIter eCA = hCA->edge;

		int idxA = v1->indexN;
		int idxB = v2->indexN;
		int idxC = v3->indexN;

		double lAB = eAB->length;
		double lBC = eBC->length;
		double lCA = eCA->length;
		double wAB = EweightIDT[eAB->index];
		double wBC = EweightIDT[eBC->index];
		double wCA = EweightIDT[eCA->index];

		// Angles at v1 (A), v2 (B), v3 (C). If the triangle inequality fails or holds with
		// equality, the triangle is degenerate and gets the extended angles (0, 0, pi).
		// The equality case is handled here because the half-angle formula would divide by 0;
		// the resulting angles are the same.
		double val;
		double alpha_A, alpha_B, alpha_C;
		bool degenerate = true;
		if( lAB >= lBC + lCA) {
			alpha_A = 0; alpha_B = 0; alpha_C = M_PI;
		} else if (lBC >= lAB + lCA) {
			alpha_A = M_PI; alpha_B = 0; alpha_C = 0;
		} else if (lCA >= lAB + lBC) {
			alpha_A = 0; alpha_B = M_PI; alpha_C = 0;
		} else {
			degenerate = false;
			// Half-angle formula: tan(A/2) = sqrt((s-b)(s-c) / (s(s-a))), a = side opposite A
			val = ((lBC + lAB-lCA)*(lBC-lAB+lCA))/((lAB+lBC+lCA)*(-lBC+lAB+lCA));
			val = std::sqrt(val);
			alpha_A = 2.0*atan(val); 
			val = ((lCA + lAB-lBC)*(lCA-lAB+lBC))/((lAB+lBC+lCA)*(-lCA+lAB+lBC));
			val = std::sqrt(val);
			alpha_B = 2.0*atan(val); 
			val = ((lAB + lBC-lCA)*(lAB-lBC+lCA))/((lAB+lBC+lCA)*(-lAB+lBC+lCA));
			val = std::sqrt(val);
			alpha_C = 2.0*atan(val); 
		}

		if(!b1) {
			B_RicciIDT[idxA] += alpha_A;
		}
		if(!b2) {
			B_RicciIDT[idxB] += alpha_B;
		}
		if(!b3) {
			B_RicciIDT[idxC] += alpha_C;
		}

		// The extended angles of a degenerate triangle are constant (0 or pi), so the triangle
		// contributes nothing to the Hessian. Skipping it also avoids the divisions by
		// sin(alpha) = 0 below, which would put inf/NaN entries into H.
		if(degenerate) continue;

		double rA = RadiiIDT[v1->index];
		double rB = RadiiIDT[v2->index];
		double rC = RadiiIDT[v3->index];

		double sA = std::sin(alpha_A), cA = std::cos(alpha_A);
		double sB = std::sin(alpha_B), cB = std::cos(alpha_B);

		double val2;

		// Coupling A-B
		val = rB/(lAB*lCA*sA);
		val2 = (rA*wAB - rC*wBC - lCA*cA*(rB+rA*wAB)/lAB);
		double valAB = val*val2;
		if( !b1 && !b2 ) {
			H_RicciIDT.coeffRef(idxA, idxB) += valAB;
			H_RicciIDT.coeffRef(idxB, idxA) += valAB;
		}

		// Coupling A-C
		val = rC/(lCA*lAB*sA);
		val2 = (rA*wCA - rB*wBC - lAB*cA*(rC+rA*wCA)/lCA);
		double valCA = val*val2;
		if( !b1 && !b3 ) {
			H_RicciIDT.coeffRef(idxA, idxC) += valCA;
			H_RicciIDT.coeffRef(idxC, idxA) += valCA;
		}

		// Coupling B-C
		val = rC/(lBC*lAB*sB);
		val2 = (rB*wBC - rA*wCA - lAB*cB*(rC+rB*wBC)/lBC);
		double valBC = val*val2;
		if( !b2 && !b3 ) {
			H_RicciIDT.coeffRef(idxB, idxC) += valBC;
			H_RicciIDT.coeffRef(idxC, idxB) += valBC;
		}

		// Diagonal: minus the sum of the off-diagonal couplings of each vertex in this face
		if(!b1) {
			H_RicciIDT.coeffRef(idxA, idxA) -= valAB;
			H_RicciIDT.coeffRef(idxA, idxA) -= valCA;
		}
		if(!b2) {
			H_RicciIDT.coeffRef(idxB, idxB) -= valAB;
			H_RicciIDT.coeffRef(idxB, idxB) -= valBC;
		}
		if(!b3) {
			H_RicciIDT.coeffRef(idxC, idxC) -= valCA;
			H_RicciIDT.coeffRef(idxC, idxC) -= valBC;
		}

	}
  }

  /* ===============================================================================================
   flipMesh: flip the "bad" edges (longest edge of each degenerate triangle)
   =============================================================================================== */

  void RicciFlowIDT::flipMesh(Mesh& mesh, std::vector<EdgeIter>& list_bad)
  {
	// Flips every distinct edge of list_bad once, in list order.
	// An edge shared by two degenerate triangles is reported once per triangle. Flipping it
	// a second time would undo the first flip, so repeated entries are skipped.
	std::vector<EdgeIter> flipped;
	flipped.reserve(list_bad.size());

	for(std::size_t i = 0; i < list_bad.size(); i++)
	{
		EdgeIter e = list_bad[i];
		if(std::find(flipped.begin(), flipped.end(), e) != flipped.end()) continue;

		flipped.push_back(e);
		IDT::flipEdge(e);
	}
  }

  /* ===============================================================================================
   computeCurv: computes the curvatures of all vertices

   The curvature is defined as the angle deficit at each vertex i, i.e.
   2pi - sum_t Theta(i,ijk)
   where sum_t indicates a sum over all triangles (ijk) that are incident to i
         Theta(i,ijk) is the angle at vertex i in the triangle (ijk)
   =============================================================================================== */

  // Fills VcurvIDT[v->index] with the angle deficit of every vertex (the pole and its
  // neighbours included) computed from the current edge lengths.
  //   *maxK    : largest |curvature| over active vertices (not the pole, not its neighbours)
  //   *nbad    : number of triangles violating the triangle inequality
  //   list_bad : cleared, then receives the longest edge of each such triangle. An edge shared
  //              by two bad triangles can appear twice.
  // Returns the L2 norm of the curvature over the active vertices.
  double RicciFlowIDT::computeCurv(Mesh& mesh, double *maxK, int *nbad, std::vector<EdgeIter>& list_bad)
  {
	double TwoPI = 2*M_PI;

/* 	==========================================================================================
	Initialize Curvature to 2*PI
        ========================================================================================== */

	HalfEdgeIter hAB, hBC, hCA;
	EdgeIter eAB, eBC, eCA;
	VertexCIter va, vb, vc;
	double lAB, lCA, lBC;

	int idxa, idxb, idxc;

	for (VertexCIter v = mesh.vertices.begin(); v != mesh.vertices.end(); v++)
	{
		idxa = v->index;
		VcurvIDT[idxa] = TwoPI;
	}

/* 	==========================================================================================
	Iterate over all triangles and remove angle for each vertex
        ========================================================================================== */

	int nb = 0;
	list_bad.clear();
	int n;
	double val, alpha_A, alpha_B, alpha_C;

	for (FaceCIter f = mesh.faces.begin(); f != mesh.faces.end(); f++)
	{
		hAB = f->he;
		hBC = hAB->next;
		hCA = hBC->next;
		
		va = hAB->vertex;
		vb = hBC->vertex;
		vc = hCA->vertex;
		idxa = va->index;
		idxb = vb->index;
		idxc = vc->index;

		lAB = hAB->edge->length;
		lBC = hBC->edge->length;
		lCA = hCA->edge->length;

		// Degenerate triangle (strict inequality violated): extended angles (0, 0, pi),
		// with pi at the vertex opposite the too-long edge; n flags the triangle as bad.
		n = 0;
		if( lAB > lBC + lCA) {
			alpha_A = 0; alpha_B = 0; alpha_C = M_PI;
			n++;
		} else if (lBC > lAB + lCA) {
			alpha_A = M_PI; alpha_B = 0; alpha_C = 0;
			n++;
		} else if (lCA > lAB + lBC) {
			alpha_A = 0; alpha_B = M_PI; alpha_C = 0;
			n++;
		} else {
			// Half-angle formula: tan(A/2) = sqrt((s-b)(s-c) / (s(s-a))), a = side opposite A
			val = ((lBC + lAB-lCA)*(lBC-lAB+lCA))/((lAB+lBC+lCA)*(-lBC+lAB+lCA));
			val = std::sqrt(val);
			alpha_A = 2.0*atan(val);
			val = ((lCA + lAB-lBC)*(lCA-lAB+lBC))/((lAB+lBC+lCA)*(-lCA+lAB+lBC));
			val = std::sqrt(val);
			alpha_B = 2.0*atan(val);
			val = ((lAB + lBC-lCA)*(lAB-lBC+lCA))/((lAB+lBC+lCA)*(-lAB+lBC+lCA));
			val = std::sqrt(val);
			alpha_C = 2.0*atan(val);
		}

		VcurvIDT[idxa] -= alpha_A;
		VcurvIDT[idxb] -= alpha_B;
		VcurvIDT[idxc] -= alpha_C;
		nb += n;

		// Record the longest edge of a bad triangle (CA when AB/BC are not strictly longest)
		if(n > 0) {
			if(lAB > std::max(lBC, lCA)) {
				list_bad.push_back(hAB->edge);
			} else if (lBC > std::max(lAB, lCA)) {
				list_bad.push_back(hBC->edge);
			} else {
				list_bad.push_back(hCA->edge);
			}
		}
	}

	// Error measures restricted to the unknowns of the flow (pole star excluded)
	double err=0, errM=0;
	for (VertexCIter v = mesh.vertices.begin(); v != mesh.vertices.end(); v++)
	{
		idxa = v->index;
		if(!v->inNorthPoleVicinity && !v->NorthPole) {
			val = VcurvIDT[idxa];
			err += val*val;
			errM = std::max(errM, std::abs(val));
		}
	}
	err = std::sqrt(err);
	*maxK = errM;
	*nbad = nb;
	return err;

  }


  /* =============================================================================================
  Solve Linear System

   =============================================================================================== */

  void RicciFlowIDT::solveSystem(Mesh& mesh)
  {
	// Assemble H and B for the current metric, then solve H * Sol = B with CHOLMOD.
	// Any failure of the factorization or of the solve terminates the program.
	computeHessian(mesh);

	Eigen::CholmodDecomposition<Eigen::SparseMatrix<double>> factor(H_RicciIDT);

	// Every non-Success status (NumericalIssue included) is reported the same way.
	if (factor.info() != Eigen::Success)
	{
		std::cout << "Problem with Cholesky factorization!!" << std::endl;
		exit(1);
	}

	Sol_RicciIDT = factor.solve(B_RicciIDT); 

	if (factor.info() != Eigen::Success)
	{
		std::cout << "Problem with Cholesky solver!!" << std::endl;
		exit(1);
	}
  }

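  /* ===== euclideanRicciIDT  ========================================================================
   * Euclidean Ricci flow (inversive distance metric) on a punctured copy of the mesh, followed by
   * a planar layout and a projection onto the unit sphere.
   ==================================================================================================*/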

  VertexIter RicciFlowIDT::euclideanRicciIDT(Mesh& mesh, int vtype, bool *SUCCESS)
  {
	// Runs the Ricci flow on a copy of mesh, lays the result out in the plane and writes the
	// spherical positions (inverse stereographic projection) into mesh's position2.
	//   vtype   : pole selection rule, forwarded to pickPole
	//   SUCCESS : set to false if the flow does not converge or the layout fails
	// Returns the vertex of mesh used as pole.

	// The flow runs on a private copy: iDT and edge flips change the connectivity of mesh2,
	// while mesh only receives the final positions.
	Mesh mesh2(mesh);

	int niter_max = 300;
	double TOL = 1.e-10;

	*SUCCESS = true;

	int n_vertices = mesh.vertices.size();
	int n_edges    = mesh.edges.size();

  /* ==================================================================================================
        Compute initial edge lengths
   ==================================================================================================*/

	computeLength0(mesh2);

  /* ==================================================================================================
        Puncture Mesh
   ==================================================================================================*/

	VertexIter pole = pickPole(mesh2, vtype);
	punctureMesh(mesh2, pole);

	// pole refers to mesh2, which is destroyed on return. Callers get the matching vertex of
	// mesh instead. Vertices are matched by position in the container, the same
	// correspondence the layout copy at the end of this function relies on.
	VertexIter poleMesh = mesh.vertices.begin();
	std::advance(poleMesh, std::distance(mesh2.vertices.begin(), pole));

  /* ==================================================================================================
	Count number of active vertices
   ==================================================================================================*/

	int idx;
	int ntot = 0;
	for (VertexIter v_it = mesh2.vertices.begin(); v_it != mesh2.vertices.end(); v_it++)
	{
		if(!v_it->inNorthPoleVicinity && !v_it->NorthPole) ntot++;
	}

  /* ==================================================================================================
	Create Eigen matrices and arrays needed to minimize energy
   ==================================================================================================*/

	H_RicciIDT.resize(ntot, ntot);
	B_RicciIDT.resize(ntot);
	Sol_RicciIDT.resize(ntot);

	RadiiIDT   = new double[n_vertices];
	UIDT       = new double[n_vertices];
	VcurvIDT   = new double[n_vertices];
	EweightIDT = new double[n_edges];

	// Releases the global work arrays; called on every return path.
	auto freeWorkArrays = []() {
		delete [] RadiiIDT;   RadiiIDT   = nullptr;
		delete [] UIDT;       UIDT       = nullptr;
		delete [] VcurvIDT;   VcurvIDT   = nullptr;
		delete [] EweightIDT; EweightIDT = nullptr;
	};

	H_RicciIDT.setZero();
	for(int i = 0; i < n_vertices; i++) RadiiIDT[i]=1.0;

  /* ==================================================================================================
        apply iDT on initial mesh
   ==================================================================================================*/

	IDT::iDT(mesh2);

  /* ==================================================================================================
	Initialize discrete metric (radii and inversive distances)
   ==================================================================================================*/

	initRadii(mesh2);
	initWeight(mesh2);
	for(int i = 0; i < n_vertices; i++) UIDT[i] = 0;

  /* ==================================================================================================
	Initial energy that will be refined
   ==================================================================================================*/

	double ene0, ene1;
	int nbad;
	std::vector<EdgeIter> list_bad;
	computeLength(mesh2);
	computeCurv(mesh2, &ene0, &nbad, list_bad);
	ene1 = ene0;

  /* ==================================================================================================
	Now perform refinement
   ==================================================================================================*/

	double step = 0;
	double u;
	int zero = 0;

	std::cout << "Minimize Ricci functional:" << std::endl;
	std::cout << "==========================" << std::endl;
	std::cout << " " << std::endl;

	std::cout << "        " << "===========================================================================" << std::endl;
	std::cout << "        " << "       Iter       Step size          Max curv.     Bad triangles           " << std::endl;
	std::cout << "        " << "===========================================================================" << std::endl;
	std::cout << "        " << "   " << std::setw(8)<< zero << "    " << std::setw(12) << step;
	std::cout << "    " << std::setw(12) << ene0 ;
	std::cout << "      " << std::setw(8) << zero << std::endl;

	std::vector<double> save_length;

	for(int niter = 0; niter < niter_max; niter++)
	{

		save_length.clear();
		for(EdgeIter e = mesh2.edges.begin(); e != mesh2.edges.end(); e++) save_length.push_back(e->length);

		for(int i = 0; i < n_vertices; i++) UIDT[i] = std::log(RadiiIDT[i]);

		// Newton direction
		solveSystem(mesh2);

		// Step size: try step = 1, then halve while the max curvature increased or some
		// triangles are degenerate, without going below step_min.
		step = 1.0;
		for (VertexIter v_it = mesh2.vertices.begin(); v_it != mesh2.vertices.end(); v_it++)
		{
			idx = v_it->index;
			if(!v_it->inNorthPoleVicinity && !v_it->NorthPole) {
				u = UIDT[idx] + step*Sol_RicciIDT[v_it->indexN]; 
				RadiiIDT[idx] = std::exp(u);
			}
		}
		computeLength(mesh2);
		computeCurv(mesh2, &ene1, &nbad, list_bad);

		double step_min = 0.5;

		while (((ene1 > ene0) || (nbad > 0)) && (step > step_min)) {
			step = 0.5*step;
			for (VertexIter v_it = mesh2.vertices.begin(); v_it != mesh2.vertices.end(); v_it++)
			{
				idx = v_it->index;
				if(!v_it->inNorthPoleVicinity && !v_it->NorthPole) {
					u = UIDT[idx] + step*Sol_RicciIDT[v_it->indexN]; 
					RadiiIDT[idx] = std::exp(u);
				}
			}
			computeLength(mesh2);
			computeCurv(mesh2, &ene1, &nbad, list_bad);
		}

		std::cout << "        " << "   " << std::setw(8)<< niter+1 << "    " << std::setw(12) << step;
		std::cout << "    " << std::setw(12) << ene1 ;
		std::cout << "      " << std::setw(8) << nbad << std::endl;

		if(std::abs(ene1 - ene0) < TOL*ene1 && ene1 < TOL) break;
		if(ene1 < TOL && nbad == 0) break;
		ene0 = ene1;

		// Degenerate triangles remain: go back to the metric at the start of this iteration
		// and flip the offending edges.
		if(nbad > 0) {
			int ie = 0;
			for(EdgeIter e = mesh2.edges.begin(); e != mesh2.edges.end(); e++) {
				e->length = save_length[ie];
				ie++;
			}
			for(int i = 0; i < n_vertices; i++) RadiiIDT[i] = std::exp(UIDT[i]);
			flipMesh(mesh2, list_bad);
		}
	}
	std::cout << "        " << "===========================================================================" << std::endl;
	std::cout << " " << std::endl;

	if(std::abs(ene1) > TOL || nbad > 0) {
		std::cout << " " << std::endl;
		std::cout << "Ricci flow has failed!" << std::endl;
		std::cout << " " << std::endl;
		*SUCCESS = false;
		freeWorkArrays();
		return poleMesh;
	}

  /* ==================================================================================================
	Now layout on plane, and project onto sphere
   ==================================================================================================*/

	bool info = Layout::layout2D(mesh2);

	if(!info) {
		*SUCCESS = false;
		freeWorkArrays();
		return poleMesh;
	}

	// Copy the planar layout back, matching vertices by their position in the container
	VertexIter vSrc = mesh2.vertices.begin();
	for(VertexIter v = mesh.vertices.begin(); v != mesh.vertices.end(); v++, vSrc++) {
		v->position2 = vSrc->position2;
	}

	// stereo2Sphere sends the vertex flagged NorthPole to (0,0,-1). The puncture was done on
	// mesh2, so flag the pole (and only the pole) in mesh for the projection, then clear it.
	for(VertexIter v = mesh.vertices.begin(); v != mesh.vertices.end(); v++) v->NorthPole = false;
	poleMesh->NorthPole = true;
	stereo2Sphere(mesh);
	poleMesh->NorthPole = false;

	freeWorkArrays();

	return poleMesh;

  }

