/*================================================================================================
  ProteinNM
  Version 1: 08/06/2024

Copyright (c) Patrice Koehl.

>>> SOURCE LICENSE >>>

  This library is free software; you can redistribute it and/or
  modify it under the terms of the GNU Lesser General Public
  License as published by the Free Software Foundation; either
  version 2.1 of the License, or (at your option) any later version.

  This library is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  Lesser General Public License for more details.

  You should have received a copy of the GNU Lesser General Public
  License along with this library; if not, write to the Free Software
  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA

>>> END OF LICENSE >>>

================================================================================================== */

/*================================================================================================
  All includes
================================================================================================== */

#include <sys/time.h>
#include <unistd.h>
#include "ProteinNM.h"

#if defined(GPU)
#include "setGPU.h"
#endif

/*================================================================================================
  Main
================================================================================================== */

int main(int argc, char** argv)
{

#if defined(GPU)

/*	==========================================================================================
	Create handles for CUDA libraries, if using GPU
	========================================================================================== */

	// cuBLAS
	cublasCreate(&handle);

	// cuRAND
	curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT);
	curandSetPseudoRandomGeneratorSeed(gen, 1234ULL);

	// cuSolver
	cusolverDnCreate (&cusolverH);

#endif

/*	==========================================================================================
         Show usage
	========================================================================================== */

	if(argc < 2)
	{
		usage(argv);
		return -1;
	}

	std::string input = argv[1];
	if(input == "-h" || input == "-help")
	{
		usage(argv);
		return -1;
	}

/* 	==========================================================================================
	Some hard-coded parameters
   	========================================================================================== */

	T tol;						// tolerance for eigenpairs
#if defined(DOUBLE)
	tol = 1.e-8;
#else
	tol = 1.e-6;
#endif

/*	==========================================================================================
         Read command line arguments
	========================================================================================== */

	std::string fileIN=" ";
	std::string fileIN2=" ";
	std::string fileOUT=" ";

	int atom_type   = 0;
	int potential   = 1;
	int enm_type    = 0;
	int filter      = 0;
	int method      = 0;
	int nmodes      = 100;
	int ndeflations = 1;
	int maxlan      = 100;
        int mpoly       = 0;
	int order       = 0;
	int nthreads	= sysconf( _SC_NPROCESSORS_ONLN );
	nthreads        = std::max(nthreads, 1);
	int nblocks     = nthreads;
	T cutoff        = 14;
	T kconst        = 1.0;

	read_flags(argc, argv, &fileIN, &fileOUT, &fileIN2, &atom_type, &potential, &enm_type, &cutoff, &kconst, 
		&filter, &order, &nmodes, &ndeflations, &method, &maxlan, &mpoly, &nblocks, &nthreads);

/*	==========================================================================================
         Adjustments
	========================================================================================== */

	if(potential==2) atom_type = 0; 		// If GO potential, only CA

	int mindiff = 0;				// r_ij considered if abs(residue(i) - residue(j)) > mindiff
	if(potential==2) mindiff = 4;

	int flag_mass = 0;

/* 	==========================================================================================
	Count atoms in PDB file, generate array of vertices, and read in coordinates
   	========================================================================================== */

	std::vector<Atoms<T> > atoms;
	int nchains;
	std::size_t found = fileIN.find("pdb");
	if(found !=std::string::npos) {
		readinput.readFromPDB(fileIN, atom_type, flag_mass, atoms, &nchains);
	} else {
		found = fileIN.find("pqr");
		if(found !=std::string::npos) {
			readinput.readFromPQR(fileIN, atom_type, flag_mass, atoms, &nchains);
		} else {
			found = fileIN.find("cif");
			if(found !=std::string::npos) {
				readinput.readFromCIF(fileIN, atom_type, flag_mass, atoms, &nchains);
			} else {
				std::cout << " " << std::endl;
				std::cout << "Input file format not recognized; program can only read PDB, PQR, and CIF files" << std::endl;
				std::cout << " " << std::endl;
				exit(1);
			}
		}
	}

	int natoms = (int) atoms.size();

	std::cout << " " << std::endl;
	std::cout << "Number of atoms found in file      : " << natoms << std::endl;
	std::cout << "Number of chains in protein complex: " << nchains << std::endl;
	std::cout << " " << std::endl;

