/*================================================================================================

  FitNMA
  Version 1: 9/30/2019

Copyright (c) Patrice Koehl.

================================================================================================== */


/*================================================================================================
  All includes
================================================================================================== */

#include "FitNMA.h"

/*================================================================================================
  Main
================================================================================================== */

int main(int argc, char** argv)
{

/*	==========================================================================================
         Show usage
	========================================================================================== */

	if(argc < 2)
	{
		usage(argv);
		return -1;
	}

	std::string input = argv[1];
	if(input == "-h" || input == "-help")
	{
		usage(argv);
		return -1;
	}

/*	==========================================================================================
         Read command line arguments
	========================================================================================== */

	OptimBfactor_ki optim_ki;

	std::string inputfile, outputfile; 
	std::string targetfile = " ";
	double cutoff    = 0;
	int atom_type    = 0;
	int flag_network = 2;
	int potential    = 1;
	int nmodes       = 0;
	int nthreads     = sysconf( _SC_NPROCESSORS_ONLN );
	int flag_mass    = 0;
	int flag_rigid   = 1;
	int flag_mean    = 2;
	int flag_ent     = 1;
	int flag_optim   = 1;
	int flag_rna	 = 0;

	parse_args(argc, argv, &inputfile, &outputfile, &targetfile, &flag_rna, &atom_type, 
	&flag_network, &potential, &cutoff, &flag_mass, &nmodes, &flag_optim,
	&flag_rigid, &flag_mean, &flag_ent); 

	nmodes += 6;

	if(nthreads==0) nthreads = 1;
	if(potential==2) atom_type = 0;
	
/* 	==========================================================================================
	Count atoms in PDB file, generate array of vertices, and read in coordinates
   	========================================================================================== */

	std::vector<Atoms> atoms;
	double center[3];
	int nchains;
	if(flag_rna==0) {
		readPDB(inputfile, atom_type, flag_mass, atoms, &nchains, center);
	} else {
		readRNA(inputfile, atom_type, flag_mass, atoms, &nchains);
	}
	int natoms = atoms.size();
	for(int i = 0; i < natoms; i++) {
		std::cout << "i = " << i << " coord = " << atoms[i].coord[0] << " " << atoms[i].coord[1] << " " << atoms[i].coord[2] << std::endl;
	}


	std::cout << " " << std::endl;
	std::cout << "Number of atoms found in file      : " << natoms << std::endl;
	std::cout << "Number of chains in protein complex: " << nchains << std::endl;
	std::cout << " " << std::endl;

	int mindiff = 0;
	if(potential==2) mindiff = 4;

	int npair, N;
	double *Uij, *U1ij, *Uijk, *Uijkl;
	double *hessian, *eigVal, *eigVect;

	N     = 3*natoms;
	int dim_max = N;
	long isize = N*dim_max;

	eigVal = new double[dim_max];
	eigVect = new double[isize];
	hessian = new double[isize];

	double eps;
	if(potential==1) {

		U1ij = NULL;
		Uijk = NULL;
		Uijkl = NULL;

	} else {

		U1ij  = new double[3*natoms];
		Uijk  = new double[9*natoms];
		Uijkl = new double[12*natoms];

		hess.buildU1ij(atoms, natoms, U1ij);
		hess.buildUijk(atoms, natoms, Uijk);
		hess.buildUijkl(atoms, natoms, Uijkl);

		eps = 0.36;
		double k_bond  = 200.0*eps;
		double k_angle = 40.0*eps;
		double k_dihed = 5.5*eps;
		for(int i = 0; i < natoms; i++) {
			atoms[i].k_bond  = k_bond;
			atoms[i].k_angle = k_angle;
			atoms[i].k_dihed = k_dihed;
		}

	}

	int npairs, nvar;
	int N2 = 9*natoms*natoms;
	int inc = 1;

	std::vector<Links> List;

	if(flag_network == 1) {

/* 	==========================================================================================
		Build cutoff-based elastic network
		modify cutoff if problems with Hessian (more than 6 zero eigenvalues)
   	========================================================================================== */

		do {
			List.clear();
			std::cout << "mindiff = " << mindiff << std::endl;
			net.network(atoms, cutoff, mindiff, List, nthreads);

			double *X;
			npairs = List.size();
			nvar = natoms;
			X = new double[nvar];

			for(int i = 0; i < nvar; i++) {
				if(potential==1) {
					X[i] = 1.0;
				} else {
					X[i] = 120*0.36;
				}
			}

			Uij = new double[3*npairs];
			hess.buildUij(atoms, List, potential, Uij);

			hess.buildK2(List, atoms, X, flag_mean);
			hess.fullHessian(atoms, List, Uij, U1ij, Uijk, Uijkl, hessian);
			dcopy_(&N2, hessian, &inc, eigVect, &inc);
			hess.fullEigen(natoms, eigVal, eigVect);

			if(eigVal[6] < 1.e-10) cutoff += 1;

			delete [] X;

		} while (eigVal[6] < 1.e-10);

	} else {

/* 	==========================================================================================
		Build Delaunay-based elastic network
   	========================================================================================== */

		double *coord = new double[3*natoms];
		double *radii = new double[natoms];
		double *coef  = new double[natoms];

		for(int i = 0; i < natoms; i++) {
			for(int j = 0; j < 3; j++) coord[3*i+j] = atoms[i].coord[j];
			radii[i] = 2.0;
			coef[i] = 1.0;
		}

		std::vector<Vertex> vertices;
		std::vector<Tetrahedron> tetra;

		delcx.setup(natoms, coord, radii, coef, vertices, tetra);
		delcx.regular3D(vertices, tetra);
		
		std::vector<std::pair<int, int> > edges;
		delcx.delaunayEdges(tetra, edges);

		npairs = edges.size();
		double kval = 0.0;
		int i1, j1;
		double r;
		int is = edges.size();
		for(int i = 0; i < is; i++) {
			i1 = edges[i].first;
			j1 = edges[i].second;
			r = distancesq(atoms, i1, j1);
			r = std::sqrt(r);
			Links l(i1, j1, atoms[i1].resid, atoms[j1].resid, kval, r);
			List.push_back(l);
		}

		double *X;
		npairs = List.size();
		nvar = natoms;
		X = new double[nvar];

		for(int i = 0; i < nvar; i++) {
			if(potential==1) {
				X[i] = 1.0;
			} else {
				X[i] = 120.0*0.36;
			}
		}

		Uij = new double[3*npairs];
		hess.buildUij(atoms, List, potential, Uij);

		hess.buildK2(List, atoms, X, flag_mean);
		hess.fullHessian(atoms, List, Uij, U1ij, Uijk, Uijkl, hessian);
		dcopy_(&N2, hessian, &inc, eigVect, &inc);
		hess.fullEigen(natoms, eigVal, eigVect);

		vertices.clear();
		tetra.clear();
		edges.clear();

		delete [] coord; delete [] radii; delete [] coef; delete [] X;

		if(eigVal[6] < 1.e-10) {
			std::cout << "Status: BAD Problem with Delaunay: Null space > 6" << std::endl;
			exit(1);
		}
	}
	std::cout << "Status: OK" << std::endl;

/* 	==========================================================================================
		Optimize kconsts
   	========================================================================================== */

	npair = List.size();

	std::cout << " " << std::endl;
	std::cout << "Cutoff for ENM       : " << cutoff << std::endl;
	std::cout << "Number of pairs found: " << npair << std::endl;
	std::cout << " " << std::endl;

	double check = 0;
	for(int i = 0; i < natoms; i++) {
		check = check + atoms[i].bfact;
	}

	if(check == 0) {
		std::cout << "no B-factors in PDB file..." << std::endl;
		exit(1);
	}

	int nmode1 = 6;
	int nmode2 = 3*natoms;
	if(nmodes==6) nmodes = nmode2;
	nmodes = std::min(3*natoms, nmodes);
	double *bval;
	bval = new double[natoms];

/*
	optim_ki.checkDeriv(atoms, List, nmode1, nmode2, Uij, U1ij, Uijk, Uijkl,
		hessian, eigVal, eigVect, bval, flag_mean, flag_ent, flag_rigid,
		nthreads);

	exit(1);
*/

	double scale = tools.findScale(atoms);
	scale = 1.0;
	tools.scaleBfact(atoms, scale);

	double *rigid = new double[10];
	memset(rigid, 0, 10*sizeof(double));
	if(flag_optim==1) {
		optim_ki.optimkval(atoms, List, nmode1, nmode2, nmodes, Uij, U1ij, Uijk, Uijkl,
			hessian, eigVal, eigVect, bval, flag_mean, flag_ent, flag_rigid,
			rigid, nthreads);
	}

/* 	==========================================================================================
		Check final normal modes
   	========================================================================================== */

	double *err = new double[nmodes];
	double Emin, Emax, Emean;
	hess.checkEigen(N, nmodes, hessian, eigVal, eigVect, err, &Emin, &Emax, &Emean, 
	nthreads);

	std::cout << " " << std::endl;
	std::cout << "Mean error on eigenpairs      : " << Emean << std::endl;
	std::cout << "Min error on eigenpairs       : " << Emin << std::endl;
	std::cout << "Max error on eigenpairs       : " << Emax << std::endl;
	std::cout << " " << std::endl;

	double scale1 = scale;
	tools.scaleEig(nmodes, eigVal, scale1);

/* 	==========================================================================================
	Print eigenvalues and entropy
   	========================================================================================== */

	double Temp = 300;
	double hbar = 1.0546e-34;
	double kb   = 1.3807e-23;
	double t0   = 4.88882e-14;
	double fac  = hbar/(Temp*kb*t0);
	double fac2 = 108.59;

	double ent;
	double x, A, B, expA, expmA;
	double entropy = 0.0;

	std::cout << "================================================================================================="<<std::endl;
	std::cout << "         Mode #      Error     Eigenval       Frequency    entropy (this mode)   entropy (total) "<<std::endl;
	std::cout << "================================================================================================="<<std::endl;

	for(int i = 0; i < nmodes; i++) {
		if(i < 6) {
			ent = 0.0;
			B = 0.0;
		} else {
			x = std::sqrt(eigVal[i]);
			A = x*fac;
			B = x*fac2;
			expA = std::exp(A);
			expmA = 1.0/expA;
			ent = A/(expA-1) - std::log(1-expmA);
		}
		entropy = entropy + ent;
		std::cout << "    " << std::setw(10) << i+1 << "      " ;
		std::cout << std::setw(10) << std::abs(err[i]) << "      " ;
		std::cout << std::setw(10) << std::abs(eigVal[i]) << "      " ;
		std::cout <<  std::setw(10) << B << "       ";
		std::cout << std::setw(10) << ent << "        "; 
		std::cout << std::setw(10) << entropy << std::endl;
	}
	std::cout << "================================================================================================="<<std::endl;
	std::cout << " " << std::endl;

/* 	==========================================================================================
	Print elastic network, b-factors, and eigenvalues;
   	========================================================================================== */

	std::string extension;

	double rms_bfact, correl_bfact;

	extension=".bfact";
	std::string bfile = outputfile;
	bfile.append(extension);

	tools.computeBfact(atoms, rigid, nmode1, nmodes, eigVal, eigVect, bval);
	scale1 = 1./scale;
	tools.scaleBfact(atoms, scale1);
	tools.compareBfact(atoms, bval, &rms_bfact, &correl_bfact); 

	std::cout << "RMS between exp and calc b-fact        :  " << std::setw(10) <<  rms_bfact << std::endl;
	std::cout << "CC between exp and calc b-fact         :  " << std::setw(10) <<  correl_bfact << std::endl;

	wres.writeBfact(bfile, atoms, bval);

	delete [] bval;
	std::cout << " " << std::endl;

	extension=".enm";
	bfile = outputfile;
	bfile.append(extension);
	wres.writeEN(bfile, atoms, List);

/*
	extension="_atm.pdb";
	bfile = outputfile;
	bfile.append(extension);
	wres.writekAtom(bfile, atoms);

	extension=".eig";
	bfile = outputfile;
	bfile.append(extension);
	wres.writeEigen(bfile, nmodes, eigVal);

//	int nc = 50;
//	extension="_traj7.pdb";
//	bfile = outputfile;
//	bfile.append(extension);
//	wres.genTraj(bfile, atoms, nmode1, nmode1, nc, eigVal, eigVect);

//	extension=".nmd";
//	bfile = outputfile;
//	bfile.append(extension);
//	wres.writeNMD(bfile, atoms, nmode1, nmode2, eigVal, eigVect);

	extension="_1.pml";
	bfile = outputfile;
	bfile.append(extension);
	int flag_weight = 0;
	int idx = 1;
	pml.write_network(bfile, atoms, center, List, flag_mean, flag_weight, idx);
*/
	extension=".pml";
	bfile = outputfile;
	bfile.append(extension);
	int flag_weight = 1;
	int idx = 2;
	pml.write_network(bfile, atoms, center, List, flag_mean, flag_weight, idx);
/*
	double *Covar = new double[9*natoms*natoms];
	tools.computeCovar(natoms, nmode1, nmodes, eigVal, eigVect, Covar);

	extension="_covar.dat";
	bfile = outputfile;
	bfile.append(extension);
	wres.writeCOVAR(bfile, natoms, Covar);

	double *Correl = new double[natoms*natoms];
	tools.computeCorrel(natoms, Covar, Correl);

	extension="_correl.dat";
	bfile = outputfile;
	bfile.append(extension);
	wres.writeCORREL(bfile, natoms, Correl);
*/

/* 	==========================================================================================
	If target file is provided, compute overlap between eigenvectors and structural differences
	between structure and target
   	========================================================================================== */

	if(targetfile != " ") {
		std::cout << "target file = " << targetfile << std::endl;
		std::vector<Atoms> atoms2;
		double center2[3];
		int nchains2;
		readPDB(targetfile, atom_type, flag_mass, atoms2, &nchains2, center2);
		int natoms2 = atoms2.size();
		if(natoms!=natoms2) {
			std::cout << "Problem with target file!" << std::endl;
			std::cout << "natoms1 = " << natoms << " natoms2 = " << natoms2 << std::endl;
			std::cout << std::endl;
			exit(1);
		}
		double rms, norm;
		double *rms_over  = new double[nmodes];
		double *overlap    = new  double[nmodes];
		memset(rms_over, 0, nmodes*sizeof(double));
		memset(overlap, 0, nmodes*sizeof(double));

		nmode1 = 0;
                tools.computeOverlap(nmode1, nmodes, eigVect, atoms, atoms2, &rms, rms_over, overlap);
                norm = dnrm2_(&nmodes, overlap, &inc);
                std::cout << "RMS between B_map and target   :  " << std::setw(10) <<  rms_over[nmodes-1] << std::endl;
                std::cout << "Total overlap NM - deformation :  " << std::setw(10) <<  norm << std::endl;
		extension=".over";
		bfile = outputfile;
		bfile.append(extension);
                wres.writeOverlap(bfile, nmodes, rms, rms_over, overlap);
        }
        std::cout << " " << std::endl;

}

