/* ===============================================================================================
   Optim     		Version 1 9/23/2019		Patrice Koehl

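   Given a set of target points (Target), optimize the positions of a set of input points (Input)
   so that the optimal-transport (OT) distance between Input and Target is minimized.
   Only rigid-body transformations (rotation and translation), combined with a uniform
   scale factor, are allowed.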

   =============================================================================================== */

#ifndef _OPTIM_H_
#define _OPTIM_H_

/* ===============================================================================================
   Includes
   =============================================================================================== */

  #include "ExpMap.h"

  // Fortran routines (BLAS and optimizers), declared with Fortran linkage conventions:
  // lower-case name with a trailing underscore, every argument passed by pointer.
  extern "C" {

	// BLAS1: copy one vector into another (used by Optim::initOptim to fill Yref)
	void dcopy_(int * n, double * X, int *incx, double * Y, int *incy);

	// BLAS1: dot product of two vectors
	double ddot_(int * n, double * u, int * incu, double * v, int *incv);

	// BLAS1: norm of a vector
	double dnrm2_(int * n, double * X, int *incx);

	// Unconstrained L-BFGS driver (presumably Nocedal's lbfgs.f; verify against the
	// linked library). Not called by the Optim methods in this header.
	void lbfgs_(int *N, int *M, double *X, double *F, double *G, int *DIAGCO, double *DIAG,
	int *IPRINT, double *EPS, double *XTOL, double *W, double *GNORM, int *IFLAG);

	// Nonlinear conjugate-gradient driver (presumably CG+ cgfam; verify). Not called
	// by the Optim methods in this header.
	void cgfam_(int *N, double *X, double *F, double *G, double *D, double *GOLD,
	int *IPRINT, double *EPS, double *W, int *IFLAG, int *IREST, int *METHOD,int *FINISH);

	// L-BFGS-B driver (reverse communication), used by Optim::oneStep_LBFGSB.
	// NOTE(review): on the Fortran side LSAVE is LOGICAL lsave(4). A default-kind
	// LOGICAL occupies the same storage as a default INTEGER (4 bytes with gfortran),
	// not sizeof(bool), so the buffer passed here must hold at least 4*sizeof(int)
	// bytes. TASK and CSAVE are CHARACTER*60; the hidden Fortran string-length
	// arguments are not declared in this prototype.
	void setulb_(int *N, int *M, double *X, double *L, double *U, int *NBD, double *F, 
		double *G, double *FACTR, double *PGTOL, double *WA, int *IWA, char *TASK,
		int *IPRINT, char *CSAVE, bool *LSAVE, int *ISAVE, double *DSAVE);
  }

  // Global helper used by Optim to build the rotation matrix (EM3_To_R) and its
  // partial derivatives (dR_dVi) from the rotation parameters X[0..2].
  // NOTE(review): a non-inline global defined in a header violates the one-definition
  // rule if this header is included from more than one translation unit.
  EXPMAP emap;

/* ===============================================================================================
   Main class
   =============================================================================================== */

  /* Optim: registers a moving point set (vertices of model2) onto a fixed target point
     set (vertices of model1) by minimizing an optimal-transport distance with L-BFGS-B
     (setulb_). The 7 optimization parameters are
        X[0..2]  rotation parameters (converted to a matrix by emap.EM3_To_R)
        X[3..5]  translation
        X[6]     uniform scale
     so that each moving point is Y_j = X[6] * R(X[0..2]) * Yref_j + X[3..5].
     NOTE(review): the class declares no destructor; the arrays allocated by initOptim
     are never freed, and calling initOptim twice leaks the first set. */
  class Optim{

  public:

	// Register model2 onto model1 (writes the result to model2's vertex position2);
	// returns the last total energy reported by the optimizer
	double refine(Mesh& model1, Mesh& model2, double a1, double a2, int niter, 
	int method, int nthreads);

	// Allocate and initialize all working storage; Y becomes the reference shape Yref
	void initOptim(int npoint1, double *Target, int npoint2, double *Y, 
		double *X, double *dX, double alp1, double alp2, int nthreads);

	// Apply the transformation X to Yref: Y = X[6]*R(X[0..2])*Yref + X[3..5]
	void applyTransform(double *X, int npoint2, double *Y);

	// Run L-BFGS-B until it produces a new iterate (one outer iteration)
	double oneStep_LBFGSB(int npoint1, double *Target, int npoint2, double *Y, 
	double *X, double *dX, int method, int *IFLAG, int iter);

	// Compute energy and its gradient with respect to X
	int eneAndDer(int npoint1, double *Target, int npoint2, double *Y,
	double *X, double *dX, int method, double *d, double *dM1, double *dM2);

	// Compute energy only (OT solve restarted from zero dual variables)
	double ene(int npoint1, double *Target, int npoint2, double *Y,
	double *X, int method);

	// Debug: compare analytic and finite-difference derivatives, then exit(1)
	void checkDeriv(int npoint1, double *Target, int npoint2, double *Y,
	double *X, double *dX, int method);

	// m1, m2: masses of target / moving points (reset to uniform before each solve).
	// Cost, G: npoint2 x npoint1 row-major distance matrix and transport plan,
	//          indexed [j*npoint1 + i] for moving point j and target point i.
	// lambda, mu: per-point arrays handed to the OT solvers (presumably dual
	//             variables; TODO confirm against the solver).
	double *m1, *m2, *G, *Cost, *lambda, *mu;

  private:

	// q = R * p for a single 3-vector
	void Rotation(double R[3][3], double *p, double *q);

	// Fill Cost with Euclidean distances between Y and Target
	void computeCostMatrix(int npoint1, double *Target, int npoint2, double *Y);

	// Analytic gradient of the energy with respect to X, using the current plan G
	void computeGradient(int npoint1, double *Target, int npoint2, double *Y,
	double *X, double *dX);

	// Central finite-difference gradient of ene() with respect to X
	void computeNumGradient(int npoint1, double *Target, int npoint2, double *Y,
	double *X, double *dX, int method);

	// internal variables

	// nproc: thread count passed to the OT solvers
	int nproc;
	// M: number of L-BFGS-B correction pairs; N: number of parameters (7)
	int M, N;
	// Yref: reference (centered) coordinates of the moving points, 3*npoint2
	double *Yref;
	// x: extra scalar passed to otw.ot1_w (semantics defined by the solver);
	// a1, a2: weights of the mass terms a1*sum(m1^2), a2*sum(m2^2)
	double x, a1, a2;
	// passed through to the OT solvers (presumably an annealing range; verify)
	double beta_init, beta_final;
	// iprint: verbosity passed to the OT solvers; init: count of OT solves so far
	int iprint, init;

	// allocated in initOptim but not used by the L-BFGS-B path
	double *Work, *Diag;

	// L-BFGS-B (setulb_) bounds and workspace
	double *L, *U, *WA, *DSAVE;
	int *NBD, *IWA, *ISAVE;
	char *TASK, *CSAVE;
	bool *LSAVE;

  };

/* ===============================================================================================
   Initialize parameters
   =============================================================================================== */

  void Optim::initOptim(int npoint1, double *Target, int npoint2, double *Y, 
	double *X, double *dX, double alp1, double alp2, int nthreads)
  {
	/* Allocate and initialize all working storage for one registration problem.
	   npoint1 / npoint2 : number of target / moving points
	   Y                 : moving-point coordinates (3*npoint2), copied into Yref as
	                       the reference shape that applyTransform maps from
	   alp1, alp2        : weights of the mass terms used when method != 0
	   nthreads          : thread count passed to the OT solvers
	   Target, X and dX are not read here.
	   Also prints the header of the per-iteration table.
	   NOTE(review): every call allocates fresh arrays and nothing frees them. */

	nproc = nthreads;

	a1 = alp1;
	a2 = alp2;

	// L-BFGS-B problem size: N = 7 parameters (3 rotation, 3 translation, 1 scale),
	// M = 10 stored correction pairs
	N = 7;
	M = 10;

	// OT data: per-point solver arrays, masses, cost matrix and transport plan
	lambda = new double[npoint1];
	mu     = new double[npoint2];
	m1     = new double[npoint1];
	m2     = new double[npoint2];
	Cost   = new double[npoint1*npoint2];

	G = new double[npoint1*npoint2];

	Yref = new double[3*npoint2];

	// Workspace apparently sized for lbfgs_; not used by the L-BFGS-B path
	Work  = new double[(2*M+1)*N + 2*M];
	Diag  = new double[N];

	// L-BFGS-B workspace, with the sizes required by setulb_
	L = new double[N];
	U = new double[N];
	NBD = new int[N];
	WA = new double[2*M*N+5*N+11*M*M+8*M];
	IWA = new int[3*N];
	ISAVE = new int[44];
	DSAVE = new double[29];
	// On the Fortran side LSAVE is LOGICAL lsave(4). A default-kind LOGICAL occupies
	// the same storage as a default INTEGER, so 4 C++ bools (4 bytes) are too small
	// and setulb_ would write past the end of the buffer. Reserve room for four
	// integer-sized LOGICALs instead.
	LSAVE = new bool[4*sizeof(int)];
	TASK = new char[60];
	CSAVE = new char[60];
	// NBD[i] = 0 marks parameter i as unbounded (L and U are then not used)
	memset(L, 0, N*sizeof(double));
	memset(U, 0, N*sizeof(double));
	memset(NBD, 0, N*sizeof(int));

	// Fortran CHARACTER*60 buffers are blank-padded, not NUL-terminated
	for(int i = 0; i < 60; i++) {
		TASK[i] = ' ';
		CSAVE[i] = ' ';
	}

	memset(lambda, 0, npoint1*sizeof(double));
	memset(mu, 0, npoint2*sizeof(double));
	memset(Cost, 0, npoint1*npoint2*sizeof(double));
	memset(G, 0, npoint1*npoint2*sizeof(double));
	x = 0;

	// Yref <- Y
	int nv = 3*npoint2;
	int inc = 1;
	dcopy_(&nv, Y, &inc, Yref, &inc);

	// Uniform masses on both point sets
	for(int i = 0; i < npoint1; i++) m1[i] = 1./npoint1;
	for(int i = 0; i < npoint2; i++) m2[i] = 1./npoint2;

	beta_init = 1.e4;
	beta_final = 1.e10;
	iprint = 1;
	init = 0;

	std::cout << " " << std::endl;
	std::cout << "        " << "===============================================================" << std::endl;
	std::cout << "        " << "       Iter      E_OT      E_mass1      E_mass2        E_tot   " << std::endl;
	std::cout << "        " << "===============================================================" << std::endl;

  }
	
/* ===============================================================================================
   Apply rotation on one point
   =============================================================================================== */

   void Optim::Rotation(double R[3][3], double *p, double *q)
   {
	// q = R * p : each output component is the dot product of one row of R with p
	for(int row = 0; row < 3; row++) {
		const double *r = R[row];
		q[row] = r[0]*p[0] + r[1]*p[1] + r[2]*p[2];
	}
   }

/* ===============================================================================================
   Apply rigid body transformation
   =============================================================================================== */

   void Optim::applyTransform(double *X, int npoint2, double *Y)
   {
	// Y_i = X[6] * R(X[0..2]) * Yref_i + X[3..5] for every moving point i
	double Rmat[3][3];
	double rotated[3];
	emap.EM3_To_R(X, Rmat);

	for(int pt = 0; pt < npoint2; pt++) {
		Rotation(Rmat, Yref + 3*pt, rotated);
		double *out = Y + 3*pt;
		for(int k = 0; k < 3; k++) {
			rotated[k] *= X[6];      // uniform scale
			rotated[k] += X[k+3];    // translation
			out[k] = rotated[k];
		}
	}
   }
		

/* ===============================================================================================
   Compute cost matrix
   =============================================================================================== */

   void Optim::computeCostMatrix(int npoint1, double *Target, int npoint2, double *Y)
   {
	// Row r of Cost belongs to moving point Y_r; column c holds |Target_c - Y_r|.
	for(int r = 0; r < npoint2; r++) {
		const double *yr = &Y[3*r];
		double *row = &Cost[r*npoint1];
		for(int c = 0; c < npoint1; c++) {
			const double *tc = &Target[3*c];
			double d2 = 0;
			for(int k = 0; k < 3; k++) {
				const double diff = tc[k] - yr[k];
				d2 += diff*diff;
			}
			row[c] = std::sqrt(d2);
			// squared-distance variant, kept for reference: row[c] = d2;
		}
	}
   }

/* ===============================================================================================
   Compute the gradient numerically (central finite differences of the energy)
   =============================================================================================== */

   void Optim::computeNumGradient(int npoint1, double *Target, int npoint2, double *Y, 
	double *X, double *dX , int method)
  {
	/* Central finite-difference gradient of ene() with respect to the 7 parameters:
	     dX[i] ~ (E(X + h e_i) - E(X - h e_i)) / (2h),   h = 0.1
	   Each ene() call overwrites Y, Cost, lambda, mu, m1, m2 (and whatever the OT
	   solver writes, e.g. G) and increments init, so those members reflect the last
	   perturbed evaluation on return. X is restored exactly. */
	const double h = 0.1;
	for(int i = 0; i < 7; i++) {
		// Save the coordinate and evaluate at exactly x0 +/- h. Applying +h, -2h, +h
		// in place would leave round-off drift in X[i].
		const double x0 = X[i];
		X[i] = x0 + h;
		const double ePlus = ene(npoint1, Target, npoint2, Y, X, method);
		X[i] = x0 - h;
		const double eMinus = ene(npoint1, Target, npoint2, Y, X, method);
		X[i] = x0;
		dX[i] = (ePlus - eMinus)/(2*h);
	}
   }

/* ===============================================================================================
   Compute gradient vector from optimized transfer plan G
   =============================================================================================== */

   void Optim::computeGradient(int npoint1, double *Target, int npoint2, double *Y, 
	double *X, double *dX )
   {
	/* Gradient of E(X) = sum_{j,i} G[j][i] * |Y_j - T_i| with respect to the 7
	   parameters, where Y_j = s * R(X[0..2]) * Yref_j + X[3..5] and s = X[6].
	   The transport plan G from the last OT solve is held fixed, so
	     dE/dp = sum_{j,i} G[j][i] * (Y_j - T_i) . dY_j/dp / |Y_j - T_i|
	   with
	     dY_j/dX[k]   = s * (dR/dX[k]) * Yref_j   (k = 0,1,2)
	     dY_j/dX[3+k] = e_k                       (k = 0,1,2)
	     dY_j/dX[6]   = R * Yref_j
	   Y must have been produced by applyTransform with the same X. */
	memset(dX, 0, 7*sizeof(double));

	double R[3][3];
	double dRx[3][3]; double dRy[3][3]; double dRz[3][3];
	double v2[3], v0[3], v1x[3], v1y[3], v1z[3];

	emap.EM3_To_R(X, R);
	emap.dR_dVi(X, 0, dRx);
	emap.dR_dVi(X, 1, dRy);
	emap.dR_dVi(X, 2, dRz);

	const double s = X[6];

	for(int j = 0; j < npoint2; j++) {

		Rotation(R,   &Yref[3*j], v0);
		Rotation(dRx, &Yref[3*j], v1x);
		Rotation(dRy, &Yref[3*j], v1y);
		Rotation(dRz, &Yref[3*j], v1z);

		// Chain rule: the rotation acts on Yref before the scale s is applied,
		// so its derivatives carry the factor s
		for(int k = 0; k < 3; k++) {
			v1x[k] *= s;
			v1y[k] *= s;
			v1z[k] *= s;
		}

		for(int i = 0; i < npoint1; i++) {
			double sum = 0;
			for(int k = 0; k < 3; k++) {
				v2[k] = Y[3*j+k]-Target[3*i+k];
				sum += v2[k]*v2[k];
			}
			// |Y_j - T_i| is not differentiable at 0; use 0 as the subgradient there
			const double coef = (sum != 0) ? G[j*npoint1+i]/std::sqrt(sum) : 0.0;
			// (for a squared-distance cost this would be coef = 2*G[j*npoint1+i])

			dX[0] += coef*(v1x[0]*v2[0] + v1x[1]*v2[1] + v1x[2]*v2[2]);
			dX[1] += coef*(v1y[0]*v2[0] + v1y[1]*v2[1] + v1y[2]*v2[2]);
			dX[2] += coef*(v1z[0]*v2[0] + v1z[1]*v2[1] + v1z[2]*v2[2]);
			dX[3] += coef*v2[0];
			dX[4] += coef*v2[1];
			dX[5] += coef*v2[2];
			// dY_j/ds = R*Yref_j, which equals (Y_j - t)/s but stays defined at s == 0
			dX[6] += coef*(v0[0]*v2[0] + v0[1]*v2[1] + v0[2]*v2[2]);
		}
	}
   }

/* ===============================================================================================
   Compute "energy" (OT distance between the points) and its gradient with respect to the transformation parameters X
   =============================================================================================== */

   int Optim::eneAndDer(int npoint1, double *Target, int npoint2, double *Y, 
	double *X, double *dX, int method, double *d, double *dm1, double *dm2)
   {
	/* Evaluate the total energy at parameters X and, unless the OT distance is
	   already below tol, its gradient dX.
	   - Y is overwritten with the transformed moving points (applyTransform).
	   - The cost matrix is rebuilt and the OT problem is solved with ot.ot1
	     (method == 0) or otw.ot1_w (otherwise).
	   - For method != 0 the mass terms a1*sum(m1^2) and a2*sum(m2^2) are added.
	   Outputs: *d = total energy; *dm1, *dm2 = the two mass terms (0 for method 0).
	   Returns 1 when the OT distance is below tol (dX is then NOT updated), else 0.
	   Prints progress messages to std::cout. */

	int info = 0;
	double tol = 1.e-4;

	// lambda, mu and x are zeroed only before the first solve; later solves receive
	// whatever the previous solve left there (presumably a warm start; verify)
	if(init == 0) {
		memset(lambda, 0, npoint1*sizeof(double));
		memset(mu, 0, npoint2*sizeof(double));
		x = 0;
	}

	std::cout << "Apply transform" << std::endl;
	applyTransform(X, npoint2, Y);

	memset(Cost, 0, npoint1*npoint2*sizeof(double));
	std::cout << "Compute cost matrix" << std::endl;
	computeCostMatrix(npoint1, Target, npoint2, Y);

	// Reset to uniform masses before every solve
	for(int i = 0; i < npoint1; i++) m1[i] = 1./npoint1;
	for(int i = 0; i < npoint2; i++) m2[i] = 1./npoint2;

	double F_opt;
	double dist;
	double dmass1 = 0.0;
	double dmass2 = 0.0;
	double dtot;
	if(method==0) {
		dist = ot.ot1(npoint1, m1, npoint2, m2, Cost, G, lambda, mu,
		beta_init, beta_final, &F_opt, iprint, init, nproc); 
		dtot = dist;
	} else {
		std::cout << "Apply otm" << std::endl;
		dist = otw.ot1_w(npoint1, m1, npoint2, m2, Cost, G, lambda, mu, &x,
		beta_init, beta_final, &F_opt, iprint, init, a1, a2, nproc); 
		// NOTE(review): these sums use m1/m2 after ot1_w returns; if the solver does
		// not modify them they reduce to a1/npoint1 and a2/npoint2
		for(int i = 0; i < npoint1; i++) {
			dmass1 += a1*m1[i]*m1[i];
		}
		for(int i = 0; i < npoint2; i++) {
			dmass2 += a2*m2[i]*m2[i];
		}
		dtot = dist + dmass1 + dmass2;
		std::cout << "otm done" << std::endl;
	}
	init++;

	*d = dtot;
	*dm1 = dmass1;
	*dm2 = dmass2;

	// Point sets already coincide (within tol): skip the gradient
	if(dist < tol) {
		info = 1;
		return info;
	}

	std::cout << "Compute Gradient" << std::endl;
	computeGradient(npoint1, Target, npoint2, Y, X, dX);
	std::cout << "Gradient done" << std::endl;

	return info;

  }

/* ===============================================================================================
   Compute "energy" (OT distance between the points)
   =============================================================================================== */

   double Optim::ene(int npoint1, double *Target, int npoint2, double *Y, double *X, int method)
   {

        memset(lambda, 0, npoint1*sizeof(double));
       	memset(mu, 0, npoint2*sizeof(double));
	x = 0;

	applyTransform(X, npoint2, Y);
	memset(Cost, 0, npoint1*npoint2*sizeof(double));
	computeCostMatrix(npoint1, Target, npoint2, Y);

	for(int i = 0; i < npoint1; i++) m1[i] = 1./npoint1;
	for(int i = 0; i < npoint2; i++) m2[i] = 1./npoint2;

	double dist, F_opt;
	double dmass1 = 0;
	double dmass2 = 0;
	if(method==0) {
		dist = ot.ot1(npoint1, m1, npoint2, m2, Cost, G, lambda, mu,
		beta_init, beta_final, &F_opt, iprint, init, nproc); 
	} else {
		dist = otw.ot1_w(npoint1, m1, npoint2, m2, Cost, G, lambda, mu, &x,
		beta_init, beta_final, &F_opt, iprint, init, a1, a2, nproc); 
		for(int i = 0; i < npoint1; i++) {
			dmass1 += a1*m1[i]*m1[i];
		}
		for(int i = 0; i < npoint2; i++) {
			dmass2 += a2*m2[i]*m2[i];
		}
		dist += dmass1 + dmass2;
	}
	init++;

	return dist;

  }

  /* ==============================================================================================
	Optimizes moving points with lbfgsb
   =============================================================================================== */

  double Optim::oneStep_LBFGSB(int npoint1, double *Target, int npoint2, double *Y, 
	double *X, double *dX, int method, int *IFLAG, int iter)
  {
	/* Drive L-BFGS-B (setulb_) in reverse-communication mode until it stops asking
	   for function/gradient evaluations. Each pass evaluates energy and gradient at
	   X with eneAndDer and hands them to setulb_, which updates X and TASK. The loop
	   repeats while TASK starts with 'F' (an "FG..." request).
	   When *IFLAG == 0 on entry, TASK is set to "START" to (re)initialize L-BFGS-B.
	   Prints one row of the iteration table (iteration iter+1).
	   On return:
	     *IFLAG = 0  converged (TASK starts with 'C') or OT distance below 1e-5
	     *IFLAG = 1  TASK starts with 'A' or 'E' (abnormal termination / error)
	     *IFLAG = 2  otherwise (keep iterating)
	   Returns the total energy (OT distance + mass terms) from the last evaluation. */

  /* ==================================================================================================
	Parameters for L-BFGS-B (setulb_)
   ==================================================================================================*/

	int Msize = M;
	// negative IPRINT: no output from setulb_
	int IPRINT = -1;

	// FACTR: stop when the relative reduction of f is below FACTR * machine epsilon
	// PGTOL: stop when the largest projected-gradient component is below PGTOL
	double FACTR = 1.e7;
	double PGTOL = 1.e-4;

	double dm1 = 0.0;
	double dm2 = 0.0;

	if(*IFLAG == 0) {
		TASK[0] = 'S'; TASK[1] = 'T'; TASK[2]='A'; TASK[3]='R'; TASK[4]='T';
	}

	int info;
	double dist;
	double dist1;
	do {
		info = eneAndDer(npoint1, Target, npoint2, Y, X, dX, method, &dist, &dm1, &dm2);
		std::cout << "info = " << info << std::endl;
		// OT distance already below eneAndDer's tolerance: report and mark as converged
		if(info==1) {
			std::cout << "dtot = " << dist << std::endl;
			std::cout << "dm1  = " << dm1 << std::endl;
			std::cout << "dm2  = " << dm2 << std::endl;
			std::cout << "dist = " << dist-dm1-dm2 << std::endl;
			TASK[0] = 'C';
			break;
		}
		setulb_(&N, &Msize, X, L, U, NBD, &dist, dX, &FACTR, &PGTOL, WA, IWA,
		TASK, &IPRINT, CSAVE, LSAVE, ISAVE, DSAVE);
//		std::cout << "TASK = " ;
//		for(int i = 0; i < 5; i++) std::cout << TASK[i];
//		std::cout << std::endl;
	} while(TASK[0]=='F');

	// dist1: OT distance without the mass terms
	dist1 = dist - dm1 - dm2;
        std::cout << "     " << "   " << std::setw(8)<< iter+1;
        std::cout << "     " << std::fixed << std::setprecision(6) << std::setw(8) << dist1 ;
        std::cout << "     " << std::fixed << std::setprecision(6) << std::setw(8) << dm1;
        std::cout << "     " << std::fixed << std::setprecision(6) << std::setw(8) << dm2;
        std::cout << "     " << std::fixed << std::setprecision(6) << std::setw(8) << dist ;
	std::cout << std::endl;

	if(TASK[0]=='C' || dist1 < 1.e-5) {
		*IFLAG=0;
	} else if( TASK[0]=='A' || TASK[0]=='E') {
		*IFLAG=1;
	} else {
		*IFLAG=2;
	}

	return dist;
  }

  /* ==============================================================================================
	Check derivatives
   =============================================================================================== */

  void Optim::checkDeriv(int npoint1, double *Target, int npoint2, double *Y,
		double *X, double *dX, int method)
  {
	/* Debug helper: compare analytic and finite-difference derivatives.
	   1. Overwrite X with a fixed test point (note X[6] = 1, so scale-dependent
	      terms are only exercised at unit scale).
	   2. For k = 0,1,2 print emap.dR_dVi(X, k) next to central differences of
	      emap.EM3_To_R (step dx = 0.01).
	   3. Print the gradient from eneAndDer next to central differences of ene()
	      for all 7 parameters.
	   Terminates the program with exit(1) when done. */

  /* ==================================================================================================
	Fixed test point and step size
   ==================================================================================================*/

	X[0] =M_PI/3; X[1] = M_PI/4; X[2] = M_PI/6;
	X[3] = 3.; X[4] = 4.; X[5] = 5.; X[6] = 1.0;
	double anal[7], num[7];

	double dx = 0.01;

	double R[3][3];
	double R1[3][3], R2[3][3];
	double Rnum[3][3];
	double dRx[3][3]; double dRy[3][3]; double dRz[3][3];

	emap.EM3_To_R(X, R);
	emap.dR_dVi(X, 0, dRx);
	emap.dR_dVi(X, 1, dRy);
	emap.dR_dVi(X, 2, dRz);

	// dR/dX[0]: analytic vs central difference
	std::cout << std::endl;
	std::cout << "First coord: " << std::endl;
	X[0] = X[0] + dx;
	emap.EM3_To_R(X, R1);
	X[0] = X[0] - 2.0*dx;
	emap.EM3_To_R(X, R2);
	X[0] = X[0] + dx;
	for(int j = 0; j < 3; j++) {
		for(int k = 0; k < 3; k++) {
			Rnum[j][k] = (R1[j][k]-R2[j][k])/(2*dx);
			std::cout << "j = " << j << " k = " << k << " anal = " << dRx[j][k] << " num = " << Rnum[j][k] << std::endl;
		}
	}
	// dR/dX[1]
	std::cout << std::endl;
	std::cout << "Second coord: " << std::endl;
	X[1] = X[1] + dx;
	emap.EM3_To_R(X, R1);
	X[1] = X[1] - 2.0*dx;
	emap.EM3_To_R(X, R2);
	X[1] = X[1] + dx;
	for(int j = 0; j < 3; j++) {
		for(int k = 0; k < 3; k++) {
			Rnum[j][k] = (R1[j][k]-R2[j][k])/(2*dx);
			std::cout << "j = " << j << " k = " << k << " anal = " << dRy[j][k] << " num = " << Rnum[j][k] << std::endl;
		}
	}
	// dR/dX[2]
	std::cout << std::endl;
	std::cout << std::endl;
	std::cout << "Third coord: " << std::endl;
	X[2] = X[2] + dx;
	emap.EM3_To_R(X, R1);
	X[2] = X[2] - 2.0*dx;
	emap.EM3_To_R(X, R2);
	X[2] = X[2] + dx;
	for(int j = 0; j < 3; j++) {
		for(int k = 0; k < 3; k++) {
			Rnum[j][k] = (R1[j][k]-R2[j][k])/(2*dx);
			std::cout << "j = " << j << " k = " << k << " anal = " << dRz[j][k] << " num = " << Rnum[j][k] << std::endl;
		}
	}
		

	// Energy gradient: analytic (eneAndDer) vs central difference of ene()
	double dm1 = 0, dm2=0;
	double dist;
	eneAndDer(npoint1, Target, npoint2, Y, X, dX, method, &dist, &dm1, &dm2);
	for(int i = 0; i < 7; i++) anal[i] = dX[i];
	std::cout << std::endl;
	std::cout << "dist = " << dist << std::endl;

	double e1, e2;
	for(int i = 0; i < 7; i++) {
		X[i] = X[i] + dx;
		e1 = ene(npoint1, Target, npoint2, Y, X, method);
		X[i] = X[i] - 2.0*dx;
		e2 = ene(npoint1, Target, npoint2, Y, X, method);
		X[i] = X[i] + dx;
		std::cout << "e1 = " << e1 << " e2 = " << e2 << std::endl;
		num[i] = (e1-e2)/(2*dx);
	}

	std::cout << std::endl;
	for(int i = 0; i < 7; i++) {
		std::cout << "anal[i] = " << anal[i] << " num[i] = " << num[i] << std::endl;
	}
	// debugging aid only: stop the program after reporting
	exit(1);
   }

/*==========================================================================================
	Rigid body registration
========================================================================================== */

  double Optim::refine(Mesh& model1, Mesh& model2, double a1, double a2, int niter, 
  int method, int nthreads)
  {
	/* Register model2 onto model1 with a rotation + translation + uniform scale.
	   Both vertex sets are expressed relative to model1's vertex centroid and scaled
	   by 1/sqrt(total face area of model1). Up to niter L-BFGS-B iterations then
	   minimize the OT distance (method == 0: ot.ot1; otherwise otw.ot1_w with mass
	   weights a1, a2). The registered coordinates are written, in the original frame,
	   to each model2 vertex's position2.
	   Returns the last total energy from oneStep_LBFGSB (0 if niter <= 0). */

/*==========================================================================================
	Transfer model1 vertices to Target
========================================================================================== */

	int npoint1 = model1.vertices.size();
	double *Target = new double[3*npoint1];
	Vector cm;	// NOTE(review): assumes Vector default-constructs to zero
	double Scale;

	for(VertexIter v = model1.vertices.begin(); v != model1.vertices.end(); v++) {
		cm += v->position;
	}
	cm /= npoint1;
	double Area = 0;
	for (FaceIter f = model1.faces.begin(); f != model1.faces.end(); f++) {
		Area += f->area();
	}
	// Normalize model1 to unit area. With no (or only degenerate) faces the area is 0,
	// which would make Scale infinite and turn every coordinate into inf/NaN. Fall back
	// to no scaling in that case.
	Scale = (Area > 0) ? std::sqrt(1.0/Area) : 1.0;

	int idx = 0;
	for(VertexIter v = model1.vertices.begin(); v != model1.vertices.end(); v++) {
		for(int k = 0; k < 3; k++) {
			Target[3*idx + k] = Scale*(v->position[k]-cm[k]);
		}
		idx++;
	}

/*==========================================================================================
	Transfer model2 vertices to moving points Y (same frame as Target);
	Y0 = Y centered on its own centroid CG2, used as the reference shape
========================================================================================== */

	int npoint2 = model2.vertices.size();
	double *Y = new double[3*npoint2];
	double *Y0 = new double[3*npoint2];
	double CG2[3]={0.,0.,0.};

	idx = 0;
	for(VertexIter v = model2.vertices.begin(); v != model2.vertices.end(); v++) {
		for(int k = 0; k < 3; k++) {
			Y[3*idx + k] = Scale*(v->position[k]-cm[k]);
			CG2[k] += Y[3*idx + k];
		}
		idx++;
	}
	CG2[0] /= npoint2; CG2[1] /= npoint2; CG2[2] /= npoint2;

	for(int i = 0; i < npoint2; i++) {
		for(int k = 0; k < 3; k++) {
			Y0[3*i + k] = Y[3*i+k] - CG2[k];
		}
	}

/*==========================================================================================
	Optimize positions: start from identity rotation, translation CG2, unit scale,
	so the initial transform maps Y0 back onto Y
========================================================================================== */

	double X[7];
	double dX[7];

	memset(X, 0, 7*sizeof(double));
	memset(dX, 0, 7*sizeof(double));
	X[3] = CG2[0]; X[4] = CG2[1]; X[5] = CG2[2];
	X[6] = 1.;

/* ===============================================================================================
   	Estimate values for alpha1 and alpha2 (disabled)
   =============================================================================================== */

/*
	double sum, val;
	double s = 0.0;
	for(int i = 0; i < npoint2; i++) {
		for(int j = 0; j < npoint1; j++) {
			sum = 0;
			for(int k = 0; k < 3; k++) {
				val = Target[3*j+k] - Y[3*i+k];
				sum += val*val;
			}
			s+= std::sqrt(sum);
		}
	}
	s = s/(npoint1*npoint2);

	double alpha1 = 10*s;
	double alpha2 = 10*s;

	if(a1==-1) {
		a1 = alpha1;
		a2 = alpha2;
	}
*/

/* ===============================================================================================
   	Initialize optimization
   =============================================================================================== */

	initOptim(npoint1, Target, npoint2, Y0, X, dX, a1, a2, nthreads);

//	checkDeriv(npoint1, Target, npoint2, Y, X, dX, method);

/* ===============================================================================================
   	Perform optimization
   =============================================================================================== */

	int IFLAG = 0;
	double dist = 0;

	for(int i = 0; i < niter; i++) {
		dist = oneStep_LBFGSB(npoint1, Target, npoint2, Y, X, dX, method, &IFLAG, i);
		// 0: converged, 1: L-BFGS-B stopped abnormally / reported an error
		if(IFLAG==0 || IFLAG == 1) break;
	}

	// Rebuild Y from the final parameters. Y holds the points from the last energy
	// evaluation, which may not match X (L-BFGS-B can restore an earlier iterate
	// when it stops abnormally).
	applyTransform(X, npoint2, Y);

	// Write the registered positions back in the original (unscaled) frame
	idx = 0;
	for (VertexIter v = model2.vertices.begin(); v != model2.vertices.end(); v++) {
		v->position2[0] = Y[3*idx+0]/Scale + cm[0];
		v->position2[1] = Y[3*idx+1]/Scale + cm[1];
		v->position2[2] = Y[3*idx+2]/Scale + cm[2];
		idx++;
	}

	// These buffers are local to this call: initOptim copied Y0 into Yref, and
	// Optim keeps no pointer to Target or Y
	delete [] Target;
	delete [] Y;
	delete [] Y0;

	return dist;

  }

#endif
