/* ====SpinTransform.h ==========================================================================
 *
 * Author: Patrice Koehl, July 2018
 * Department of Computer Science
 * University of California, Davis
 *
 * This file applies a spin transform to compute new vertex coordinates, either by solving a Poisson problem or by a least-squares fit to the transformed edge vectors
 *
 * It is based on the code SpinXform_OpenGL from Keenan Crane (available at:
 *	https://www.cs.cmu.edu/~kmcrane/index.html#code
 =============================================================================================== */

#pragma once

  /* ===== INCLUDES AND TYPEDEFS =================================================================
   *
   =============================================================================================== */

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <vector>

#include "Quaternion.h"

/*================================================================================================
  Prototypes for BLAS and LAPACK
================================================================================================== */

extern "C" {

        // Fortran BLAS level-1 DSCAL: X <- alpha * X (n elements, stride incx)
        // NOTE(review): not referenced anywhere in this header.
        void dscal_(int * n, double * alpha, double * X, int *incx);
        // Fortran BLAS level-1 DDOT: returns sum_i u[i*incu] * v[i*incv]
        // NOTE(review): not referenced anywhere in this header.
        double ddot_(int * n, double * u, int * incu, double * v, int *incv);

        // Fortran BLAS level-2 DGEMV: Y <- alpha * op(A) * X + beta * Y,
        // with op selected by trans ('N' or 'T') and A column-major (m x n, leading dim lda)
        // NOTE(review): not referenced anywhere in this header.
        void dgemv_(char * trans, int * m, int * n, double * alpha, double *A,
                int *lda, double * X, int * incx, double *beta, double * Y, int * incy);

}

 // Sparse solvers and call counters shared by every SpinTransform object.
 //  solverS  : sparse Cholesky (CHOLMOD) used by the Poisson path
 //  solverLS : least-squares conjugate gradient used by the edge-vector path
 //  nS, nLS  : number of Poisson / least-squares solves done so far; a value
 //             of 0 means "first call" and triggers one-time setup.
 // NOTE(review): these are non-inline definitions in a header (as are the
 // member functions below), so this header can only be included in a single
 // translation unit. Because the counters are globals rather than members, a
 // second SpinTransform instance would skip its own one-time setup.
 Eigen::CholmodDecomposition< Eigen::SparseMatrix<double> > solverS;
 Eigen::LeastSquaresConjugateGradient< Eigen::SparseMatrix<double> > solverLS;

 int nS(0);   // Poisson solves performed so far
 int nLS(0);  // least-squares solves performed so far

  /* ===== The SpinTransform class    =================================================================
   *
   =============================================================================================== */

