/*================================================================================================
  Align.h
  Version 1: 3/24/2020

Copyright (c) Patrice Koehl.

>>> SOURCE LICENSE >>>

  This library is free software; you can redistribute it and/or
  modify it under the terms of the GNU Lesser General Public
  License as published by the Free Software Foundation; either
  version 2.1 of the License, or (at your option) any later version.

  This library is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  Lesser General Public License for more details.

  You should have received a copy of the GNU Lesser General Public
  License along with this library; if not, write to the Free Software
  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA

>>> END OF LICENSE >>>

================================================================================================== */

#ifndef _ALIGN_H_
#define _ALIGN_H_

/*================================================================================================
 Includes
================================================================================================== */

#include <math.h>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <utility>
#include <vector>

/*================================================================================================
  Prototypes for BLAS and LAPACK
================================================================================================== */

extern "C" {

	/* Fortran routines: every argument, scalars included, is passed by address. */

	/* --- BLAS level 1 (see the reference BLAS for the exact semantics) --- */

	// y <- alpha*x + y
	void daxpy_(int *n, double *alpha, double *x, int *incx,
	            double *y, int *incy);

	// Euclidean norm of x
	double dnrm2_(int *n, double *x, int *incx);

	// x <- alpha*x
	void dscal_(int *n, double *alpha, double *x, int *incx);

	// y <- x
	void dcopy_(int *n, double *x, int *incx,
	            double *y, int *incy);

	// dot product of u and v
	double ddot_(int *n, double *u, int *incu,
	             double *v, int *incv);

	/* --- LAPACK drivers --- */

	// Eigen-decomposition of a real symmetric matrix.
	// NOTE(review): unlike every other prototype here, this name has no trailing
	// underscore - confirm it matches the symbol exported by the LAPACK build in use.
	void dsyev(char *jobz, char *uplo, int *n, double *a, int *lda,
	           double *w, double *work, int *lwork, int *info);

	// Singular value decomposition A = U * diag(S) * VT of a real M x N matrix
	void dgesvd_(char *jobu, char *jobvt, int *m, int *n, double *a, int *lda,
	             double *s, double *u, int *ldu, double *vt, int *ldvt,
	             double *work, int *lwork, int *info);
}

 class Align {

	public:

		// Position mesh 2 on mesh 1: vertex correspondences are derived from the
		// transport plan between the keypoints of the two meshes, then mesh 2 is
		// superposed onto mesh 1 with procustes(). Vertex positions of mesh2 are
		// overwritten.
		void alignMesh2(
			Mesh& mesh1,
			Mesh& mesh2,
			std::vector<int>& keypoints1,
			std::vector<int>& keypoints2,
			double *plan,
			double *m1);

		// Mass-weighted best fit of two sets of points.
		// points2 is modified in place; the value returned is the RMS deviation
		// over the pairs listed in corresp. *info is 0 on success and 1 on
		// failure; check it before using the returned value.
		double bestfitm(
			std::vector<Vector>& points1,
			std::vector<double> mass1,
			std::vector<Vector>& points2,
			std::vector<double> mass2,
			std::vector<std::pair<int, int> > corresp,
			int *info);

		// Procrustes best fit of two sets of points (both sets are centered and
		// scaled before the rotation is computed).
		// points2 is modified in place; the value returned is the RMS deviation
		// over the pairs listed in corresp. *info is 0 on success and 1 on
		// failure; check it before using the returned value.
		double procustes(
			std::vector<Vector>& points1,
			std::vector<Vector>& points2,
			std::vector<std::pair<int, int> > corresp,
			int *info);

	private:

		// Determinant of a 3x3 matrix stored as 9 contiguous doubles
		double deter(double *mat);

  };

