/*================================================================================================
  OptimBfactor.h
  Version 1: 12/1/2017

Copyright (c) Patrice Koehl.

================================================================================================== */

#ifndef _OPTIMBFACTOR_KI_H_
#define _OPTIMBFACTOR_KI_H_

/*================================================================================================
 Includes
================================================================================================== */

#include <math.h>
#include <cstdlib>
#include "derivs_ki_thread.h"

#ifndef M_PI
    #define M_PI 3.14159265358979323846
#endif

/*================================================================================================
  Prototypes for BLAS and L-BFGS-B (Fortran) routines
================================================================================================== */

extern "C" {

	// Fortran-style entry points: every argument, scalars included, is passed by address.

	// Level-1 BLAS: y := alpha*x + y
	void daxpy_(int *n, double *alpha, double *x, int *incx, double *y, int *incy);

	// Level-1 BLAS: Euclidean norm of x
	double dnrm2_(int *n, double *x, int *incx);

	// Level-1 BLAS: x := alpha*x
	void dscal_(int *n, double *alpha, double *x, int *incx);

	// Level-1 BLAS: y := x
	void dcopy_(int *n, double *x, int *incx, double *y, int *incy);

	// Level-1 BLAS: dot product of x and y
	double ddot_(int *n, double *x, int *incx, double *y, int *incy);

	// L-BFGS-B driver. In this file it is called repeatedly, with `task` carrying the
	// request/status string between calls and the *save arrays carrying solver state.
	// NOTE(review): lsave is declared bool* here; if the routine is the Fortran L-BFGS-B,
	// its lsave is LOGICAL(4), which is usually wider than four C++ bools - confirm.
	void setulb_(int *n, int *m, double *x, double *l, double *u, int *nbd, double *f,
		double *g, double *factr, double *pgtol, double *wa, int *iwa, char *task,
		int *iprint, char *csave, bool *lsave, int *isave, double *dsave);

}