/* 	==========================================================================================
	Build elastic network
   	========================================================================================== */

	timeval tim;
	double t1, t2;

	gettimeofday(&tim,NULL);
	t1 = tim.tv_sec + tim.tv_usec*1.e-6;

	std::vector<Edges<T> > List;

	if(enm_type == 1) {
		std::cout << "ENM with cutoff" << std::endl;
		std::cout << "cutoff, mindiff: " << cutoff << " " << mindiff << std::endl;
		net.network(atoms, cutoff, mindiff, List, nthreads);
		std::cout << "done..." << std::endl;
	} else if (enm_type == 2) {
		std::cout << "ENM from Delaunay" << std::endl;
		net.delaunay(atoms, mindiff, potential, List);
	}

	if(filter==1) {
		std::vector<Edges<T> > Listf;
		filterR.minimGraph(potential, atoms, List, Listf, nthreads);
		List.clear();
		for(int i = 0; i < (int) Listf.size(); i++) {
			Edges<T> l(Listf[i].atm1, Listf[i].atm2);
			List.push_back(l);
		}
	}

	gettimeofday(&tim,NULL);
	t2 = tim.tv_sec + tim.tv_usec*1.e-6;

	std::cout << " " << std::endl;
	std::cout << "Network running time: " << t2 - t1 << "seconds" << std::endl;
	std::cout << "Number of pairs found: " << List.size() << std::endl;
	std::cout << " " << std::endl;

	// Write elastic network
	if(fileOUT != " ") {
		std::string ext=".pml";
		std::string nfile = fileOUT;
		nfile.append(ext);
		wres.write_pml(nfile, atoms, List);
	}

/* 	==========================================================================================
	Build Hessian; optionally, apply reverse cutHill-McKee ordering
   	========================================================================================== */

	// Load Hessian

	hessianMat<T> *csrHessian = new hessianMat<T>;
	hess.buildH(atoms, List, kconst, potential, csrHessian);
	if(potential==2) {
		csrHessian->type = 2;
		hess.buildGO(atoms, csrHessian);
	} else {
		csrHessian->type = 1;
	}


	int d1, d2;
	int bandwidth = rcm.computeBandwidth(natoms, csrHessian->ia, csrHessian->ja);
	rcm.computeDegrees(natoms, csrHessian->ia, &d1, &d2);

	std::cout << std::endl;
	std::cout << "Hessian matrix bandwidth           :       " << bandwidth << std::endl;
	std::cout << "ENM largest vertex degree          :       " << d1 << std::endl;
	std::cout << "ENM second largest vertex degree   :       " << d2 << std::endl;
	std::cout << std::endl;

	std::vector<int> perm(natoms);
	std::iota(perm.begin(), perm.end(), 0);

	if(order==1) {
		int nnz = csrHessian->nnz;
		int *new_ia = new int[natoms+1];
		int *new_ja = new int[nnz];
		T *new_val  = new T[3*nnz];
		perm = rcm.computeRCM(natoms, csrHessian->ia, csrHessian->ja);
		rcm.buildRCM(natoms, csrHessian->ia, csrHessian->ja, csrHessian->val, perm, 
                       new_ia, new_ja, new_val);

		for(int i = 0; i < natoms+1; i++) csrHessian->ia[i] = new_ia[i];
		for(int i = 0; i < nnz; i++) csrHessian->ja[i] = new_ja[i];
		for(int i = 0; i < 3*nnz; i++) csrHessian->val[i] = new_val[i];

		delete [] new_ia; delete [] new_ja; delete [] new_val;

		bandwidth = rcm.computeBandwidth(natoms, csrHessian->ia, csrHessian->ja);

		std::cout << std::endl;
		std::cout << "After reverse cutHill-McKee reordering: " << std::endl;
		std::cout << "Hessian matrix bandwidth           :       " << bandwidth << std::endl;
		std::cout << std::endl;
	}

	//if GPU, Transfer matrix on GPU
