/*================================================================================================
  Init.h
  Version 1: 06/21/2024

  Purpose: Finds the smallest N eigenvalues / eigenvectors of a real symmetric matrix

	   All methods use Chebyshev filtering

Copyright (c) Patrice Koehl.

>>> SOURCE LICENSE >>>

  This library is free software; you can redistribute it and/or
  modify it under the terms of the GNU Lesser General Public
  License as published by the Free Software Foundation; either
  version 2.1 of the License, or (at your option) any later version.

  This library is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  Lesser General Public License for more details.

  You should have received a copy of the GNU Lesser General Public
  License along with this library; if not, write to the Free Software
  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA

>>> END OF LICENSE >>>

================================================================================================== */

#ifndef _INIT_
#define _INIT_

/*================================================================================================
 Includes
================================================================================================== */

#include <math.h>
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <numeric>

/*================================================================================================
 Local Includes
================================================================================================== */

/*================================================================================================
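  Define a helper class for the solver of eigenvectors / eigenvalues of a symmetric sparse matrix
  (stored as sum of outer products): it initializes the eig_info and matrix-vector parameter
  structures, and prints run statistics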
================================================================================================== */

template <typename T>
class INIT {

public:

	// Fill the eig_info structure "info" from the run parameters: problem size, active space,
	// block size, Chebyshev order, method, start flag, number of eigenpairs / deflation cycles,
	// residual tolerance, number of threads and the matrix-vector parameter block.
	void init_info(eig_info<T> * info, int N, int act_max, int nblock, int mpoly,
		int method, int flag_start, int nev, int ned, T tol, int nthreads, void *mvparams);

	// Attach the matrix and the eigenvalue / eigenvector arrays to the matrix-vector
	// parameter block referenced by info->mvparams, allocating work space when the method needs it.
	void init_params(eig_info<T> * info, hessianMat<T> *csrMat, T *eigVal, T *eigVect);

	// Print timings, operation counts and estimated memory usage recorded in "info",
	// together with the orthonormality error "orth".
	void printinfo(eig_info<T> *info, T orth);

};
/*================================================================================================
  init_info

	Purpose:
	========
	Initializes an INFO structure. This function must be called before calling
	any other user level routine based on Lanczos

	Arguments:
	==========
	info    (input) pointer to the structure eig_info
	         On entry, points to the data structure to store the information
	         about the eigenvalue problem and the progress of chebLan

	N       (input) integer
	         On entry, specifies the local dimension of the problem.

	act_max (input) integer
	         On entry, specifies the size of the active space used

	nblock   (input) integer
	         On entry, specifies the block size used

	mpoly	(input) integer
	         On entry, specifies the order of the Chebyshev polynomial filter

	method	(input) integer
		On entry, method used for solving eigenpair problem

	flag_start (input) integer
		 On entry, flag:
			- if 0, fresh start
			- if 1, use the existing vectors in eigVect to start

	nev      (input) integer
	          On entry, specifies the number of wanted eigenvalues and eigenvectors.

	ned      (input) integer
	          On entry, specifies the number of deflation cycles to use

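	tol      (input) float or double
	          If positive, specifies the tolerance on the residual norm; values at or
	          below the smallest normalized float/double (FLT_MIN/DBL_MIN) are replaced
	          by epsilon, and values above 1 by min(0.1, 1/tol). If tol <= 0, tol is
	          set to sqrt(epsilon).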

	mvparams (input) pointer
		 Pointer to params for matrix vector operator

	nthreads (input) integer
		 On entry, number of threads available for computation

================================================================================================== */

template <typename T>
void INIT<T>::init_info(eig_info<T> * info, int N, int act_max, int nblock,
		int mpoly, int method, int start, int nev, int ned, T tol, int nthreads, void *mvparams)
{

/*================================================================================================
	Parameters: constants of the working precision (the DOUBLE macro selects double precision)
================================================================================================== */

#if defined(DOUBLE)
	const T dmin    = DBL_MIN;
	const T eps     = DBL_EPSILON;
	const T orthTol = 2.2e-14;
#else
	const T dmin    = FLT_MIN;
	const T eps     = FLT_EPSILON;
	const T orthTol = 2.2e-6;
#endif

	const T EPS_MULT = 10;
	info->epsilon  = EPS_MULT*eps;
	info->itmax    = 1000;
	info->orthTol  = orthTol;

/*================================================================================================
	Set tolerance on the residual norm:
		tol <= 0		: default, sqrt(eps)
		0 < tol <= dmin		: too small to be meaningful, replaced by eps
		tol > 1			: replaced by min(0.1, 1/tol)
================================================================================================== */

	if (tol > 0) {
		info->tol = tol;
		if (info->tol <= dmin) {
			info->tol = eps;
		} else if (info->tol > 1.0) {
			info->tol = std::min(0.1, 1.0 / (info->tol));
		}
	} else {
		info->tol = std::sqrt(eps);
	}

/*================================================================================================
	Initialisation of info based on input parameters:
		n:		size of the matrix
		act_max:	active space
		block:		block size
		mpoly:		Chebyshev order
		nev:		# of eigenpairs to be computed
		ned:		# of deflation cycles
		nev_d:		# of eigenpairs per deflation cycle, i.e. ceil(nev/ned)
================================================================================================== */

	// ned is used as a divisor below: a non-positive value would divide by zero,
	// so fall back to a single cycle (no deflation).
	if (ned < 1) ned = 1;

	int nev_d = nev/ned + 1;
	if (nev%ned == 0) nev_d -= 1;

	info->n          = N;
	info->nev        = nev;
	info->ned        = ned;
	info->nev_d      = nev_d;
	info->act_max    = act_max;
	info->block      = nblock;
	info->mpoly      = mpoly;
	info->flag_start = start;
	info->flag_zero  = 1;

	info->method     = method;

	info->current_slice = 0;

	info->slices.clear();

/*================================================================================================
	Verbose
================================================================================================== */

	info->verbose = 1;

/*================================================================================================
	Info about computing time and operation counts (reset to zero)
================================================================================================== */

	info->clk_tot  = 0;
	info->clk_op   = 0;
	info->clk_orth = 0;

	info->walltime = 0;

	info->mvp      = 0;
	info->north    = 0;

/*================================================================================================
	Passing information for Matrix-Vector operations
================================================================================================== */

	info->mvparams = mvparams;

/*================================================================================================
	Number of threads available
================================================================================================== */

	info->nthreads = nthreads;

/*================================================================================================
	Type of Chebyshev filtering: initialized to 3 (polynomial expansion)
================================================================================================== */

	info->filter_type = 3;

	// NOTE(review): presumably storage for the filter expansion coefficients - TODO confirm.
	// A fresh array is allocated on every call and any previous one is not released here,
	// so calling init_info repeatedly on the same info may leak; verify who frees info->mu.
	info->mu = new T[1000];

/*================================================================================================
	Space needed for dsyev on GPU
================================================================================================== */

#if defined(GPU)
	// NOTE(review): the double-precision query (Dsyevd) is used regardless of T, with
	// lda = N for an act_max x act_max problem (requires N >= act_max) - confirm intended.
	cusolverEigMode_t jobz_GPU    = CUSOLVER_EIG_MODE_VECTOR;
	cublasFillMode_t uplo_GPU     = CUBLAS_FILL_MODE_UPPER;
	int bufferSize;
	T *A = NULL; T *W = NULL;
	cusolverDnDsyevd_bufferSize(cusolverH, jobz_GPU, uplo_GPU, act_max, A, N, W,
                        &bufferSize);
	info->bufferSize = bufferSize;
#endif

}

/*================================================================================================
  init_params

	Purpose:
	========
	Initializes the matrix-vector params structure (opparams) pointed to by info->mvparams.

	Arguments:
	==========
	info    (input) pointer to the structure eig_info
	         On entry, points to the data structure to store the information
	         about the eigenvalue problem and the progress of chebLan
	
	csrMat	(input) pointer to sparse matrix structure
		On entry, points to the matrix A on host

	eigVal	(input) pointer to array containing eigenvalues

	eigVect	(input) pointer to array containing eigenvectors

================================================================================================== */


template <typename T>
void INIT<T>::init_params(eig_info<T> * info, hessianMat<T> *csrMat, T *eigVal, T *eigVect)
{
	// The parameter block for matrix-vector products was stored in info->mvparams by init_info
	opparams<T> *params = static_cast<opparams<T> *>(info->mvparams);

	// Matrix and eigenpair storage
	params->csrHessian = csrMat;
	params->eigVal     = eigVal;
	params->Ud         = eigVect;
	params->need       = 0;
	params->nthreads   = info->nthreads;

	// Work arrays: empty unless the method requires them
	params->sigma  = nullptr;
	params->sigma2 = nullptr;
	params->space  = nullptr;

	// Methods with method % 3 == 0 need no extra buffers
	if (info->method % 3 == 0) return;

	const int nSigma = 2*info->nev;			// length of sigma (and of sigma2)
	const int nWork  = 2*info->nev * info->block;	// length of the work space
	const int nTotal = 2*nSigma + nWork;

	// Host sigma / sigma2 share one allocation, back to back
	T *hostBuf = new T[2*nSigma];
	params->sigma  = hostBuf;
	params->sigma2 = hostBuf + nSigma;

#if defined(GPU)
	// Device buffer layout: [ d_sigma | d_sigma2 | space ]
	T *devBuf;
	cudaMalloc((void **)&devBuf, nTotal*sizeof(T));
	params->d_sigma  = devBuf;
	params->d_sigma2 = devBuf + nSigma;
	params->space    = devBuf + 2*nSigma;
#else
	params->space = new T[nTotal];
#endif
}

/*================================================================================================
 printinfo

	Purpose:
	========
	prints information on the run that were stored in the structure "info"

	Arguments:
	==========

	info    (input) pointer to the structure eig_info
	         On entry, points to the data structure that stores the information
	         about the eigenvalue problem and the progress of chebDav.
	         It is only read by this function.

	T orth	(input) float or double
		On input, error on orthonormality

================================================================================================== */

 template <typename T>
 void INIT<T>::printinfo(eig_info<T> *info, T orth)
{

	// Timings: clk_* are accumulated clock ticks, walltime is already in seconds
	const double t_op   = info->clk_op   / double(CLOCKS_PER_SEC);
	const double t_orth = info->clk_orth / double(CLOCKS_PER_SEC);
	const double t_tot  = info->clk_tot  / double(CLOCKS_PER_SEC);
	const double t_wall = info->walltime;

	const int n_mvp  = info->mvp;
	const int n_orth = info->north;

	// Array sizes (in number of T elements) are computed in 64-bit integers:
	// n * (act_max+1) overflows a 32-bit int for large problems.
	const long long n     = info->n;
	const long long nev   = info->nev;
	const long long lanm1 = static_cast<long long>(info->act_max) + 1;

	const long long lwork =
		n * lanm1		// For active space
		+ 3*n			// work space
		+ lanm1 * lanm1		// for projection matrix
		+ 2*lanm1;		// for Ritz values

	// Estimate with a 20% margin, as in the original accounting
	const long long leigen = static_cast<long long>(
		1.2 * nev		// For eigenvalues
		+ 1.2 * n * nev);	// For eigenvectors

	const opparams<T> *params = static_cast<const opparams<T> *>(info->mvparams);

	// Only the nnz values of the matrix are counted (index arrays are not included)
	const long long lmat = params->csrHessian->nnz;

	// Memory in MB (1 MB = 1e6 bytes); converted to double first to avoid integer overflow
	const double memsize1 = static_cast<double>(lwork)  * sizeof(T) * 1.e-6;
	const double memsize2 = static_cast<double>(leigen) * sizeof(T) * 1.e-6;
	const double memsize3 = static_cast<double>(lmat)   * sizeof(T) * 1.e-6;
	const double memsize  = memsize1 + memsize2 + memsize3;

	std::cout << std::endl;
	std::cout << "=============================================statistics==================================================" << std::endl;
	std::cout << " Poly     Niter        Time (MV)      Time (orth)    Time (total)    # MV       #ORTH      orthogonality " << std::endl;
	std::cout << std::setw(5) << info->mpoly << "    ";
	std::cout << std::setw(5)  << info->niter << "      ";
	std::cout << std::setw(10) << t_op << "      ";
	std::cout << std::setw(10) << t_orth << "      ";
	std::cout << std::setw(10) << t_tot << "  ";
	std::cout << std::setw(10) << n_mvp << " ";
	std::cout << std::setw(10) << n_orth << "      ";
	std::cout << std::setw(10) << orth ;
	std::cout << std::endl;
	std::cout << "=========================================================================================================" << std::endl;
	std::cout << std::endl;
	std::cout << std::endl;
	std::cout << "Total walltime for the run: " << t_wall << " seconds" << std::endl;
	std::cout << std::endl;
	std::cout << "Memory for storing matrix : " << memsize3 << " MB" << std::endl;
	std::cout << "Memory for eigenpairs     : " << memsize2 << " MB" << std::endl;
	std::cout << "Memory for active space   : " << memsize1 << " MB" << std::endl;
	std::cout << std::endl;
	std::cout << "Total memory              : " << memsize << " MB" << std::endl;
	std::cout << std::endl;

}
#endif