class SpinTransform
 {
  public:

	// Compute new vertex positions (written to position2) by solving the
	// Poisson problem L x = omega. L is a cotan Laplacian built from the
	// current position2 values, and omega is built from lambda_v.
	// lambda_v: one quaternion (4 consecutive doubles) per vertex; read only.
	void solveForPositions_poisson(Mesh& mesh, Eigen::VectorXd& lambda_v);

	// Compute new vertex positions (written to position2) as the
	// least-squares solution of EV x = b, where b holds the spin-transformed
	// edge vectors.
	// ecurv:    one scalar per edge (used as the real part of the edge quaternion)
	// lambda_f: one quaternion (4 consecutive doubles) per face; rescaled in place
	void solveForPositions_lsq(Mesh& mesh, std::vector<double>& ecurv, Eigen::VectorXd& lambda_f);


   protected:

	// Create the sparsity pattern of the 4x4-block matrix Mat (all values 0)
	void initMatrixL(Mesh& mesh, Eigen::SparseMatrix<double> &Mat);

	// Set the stored coefficients of Mat back to 0
	void resetMatrixL(Mesh& mesh, Eigen::SparseMatrix<double> &Mat);

	// Right-hand side of the Poisson problem (divergence of the target edge
	// vectors): 4 doubles per vertex, last (pinned) vertex excluded
	Eigen::VectorXd omega;

	// Scratch vectors; temp2 receives the Poisson solution
	Eigen::VectorXd temp1, temp2;

	Eigen::SparseMatrix<double> EV; // Edge-vertex incidence matrix (pinned last vertex has no column)
	Eigen::SparseMatrix<double> L;  // Cotan Laplacian in 4x4 blocks (pinned last vertex removed)

	// Add the diagonal of the 4x4 block Q to block (i0, j0) of Mat
	void updateMatrixL(Eigen::SparseMatrix<double>& Mat, int i0, int j0, double Q[4][4]);

	// Add the 4 components of q to entries 4*i0 .. 4*i0+3 of Vect
	void updateVector(Eigen::VectorXd& Vect, int i0, Quaternion q);

	// Build the cotan Laplacian L from the current vertex positions
	void buildLaplacian(Mesh& mesh);

	// Build the Poisson right-hand side omega from the per-vertex quaternions
	void buildOmega(Mesh& mesh, Eigen::VectorXd& lambda_v);

	// Factorize L, solve L x = omega and copy x into the vertex positions
	void solveForPositions(Mesh& mesh);

	// Build the edge-vertex incidence matrix used by the least-squares path
	void initEV(Mesh& mesh, Eigen::SparseMatrix<double> &Mat);

	// Apply the spin transform to the edges and solve the least-squares problem
	void applySpinX(Mesh& mesh, std::vector<double>& ecurv, Eigen::VectorXd& lambda_f);

};

  /* =============================================================================================
   Init matrix L
   =============================================================================================== */

  void SpinTransform::initMatrixL(Mesh& mesh, Eigen::SparseMatrix<double> &Mat)
  {
	// Declare the sparsity pattern of Mat. For every vertex with itself, and
	// for every pair of vertices joined by an edge, the diagonal of the
	// corresponding 4x4 block is stored explicitly with value 0. The last
	// vertex is pinned, so it gets no rows or columns.
	const int pinned = static_cast<int>(mesh.vertices.size()) - 1;

	std::vector<Triplet> pattern;
	pattern.reserve(4*mesh.vertices.size() + 8*mesh.edges.size());

	// push the 4 diagonal entries of block (row, col), all with value 0
	auto addBlockDiagonal = [&pattern](int row, int col) {
		for (int k = 0; k < 4; ++k) {
			pattern.push_back(Triplet(4*row + k, 4*col + k, 0.0));
		}
	};

	const VertexIter vStop = mesh.vertices.end() - 1;
	for (VertexIter v = mesh.vertices.begin(); v != vStop; ++v) {
		addBlockDiagonal(v->index, v->index);
	}

	for (EdgeIter e = mesh.edges.begin(); e != mesh.edges.end(); ++e) {
		const int a = e->he->vertex->index;
		const int b = e->he->flip->vertex->index;
		if (a == pinned || b == pinned) continue;
		addBlockDiagonal(a, b);
		addBlockDiagonal(b, a);
	}

	Mat.setFromTriplets(pattern.begin(), pattern.end());
  }

  /* =============================================================================================
   Reset matrix L
   =============================================================================================== */

  void SpinTransform::resetMatrixL(Mesh& mesh, Eigen::SparseMatrix<double> &Mat)
  {
	// Set every stored coefficient of Mat to 0 and keep its sparsity pattern
	// unchanged.
	//
	// The previous version looked up each block-diagonal entry with
	// coeffRef(). That costs a binary search per entry, and coeffRef inserts
	// (and so changes the pattern reused by the Cholesky analyzePattern) if
	// an entry is missing. A single sweep over the stored entries is linear
	// in nnz, never inserts, and also clears any entry added after
	// initMatrixL.
	//
	// mesh is no longer needed: the pattern already encodes the connectivity.
	// It is kept in the signature for compatibility.
	(void) mesh;

	for (int k = 0; k < Mat.outerSize(); ++k) {
		for (Eigen::SparseMatrix<double>::InnerIterator it(Mat, k); it; ++it) {
			it.valueRef() = 0.0;
		}
	}
  }


  /* =============================================================================================
   * update real vector with a quaternion value
   =============================================================================================== */

  void SpinTransform::updateVector(Eigen::VectorXd& Vect, int i0, Quaternion q)
  {
	// Accumulate q into the 4 consecutive slots reserved for entry i0,
	// ordered (real, i, j, k).
	const int base = 4*i0;
	const auto imag = q.im();
	Vect[base    ] += q.re();
	Vect[base + 1] += imag.x;
	Vect[base + 2] += imag.y;
	Vect[base + 3] += imag.z;
  }

  /* =============================================================================================
   * update real sparse matrix with a quaternion value
   =============================================================================================== */

  void SpinTransform::updateMatrixL(Eigen::SparseMatrix<double>& Mat, int i0, int j0, double Q[4][4])
  {
	// Accumulate the diagonal of the 4x4 block Q into block (i0, j0) of Mat.
	// Off-diagonal terms of Q are ignored. Every Q passed in this file is
	// built from a quaternion with only a real part (see buildLaplacian),
	// whose matrix form is presumably s*I. The pattern created by initMatrixL
	// also stores only block diagonals.
	const int row0 = 4*i0;
	const int col0 = 4*j0;
	for (int k = 0; k < 4; ++k) {
		Mat.coeffRef(row0 + k, col0 + k) += Q[k][k];
	}
   }

  /* =============================================================================================
   *	Solve using Poisson equation
   =============================================================================================== */