/*================================================================================================
  class
================================================================================================== */

  // Optimises the force constants held in X so that B-factors computed from the normal
  // modes (eigVal / eigVect) reproduce the experimental values stored in atoms[i].bfact.
  // The minimiser is L-BFGS-B (setulb_), driven one call at a time.
  //
  // When a non-null U1ij is supplied, `potential` is switched to 1 and X gains three
  // leading entries (bond, angle, dihedral constants) ahead of the per-atom constants.
  //
  // All work buffers are allocated by init() with new[]; this class does not release
  // them. Every member now has a defined initial value so that an object that has not
  // yet gone through init() holds null pointers instead of indeterminate ones.
  class OptimBfactor_ki {

	public:

		// Runs the optimisation and returns the last chi2 (sum of squared differences
		// between computed and experimental B-factors).
		//   nm1, nm2 : mode range [nm1, nm2) used for derivatives and the entropy term
		//   nm3      : upper mode bound used when computing B-factors in energy()
		//   U1ij     : non-null selects potential = 1 (bonded Go terms)
		//   bfact    : receives the computed B-factors (one per atom)
		//   flag_ave : forwarded to the Hessian builder and the derivative threads
		//   flag_ent : 1 turns the entropy weight off; 2 repeats the optimisation with
		//              lambda doubled until chi2 <= 0.01
		//   nthreads : number of pthreads used for the pair derivatives
		double optimkval(std::vector<Atoms>& atoms, std::vector<Links>& pairs, int nm1,
		int nm2, int nm3, double *Uij, double *U1ij, double *Uijk, double *Uijkl,
		double *hessian, double *eigVal, double *eigVect, double *bfact, int flag_ave, 
		int flag_ent, int nthreads);

		// Prints analytical versus central-difference derivatives, then exits the program.
		void checkDeriv(std::vector<Atoms>& atoms, std::vector<Links>& pairs, 
		int nm1, int nm2, double *Uij, double *U1ij, double *Uijk, double *Uijkl,
		double *hessian, double *eigVal, double *eigVect, double *bfact, int flag_ave, 
		int flag_ent, int nthreads);

	private:

		// Stores mode ranges, sets constants and allocates every work array
		void init(int natoms, int nlinks, int nm1, int nm2, int nm3);

		// Sets all per-atom constants to one common least-squares value and returns it
		double initKconst(std::vector<Atoms>& atoms, std::vector<Links>& pairs, 
		double *Uij, double *U1ij, double *Uijk, double *Uijkl, double *hessian,
		double *eigVal, double *eigVect, int flag_ave);

		// Copies the force constants in X back into the atom records
		void resetKconst(std::vector<Atoms>& atoms, double *X, int potential);

		// Sum of log(eigVal[i]) over modes [nmode1, nmode2)
		double entropy(double *eigVal);

		// chi2 between computed and experimental B-factors; fills bfact
		double energy(std::vector<Atoms>& atoms, double *eigVal, double *eigVect, double *bfact);

		// Derivatives with respect to the three bonded constants; writes G[0..2]
		void deriv_Go(std::vector<Atoms>& atoms, double *U1ij, double *Uijk, double *Uijkl, 
		double *eigVal, double *eigVect, double *bfact);

		// chi2 plus the full gradient in G
		double eneAndDer(std::vector<Atoms>& atoms, std::vector<Links>& pairs, 
		double *Uij, double *U1ij, double *Uijk, double *Uijkl,
		double *eigVal, double *eigVect, double *bfact, int flag_ave, int flag_ent, int nthreads);

		// Rebuilds the Hessian for the current X and diagonalises it
		void diagHessian(std::vector<Atoms>& atoms, std::vector<Links>& pairs, 
		double *Uij, double *U1ij, double *Uijk, double *Uijkl, double *hessian,
		double *eigVal, double *eigVect, int flag_ave);

		// One reverse-communication step of L-BFGS-B; reports the outcome through IFLAG
		double oneStep_LBFGSB(std::vector<Atoms>& atoms, std::vector<Links>& pairs, 
		double *Uij, double *U1ij, double *Uijk, double *Uijkl, double *hessian,
		double *eigVal, double *eigVect, double *bfact, int *IFLAG, int iter, 
		int flag_ave, int flag_ent, int nthreads);

		// Mode ranges (set by init)
		int nmode1 = 0, nmode2 = 0, nmodeB = 0;
		// Scratch of length 3*natoms: diagonal of the scaled inverse Hessian and its derivative
		double *DiagHm1 = nullptr, *dDiagHm1 = nullptr;
		// Prefactor 8*pi^2*kT/3 (set by init)
		double facb = 0.0;

		// L-BFGS-B state: N variables, M stored corrections
		int N = 0, M = 0;
		double *X = nullptr, *G = nullptr;
		double *L = nullptr, *U = nullptr, *WA = nullptr, *DSAVE = nullptr;
		int *NBD = nullptr, *IWA = nullptr, *ISAVE = nullptr;
		char *TASK = nullptr, *CSAVE = nullptr;
		bool *LSAVE = nullptr;

		// L-BFGS-B controls passed straight to setulb_
		int IPRINT = -1;
		double FACTR = 1.e7;
		double PGTOL = 1.e-4;

		// flag_param is forwarded to hess.buildK2; potential == 1 enables the bonded terms
		int flag_param = 2;
		int potential = 0;
		// Weights of the chi2 term and of the entropy term in the objective
		double lambda = 0.0, lambda_ent = 0.0;

  };

/* ===============================================================================================
   Initialize parameters
   =============================================================================================== */