/*================================================================================================
  Position mesh 2 based on transport plan between mesh1 and mesh2

  A keypoint i of mesh1 is paired with the keypoint j of mesh2 that receives the largest share
  of its mass, provided its mass m1[i] is at least 1.e-4 and that share exceeds 90% of m1[i].
  The resulting vertex pairs are passed to procustes(), which moves the vertices of mesh2.

  Input:
	mesh1:      reference mesh; its vertex positions are only read
	mesh2:      mesh to be moved
	keypoints1: indices, in mesh1.vertices, of the n1 keypoints of mesh1
	keypoints2: indices, in mesh2.vertices, of the n2 keypoints of mesh2
	plan:       transport plan, n1*n2 values; plan[i + j*n1] is the coupling between
	            keypoint i of mesh1 and keypoint j of mesh2
	m1:         n1 values, the mass of each keypoint of mesh1

  Output:
	mesh2:      position and position2 of every vertex are set to the fitted coordinates.
	            If no correspondence is found or the fit fails, a warning is printed,
	            position is left as it was and position2 is set equal to position.

  Diagnostics (number of pairs, mesh sizes, rms) are written to std::cout.
================================================================================================== */

  inline void Align::alignMesh2(Mesh& mesh1, Mesh& mesh2, std::vector<int>& keypoints1,
		std::vector<int>& keypoints2, double *plan, double *m1)
  {

	const int nvertex1 = static_cast<int>(mesh1.vertices.size());
	const int nvertex2 = static_cast<int>(mesh2.vertices.size());

	const int n1 = static_cast<int>(keypoints1.size());
	const int n2 = static_cast<int>(keypoints2.size());

	const double tol       = 1.e-4;	// keypoints lighter than this are not matched
	const double dominance = 0.9;	// share of the mass that must go to a single target

	// assignment[v]: vertex of mesh2 matched with vertex v of mesh1, or -1 if none.
	// A std::vector releases the storage automatically when the function returns.
	std::vector<int> assignment(nvertex1, -1);

	for(int i = 0; i < n1; i++) {
		int target = -1;
		if(n2 > 0 && !(m1[i] < tol)) {
			// column of the plan with the largest value on row i
			int jmax = 0;
			double Gmax = plan[i];
			for(int j = 1; j < n2; j++) {
				if(plan[i+j*n1] > Gmax) {
					Gmax = plan[i+j*n1];
					jmax = j;
				}
			}
			if(Gmax > dominance*m1[i]) target = keypoints2[jmax];
		}
		// always index by the mesh vertex, never by the keypoint rank i
		assignment[keypoints1[i]] = target;
	}

	std::vector<std::pair<int, int> > corresp;
	for(int i = 0; i < nvertex1; i++) {
		if(assignment[i] != -1) {
			corresp.push_back(std::make_pair(i, assignment[i]));
		}
	}

	std::vector<Vector> vertices1;
	std::vector<Vector> vertices2;
	vertices1.reserve(nvertex1);
	vertices2.reserve(nvertex2);
	for(int i = 0; i < nvertex1; i++) {
		vertices1.push_back(mesh1.vertices[i].position);
	}
	for(int i = 0; i < nvertex2; i++) {
		vertices2.push_back(mesh2.vertices[i].position);
	}
	std::cout << "ncorresp: " << corresp.size() << std::endl;
	std::cout << "nvertex1: " << vertices1.size() << std::endl;
	std::cout << "nvertex2: " << vertices2.size() << std::endl;

	// The fit is only trusted if there is at least one pair, procustes reports
	// success and the rms it returns is a finite number.
	bool aligned = false;
	if(!corresp.empty()) {
		int info = 0;
		double rms = procustes(vertices1, vertices2, corresp, &info);
		aligned = (info == 0) && std::isfinite(rms);
		if(aligned) std::cout << "rms = " << rms << std::endl;
	}

	if(!aligned) {
		std::cout << "Warning: alignMesh2 could not align the meshes; mesh 2 is left in place"
			<< std::endl;
		// discard whatever the failed fit may have written into the working copy
		for(int i = 0; i < nvertex2; i++) {
			vertices2[i] = mesh2.vertices[i].position;
		}
	}

	for(int i = 0; i < nvertex2; i++) {
		mesh2.vertices[i].position = vertices2[i];
		mesh2.vertices[i].position2 = vertices2[i];
	}

  }