void SpinTransform::solveForPositions_poisson(Mesh& mesh, Eigen::VectorXd& lambda_v)
{
	// One-time setup, done only on the very first Poisson solve:
	// - size the reduced system, in which the last vertex is pinned and dropped;
	// - create the sparsity pattern of L.
	if (nS == 0) {
		const int nFull    = 4*mesh.vertices.size();
		const int nReduced = nFull - 4;

		L.resize(nReduced, nReduced);
		omega.resize(nReduced);
		temp1.resize(nFull);
		temp2.resize(nFull);

		initMatrixL(mesh, L);
	}

	// Assemble L x = omega for the current geometry, then solve it.
	buildLaplacian(mesh);
	buildOmega(mesh, lambda_v);
	solveForPositions(mesh);
}

  /* =============================================================================================
   *	Build Laplacian matrix (for Poisson problem)
   =============================================================================================== */

void SpinTransform::buildLaplacian(Mesh& mesh)
{
	// Clear the values of L but keep its sparsity pattern.
	resetMatrixL(mesh, L);

	// The last vertex is pinned: it has no rows/columns in L.
	const int pinned = static_cast<int>(mesh.vertices.size()) - 1;

	// Real-valued quaternion weights and their 4x4 matrix forms
	Quaternion wA, wB, wC, wDiag;
	double MA[4][4], MB[4][4], MC[4][4], MDiag[4][4];

	for (FaceIter f = mesh.faces.begin(); f != mesh.faces.end(); ++f)
	{
		// Corners a, b, c of the face, in half-edge order
		const HalfEdgeIter h0 = f->he;
		const HalfEdgeIter h1 = h0->next;
		const HalfEdgeIter h2 = h1->next;

		const int a = h0->vertex->index;
		const int b = h1->vertex->index;
		const int c = h2->vertex->index;

		Vector pA = h0->vertex->position2;
		Vector pB = h1->vertex->position2;
		Vector pC = h2->vertex->position2;

		// cot(corner angle) = dot(u, v) / |u x v| for the two edges leaving the corner
		const double cotA = dot(pB-pA, pC-pA) / ( cross(pB-pA, pC-pA).norm());
		const double cotB = dot(pA-pB, pC-pB) / ( cross(pA-pB, pC-pB).norm());
		const double cotC = dot(pA-pC, pB-pC) / ( cross(pA-pC, pB-pC).norm());

		// Off-diagonal weight of an edge: -cot(opposite angle) / 2
		wA = -0.5*cotA;
		wB = -0.5*cotB;
		wC = -0.5*cotC;
		wA.toMatrix(MA);
		wB.toMatrix(MB);
		wC.toMatrix(MC);

		const bool freeA = (a != pinned);
		const bool freeB = (b != pinned);
		const bool freeC = (c != pinned);

		// Diagonal blocks: minus the sum of the weights of the two incident edges
		wDiag = -wB - wC; wDiag.toMatrix(MDiag);
		if (freeA) updateMatrixL(L, a, a, MDiag);
		wDiag = -wA - wC; wDiag.toMatrix(MDiag);
		if (freeB) updateMatrixL(L, b, b, MDiag);
		wDiag = -wA - wB; wDiag.toMatrix(MDiag);
		if (freeC) updateMatrixL(L, c, c, MDiag);

		// Off-diagonal blocks, symmetric, weighted by the opposite corner
		if (freeA && freeB) {
			updateMatrixL(L, a, b, MC);
			updateMatrixL(L, b, a, MC);
		}
		if (freeA && freeC) {
			updateMatrixL(L, a, c, MB);
			updateMatrixL(L, c, a, MB);
		}
		if (freeB && freeC) {
			updateMatrixL(L, b, c, MA);
			updateMatrixL(L, c, b, MA);
		}
	}

}

  /* =============================================================================================
   *	Build Omega vector (new tangent vector), for Poisson problem
   =============================================================================================== */

 void SpinTransform::buildOmega(Mesh& mesh, Eigen::VectorXd& lambda)
 {
	// Build the right-hand side omega of the Poisson problem L x = omega.
	// For every face, each edge vector is transformed by the quaternions of
	// its two end vertices, using weights 1/3, 1/6, 1/6, 1/3. These are
	// presumably the integral of a linear interpolation along the edge, as in
	// Crane's SpinXform. The result is weighted by 0.5*cot of the opposite
	// angle and added with opposite signs to the edge's two end vertices.
	// lambda: one quaternion (4 consecutive doubles) per vertex; read only.
	// The last vertex is pinned and gets no entries in omega.

	// clear current omega vector (idiomatic, and avoids the signed/unsigned
	// comparison of the previous manual loop)
	omega.setZero();

	// Pointers to mesh structure
	HalfEdgeIter hAB, hBC, hCA;
	VertexIter v1, v2, v3;
	Vector p1, p2, p3;
	int idxA, idxB, idxC;
	int idA, idB, idC;
	int n_vertices = mesh.vertices.size();

	double cotan_A, cotan_B, cotan_C;
	Quaternion qA, qB, qC, q, qm;
	Quaternion eTildeAB, eTildeBC, eTildeCA;

	// visit each face
	for( FaceIter f_iter = mesh.faces.begin(); f_iter != mesh.faces.end(); f_iter++ )
	{

		hAB =f_iter->he;
		hBC =hAB->next;
		hCA =hBC->next;

		v1 = hAB->vertex;
		v2 = hBC->vertex;
		v3 = hCA->vertex;

		idA = v1->index;
		idB = v2->index;
		idC = v3->index;

		// Relabel the corners by increasing vertex index. The contributions
		// below are antisymmetric per edge, so the labelling does not change
		// the assembled omega.
		idxA = std::min(idA, std::min(idB, idC));
		idxC = std::max(idA, std::max(idB, idC));
		idxB = idA + idB + idC - idxA - idxC;

		// assumes vertex.index equals its position in mesh.vertices
		p1 = mesh.vertices[idxA].position2;
		p2 = mesh.vertices[idxB].position2;
		p3 = mesh.vertices[idxC].position2;

		cotan_A = dot(p2-p1, p3-p1) / ( cross(p2-p1, p3-p1).norm());
		cotan_B = dot(p1-p2, p3-p2) / ( cross(p1-p2, p3-p2).norm());
		cotan_C = dot(p1-p3, p2-p3) / ( cross(p1-p3, p2-p3).norm());

		qA = Quaternion(lambda[4*idxA], lambda[4*idxA+1], lambda[4*idxA+2], lambda[4*idxA+3]);
		qB = Quaternion(lambda[4*idxB], lambda[4*idxB+1], lambda[4*idxB+2], lambda[4*idxB+3]);
		qC = Quaternion(lambda[4*idxC], lambda[4*idxC+1], lambda[4*idxC+2], lambda[4*idxC+3]);

		// Edge vectors as pure-imaginary quaternions
		Quaternion eAB = Quaternion(0., p2[0] -p1[0], p2[1]-p1[1], p2[2]-p1[2]);
		Quaternion eBC = Quaternion(0., p3[0] -p2[0], p3[1]-p2[1], p3[2]-p2[2]);
		Quaternion eCA = Quaternion(0., p1[0] -p3[0], p1[1]-p3[1], p1[2]-p3[2]);

		eTildeAB = (1./3.) * (~qA) * eAB * qA +
		           (1./6.) * (~qA) * eAB * qB +
		           (1./6.) * (~qB) * eAB * qA +
		           (1./3.) * (~qB) * eAB * qB ;
		eTildeBC = (1./3.) * (~qB) * eBC * qB +
		           (1./6.) * (~qB) * eBC * qC +
		           (1./6.) * (~qC) * eBC * qB +
		           (1./3.) * (~qC) * eBC * qC ;
		eTildeCA = (1./3.) * (~qC) * eCA * qC +
		           (1./6.) * (~qC) * eCA * qA +
		           (1./6.) * (~qA) * eCA * qC +
		           (1./3.) * (~qA) * eCA * qA ;

		// Each edge adds -w*eTilde to its first vertex and +w*eTilde to its
		// second vertex, with w = 0.5*cot(opposite angle)
		q = 0.5*cotan_C * eTildeAB; qm = -q;
		if(idxA != n_vertices-1) updateVector(omega, idxA, qm);
		if(idxB != n_vertices-1) updateVector(omega, idxB, q);
		q = 0.5*cotan_A * eTildeBC; qm = -q;
		if(idxB != n_vertices-1) updateVector(omega, idxB, qm);
		if(idxC != n_vertices-1) updateVector(omega, idxC, q);
		q = 0.5*cotan_B * eTildeCA; qm = -q;
		if(idxC != n_vertices-1) updateVector(omega, idxC, qm);
		if(idxA != n_vertices-1) updateVector(omega, idxA, q);

	}

}