/*================================================================================================
  Define Usage
================================================================================================== */

/* Return text padded with blanks on the right up to width characters (longer text is returned unchanged). */
static std::string padUsageLine(const std::string &text, std::string::size_type width)
{
	std::string padded = text;
	if(padded.size() < width) padded.append(width - padded.size(), ' ');
	return padded;
}

/*	Print the command-line help of FitNMA inside a framed box on std::cout.
	argv: command-line arguments; not used (the program name is printed as "FitNMA").
	Every framed line is padded to the same width, so the frame stays aligned when the
	text changes.
	NOTE(review): main starts from cutoff = 0 and raises it until the Hessian has only six
	zero eigenvalues, while -c is advertised with default 14 — confirm which is intended.
	parse_args also accepts -p, -g, -w, -t, -m, -r and -u, which are not documented here. */
static void usage(char** argv)
{
	(void) argv;

	const std::string::size_type inner = 80;	/* characters between the two '=' borders */
	const std::string::size_type dcol  = 46;	/* column of "(default: ...)" in an option line */
	const std::string margin = "     ";
	const std::string rule(inner + 2, '=');
	const std::string blank  = "";
	const std::string opt(20, ' ');			/* indentation of option lines */
	const std::string note(21, ' ');		/* indentation of notes attached to an option */

	std::vector<std::string> body;
	body.push_back(blank);
	body.push_back(std::string(28, ' ') + "FitNMA");
	body.push_back(blank);
	body.push_back("   This program computes the force constant of edges in an elastic");
	body.push_back("   network such that its normal modes reproduce experimental Bfactors");
	body.push_back(blank);
	body.push_back("     Usage is:");
	body.push_back(std::string(17, ' ') + "FitNMA");
	body.push_back(opt + "-i  <path to input .pdb file>");
	body.push_back(opt + "-o  <basename BAS for output files>");
	body.push_back(note + "(files generated: BAS.bfact, BAS.enm and BAS.pml)");
	body.push_back(opt + padUsageLine("-a  <atoms to include (0) CA, (1) all>", dcol) + "(default: 0)");
	body.push_back(opt + padUsageLine("-e  <(1) Elastic or (2) Delaunay network>", dcol) + "(default: 2)");
	body.push_back(opt + padUsageLine("-c  <cutoff value, only if e=1>", dcol) + "(default: 14)");
	body.push_back(opt + padUsageLine("-n  <# of modes for approximation, -1 if all>", dcol) + "(default: -1)");
	body.push_back(opt + "-d  <path to target .pdb file, optional>");
	body.push_back(note + "(if given, BAS.over is also generated)");
	body.push_back(blank);

	std::cout << "\n\n" << std::endl;
	std::cout << margin << rule << std::endl;
	std::cout << margin << rule << std::endl;
	for(std::vector<std::string>::size_type k = 0; k < body.size(); k++)
		std::cout << margin << "=" << padUsageLine(body[k], inner) << "=" << std::endl;
	std::cout << margin << rule << std::endl;
	std::cout << margin << rule << std::endl;
	std::cout << "\n\n" << std::endl;
}