// Records the mode ranges, computes the B-factor prefactor and allocates every buffer
// used by the optimiser, then prints the header of the iteration table.
//   natoms : number of atoms; the problem size N is natoms, plus 3 when potential == 1
//   nlinks : accepted for interface compatibility; not used here
//   nm1, nm2, nm3 : stored as nmode1, nmode2, nmodeB
// Buffers are allocated with new[] on every call and are not released here.
void OptimBfactor_ki::init(int natoms, int nlinks, int nm1, int nm2, int nm3)
{
	nmode1 = nm1;
	nmode2 = nm2;
	nmodeB = nm3;

	const double kT = 0.593;
	facb = 8.0*kT*M_PI*M_PI/3.0;

	DiagHm1 = new double[3*natoms];
	dDiagHm1 = new double[3*natoms];

	N = natoms;
	if(potential==1) N = N + 3;
	M = 10;

	// X is filled later by initKconst; G and U start at zero.
	X = new double[N];
	G = new double[N]();

	L = new double[N];
	U = new double[N]();
	NBD = new int[N];
	WA = new double[2*M*N+5*N+11*M*M+8*M];
	IWA = new int[3*N];
	ISAVE = new int[44];
	DSAVE = new double[29];

	// setulb_ receives LSAVE as its 4-element logical work array. A Fortran LOGICAL is
	// normally as wide as an int, so four C++ bools (4 bytes) are too small and the
	// callee would write past the allocation. Reserve 4*sizeof(int) bytes instead; the
	// member type and the call site are unchanged.
	// NOTE(review): assumes setulb_ is the Fortran L-BFGS-B routine - confirm.
	LSAVE = new bool[4*sizeof(int)]();

	TASK = new char[60];
	CSAVE = new char[60];

	// NBD = 1 : each variable has a lower bound only, here 1e-4.
	for(int i = 0; i < N; i++) {
		NBD[i] = 1;
		L[i] = 0.0001;
	}

	// Fortran character arguments are blank padded, not NUL terminated.
	for(int i = 0; i < 60; i++) {
		TASK[i] = ' ';
		CSAVE[i] = ' ';
	}

	std::cout << "        " << "=======================================================================" << std::endl;
	std::cout << "        " << "       Iter           FitB           F           chi2_tot       NormG  " << std::endl;
	std::cout << "        " << "=======================================================================" << std::endl;

}

/*================================================================================================
 initKconst: initialize all ki to the same value
================================================================================================== */

 // Starts every per-atom force constant from one common value.
 // The constants are first set to 1, the Hessian is diagonalised, and B-factors are
 // computed from modes [nmode1, nmode2). The common value is then
 //     kval = sum(b_calc^2) / sum(b_exp * b_calc)
 // which is printed, written to all per-atom entries of X, and returned.
 // With potential == 1 the three leading entries of X are preset to 100, 20 and 50
 // times eps = 0.36 and are left untouched by the rescaling.
 double OptimBfactor_ki::initKconst(std::vector<Atoms>& atoms, std::vector<Links>& pairs, 
	double *Uij, double *U1ij, double *Uijk, double *Uijkl, double *hessian,
	double *eigVal, double *eigVect, int flag_ave)
  {
	const int natoms = atoms.size();
	const int nshift = (potential==1) ? 3 : 0;

	if(nshift != 0) {
		const double eps = 0.36;
		X[0] = 100.0*eps;
		X[1] = 20.0*eps;
		X[2] = 50*eps;
	}

	const int firstK = nshift;
	const int lastK = nshift + natoms;
	for(int p = firstK; p < lastK; ++p) X[p] = 1.0;

	diagHessian(atoms, pairs, Uij, U1ij, Uijk, Uijkl, hessian, eigVal, eigVect, flag_ave);

	// Diagonal of facb * H^-1 restricted to the selected modes; every entry is written.
	const int npar = 3*natoms;
	for(int c = 0; c < npar; ++c) {
		double acc = 0.0;
		for(int m = nmode1; m < nmode2; ++m) {
			const double w = facb/eigVal[m];
			const double comp = eigVect[m*npar + c];
			acc = acc + w*comp*comp;
		}
		DiagHm1[c] = acc;
	}

	// Least-squares scale between calculated and experimental B-factors.
	double crossSum = 0;
	double calcSum = 0;
	for(int a = 0; a < natoms; ++a) {
		const double b_exp = atoms[a].bfact;
		const double b_calc = DiagHm1[3*a] + DiagHm1[3*a+1] + DiagHm1[3*a+2];
		crossSum += b_exp*b_calc;
		calcSum += b_calc*b_calc;
	}
	const double kval = calcSum/crossSum;
	std::cout << "kval = " << kval << std::endl;

	for(int p = firstK; p < lastK; ++p) X[p] = kval;

	return kval;
  }

/*================================================================================================
 entropy: computes entropy regularisation
================================================================================================== */

  double OptimBfactor_ki::entropy(double *eigVal)
  {
	double sum = 0.0;
	for(int i = nmode1; i < nmode2; i++) {
		sum += std::log(eigVal[i]);
	}
	return sum;
  }

/*================================================================================================
 Energy: computes difference between computed and experimental B-factors
================================================================================================== */