#if defined(GPU)
	hessianMat<T> *csrHessian_GPU = new hessianMat<T>;
	transferCSR2GPU(csrHessian, csrHessian_GPU);
	csrHessian = csrHessian_GPU;
#endif

/* 	==========================================================================================
 	Define arrays for eigenvalues and eigenvectors
	eigenvalues:  on CPU
	eigenvectors: on GPU or CPU, depending on space
   	========================================================================================== */

	int N = 3*natoms;
	int dim_max; 

	if(method==0) {
		dim_max = N;
	} else {
		dim_max = nmodes + maxlan;
	} 

	long isize = N*dim_max*sizeof(T);

	T *eigVal = NULL, *eigVect = NULL; 
	eigVal = new T[dim_max];
	std::memset(eigVal, 0, dim_max*sizeof(T));

#if defined(GPU)
	cudaMalloc((void **)&eigVect, isize);
	cudaMemset(eigVect, 0, isize);
#else
	eigVect = new T[N*dim_max];
	std::memset(eigVect, 0, isize);
#endif

/* 	==========================================================================================
	Initialization
   	========================================================================================== */

	int nconv=0;					// Number of converged eigenpairs
	T Anorm=0;
	eig_info<T> *info = new eig_info<T>;
	std::vector<std::vector<double> > edinfo; // vector containing information on the actual run

	opparams<T> *params = new opparams<T>;
	int istart = 0;

	if(method==0 ) {
		core.init_info(info, N, maxlan, nblocks, mpoly, method, istart, nmodes, 
		ndeflations, tol, nthreads, params);
		core.init_params(info, csrHessian, eigVal, eigVect);
	} else {
		if(mpoly==0) mpoly = 80;
		dav.init_info(info, N, maxlan, nblocks, mpoly, method, istart, nmodes, 
			ndeflations, tol, nthreads, params);
		dav.init_params(info, csrHessian, eigVal, eigVect);
	}


	if(method == 0) {

/* 		====================================================================================
		use full diagonalization; only works for small system and on CPU!
   		==================================================================================== */

		hess.fullHessian(csrHessian, eigVect);

		hess.fullEigen(natoms, eigVal, eigVect);
		nconv = N;
		Anorm = eigVal[N-1];

	} else {

/* 		=====================================================================================
		Chebishev block Jacobi-Davidson
		1. Work space 
   		===================================================================================== */

		int lwork1 = N*(nblocks+maxlan) + maxlan*(maxlan+1);
		T *work;

#if defined(GPU)
		cudaMalloc((void **)&work, lwork1*sizeof(T));
#else
		work = new T[lwork1];
#endif

/* 		=====================================================================================
		3. Now run CBJD method, with explicit deflation, either:
			-naive, replacing A with A+USU^T (method==1)
			-conmmunication avoiding (method==2)
   		===================================================================================== */

		dav.chebDav(spmm<T>, spmm<T>, info, N, eigVal, eigVect, &nconv, work);
		Anorm = info->anrm;

#if defined(GPU)
		cudaFree(work);
#else
		delete [] work;
#endif

	} 

/* 	==========================================================================================
	Check eigen pairs
   	========================================================================================== */

	T *err = new T[nconv];
	T Emin=0, Emax=0, Emean=0;
	anal.checkEigVal(spmv<T>, N, nconv, Anorm, eigVal, eigVect, err, 
		&Emin, &Emax, &Emean, params, nthreads);

	T ortho=0;
	std::vector<double> orth;
	ortho = anal.checkOrtho(N, nconv, eigVect);

	std::cout << " " << std::endl;
	std::cout << "Mean error on eigenpairs      : " << Emean << std::endl;
	std::cout << "Min error on eigenpairs       : " << Emin << std::endl;
	std::cout << "Max error on eigenpairs       : " << Emax << std::endl;
	std::cout << "Orthogonalization error       : " << ortho << std::endl;
	std::cout << " " << std::endl;