/*================================================================================================
 Solve for new coordinates
================================================================================================== */

void SpinTransform::solveForPositions(Mesh& mesh)
{
	// Solve the reduced Poisson system L x = omega (last vertex pinned). The
	// imaginary parts of the per-vertex quaternion solution are copied into
	// position2, and the pinned vertex is placed at the origin.

	// The sparsity pattern of L is fixed once initMatrixL has run, so the
	// symbolic analysis is done only on the first solve. The numeric Cholesky
	// factorization is redone every call because the cotan weights follow
	// the current positions.
	if(nS==0) solverS.analyzePattern(L);
	nS++;
	solverS.factorize(L);

	// If the factorization fails (e.g. L is not positive definite because of
	// degenerate triangles), solve() would silently return garbage that is
	// then written into the mesh. Report it the same way applySpinX reports
	// solver failures.
	if(solverS.info()!=Eigen::Success) {
		std::cout << "Cholesky factorization failed..." << std::endl;
		exit(1);
	}

	temp2 = solverS.solve(omega);

	// Component 0 (real part) of each quaternion is discarded.
	for(VertexIter v = mesh.vertices.begin(); v != mesh.vertices.end()-1 ; v++) {
		v->position2[0] = temp2[4*v->index + 1];
		v->position2[1] = temp2[4*v->index + 2];
		v->position2[2] = temp2[4*v->index + 3];
	}

	int nv = mesh.vertices.size();
	mesh.vertices[nv-1].position2[0] = 0;
	mesh.vertices[nv-1].position2[1] = 0;
	mesh.vertices[nv-1].position2[2] = 0;

}

  /* =============================================================================================
   *	Solve for new positions by least squares (spin-transformed edge vectors)
   =============================================================================================== */