// Computes B-factors from modes [nmode1, nmodeB) and compares them with experiment.
//   bfact : output, receives the computed B-factor of each atom
// Returns chi2 = sum over atoms of (b_calc - atoms[i].bfact)^2.
// Side effect: DiagHm1 is overwritten with the per-coordinate contributions.
double OptimBfactor_ki::energy(std::vector<Atoms>& atoms, double *eigVal, double *eigVect, double *bfact)
{
	const int natoms = atoms.size();
	const int npar = 3*natoms;

	// Per-coordinate sum over modes of (facb / eigenvalue) * component^2.
	for(int c = 0; c < npar; ++c) {
		double acc = 0.0;
		for(int m = nmode1; m < nmodeB; ++m) {
			const double w = facb/eigVal[m];
			const double comp = eigVect[m*npar + c];
			acc = acc + w*comp*comp;
		}
		DiagHm1[c] = acc;
	}

	// An atom's B-factor is the sum of its x, y and z contributions.
	double chi2 = 0;
	for(int a = 0; a < natoms; ++a) {
		const double b_exp = atoms[a].bfact;
		const double b_calc = DiagHm1[3*a] + DiagHm1[3*a+1] + DiagHm1[3*a+2];
		const double miss = b_calc - b_exp;
		chi2 += miss*miss;
		bfact[a] = b_calc;
	}

	return chi2;
}

/*================================================================================================
  derivatives of the specific Go terms
================================================================================================== */

  // Derivatives of the objective with respect to the three bonded force constants.
  // Overwrites G[0] (bonds), G[1] (angles) and G[2] (dihedrals).
  //
  // For each internal coordinate with direction vector u, dDiagHm1 is built as
  //     sum over modes of ((v_mode . u) / eigVal) * v_mode
  // then each entry d is replaced by -d*d. Contracting the per-atom sums of that array
  // with (bfact[i] - atoms[i].bfact) gives the fit part, weighted by 2*lambda; the
  // entropy part is sum over modes of (v_mode . u)^2 / eigVal, weighted by lambda_ent.
  //   U1ij  : 3 values per bond      (atoms k, k+1)
  //   Uijk  : 9 values per angle     (atoms k..k+2)
  //   Uijkl : 12 values per dihedral (atoms k..k+3)
  //   bfact : computed B-factors, read only
  void OptimBfactor_ki::deriv_Go(std::vector<Atoms>& atoms,
	double *U1ij, double *Uijk, double *Uijkl, double *eigVal, double *eigVect, double *bfact)
  {

	const int natoms = atoms.size();
	int npar = 3*natoms;		// handed to BLAS by address
	int unit = 1;

	// Squares (with a minus sign) the accumulated vector in dDiagHm1 and contracts
	// it with the B-factor residuals; the result is the unweighted fit slope.
	auto fitSlope = [&]() -> double {
		for(int p = 0; p < npar; ++p) {
			const double h = dDiagHm1[p];
			dDiagHm1[p] = -h * h;
		}
		double acc = 0.;
		for(int at = 0; at < natoms; ++at) {
			const double resid = bfact[at] - atoms[at].bfact;
			const double slope = facb*(dDiagHm1[3*at] + dDiagHm1[3*at+1] + dDiagHm1[3*at+2]);
			acc += resid*slope;
		}
		return acc;
	};

	// Adds the weighted fit and entropy contributions of one coordinate to G[slot],
	// in the same order as they are produced: fit first, entropy second.
	auto deposit = [&](int slot, double fit, double ent) {
		fit *= lambda;
		G[slot] += 2*fit;
		ent *= lambda_ent;
		G[slot] += ent;
	};

	// Bond term: projection of the displacement difference of atoms k and k+1 on u
	G[0] = 0.;
	for(int bond = 0; bond < natoms-1; ++bond) {
		double *dir = &U1ij[3*bond];

		memset(dDiagHm1, 0, npar*sizeof(double));
		double entAcc = 0.;
		for(int m = nmode1; m < nmode2; ++m) {
			double *vi = &eigVect[m*npar + 3*bond];
			double *vj = &eigVect[m*npar + 3*(bond+1)];
			const double proj = dir[0]*(vi[0]-vj[0]) + dir[1]*(vi[1]-vj[1]) + dir[2]*(vi[2]-vj[2]);
			entAcc += (proj*proj)/eigVal[m];
			double coef = proj/eigVal[m];
			daxpy_(&npar, &coef, &eigVect[m*npar], &unit, dDiagHm1, &unit);
		}

		deposit(0, fitSlope(), entAcc);
	}

	// Angle and dihedral terms share one form: a dot product of `width` consecutive
	// eigenvector components (starting at atom k) with the stored direction vector.
	auto foldTerms = [&](int slot, double *table, int width, int count) {
		G[slot] = 0;
		for(int k = 0; k < count; ++k) {
			double *dir = &table[width*k];

			memset(dDiagHm1, 0, npar*sizeof(double));
			double entAcc = 0.;
			for(int m = nmode1; m < nmode2; ++m) {
				double *seg = &eigVect[m*npar + 3*k];
				const double proj = ddot_(&width, dir, &unit, seg, &unit);
				entAcc += (proj*proj)/eigVal[m];
				double coef = proj/eigVal[m];
				daxpy_(&npar, &coef, &eigVect[m*npar], &unit, dDiagHm1, &unit);
			}

			deposit(slot, fitSlope(), entAcc);
		}
	};

	// Angle term
	foldTerms(1, Uijk, 9, natoms-2);

	// Dihedral term
	foldTerms(2, Uijkl, 12, natoms-3);

  }