/*================================================================================================
 Performs mass-weighted bestfit between two sets of points

 The rotation is obtained from the singular value decomposition of the 3x3 covariance matrix
 of the centered, corresponding points (weighted by mass1).

 Input:
	points1: the first set of points, stored as a vector of Vector (3D class)
	mass1:   mass assigned to each point in set 1
	points2: the second set of points, stored as a vector of Vector (3D class)
	mass2:   mass assigned to each point in set 2
	corresp: correspondence between the two sets of points; each pair is
	         (index in points1, index in points2)

 Output:
	points2: on success, ALL points of set 2 are rotated and translated onto set 1;
	         on failure the points are left untouched
	info:    0 on success, 1 on failure (no pair, zero total mass, covariance matrix with
	         zero determinant, SVD failure, or improper rotation)

 Returns the RMS deviation over the pairs in corresp on success. On failure the value
 returned is not an RMS deviation (0, or the determinant computed last).
================================================================================================== */

  inline double Align::bestfitm(std::vector<Vector>& points1, std::vector<double> mass1,
		std::vector<Vector>& points2, std::vector<double> mass2, 
		std::vector<std::pair<int, int> > corresp, int *info)
  {

	const int npoints2 = static_cast<int>(points2.size());
	const int npoints = static_cast<int>(corresp.size());
	*info = 0;

	// Without any pair the centers of mass below would be 0/0
	if(npoints == 0) {
		*info = 1;
		return 0.0;
	}

/*================================================================================================
  Find center of masses of the two sets of points in correspondence
================================================================================================== */

	Vector cm1(0.,0.,0.);
	Vector cm2(0.,0.,0.);

	double xmass1 = 0, xmass2 = 0;
	int ip1, ip2;
	for(int i = 0; i < npoints; i++) {
		ip1 = corresp[i].first; 
		ip2 = corresp[i].second; 
		cm1 += points1[ip1]*mass1[ip1];
		cm2 += points2[ip2]*mass2[ip2];
		xmass1 += mass1[ip1];
		xmass2 += mass2[ip2];
	}

	// A zero total mass makes the centers of mass undefined
	if(xmass1 == 0 || xmass2 == 0) {
		*info = 1;
		return 0.0;
	}
	cm1 /= xmass1;
	cm2 /= xmass2;

/*================================================================================================
  Calculate covariance matrix on centered points
  (covar[i + 3*j] = sum over pairs of mass1 * p1[i] * p2[j]; each pair is centered once)
================================================================================================== */

	double covar[9];
	for(int idx = 0; idx < 9; idx++) covar[idx] = 0;

	Vector p1, p2;
	for(int k = 0; k < npoints; k++) {
		ip1 = corresp[k].first; 
		ip2 = corresp[k].second; 
		p1 = points1[ip1] - cm1;
		p2 = points2[ip2] - cm2;
		for(int j = 0; j < 3; j++) {
			for(int i = 0; i < 3; i++) {
				covar[i + 3*j] += p1[i]*p2[j]*mass1[ip1];
			}
		}
	}

/*================================================================================================
  Compute determinant of covariance matrix; if 0, problem
================================================================================================== */

	double det = deter(covar);

	if(det == 0) {
		*info = 1;
		return det;
	}

	// sign of det(covar) = det(U)*det(V); used to keep the rotation proper
	double sign = 1.0;
	if(det < 0) sign = -1.0;

/*================================================================================================
  Perform SVD on covariance matrix
================================================================================================== */

	char JOBU ='A';
	char JOBVT='A';
	int N     = 3;
	double U[9], VT[9], V[9], S[3], WORK[50];
	int lwork = 50;
	int inf = 0;
	dgesvd_(&JOBU, &JOBVT, &N, &N, covar, &N, S, U, &N, VT, &N, WORK, &lwork, &inf);

	// Check the status before reading any output of the SVD
	if(inf!=0) {
		*info = 1;
		return det;
	}

	// V is the transpose of VT (both stored column by column)
	for(int i = 0; i < 3; i++) {
		for(int j = 0; j < 3; j++) {
			V[i+3*j] = VT[j+3*i];
		}
	}

/*================================================================================================
  Calculate bestfit rotation matrix
================================================================================================== */

	double R[9];
	const double tiny = 1.e-14;

	if(S[1] > tiny) {
		if(S[2] < tiny) {
			// Rank 2: the third singular vectors are arbitrary; rebuild them as the
			// cross product of the first two so that U and V are proper rotations
			sign = 1.0;
			U[6] = U[1]*U[5] - U[2]*U[4];
			U[7] = U[2]*U[3] - U[0]*U[5];
			U[8] = U[0]*U[4] - U[1]*U[3];
			V[6] = V[1]*V[5] - V[2]*V[4];
			V[7] = V[2]*V[3] - V[0]*V[5];
			V[8] = V[0]*V[4] - V[1]*V[3];
		}
		// R = U * diag(1, 1, sign) * V^T
		for(int i = 0; i < 3; i++) {
			for(int j = 0; j < 3; j++) {
				R[i+3*j] = U[i]*V[j] + U[i+3]*V[j+3]
					+sign*U[i+6]*V[j+6];
			}
		}
	} else {
		// Rank 1: only the first singular vectors u0 = U[0..2] and v0 = V[0..2] are
		// meaningful. A half turn about the bisector u0 + v0 maps v0 onto u0.
		double axis[3];
		double norm = 0;
		for(int i = 0; i < 3; i++) {
			axis[i] = U[i] + V[i];
			norm += axis[i]*axis[i];
		}
		if(norm < tiny) {
			// u0 and v0 are opposite: the bisector vanishes, any axis perpendicular
			// to u0 will do. Build it from the two largest components of u0 so that
			// it cannot be the null vector.
			if(U[0]*U[0] + U[1]*U[1] > U[2]*U[2]) {
				axis[0] = -U[1]; axis[1] = U[0]; axis[2] = 0;
			} else {
				axis[0] = 0; axis[1] = -U[2]; axis[2] = U[1];
			}
			norm = 0;
			for(int i = 0; i < 3; i++) norm += axis[i]*axis[i];
		}
		if(!(norm > 0)) {
			*info = 1;
			return det;
		}
		norm = std::sqrt(norm);
		for(int i = 0; i < 3; i++) axis[i] /= norm;

		// R = 2 * axis * axis^T - I (rotation by 180 degrees about axis)
		for(int i = 0; i < 3; i++) {
			for(int j = 0; j < 3; j++) {
				R[i+3*j] = 2*axis[i]*axis[j];
			}
			R[i+3*i] -= 1;
		}
	}

	det = deter(R);
	if(det < 0) {
		std::cout << "Warning: rotation matrix with negative determinant" << std::endl;
		*info = 1;
		return det;
	}

/*================================================================================================
  Apply rotation matrix (to all points of set 2, not only those in correspondence)
================================================================================================== */

	double rotated[3];
	for(int i = 0; i < npoints2; i++) {
		for(int j = 0; j < 3; j++) {
			rotated[j] = 0;
			for(int k = 0; k < 3; k++) {
				rotated[j] += R[j+3*k]*(points2[i][k]-cm2[k]);
			}
		}
		for(int j = 0; j < 3; j++) {
			points2[i][j] = rotated[j] + cm1[j];
		}
	}

/*================================================================================================
  Calculate RMSD (accumulates Vector::norm2(), presumably the squared length, over the pairs)
================================================================================================== */

	double rmsd = 0;
	for(int i = 0; i < npoints; i++) {
		ip1 = corresp[i].first; 
		ip2 = corresp[i].second; 
		p1 = points1[ip1];
		p2 = points2[ip2];
		rmsd += (p1-p2).norm2();
	}

	rmsd = std::sqrt(std::abs(rmsd)/npoints);

	return rmsd;

 }