/* 	==========================================================================================
	Print eigenvalues
   	========================================================================================== */

	std::cout << "========================================================================="<<std::endl;
	std::cout << "         Mode #        Error           Eigenval                          "<<std::endl;
	std::cout << "========================================================================="<<std::endl;

	int nmax = std::min(nconv, nmodes);

	for(int i = 0; i < nmax ; i++) {
		std::cout << "    " << std::setw(10) << i+1 << "      ";
		std::cout << std::setw(10) << std::abs(err[i]) << "      " << std::setw(10) << eigVal[i] << "      ";
		std::cout << std::endl;
	}
	std::cout << "========================================================================="<<std::endl;
	std::cout << " " << std::endl;


	if(method == 0) {
		core.printinfo(info, ortho);
	}

/* 	==========================================================================================
	Compute B-factors
   	========================================================================================== */

	T check = 0;
	for(int i = 0; i < natoms; i++)  check = check + atoms[i].bfact;

	if(nmax > 0  && check > 0.0) {
		T *rms_bfact = new T[nmax]; T *correl_bfact = new T[nmax];
		memset(rms_bfact, 0, nmax*sizeof(T)); memset(correl_bfact, 0, nmax*sizeof(T));
		std::string bfile = " ";
		std::string bfile2 = " ";
		if(fileOUT != " ") {
			std::string extension=".bfact"; 
			std::string extension2=".bfact2"; 
			bfile = fileOUT; bfile.append(extension);
			bfile2 = fileOUT; bfile2.append(extension2);
		}

		int nmode1 = 6;
		T *bfact = new T[natoms];
		anal.computeBfact(natoms, natoms, nmode1, nmax, eigVal, eigVect, atoms, perm, bfact, 
		rms_bfact, correl_bfact);

		std::cout << std::endl;
		std::cout << "RMS between exp and calc b-fact:  " << std::setw(10) <<  rms_bfact[nmax-1] << std::endl;
		std::cout << "CC between exp and calc b-fact :  " << std::setw(10) <<  correl_bfact[nmax-1] << std::endl;

		if(bfile != " " ) wres.writeBfact(bfile, atoms, bfact);
		if(bfile2 != " " ) wres.writeBfact2(bfile2, nmode1, nmax, rms_bfact, correl_bfact);

		if(fileIN2 != " ") {
			anal.Overlap(atom_type, atoms, nmode1, nmax, eigVect, check, rms_bfact, correl_bfact,
			fileIN2, fileOUT);
		}
		delete [] bfact; delete [] correl_bfact;
	}
	std::cout << " " << std::endl;

}

/*================================================================================================
  Define Usage
================================================================================================== */

/**
 * Print the command-line help banner to standard output.
 *
 * Every banner line is a 92-character box row ('=' + 90 characters + '=')
 * preceded by a 5-space margin; the padding is computed, so a line can be
 * edited without re-aligning it by hand.
 *
 * The defaults quoted for -c, -as, -cf, -b and -p and the list of input
 * formats follow the values coded in main().
 *
 * @param argv  command-line vector; not used by the banner, kept so that
 *              existing callers keep compiling.
 */