/*================================================================================================
 EneAndDer: computes difference between computed and experimental B-factors, as well as derivatives
	    wrt force constants
================================================================================================== */

 // Computes chi2 (through energy(), which also fills bfact) and the gradient of the
 // objective with respect to every entry of X, stored in G.
 //
 // The pair list is cut into nthreads contiguous slices; each slice is handed to
 // deriv_ki_thread through the shared derivs_ki[] table and accumulates into its own
 // gradient buffer. While the workers run, this thread evaluates the three bonded
 // derivatives (potential == 1 only). After joining, the per-thread gradients are
 // summed into G starting at offset nshift (3 with the bonded terms, 0 otherwise).
 //
 // Changes with respect to the previous version:
 //   * G is cleared over all natoms+nshift entries. Clearing only natoms entries left
 //     the last three per-atom derivatives accumulating from call to call whenever
 //     potential == 1.
 //   * The per-thread scratch buffers are owned by this function and released on
 //     return; they used to be allocated with new[] on every call and never freed.
 //     NOTE(review): assumes deriv_ki_thread neither frees nor keeps these buffers.
 //
 // nthreads must be at least 1. flag_ent is accepted but not used here.
 // Returns chi2.
 double OptimBfactor_ki::eneAndDer(std::vector<Atoms>& atoms, std::vector<Links>& pairs, 
	double *Uij, double *U1ij, double *Uijk, double *Uijkl,
	double *eigVal, double *eigVect, double *bfact, int flag_ave, int flag_ent, int nthreads)
  {

	double chi2 = energy(atoms, eigVal, eigVect, bfact);

	int natoms = atoms.size();	// passed to daxpy_ by address
	int npairs = pairs.size();
	int nshift = 0;
	if(potential==1) nshift = 3;

/*	==========================================================================================
	Per-thread scratch, zero initialised and allocated before any thread starts
	========================================================================================== */

	const std::size_t nslots = nthreads > 0 ? nthreads : 0;
	std::vector<std::vector<double> > work(nslots, std::vector<double>(pairs.size(), 0.0));
	std::vector<std::vector<double> > partialG(nslots, std::vector<double>(atoms.size(), 0.0));

/*	==========================================================================================
	Break list to all threads and send jobs
	========================================================================================== */

	int nval = npairs/nthreads;
	int N1, N2;

	for (int i=0; i < nthreads; i++)
	{
		N1 = i*nval;
		N2 = N1 + nval;
		if(i==nthreads-1) N2 = npairs;	// last slice takes the remainder
		threadids[i] = i;

		derivs_ki[i].N1 = N1;
		derivs_ki[i].N2 = N2;
		derivs_ki[i].flag_ave  = flag_ave;
		derivs_ki[i].facb      = facb;
		derivs_ki[i].nmode1    = nmode1;
		derivs_ki[i].nmode2    = nmode2;
		derivs_ki[i].lambda    = lambda;
		derivs_ki[i].lambda_ent= lambda_ent;
		derivs_ki[i].eigVect   = eigVect;
		derivs_ki[i].eigVal    = eigVal;
		derivs_ki[i].bfact     = bfact;
		derivs_ki[i].atoms     = atoms;
		derivs_ki[i].pairs     = pairs;
		derivs_ki[i].Uij       = Uij;
		derivs_ki[i].X         = &X[nshift];
		derivs_ki[i].Work      = work[i].data();
		derivs_ki[i].G         = partialG[i].data();

		pthread_create(&threads[i], nullptr, deriv_ki_thread, (void*) &threadids[i]);
	}

/*	==========================================================================================
	Compute bonded derivatives on main thread
	========================================================================================== */

	memset(G, 0, (natoms+nshift)*sizeof(double));
	if(potential==1) {
		deriv_Go(atoms, U1ij, Uijk, Uijkl, eigVal, eigVect, bfact);
	}

/*	==========================================================================================
	Join all the threads (to make sure they are all finished) and sum their gradients
	========================================================================================== */

	int inc = 1;
	double a = 1.;
	for (int i=0; i < nthreads; i++)
	{
		pthread_join(threads[i], nullptr);
		daxpy_(&natoms, &a, partialG[i].data(), &inc, &G[nshift], &inc);

		// The scratch dies with this call; do not leave stale addresses in the table.
		derivs_ki[i].Work = nullptr;
		derivs_ki[i].G    = nullptr;
	}

	return chi2;

  }