/*================================================================================================
 Performs Procrustes bestfit between two sets of points

 Both sets are centered on the centroid of their corresponding points and divided by their
 scale (s1, s2) before the rotation is computed from the singular value decomposition of the
 3x3 covariance matrix. Set 2 is then rotated, rescaled to the scale of set 1 and moved to
 the centroid of set 1.

 Input:
	points1: the first set of points, stored as a vector of Vector (3D class)
	points2: the second set of points, stored as a vector of Vector (3D class)
	corresp: correspondence between the two sets of points; each pair is
	         (index in points1, index in points2)

 Output:
	points2: on success, ALL points of set 2 are transformed onto set 1;
	         on failure the points are left untouched
	info:    0 on success, 1 on failure (no pair, zero scale, covariance matrix with
	         zero determinant, SVD failure, or improper rotation)

 Returns the RMS deviation over the pairs in corresp on success. On failure the value
 returned is not an RMS deviation (0, or the determinant computed last).
================================================================================================== */

  inline double Align::procustes(std::vector<Vector>& points1, std::vector<Vector>& points2,
		std::vector<std::pair<int, int> > corresp, int *info)
  {

	const int npoints2 = static_cast<int>(points2.size());
	const int npoints = static_cast<int>(corresp.size());
	*info = 0;

	// Without any pair the centroids below would be 0/0
	if(npoints == 0) {
		*info = 1;
		return 0.0;
	}

/*================================================================================================
  Find center of masses of the two sets of points in correspondence
================================================================================================== */

	Vector cm1(0.,0.,0.);
	Vector cm2(0.,0.,0.);

	int ip1, ip2;
	for(int i = 0; i < npoints; i++) {
		ip1 = corresp[i].first; 
		ip2 = corresp[i].second; 
		cm1 += points1[ip1];
		cm2 += points2[ip2];
	}
	cm1 /= npoints;
	cm2 /= npoints;

/*================================================================================================
  Find scales of the two sets of points in correspondence
  (square root of the mean of Vector::norm2(), presumably the squared length, of the
  centered points)
================================================================================================== */

	double s1 = 0, s2 = 0;
	Vector p1, p2;
	for(int i = 0; i < npoints; i++) {
		ip1 = corresp[i].first; 
		ip2 = corresp[i].second; 
		p1 = points1[ip1] - cm1;
		p2 = points2[ip2] - cm2;
		s1 += p1.norm2();
		s2 += p2.norm2();
	}
	s1 = std::sqrt(s1/npoints);
	s2 = std::sqrt(s2/npoints);

	// A set whose corresponding points all coincide has no scale: dividing by it
	// would fill points2 with inf/NaN
	if(!(s1 > 0) || !(s2 > 0)) {
		*info = 1;
		return 0.0;
	}

/*================================================================================================
  Calculate covariance matrix on scaled, centered points
  (covar[i + 3*j] = sum over pairs of p1[i] * p2[j]; each pair is centered and scaled once)
================================================================================================== */

	double covar[9];
	for(int idx = 0; idx < 9; idx++) covar[idx] = 0;

	for(int k = 0; k < npoints; k++) {
		ip1 = corresp[k].first; 
		ip2 = corresp[k].second; 
		p1 = (points1[ip1] - cm1)/s1;
		p2 = (points2[ip2] - cm2)/s2;
		for(int j = 0; j < 3; j++) {
			for(int i = 0; i < 3; i++) {
				covar[i + 3*j] += p1[i]*p2[j];
			}
		}
	}

/*================================================================================================
  Compute determinant of covariance matrix; if 0, problem
================================================================================================== */

	double det = deter(covar);

	if(det == 0) {
		*info = 1;
		return det;
	}

	// sign of det(covar) = det(U)*det(V); used to keep the rotation proper
	double sign = 1.0;
	if(det < 0) sign = -1.0;

/*================================================================================================
  Perform SVD on covariance matrix
================================================================================================== */

	char JOBU ='A';
	char JOBVT='A';
	int N     = 3;
	double U[9], VT[9], V[9], S[3], WORK[20];
	int lwork = 20;
	int inf = 0;
	dgesvd_(&JOBU, &JOBVT, &N, &N, covar, &N, S, U, &N, VT, &N, WORK, &lwork, &inf);

	// Check the status before reading any output of the SVD
	if(inf!=0) {
		*info = 1;
		return det;
	}

	// V is the transpose of VT (both stored column by column)
	for(int i = 0; i < 3; i++) {
		for(int j = 0; j < 3; j++) {
			V[i+3*j] = VT[j+3*i];
		}
	}

/*================================================================================================
  Calculate bestfit rotation matrix
================================================================================================== */

	double R[9];
	const double tiny = 1.e-14;

	if(S[1] > tiny) {
		if(S[2] < tiny) {
			// Rank 2: the third singular vectors are arbitrary; rebuild them as the
			// cross product of the first two so that U and V are proper rotations
			sign = 1.0;
			U[6] = U[1]*U[5] - U[2]*U[4];
			U[7] = U[2]*U[3] - U[0]*U[5];
			U[8] = U[0]*U[4] - U[1]*U[3];
			V[6] = V[1]*V[5] - V[2]*V[4];
			V[7] = V[2]*V[3] - V[0]*V[5];
			V[8] = V[0]*V[4] - V[1]*V[3];
		}
		// R = U * diag(1, 1, sign) * V^T
		for(int i = 0; i < 3; i++) {
			for(int j = 0; j < 3; j++) {
				R[i+3*j] = U[i]*V[j] + U[i+3]*V[j+3]
					+sign*U[i+6]*V[j+6];
			}
		}
	} else {
		// Rank 1: only the first singular vectors u0 = U[0..2] and v0 = V[0..2] are
		// meaningful. A half turn about the bisector u0 + v0 maps v0 onto u0.
		double axis[3];
		double norm = 0;
		for(int i = 0; i < 3; i++) {
			axis[i] = U[i] + V[i];
			norm += axis[i]*axis[i];
		}
		if(norm < tiny) {
			// u0 and v0 are opposite: the bisector vanishes, any axis perpendicular
			// to u0 will do. Build it from the two largest components of u0 so that
			// it cannot be the null vector.
			if(U[0]*U[0] + U[1]*U[1] > U[2]*U[2]) {
				axis[0] = -U[1]; axis[1] = U[0]; axis[2] = 0;
			} else {
				axis[0] = 0; axis[1] = -U[2]; axis[2] = U[1];
			}
			norm = 0;
			for(int i = 0; i < 3; i++) norm += axis[i]*axis[i];
		}
		if(!(norm > 0)) {
			*info = 1;
			return det;
		}
		norm = std::sqrt(norm);
		for(int i = 0; i < 3; i++) axis[i] /= norm;

		// R = 2 * axis * axis^T - I (rotation by 180 degrees about axis)
		for(int i = 0; i < 3; i++) {
			for(int j = 0; j < 3; j++) {
				R[i+3*j] = 2*axis[i]*axis[j];
			}
			R[i+3*i] -= 1;
		}
	}

	det = deter(R);
	if(det < 0) {
		std::cout << "Warning: rotation matrix with negative determinant" << std::endl;
		*info = 1;
		return det;
	}

/*================================================================================================
  Apply rotation matrix (to all points of set 2, not only those in correspondence);
  the points are brought to the scale and centroid of set 1
================================================================================================== */

	double rotated[3];
	for(int i = 0; i < npoints2; i++) {
		for(int j = 0; j < 3; j++) {
			rotated[j] = 0;
			for(int k = 0; k < 3; k++) {
				rotated[j] += R[j+3*k]*(points2[i][k]-cm2[k])/s2;
			}
		}
		for(int j = 0; j < 3; j++) {
			points2[i][j] = s1*rotated[j] + cm1[j];
		}
	}

/*================================================================================================
  Calculate RMSD
================================================================================================== */

	double rmsd = 0;
	for(int i = 0; i < npoints; i++) {
		ip1 = corresp[i].first; 
		ip2 = corresp[i].second; 
		p1 = points1[ip1];
		p2 = points2[ip2];
		rmsd += (p1-p2).norm2();
	}

	rmsd = std::sqrt(std::abs(rmsd)/npoints);

	return rmsd;

 }

/*================================================================================================
  Computes the determinant of a 3x3 matrix

  Input:
	mat: the 9 coefficients of the matrix, stored contiguously. The determinant of a matrix
	     equals that of its transpose, so row-wise and column-wise storage give the same value.

  Returns the determinant, obtained by cofactor expansion along mat[0], mat[1], mat[2].

  The function is defined in a header: it is declared inline so that including the header from
  several translation units does not produce multiple definitions at link time.
================================================================================================== */

  inline double Align::deter(double *mat)
  {
	// 2x2 minors associated with mat[0], mat[1] and mat[2]
	const double minor0 = mat[4]*mat[8] - mat[5]*mat[7];
	const double minor1 = mat[3]*mat[8] - mat[5]*mat[6];
	const double minor2 = mat[3]*mat[7] - mat[4]*mat[6];

	return mat[0]*minor0 - mat[1]*minor1 + mat[2]*minor2;
  }

#endif
