/* ====LAYOUT2D.H ===============================================================================
 *
 * Author: Patrice Koehl, April 2019
 * Department of Computer Science
 * University of California, Davis
 *
 * This file generates a 2D layout of a mesh based on its edge lengths
 *
 * It is based on the paper:
 *	B. Springborn, P. Schroeder, and U. Pinkall. "Conformal equivalence of triangle meshes",
 *	ACM Trans. Graph. 27, 3, Article 77 (2008).
 *
 =============================================================================================== */

#pragma once

  /* ===== INCLUDES AND TYPEDEFS =================================================================
   *
   =============================================================================================== */

  #define _USE_MATH_DEFINES // for M_PI

  #include <algorithm>
  #include <cmath>
  #include <cstdlib>
  #include <deque>
  #include <iostream>
  #include <vector>

  extern "C" {

	// External L-BFGS minimizer with C linkage (presumably a Fortran build of
	// Nocedal's routine extended with a GNORM output -- confirm against the
	// linked library). Layout::layoutMinim drives it in a loop: it fills x/f/g,
	// calls lbfgs_, copies x back, and stops once iflag <= 0.
	void lbfgs_(int *n, int *m, double *x, double *f, double *g,
		int *diagco, double *diag, int *iprint, double *eps, double *xtol,
		double *w, double *gnorm, int *iflag);

  }

  /* =============================================================================================
     Define class
   =============================================================================================== */

  class Layout {

  public:
	/// Lays out onto the plane a mesh whose geometry is given by its edge lengths
	/// (results go to each vertex's position2). Returns false when the final
	/// L-BFGS refinement reports a failure.
	static bool layout2D(Mesh& mesh);

  private:

	/// Refines the positions of vertices whose status is 1 with L-BFGS.
	static bool layoutMinim(Mesh& mesh, int *vertex_status, int iprint);

	/// Layout energy (sum of squared edge-length deviations) plus its gradient in Deriv.
	static double layoutEneDer(Mesh& mesh, int *vertex_status, double *Deriv);

	/// Layout energy only: sum of squared edge-length deviations over drawn edges.
	static double layoutEnergy(Mesh& mesh, int *vertex_status);

	/// Translates and scales the 2D mesh using a circle fitted through some of its vertices.
	static void centerScale2D(Mesh& mesh);

  };

  /* ===== LayoutEne       ========================================================================
	Computes the current "energy" of the layout
   =============================================================================================== */

  // Defined in a header: 'inline' avoids multiple-definition link errors when the header
  // is included from more than one translation unit.
  inline double Layout::layoutEnergy(Mesh& mesh, int *vertex_status) 
  {
  /* ==================================================================================================
	Returns the "energy" of the current layout:
		sum over edges (i,j) with vertex_status[i] == vertex_status[j] == 1 of
		( |position2_i - position2_j| - length_ij )^2
	Edges with an endpoint whose status is not 1 (not yet drawn, or excluded) are ignored.
	vertex_status is indexed by vertex->index.
   ==================================================================================================*/

	double ene = 0;

	for (EdgeIter e_it = mesh.edges.begin(); e_it != mesh.edges.end(); ++e_it)
	{
		HalfEdgeIter he = e_it->he;
		VertexIter v_i = he->vertex;
		VertexIter v_j = he->flip->vertex;

		// Only edges whose two endpoints are already drawn contribute
		if (vertex_status[v_i->index] != 1 || vertex_status[v_j->index] != 1) continue;

		Vector point1 = v_i->position2;
		Vector point2 = v_j->position2;
		double l1 = (point1 - point2).norm();
		double diff = l1 - e_it->length;
		ene += diff*diff;
	}

	return ene;
  }

  /* ===== LayoutEneDer    ========================================================================
	Computes the current "energy" of the layout and its derivatives with respect to the coordinates
   =============================================================================================== */

  // Defined in a header: 'inline' avoids multiple-definition link errors when the header
  // is included from more than one translation unit.
  inline double Layout::layoutEneDer(Mesh& mesh, int *vertex_status, double *Deriv)
  {
  /* ==================================================================================================
	Returns the same energy as layoutEnergy and writes its gradient into Deriv:
		E = sum over drawn edges (i,j) of (l_ij - L_ij)^2,  l_ij = |p_i - p_j|
		dE/dp_i = 2 (l_ij - L_ij) (p_i - p_j) / l_ij,   dE/dp_j = -dE/dp_i
	Deriv must hold 2*mesh.vertices.size() doubles; entry 2*index is d/dx and 2*index+1 is d/dy
	of the vertex with that index. Vertices with status != 1 get a zero gradient.
   ==================================================================================================*/

	int n_vertices = mesh.vertices.size();

	double ene = 0;
	std::fill(Deriv, Deriv + 2*n_vertices, 0.0);

	for (EdgeIter e_it = mesh.edges.begin(); e_it != mesh.edges.end(); ++e_it)
	{
		HalfEdgeIter he = e_it->he;
		VertexIter v_i = he->vertex;
		VertexIter v_j = he->flip->vertex;

		int idx1 = v_i->index;
		int idx2 = v_j->index;

		// Only edges whose two endpoints are already drawn contribute
		if (vertex_status[idx1] != 1 || vertex_status[idx2] != 1) continue;

		Vector point1 = v_i->position2;
		Vector point2 = v_j->position2;
		double l1 = (point1 - point2).norm();
		double l2 = e_it->length;
		ene += (l1-l2)*(l1-l2);

		// The gradient of |p_i - p_j| is undefined when the endpoints coincide; use the
		// zero subgradient rather than dividing by zero (which would put NaNs in Deriv).
		if (l1 <= 0) continue;

		double dx = point1[0] - point2[0];
		double dy = point1[1] - point2[1];
		double val = (l1-l2)/l1;
		double valx = 2*val*dx;
		double valy = 2*val*dy;

		Deriv[2*idx1]   += valx;
		Deriv[2*idx1+1] += valy;
		Deriv[2*idx2]   -= valx;
		Deriv[2*idx2+1] -= valy;
	}

	return ene;
  }

  /* ===== LayoutMinim     ========================================================================
	Optimizes layout using lbfgs
   =============================================================================================== */

  // Defined in a header: 'inline' avoids multiple-definition link errors when the header
  // is included from more than one translation unit.
  inline bool Layout::layoutMinim(Mesh& mesh, int *vertex_status, int iprint)
  {
  /* ==================================================================================================
	Refines the 2D positions (position2) of every vertex with vertex_status == 1 by minimizing the
	layoutEneDer energy with lbfgs_. Vertices with any other status stay fixed, and edges touching
	them are ignored.
	iprint == 1: also report edge-length errors after refinement and return false if the RMS error
	is NaN or larger than 1.e-2. Any other iprint value: always return true.
	Work buffers are std::vectors, so they are released on every exit path.
   ==================================================================================================*/

	int n_vertices = mesh.vertices.size();

	// Free vertices: each contributes two unknowns (x, y)
	int nfree = 0;
	for(int i = 0; i < n_vertices; i++) {
		if(vertex_status[i] == 1) nfree++;
	}

  /* ==================================================================================================
	Variables for lbfgs. The routine is called repeatedly on the same problem (it appears to use
	reverse communication: we evaluate f/g at X, call it, and loop while IFLAG > 0), so the problem
	size nvar is kept constant across calls instead of reusing one variable for several counts.
   ==================================================================================================*/

	int nvar = 2*nfree;
	int M = 5;
	int IFLAG = 0;
	int IPRINT[2] = {-1, 0};
	int DIAGCO = 0;

	double EPS = 1.e-6;
	double XTOL = 1.e-14;
	double GNORM = 0;

	std::vector<double> X(nvar);
	std::vector<double> Grad(nvar);
	std::vector<double> Diag(nvar);
	std::vector<double> Work((2*M+1)*nvar + 2*M);
	std::vector<double> Deriv(2*n_vertices);

  /* ==================================================================================================
	Minimization loop (skipped when there is nothing to optimize)
   ==================================================================================================*/

	const int iter_max = 1000;
	const double tol = 1.e-8;
	double ene = 0;

	for(int iter = 0; nfree > 0 && iter < iter_max; iter++)
	{
		ene = layoutEneDer(mesh, vertex_status, Deriv.data());

		// Pack coordinates and gradient of the free vertices for lbfgs
		double norm = 0;
		int k = 0;
		for (VertexIter v_it = mesh.vertices.begin(); v_it != mesh.vertices.end(); ++v_it)
		{
			int idx = v_it->index;
			if(vertex_status[idx] != 1) continue;

			Vector point = v_it->position2;
			X[2*k]      = point[0];
			X[2*k+1]    = point[1];
			Grad[2*k]   = Deriv[2*idx];
			Grad[2*k+1] = Deriv[2*idx+1];
			norm += Grad[2*k]*Grad[2*k] + Grad[2*k+1]*Grad[2*k+1];
			k++;
		}

		// RMS gradient per free vertex small enough: converged
		norm = std::sqrt(norm/nfree);
		if(norm < tol) break;

		lbfgs_(&nvar, &M, X.data(), &ene, Grad.data(), &DIAGCO, Diag.data(), IPRINT, &EPS, &XTOL,
		Work.data(), &GNORM, &IFLAG);

		// Unpack the new coordinates proposed by lbfgs
		k = 0;
		for (VertexIter v_it = mesh.vertices.begin(); v_it != mesh.vertices.end(); ++v_it)
		{
			if(vertex_status[v_it->index] != 1) continue;

			v_it->position2[0] = X[2*k];
			v_it->position2[1] = X[2*k+1];
			v_it->position2[2] = 0.0;
			k++;
		}

		// IFLAG <= 0 ends the minimization (treated as done or as an lbfgs error)
		if(IFLAG <= 0) break;
	}

  /* ==========================================================================================
       Check that edge lengths are correct
  ==========================================================================================*/

	bool SUCCESS = true;
	if(iprint == 1) {

		double err = 0;
		double errmax = 0;
		int n = 0;
		for(EdgeIter e_it = mesh.edges.begin(); e_it != mesh.edges.end(); ++e_it)
		{
			HalfEdgeIter he = e_it->he;
			VertexIter v_i = he->vertex;
			VertexIter v_j = he->flip->vertex;

			if(vertex_status[v_i->index] != 1 || vertex_status[v_j->index] != 1) continue;

			Vector pointA = v_i->position2;
			Vector pointB = v_j->position2;

			double lAB  = (pointA - pointB).norm();
			double diff = lAB - e_it->length;
			err += diff*diff;
			errmax = std::max(errmax, std::abs(diff));
			n++;
		}

		// With no drawn edge this is 0/0 = NaN, which is reported as a failure below
		err = std::sqrt(err/n);

		std::cout << " " << std::endl;
		std::cout << "After refinement  " << std::endl;
		std::cout << "RMS error on edge lengths            : " << err << std::endl;
		std::cout << "Max error (l1-norm) on edge lengths  : " << errmax << std::endl;
		std::cout << " " << std::endl;

		if(std::isnan(err) || err > 1.e-2) {
			std::cout<< "Layout was not successful " << std::endl;
			std::cout << " " << std::endl;
			SUCCESS = false;
		}

	}

	return SUCCESS;

  }

  /* ===== CenterScale2D   ========================================================================
	Center and scale current 2D mesh
   =============================================================================================== */

  // Defined in a header: 'inline' avoids multiple-definition link errors when the header
  // is included from more than one translation unit.
  inline void Layout::centerScale2D(Mesh& mesh)
  {

  /* ==================================================================================================
	Least-squares (algebraic) circle fit through the 2D positions of the vertices flagged
	inNorthPoleVicinity. Moments are taken about the centroid (xmean, ymean). The center (Uc, Vc)
	solves the 2x2 system [Suu Suv; Suv Svv] [Uc Vc]^T = 0.5 [Suuu+Suvv, Svvv+Svuu]^T, and
	Radius^2 = Uc^2 + Vc^2 + (Suu+Svv)/npoint.
	A degenerate fit (fewer than 3 points, or a singular system, e.g. collinear points) would
	divide by zero and turn every coordinate into NaN/inf. In that case the mesh is left unchanged.
   ==================================================================================================*/

	double xmean = 0;
	double ymean = 0;

	int npoint = 0;
	for (VertexIter v_it = mesh.vertices.begin(); v_it != mesh.vertices.end(); ++v_it)
	{
		if(!v_it->inNorthPoleVicinity) continue;
		Vector point = v_it->position2;
		xmean += point[0];
		ymean += point[1];
		npoint++;
	}

	if(npoint < 3) {
		std::cout << "Warning: degenerate circle fit in centerScale2D; 2D mesh left unchanged" << std::endl;
		return;
	}

	xmean = xmean/npoint;
	ymean = ymean/npoint;

	double Suu = 0; double Suv = 0; double Svv = 0;
	double Suuu = 0; double Svvv = 0; double Suvv = 0; double Svuu = 0;

	for (VertexIter v_it = mesh.vertices.begin(); v_it != mesh.vertices.end(); ++v_it)
	{
		if(!v_it->inNorthPoleVicinity) continue;
		Vector point = v_it->position2;
		double valu  = point[0] - xmean;
		double valv  = point[1] - ymean;
		double valu2 = valu*valu;
		double valv2 = valv*valv;
		Suu += valu2; Suv += valu*valv; Svv += valv2;
		Suuu += valu2*valu; Svvv += valv2*valv;
		Suvv += valu*valv2; Svuu += valv*valu2;
	}

	double val1 = 0.5*(Suuu + Suvv);
	double val2 = 0.5*(Svvv + Svuu);

	double det = Suu*Svv - Suv*Suv;
	if(det == 0.0) {
		std::cout << "Warning: degenerate circle fit in centerScale2D; 2D mesh left unchanged" << std::endl;
		return;
	}

	double Uc = (val1*Svv - val2*Suv)/det;
	double Vc = (val2*Suu - val1*Suv)/det;

	double Radius = std::sqrt(Uc*Uc + Vc*Vc + (Suu+Svv)/npoint);

	Uc = Uc + xmean;
	Vc = Vc + ymean;

  /* ==================================================================================================
	Translate the fitted center to the origin and scale so that the fitted circle gets radius
	sin(a)/(1-cos(a)) with a = M_PI/180 (one degree). Presumably this is the radius at which a
	stereographic projection from the north pole places points one degree away from the pole;
	confirm against the caller. NorthPole vertices are not moved.
   ==================================================================================================*/

	double ang1 = M_PI/180.0;

	double factor = std::sin(ang1)/(1.0-std::cos(ang1));
	factor = factor/Radius;

	for (VertexIter v_it = mesh.vertices.begin(); v_it != mesh.vertices.end(); ++v_it)
	{
		if(v_it->NorthPole) continue;
		Vector point = v_it->position2;
		v_it->position2[0] = factor*(point[0] - Uc);
		v_it->position2[1] = factor*(point[1] - Vc);
		v_it->position2[2] = 0.0;
	}
  }

  /* ===== Layout2D        ========================================================================
	Layout mesh whose geometry is set by edge lengths onto the plane
   =============================================================================================== */

  // Defined in a header: 'inline' avoids multiple-definition link errors when the header
  // is included from more than one translation unit.
  inline bool Layout::layout2D(Mesh& mesh)
  {

  /* ==================================================================================================
        Set some local arrays
        v_visited: 0 = vertex not placed yet, 1 = placed, -1 = NorthPole (never placed)
        f_visited: 0 = face not queued yet, 1 = queued/processed, or in the north pole vicinity
        Held in std::vector so they are released on every return path (the previous raw
        new[] arrays leaked whenever a refinement step returned false).
   ==================================================================================================*/

	int n_vertices = mesh.vertices.size();
	int n_faces    = mesh.faces.size();

	std::vector<int> v_visited(n_vertices);
	std::vector<int> f_visited(n_faces);

  /* ==================================================================================================
        Set iterators and pointers to Mesh data
   ==================================================================================================*/

	HalfEdgeIter he, he2;
	HalfEdgeIter hAB, hBC, hCA;
	EdgeIter e1, eAB, eBC, eCA;
	VertexIter v_i, v_j;
	VertexIter v1, v2, v3;
	FaceIter f1, f2, f3;

  /* ==================================================================================================
        Initialize vertices and faces as "not visited" (i.e. 0); reset all 2D positions
   ==================================================================================================*/

	int idx;
	int idxA, idxB, idxC;

	for (VertexIter v_it = mesh.vertices.begin(); v_it != mesh.vertices.end(); ++v_it)
	{
		idxA = v_it->index;
		v_visited[idxA] = v_it->NorthPole ? -1 : 0;
		v_it->position2[0] = 0; v_it->position2[1] = 0; v_it->position2[2] = 0;
	}

	for (FaceIter f_iter = mesh.faces.begin(); f_iter != mesh.faces.end(); ++f_iter)
	{
		idx = f_iter->index;
		// Faces near the north pole are marked visited so they are never laid out
		f_visited[idx] = f_iter->inNorthPoleVicinity ? 1 : 0;
	}

/* ==========================================================================================
       Start with the first "active" triangle: place the two endpoints of its first edge
       at (0,0) and (0,l), where l is that edge's length
   ========================================================================================*/

	double l1, l2;
	std::deque<FaceIter> heap;

	for (FaceIter f_iter = mesh.faces.begin(); f_iter != mesh.faces.end(); ++f_iter)
	{
		idx = f_iter->index;

		if(f_visited[idx] == 0) {
			he  = f_iter->he;
			e1  = he->edge;
			v_i = he->vertex;
			v_j = he->flip->vertex;
			l1  = e1->length;

			idxA = v_i->index;
			idxB = v_j->index;
			v_visited[idxA] = 1;
			v_visited[idxB] = 1;

			v_i->position2[0] = 0; v_i->position2[1] = 0; v_i->position2[2] = 0;
			v_j->position2[0] = 0; v_j->position2[1] = l1; v_j->position2[2] = 0;
			f_visited[idx] = 1;
			heap.push_back(f_iter);
			break;
		}
	}

/* ==========================================================================================
       Now loop over all triangles using a breadth first approach:
       for a given triangle in the list:
               - the position of two of its vertices should already be known
               - the length of its three edges should be known
               - we compute the position of the third vertex using the fact
                 that it is at the intersection of two circles whose centers
                 are the two other vertices, and radii the corresponding
                 edge length
       In each placement branch, e is the distance along the known edge (from its first
       endpoint) to the foot of the perpendicular, and h is the signed height. The sign of h
       puts the new vertex on the side that keeps A, B, C counterclockwise. When the
       triangle inequality fails numerically (val <= 0), sqrt(-val) is used instead and
       the face is counted in nbad.
 ========================================================================================*/

	int nbad  = 0;
	int j;
	double lAB, lBC, lCA;
	double e, h, val;
	int nvisited = 2;
	int iprint;
	Vector pointA, pointB, pointC;

	while( !heap.empty() )
	{
		FaceIter fh = heap[0];
		heap.pop_front();

		if(fh->inNorthPoleVicinity) continue;

		idx = fh->index;

		hAB = fh->he;
		hBC = hAB->next;
		hCA = hBC->next;

		v1 = hAB->vertex;
		v2 = hBC->vertex;
		v3 = hCA->vertex;

		eAB = hAB->edge;
		eBC = hBC->edge;
		eCA = hCA->edge;

		pointA = v1->position2;
		pointB = v2->position2;
		pointC = v3->position2;

		idxA = v1->index;
		idxB = v2->index;
		idxC = v3->index;

		lAB = eAB->length;
		lBC = eBC->length;
		lCA = eCA->length;

		// Number of vertices of this face already known (NorthPole, -1, counts as known)
		j = std::abs(v_visited[idxA])+std::abs(v_visited[idxB])
			+std::abs(v_visited[idxC]);

		if(j<2) {
			std::cout << "Problem with face: " << idx << std::endl;
			std::cout << "Vertices         : " << idxA << " " << idxB << " " << idxC << std::endl;
			std::cout << " " << std::endl;
			exit(1);
		}

		if(j < 3) {
			// Place C from A and B
			if(v_visited[idxC] == 0) {

				lAB = (pointA - pointB).norm();
				e = 0.5*(lAB*lAB + lCA*lCA - lBC*lBC)/lAB;

				val = lCA*lCA-e*e;
				if(val > 0) {
					h = -std::sqrt(val);
				} else {
					h = -std::sqrt(-val);
					nbad++;
				}

				v3->position2[0] = pointA[0] + e*(pointB[0]-pointA[0])/lAB 
					+h*(pointB[1]-pointA[1])/lAB;
				v3->position2[1] = pointA[1] + e*(pointB[1]-pointA[1])/lAB 
					-h*(pointB[0]-pointA[0])/lAB;
				v3->position2[2] = 0;

				v_visited[idxC] = 1;

			}

			// Place B from A and C
			if(v_visited[idxB] == 0) {

				lCA = (pointA - pointC).norm();
				e = 0.5*(lCA*lCA + lAB*lAB - lBC*lBC)/lCA;

				val = lAB*lAB-e*e;
				if(val > 0) {
					h = std::sqrt(val);
				} else {
					h = std::sqrt(-val);
					nbad++;
				}

				v2->position2[0] = pointA[0] + e*(pointC[0]-pointA[0])/lCA 
					+h*(pointC[1]-pointA[1])/lCA;
				v2->position2[1] = pointA[1] + e*(pointC[1]-pointA[1])/lCA 
					-h*(pointC[0]-pointA[0])/lCA;
				v2->position2[2] = 0;

				v_visited[idxB] = 1;

			}

			// Place A from B and C
			if(v_visited[idxA] == 0) {

				lBC = (pointB - pointC).norm();
				e = 0.5*(lBC*lBC + lAB*lAB - lCA*lCA)/lBC;

				val = lAB*lAB-e*e;
				if(val > 0) {
					h = -std::sqrt(val);
				} else {
					h = -std::sqrt(-val);
					nbad++;
				}

				v1->position2[0] = pointB[0] + e*(pointC[0]-pointB[0])/lBC 
					+h*(pointC[1]-pointB[1])/lBC;
				v1->position2[1] = pointB[1] + e*(pointC[1]-pointB[1])/lBC 
					-h*(pointC[0]-pointB[0])/lBC;
				v1->position2[2] = 0;

				v_visited[idxA] = 1;

			}

			nvisited++;

		}

		// Queue the neighbouring faces not seen yet
		f1 = hAB->flip->face;
		f2 = hBC->flip->face;
		f3 = hCA->flip->face;

		idx = f1->index;
		if(f_visited[idx] == 0) {
			heap.push_back(f1);
			f_visited[idx]=1;
		}
		idx = f2->index;
		if(f_visited[idx] == 0) {
			heap.push_back(f2);
			f_visited[idx]=1;
		}
		idx = f3->index;
		if(f_visited[idx] == 0) {
			heap.push_back(f3);
			f_visited[idx]=1;
		}

		// Periodically relax the partial layout to limit error accumulation
		if((nvisited % 20000) == 0) {
			iprint = 0;
			bool info = layoutMinim(mesh, v_visited.data(), iprint);
			if(!info) return info;
		}

	}

  /* ==========================================================================================
       Now some verifications....  all vertices should have been visited
  ==========================================================================================*/

	nvisited = 0;
	for(int i = 0; i < n_vertices; i++) {
		nvisited += std::abs(v_visited[i]);
	}

	if(nvisited != n_vertices) {
		std::cout << "Problem: not all vertices were built!" << std::endl;
		std::cout << "nvertices = " << n_vertices << std::endl;
		std::cout << "nvisited  = " << nvisited  << std::endl;
		std::cout << " " << std::endl;
		exit(1);
	}

  /* ==========================================================================================
       Check that edge lengths are correct
  ==========================================================================================*/

	double err = 0;
	double errmax = 0;
	int idx1, idx2;
	int n=0;
	Vector point1, point2;
	for (EdgeIter e_it = mesh.edges.begin(); e_it != mesh.edges.end(); ++e_it)
	{
		he  = e_it->he;
		he2 = he->flip;

		v_i = he->vertex;
		v_j = he2->vertex;

		idx1 = v_i->index;
		idx2 = v_j->index;

		if (v_visited[idx1]==1 && v_visited[idx2]==1) {
			point1 = v_i->position2;
			point2 = v_j->position2;
			l1 = (point1-point2).norm();
			l2 = e_it->length;
			err += (l1-l2)*(l1-l2);
			errmax = std::max(errmax, std::abs(l1-l2));
			n++;
		}
	}

	err=std::sqrt(err/n);

	std::cout << " " << std::endl;
	std::cout << "Number of bad triangles (num. error) : " << nbad << std::endl;
	std::cout << "RMS error on edge lengths            : " << err << std::endl;
	std::cout << "Max error (l1-norm) on edge lengths  : " << errmax << std::endl;
	std::cout << " " << std::endl;

  /* ==================================================================================================
	Improve layout with minimization
   ==================================================================================================*/

	iprint = 1;
	bool SUCCESS = layoutMinim(mesh, v_visited.data(), iprint);

	if(!SUCCESS) return SUCCESS;

  /* ==================================================================================================
	Scale 2D mesh
   ==================================================================================================*/

	centerScale2D(mesh);

	return SUCCESS;

  }