/*================================================================================================
 diagHessian: recompute Hessian and its eigenvalues and vectors for new kval
================================================================================================== */

 // Refreshes the normal modes for the force constants currently held in X.
 // Steps, in order: hess.buildK2 receives X, the full Hessian is assembled into
 // `hessian`, copied into `eigVect` (9*natoms^2 values), diagonalised by hess.fullEigen
 // (eigVal and eigVect are its outputs here), and finally eigVect is passed through
 // hess.rescaleEigVect together with nmode2.
 void OptimBfactor_ki::diagHessian(std::vector<Atoms>& atoms, std::vector<Links>& pairs, 
	double *Uij, double *U1ij, double *Uijk, double *Uijkl, double *hessian,
	double *eigVal, double *eigVect, int flag_ave)
 {
	int natoms = atoms.size();

	hess.buildK2(pairs, atoms, X, potential, flag_param, flag_ave);
	hess.fullHessian(atoms, pairs, Uij, U1ij, Uijk, Uijkl, hessian);

	// The eigen-solver is given eigVect, so `hessian` itself is left as built.
	int nelem = 9*natoms*natoms;
	int stride = 1;
	dcopy_(&nelem, hessian, &stride, eigVect, &stride);

	hess.fullEigen(natoms, eigVal, eigVect);
	hess.rescaleEigVect(atoms, nmode2, eigVect);
  }