/*================================================================================================
  Reads in parameters
================================================================================================== */


/*	Read "-flag value" pairs from the command line into the output parameters.
	Only options that appear are written; all others keep the value set by the caller.
	String options : -i input file, -o output basename, -d target file.
	Integer options: -p flag_rna, -a atoms, -e flag_network, -g flag_potential, -n nmodes,
	                 -w flag_mass, -t flag_rigid, -m flag_mean, -r flag_ent, -u flag_optim.
	Real option    : -c cutoff.
	The value following a recognized option is consumed and never re-read as an option.
	An option given as the last argument (no value) is reported on std::cerr and ignored.
	Unrecognized arguments are skipped. */
void parse_args(int argc, char** argv, std::string *inputfile, std::string *outputfile, 
	std::string *targetfile, int *flag_rna, int *atoms, int *flag_network, int *flag_potential, 
	double *cutoff, int *flag_mass, int *nmodes, int *flag_optim, int *flag_rigid, int *flag_mean, 
	int *flag_ent)
{
	/* argv[0] is the program name: start scanning at the first real argument */
	for(int i = 1; i < argc; i++)
	{
		std::string input = argv[i];
		std::string *sval = NULL;
		int *ival = NULL;
		double *dval = NULL;

		if(input == "-i")      sval = inputfile;
		else if(input == "-o") sval = outputfile;
		else if(input == "-d") sval = targetfile;
		else if(input == "-p") ival = flag_rna;
		else if(input == "-a") ival = atoms;
		else if(input == "-e") ival = flag_network;
		else if(input == "-g") ival = flag_potential;
		else if(input == "-c") dval = cutoff;
		else if(input == "-n") ival = nmodes;
		else if(input == "-w") ival = flag_mass;
		else if(input == "-t") ival = flag_rigid;
		else if(input == "-m") ival = flag_mean;
		else if(input == "-r") ival = flag_ent;
		else if(input == "-u") ival = flag_optim;
		else continue;		/* not a recognized option */

		/* reading argv[i+1] here when i+1 == argc would dereference the terminating NULL */
		if(i + 1 >= argc) {
			std::cerr << "Warning: no value given for option " << input << "; option ignored" << std::endl;
			break;
		}

		const char *value = argv[++i];	/* consume the value */
		if(sval)      *sval = value;
		else if(ival) *ival = atoi(value);
		else          *dval = atof(value);
	}
}