static void usage(char** argv)
{
	(void) argv;

	const std::size_t inner = 90;			// characters between the two '=' borders
	const std::string margin(5, ' ');
	const std::string rule(inner + 2, '=');
	const std::string txt(5, ' ');			// indent of descriptive text
	const std::string opt(20, ' ');			// indent of an option line
	const std::string sub(24, ' ');			// indent of an option's sub-item

	// Emit one boxed row, right-padding the text up to the box width.
	auto row = [&](const std::string& text) {
		std::string padded = text;
		if(padded.size() < inner) padded.append(inner - padded.size(), ' ');
		std::cout << margin << '=' << padded << '=' << std::endl;
	};

	std::cout << "\n\n" <<std::endl;
	std::cout << margin << rule << std::endl;
	std::cout << margin << rule << std::endl;
	row("");
	row(std::string(37, ' ') + "ProteinNM");
	row("");
	row(txt + "This program computes some eigenpairs of the Hessian of the ENM of a protein");
	row(txt + "Usage is:");
	row(std::string(17, ' ') + "ProteinNM");
	row(opt + "-i  <input file for protein> (in PDB / PQR / CIF format)");
	row(opt + "-o  <basename for output files>");
	row(opt + "-a  <atoms to include (0) CA, (1) all> (Default: 0)");
	// NOTE(review): main() starts with potential = 1 and treats the value 2 as the
	// Go potential, which does not match the numbering and default printed here.
	// The text is left unchanged until the intended numbering is confirmed.
	row(opt + "-g  <(0) Tirion or (1) Go potential>   (Default: 0)");
	// NOTE(review): main() starts with enm_type = 0, for which no network is built;
	// either this line or that default needs correcting.
	row(opt + "-e  <(1) cutoff or (2) Delaunay ENM>   (Default: 1)");
	row(opt + "-f  <(0) no filter, (1) rigidity filter> (Default: 0)");
	row(opt + "-cm <(0) no ordering, (1) reverse Cuthill-McKee ordering> (Default: 0)");
	row(opt + "-k  <K-constant for springs>  (Default: 1)");
	row(opt + "-c  <cutoff value>  (Default: 14)");
	row(opt + "-as <dimension of the active space>    (Default: 100)");
	row(opt + "-cf <order of Chebyshev filter> (only for ChebDav) (Default: 80)");
	row(opt + "-b  <block size> (only for ChebDav)    (Default: # of processors)");
	row(opt + "-ne <approx. number of eigenpairs>     (Default: 100)");
	row(opt + "-nd <number of deflations>             (Default: 1)");
	row(opt + "-m  <eigen method>   (Default: 0)");
	row(sub + "0: full diagonalization (only for small systems, no GPU!)");
	row(sub + "1: Chebyshev Jacobi Davidson");
	row(opt + "-p  <# of threads> (if on CPU)         (Default: # of processors)");
	row(opt + "-t  <input file for target> (in PDB / CIF format) (Default: none)");
	std::cout << margin << rule << std::endl;
	std::cout << margin << rule << std::endl;
	std::cout << "\n\n" <<std::endl;
}

/*================================================================================================
  Reads in parameters
================================================================================================== */

void read_flags(int argc, char** argv, std::string *fileIN, std::string *fileOUT, std::string *fileIN2,
		int *atom_type, int *potential, int *enm_type, T *cutoff, T *kconst, int *filter,
		int *order, int *nev, int *nd, int *method, int *maxlan, int *mpoly, int *nblocks, int *nthreads)
{
	std::string input;
	int i;
	for(i = 0; i < argc; i++)
	{
		input = argv[i];
		if(input == "-i")
		{	*fileIN = argv[i+1];         }

		if(input == "-t")
		{	*fileIN2 = argv[i+1];         }

		if(input == "-o")
		{	*fileOUT = argv[i+1];         }

		if(input == "-a")
		{	*atom_type = std::atoi(argv[i+1]); }

		if(input == "-e")
		{	*enm_type = std::atoi(argv[i+1]); }

		if(input == "-f")
		{	*filter   = std::atoi(argv[i+1]); }

		if(input == "-cm")
		{	*order   = std::atoi(argv[i+1]); }

		if(input == "-c")
		{	*cutoff   = std::atof(argv[i+1]); }

		if(input == "-k")
		{	*kconst   = std::atof(argv[i+1]); }

		if(input == "-ne")
		{	*nev = std::atoi(argv[i+1]);    }

		if(input == "-nd")
		{	*nd = std::atof(argv[i+1]);	}

		if(input == "-as")
		{	*maxlan = std::atoi(argv[i+1]);	}

		if(input == "-cf")
		{	*mpoly = std::atoi(argv[i+1]);	}

		if(input == "-m")
		{	*method = std::atoi(argv[i+1]);	}

		if(input == "-p")
		{	*nthreads = atoi(argv[i+1]);	}

		if(input == "-g")
		{	*potential = atoi(argv[i+1]);	}

		if(input == "-b")
		{	*nblocks = atoi(argv[i+1]);	}


	}
}