/*================================================================================================
 From X to atoms information
================================================================================================== */

   // Writes the optimisation variables back into the atom records.
   //   X         : variable vector; note that this parameter hides the member X
   //   potential : 1 means X starts with the bond, angle and dihedral constants,
   //               followed by the per-atom constants; otherwise X holds only the latter
   // With potential == 1, k_bond is set on all but the last atom, k_angle on all but
   // the last two and k_dihed on all but the last three.
   void OptimBfactor_ki::resetKconst(std::vector<Atoms>& atoms, double *X, int potential)
  {
	const int natoms = atoms.size();
	const bool bonded = (potential==1);
	const int offset = bonded ? 3 : 0;

	for(int a = 0; a < natoms; ++a) {
		atoms[a].kconst = X[a+offset];

		if(!bonded) continue;

		if(a < natoms-1) atoms[a].k_bond  = X[0];
		if(a < natoms-2) atoms[a].k_angle = X[1];
		if(a < natoms-3) atoms[a].k_dihed = X[2];
	}
   }

  /* ==============================================================================================
	One step of the Optimization of kconst with lbfgsb
   =============================================================================================== */

  // Advances the L-BFGS-B minimisation by one step and prints one line of the
  // iteration table (iteration, chi2, entropy term, total objective, gradient norm).
  //
  //   IFLAG (in)  : 0 requests a fresh start of the minimiser
  //   IFLAG (out) : 0 when the returned task starts with 'C',
  //                 1 when it starts with 'A' or 'E',
  //                 2 for anything else (the caller keeps iterating)
  //   iter        : zero-based iteration counter, used for the printout only
  // Returns chi2 of the last function evaluation.
  //
  // The objective handed to setulb_ is lambda*chi2 + lambda_ent*entropy. The Hessian is
  // re-diagonalised and the gradient recomputed for as long as setulb_ answers with a
  // task beginning with 'F'.
  //
  // Fix: a fresh start now blank-fills the whole 60-character TASK buffer before
  // writing "START". Fortran compares the full blank-padded string, so when TASK still
  // held the tail of a previous status message the restart was not recognised and the
  // solver carried on with stale state.
  double OptimBfactor_ki::oneStep_LBFGSB(std::vector<Atoms>& atoms, std::vector<Links>& pairs, 
		double *Uij, double *U1ij, double *Uijk, double *Uijkl, double *hessian,
		double *eigVal, double *eigVect, double *bfact, int *IFLAG, int iter, 
		int flag_ave, int flag_ent, int nthreads)
  {

  /* ==================================================================================================
	Variables for lbfgs
   ==================================================================================================*/

	if(*IFLAG == 0) {
		// TASK is allocated with 60 characters in init().
		for(int i = 0; i < 60; i++) TASK[i] = ' ';
		TASK[0] = 'S'; TASK[1] = 'T'; TASK[2]='A'; TASK[3]='R'; TASK[4]='T';
	}

	double dist, dist_ent, dist_tot;
	do {
		diagHessian(atoms, pairs, Uij, U1ij, Uijk, Uijkl, hessian, eigVal, eigVect, flag_ave);
		dist = eneAndDer(atoms, pairs, Uij, U1ij, Uijk, Uijkl, eigVal, eigVect, bfact, 
			flag_ave, flag_ent, nthreads);
		dist_ent = entropy(eigVal);
		dist_tot = dist*lambda + lambda_ent*dist_ent;
		setulb_(&N, &M, X, L, U, NBD, &dist_tot, G, &FACTR, &PGTOL, WA, IWA,
		TASK, &IPRINT, CSAVE, LSAVE, ISAVE, DSAVE);
	} while(TASK[0]=='F');

	int inc = 1;
	double gnorm = dnrm2_(&N, G, &inc);
        std::cout << "     " << "   " << std::setw(8)<< iter+1;
        std::cout << "     " << std::fixed << std::setprecision(6) << std::setw(8) << dist;
        std::cout << "     " << std::fixed << std::setprecision(6) << std::setw(8) << dist_ent;
        std::cout << "     " << std::fixed << std::setprecision(6) << std::setw(8) << dist_tot;
        std::cout << "     " << std::fixed << std::setprecision(6) << std::setw(8) << gnorm;
	std::cout << std::endl;

	// Translate the first letter of the task string into the caller's status code.
	if(TASK[0]=='C') {
		*IFLAG=0;
	} else if( TASK[0]=='A' || TASK[0]=='E') {
		*IFLAG=1;
	} else {
		*IFLAG=2;
	}

	return dist;
  }

  /* ==============================================================================================
	Check derivatives of Bfactors
   =============================================================================================== */

  // Debugging aid: compares the analytical gradient with central finite differences
  // (step 1e-5) of lambda*chi2 + lambda_ent*entropy for every optimisation variable,
  // prints both values per variable and then terminates the program with exit(1).
  //
  // A non-null U1ij switches on the bonded terms, in which case the first three
  // variables are the bond, angle and dihedral constants. nm2 is used both as the upper
  // derivative mode and as the upper B-factor mode.
  //
  // Fix: the report now covers all nval variables. It used to stop after natoms
  // entries, which dropped the last three per-atom derivatives when the bonded terms
  // were enabled.
  void OptimBfactor_ki::checkDeriv(std::vector<Atoms>& atoms, std::vector<Links>& pairs, int nm1, int nm2,
		double *Uij, double *U1ij, double *Uijk, double *Uijkl, double *hessian,
		double *eigVal, double *eigVect, double *bfact, int flag_ave, int flag_ent, int nthreads)
  {

	int natoms = atoms.size();
	int npairs = pairs.size();

	if(U1ij) potential = 1;
	init(natoms, npairs, nm1, nm2, nm2);

	int nval = natoms;
	if(potential==1) nval += 3;
	double *anal = new double[nval];

	// Called to populate X; the returned constant is not needed here.
 	initKconst(atoms, pairs, Uij, U1ij, Uijk, Uijkl, hessian, eigVal, eigVect, flag_ave);

	lambda = 4.0;
	lambda_ent = 1.0;
	if(flag_ent==1) {
		lambda = 4.0;
		lambda_ent = 0.0;
	}

	diagHessian(atoms, pairs, Uij, U1ij, Uijk, Uijkl, hessian, eigVal, eigVect, flag_ave);

	// Called to fill G with the analytical gradient at the current X.
	eneAndDer(atoms, pairs, Uij, U1ij, Uijk, Uijkl, eigVal, eigVect, bfact, 
		flag_ave, flag_ent, nthreads);

	int inc = 1;
	dcopy_(&nval, G, &inc, anal, &inc);

	double dx=1.e-5;

	// Central differences: evaluate at X[i]+dx and X[i]-dx, then restore X[i].
	double *num = new double[nval];
	double dist1, dist1_ent, dist2, dist2_ent;
	for(int i = 0; i < nval; i++) {
		X[i] += dx;
		diagHessian(atoms, pairs, Uij, U1ij, Uijk, Uijkl, hessian, eigVal, eigVect, flag_ave);
		dist1 = energy(atoms, eigVal, eigVect, bfact);
		dist1_ent = entropy(eigVal);
		dist1 = lambda*dist1 + lambda_ent*dist1_ent;
		X[i] -= 2*dx;
		diagHessian(atoms, pairs, Uij, U1ij, Uijk, Uijkl, hessian, eigVal, eigVect, flag_ave);
		dist2 = energy(atoms, eigVal, eigVect, bfact);
		dist2_ent = entropy(eigVal);
		dist2 = lambda*dist2 + lambda_ent*dist2_ent;
		X[i] += dx;
		num[i] = (dist1-dist2)/(2*dx);
	}

	for(int i = 0; i < nval; i++) {
		std::cout << "i = " << i << " anal = " << anal[i] << " num[i] = " << num[i] << std::endl;
	}

	delete [] num;
	delete [] anal;

	exit(1);

   }

  /* ==============================================================================================
	Optimization of kconst with lbfgsb
   =============================================================================================== */

   // Entry point of the force-constant optimisation.
   // Allocates the work space, starts all per-atom constants from the common value
   // found by initKconst, then iterates oneStep_LBFGSB.
   //   U1ij     : non-null switches on the bonded terms (potential = 1)
   //   flag_ent : 1 -> entropy weight 0; 2 -> at most 200 steps per pass, and passes are
   //              repeated with lambda doubled while chi2 > 0.01; any other value ->
   //              a single pass of at most 1000 steps
   // A pass ends early when the step reports status 0 or 1, or when chi2 < 1e-3.
   // NOTE(review): with flag_ent == 2 there is no cap on the number of passes.
   // Returns the chi2 of the last step taken.
   double OptimBfactor_ki::optimkval(std::vector<Atoms>& atoms, std::vector<Links>& pairs, int nm1,
		int nm2, int nm3, double *Uij, double *U1ij, double *Uijk, double *Uijkl, double *hessian,
		double *eigVal, double *eigVect, double *bfact, int flag_ave, int flag_ent, int nthreads)
   {

	const int natoms = atoms.size();
	const int npairs = pairs.size();

	if(U1ij) potential = 1;
	init(natoms, npairs, nm1, nm2, nm3);

	// Only the side effect on X is needed; the common constant itself is not used.
 	initKconst(atoms, pairs, Uij, U1ij, Uijk, Uijkl, hessian, eigVal, eigVect, flag_ave);

	const int maxSteps = (flag_ent==2) ? 200 : 1000;
	const bool repeatPasses = (flag_ent == 2);

	lambda = 4.0;
	lambda_ent = (flag_ent==1) ? 0 : 1.0;

	double chi2 = 0.0;
	bool again = true;
	while(again) {
		int status = 0;		// 0 asks the first step of each pass for a fresh start
		for(int step = 0; step < maxSteps; ++step) {
			chi2 = oneStep_LBFGSB(atoms, pairs, Uij, U1ij, Uijk, Uijkl, hessian, eigVal, eigVect, 
			bfact, &status, step, flag_ave, flag_ent, nthreads);
			const bool finished = (status==0 || status == 1);
			if(finished || chi2 < 1.e-3) break;
		}
		lambda = 2*lambda;
		again = (chi2 > 0.01 && repeatPasses);
	}

	return chi2;

  }

#endif