void SpinTransform::solveForPositions_lsq(Mesh& mesh, std::vector<double>& ecurv, Eigen::VectorXd& lambda_f)
{
	// Least-squares path: new vertex positions are obtained from the
	// spin-transformed edge vectors (see applySpinX).
	// ecurv:    one scalar per edge
	// lambda_f: one quaternion (4 consecutive doubles) per face; rescaled in place

	// One-time setup on the first least-squares solve: build the fixed
	// edge-vertex incidence matrix. Its last vertex is pinned, so it has
	// n_vertices-1 columns.
	if(nLS==0) {
		int n_edges = mesh.edges.size();
		int n_vertices = mesh.vertices.size();

		EV.resize(n_edges, n_vertices-1);
		initEV(mesh, EV);

		temp1.resize(4*n_vertices);
		temp2.resize(4*n_vertices);
	}

	applySpinX(mesh, ecurv, lambda_f);

}

  /* =============================================================================================
   Build Edge-Vertex Adjacency matrix
   =============================================================================================== */

  void SpinTransform::initEV(Mesh& mesh, Eigen::SparseMatrix<double> &Mat)
  {
	// Row e of Mat encodes edge e: -1 in the column of e->he->vertex and +1
	// in the column of e->he->flip->vertex. Mat * x therefore gives
	// x[flip vertex] - x[he vertex] for each edge. The last vertex is pinned
	// and has no column, so its entries are omitted.
	const int pinned = static_cast<int>(mesh.vertices.size()) - 1;

	std::vector<Triplet> entries;
	entries.reserve(2*mesh.edges.size());

	for (EdgeIter e = mesh.edges.begin(); e != mesh.edges.end(); ++e)
	{
		const int row  = e->index;
		const int from = e->he->vertex->index;
		const int to   = e->he->flip->vertex->index;

		if (from != pinned) entries.push_back(Triplet(row, from, -1.0));
		if (to   != pinned) entries.push_back(Triplet(row, to,    1.0));
	}

	Mat.setFromTriplets(entries.begin(), entries.end());
  }

  /* =============================================================================================
   *	Apply spin transformation
   =============================================================================================== */
 
  void SpinTransform::applySpinX(Mesh& mesh, std::vector<double>& ecurv, Eigen::VectorXd& lambda)
  {
	// Compute new vertex positions (position2) as the least-squares solution
	// of EV x = b, solved separately for x, y and z. For each edge, b is the
	// imaginary part of ~q1 * (ecurv[e] + eAB) * q2, where q1 and q2 are the
	// quaternions of the edge's two faces.
	// lambda: one quaternion (4 consecutive doubles) per face; MODIFIED in place.

	int nvertices = mesh.vertices.size();
	int nedges = mesh.edges.size();

	// Rescale lambda in place (the change is visible to the caller) so that
	// its norm is sqrt(#faces). A zero vector is left untouched: dividing by
	// its zero norm would fill lambda, and then every position, with NaNs.
	double lnorm = lambda.norm();
	if(lnorm > 0.) {
		double val = std::sqrt(mesh.faces.size()) / lnorm;
		lambda *= val;
	}

	Eigen::VectorXd Bx(nedges);
	Eigen::VectorXd By(nedges);
	Eigen::VectorXd Bz(nedges);
	Eigen::VectorXd Sol(nvertices-1);

	// EV is built once (first call of solveForPositions_lsq), so the solver
	// is set up only on the first call.
	if(nLS==0) {
		solverLS.analyzePattern(EV);
		solverLS.factorize(EV);
		if(solverLS.info()!=Eigen::Success) {
			std::cout << "Least-squares solver setup failed..." << std::endl;
			exit(1);
		}
	}
	nLS++;

	int idxA, idxB, f1, f2;
	Vector vA, vB, eAB;
	Quaternion q1, q2, qAB, q;
	for(EdgeIter e = mesh.edges.begin(); e != mesh.edges.end(); e++) {

		idxA = e->he->vertex->index;
		idxB = e->he->flip->vertex->index;
		vA = e->he->vertex->position2;
		vB = e->he->flip->vertex->position2;
		f1 = e->he->face->index;
		f2 = e->he->flip->face->index;

		eAB = vB - vA;
		q1 = Quaternion(lambda[4*f1], lambda[4*f1+1], lambda[4*f1+2], lambda[4*f1+3]);
		q2 = Quaternion(lambda[4*f2], lambda[4*f2+1], lambda[4*f2+2], lambda[4*f2+3]);
		qAB = Quaternion(ecurv[e->index], eAB);

		q = (~q1) * qAB * q2;

		// only the imaginary part is used as the target edge vector
		Bx[e->index] = q.im().x;
		By[e->index] = q.im().y;
		Bz[e->index] = q.im().z;

	}

	// One solve per coordinate; the pinned last vertex stays at the origin.
	// Assumes vertex.index equals its position in mesh.vertices.
	Sol = solverLS.solve(Bx);
	for(int i = 0; i < nvertices-1; i++) {
		mesh.vertices[i].position2.x = Sol[i];
	}
	mesh.vertices[nvertices-1].position2.x = 0.;

	Sol = solverLS.solve(By);
	for(int i = 0; i < nvertices-1; i++) {
		mesh.vertices[i].position2.y = Sol[i];
	}
	mesh.vertices[nvertices-1].position2.y = 0.;

	Sol = solverLS.solve(Bz);
	for(int i = 0; i < nvertices-1; i++) {
		mesh.vertices[i].position2.z = Sol[i];
	}
	mesh.vertices[nvertices-1].position2.z = 0.;

  }
