/*================================================================================================
  MatVect_GPU.h
  Version 1: 12/1/2017

  Purpose: Set of routines for Hessian matrix-vector (and matrix-matrix) multiplication on the GPU

Copyright (c) Patrice Koehl.

>>> SOURCE LICENSE >>>

  This library is free software; you can redistribute it and/or
  modify it under the terms of the GNU Lesser General Public
  License as published by the Free Software Foundation; either
  version 2.1 of the License, or (at your option) any later version.

  This library is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  Lesser General Public License for more details.

  You should have received a copy of the GNU Lesser General Public
  License along with this library; if not, write to the Free Software
  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA

>>> END OF LICENSE >>>

================================================================================================== */

#ifndef _MATVECTGPU_
#define _MATVECTGPU_

/*================================================================================================
 Includes
================================================================================================== */

#include <math.h>
#include <cstdlib>
#include <vector>

#define BLOCKSIZE 32

/*================================================================================================
 CUDA kernel for computing y = H * x (spmv "like"), one thread per atom.

 H is stored in CSR format with one row per atom. Each stored entry (i, j) holds a 3-vector
 g_ij (Ax[3*nz .. 3*nz+2]), and its action on x is
	y_i = sum_j <x_i - x_j, g_ij> g_ij
 where x_i is the 3-component slice of x for atom i.
 y is overwritten (not accumulated) for every atom < num_atoms. Surplus threads do nothing.
================================================================================================== */

template <typename T> 
__global__ void spmv_kernel(int num_atoms, int * Ap, int * Aj, T * Ax, T * x, T * y)
{
	const int row = blockIdx.x*blockDim.x + threadIdx.x;
	if (row >= num_atoms) return;

	T acc[3] = {0, 0, 0};
	const T *xi = x + 3*row;

	const int first = Ap[row];
	const int last  = Ap[row+1];

	for (int nz = first; nz < last; ++nz) {
		const T *xj = x + 3*Aj[nz];   // coordinates of the neighbour atom
		const T *g  = Ax + 3*nz;      // 3-vector stored for this entry

		// projection of (x_i - x_j) onto g
		T proj = 0;
		for (int c = 0; c < 3; ++c) proj += (xi[c] - xj[c]) * g[c];

		for (int c = 0; c < 3; ++c) acc[c] += proj * g[c];
	}

	for (int c = 0; c < 3; ++c) y[3*row + c] = acc[c];
}

/*================================================================================================
 CUDA kernel for computing y += H * x (spmv "like") for the Go backbone potential.

 H is NOT stored in CSR format. It is implied by the chain connectivity plus one precomputed
 vector per interaction, where interaction k involves consecutive atoms starting at atom k:
	bonds  : 3 values per bond      (bond k     : atoms k, k+1;   num_atoms-1 bonds)
	angles : 9 values per angle     (angle k    : atoms k .. k+2; num_atoms-2 angles)
	diheds : 12 values per dihedral (dihedral k : atoms k .. k+3; num_atoms-3 dihedrals)
 Let v be an interaction's vector (for a bond with stored vector b, v = (b, -b)) and X the
 concatenated coordinates of its atoms. The thread for atom i adds <v, X> * v_i to y_i,
 where v_i is the 3-component slice of v belonging to atom i.

 Launch: 1D grid with at least num_atoms threads in total. Surplus threads exit immediately.
 y is only accumulated into (+=), so the caller must initialise it beforehand.
 Each thread reads x and writes only its own 3 entries of y, so threads do not race.
================================================================================================== */

template <typename T> 
__global__ void spmv_go_kernel(int num_atoms, T * bonds, T *angles, T *diheds, T * x, T * y)
{

	// row index (one thread per atom)
	int atom = threadIdx.x+blockDim.x*blockIdx.x;

	// The grid is rounded up to whole blocks. Threads past the last atom must not read
	// bonds/angles/diheds/x or write y (every term below writes y[3*atom+k]).
	if(atom >= num_atoms) return;

	T val;
	T Xi[12];
	T vi[12];

/*================================================================================================
 	Hessian-vect: bond contribution (bond atom-1 and bond atom)
================================================================================================== */

	int iatm2, iatm1, iat0, iat1, iat2;

	iat0 = atom-1; iat1 = atom; iat2 = atom+1;
	if(atom > 0) {
		for(int k = 0; k < 3; k++) {
			vi[k]   = bonds[3*iat0+k];
			vi[k+3] = -vi[k];
			Xi[k]   = x[3*iat0+k];
			Xi[k+3] = x[3*iat1+k];
		}
		val = 0.0;
		for(int k = 0; k < 6; k++) val += Xi[k]*vi[k];
		for(int k = 0; k < 3; k++) y[3*iat1+k] += val*vi[k+3];
	}
	if(atom < num_atoms-1) {
		for(int k = 0; k < 3; k++) {
			vi[k]   = bonds[3*iat1+k];
			vi[k+3] = -vi[k];
			Xi[k]   = x[3*iat1+k];
			Xi[k+3] = x[3*iat2+k];
		}
		val = 0.0;
		for(int k = 0; k < 6; k++) val += Xi[k]*vi[k];
		for(int k = 0; k < 3; k++) y[3*iat1+k] += val*vi[k];
	}

/*================================================================================================
 	Hessian-vect: angle contribution (atom is 3rd, 2nd, then 1st atom of the angle)
================================================================================================== */

	iatm1 = atom-2; iat0 = atom-1; iat1 = atom; iat2 = atom+1;
	if(atom > 1) {
		for(int k = 0; k < 9; k++) {
			vi[k]   = angles[9*iatm1+k];
			Xi[k]   = x[3*iatm1+k];
		}
		val = 0.0;
		for(int k = 0; k < 9; k++) val += Xi[k]*vi[k];
		for(int k = 0; k < 3; k++) y[3*iat1+k] += val*vi[k+6];
	}
	if(atom>0 && atom < num_atoms-1) {
		for(int k = 0; k < 9; k++) {
			vi[k]   = angles[9*iat0+k];
			Xi[k]   = x[3*iat0+k];
		}
		val = 0.0;
		for(int k = 0; k < 9; k++) val += Xi[k]*vi[k];
		for(int k = 0; k < 3; k++) y[3*iat1+k] += val*vi[k+3];
	}
	if(atom < num_atoms-2) {
		for(int k = 0; k < 9; k++) {
			vi[k]   = angles[9*iat1+k];
			Xi[k]   = x[3*iat1+k];
		}
		val = 0.0;
		for(int k = 0; k < 9; k++) val += Xi[k]*vi[k];
		for(int k = 0; k < 3; k++) y[3*iat1+k] += val*vi[k];
	}

/*================================================================================================
 	Hessian-vect: dihedral angle contribution (atom is 4th, 3rd, 2nd, then 1st atom)
================================================================================================== */

	iatm2 = atom -3; iatm1 = atom - 2; iat0 = atom - 1; iat1 = atom;

	if(atom > 2) {
		for(int k = 0; k < 12; k++) {
			vi[k]   = diheds[12*iatm2+k];
			Xi[k]   = x[3*iatm2+k];
		}

		val = 0;
		for(int k = 0; k < 12; k++) val += vi[k]*Xi[k];
		for(int k = 0; k < 3; k++) y[3*iat1+k] += val*vi[k+9];
	}

	if(atom > 1 && atom < num_atoms-1) {
		for(int k = 0; k < 12; k++) {
			vi[k]   = diheds[12*iatm1+k];
			Xi[k]   = x[3*iatm1+k];
		}

		val = 0;
		for(int k = 0; k < 12; k++) val += vi[k]*Xi[k];
		for(int k = 0; k < 3; k++) y[3*iat1+k] += val*vi[k+6];
	}

	if(atom > 0 && atom < num_atoms - 2) {
		for(int k = 0; k < 12; k++) {
			vi[k]   = diheds[12*iat0+k];
			Xi[k]   = x[3*iat0+k];
		}

		val = 0;
		for(int k = 0; k < 12; k++) val += vi[k]*Xi[k];
		for(int k = 0; k < 3; k++) y[3*iat1+k] += val*vi[k+3];
	}

	if(atom < num_atoms - 3) {
		for(int k = 0; k < 12; k++) {
			vi[k]   = diheds[12*iat1+k];
			// Coordinates come from the input x (not the output y, which other threads
			// are concurrently updating).
			Xi[k]   = x[3*iat1+k];
		}

		val = 0;
		for(int k = 0; k < 12; k++) val += vi[k]*Xi[k];
		for(int k = 0; k < 3; k++) y[3*iat1+k] += val*vi[k];
	}

}

/*================================================================================================
  spmv_GPU

	Purpose:
	========
	Computes y = A * x
	where A is a symmetric sparse matrix corresponding to the Hessian and x is a vector.
	A always includes the CSR-stored part (params->csrHessian->ia/ja/val). When
	params->csrHessian->type == 2 it also includes the Go backbone part
	(bonds/angles/diheds), and the two contributions are summed.

	Arguments:
	==========

	N       (input) integer
		On entry, number of rows / cols of the matrix A (3 * number of atoms)

	Nvect   (input) integer
		On entry, the number of columns of x (only one vector is processed here)

	x       (input) array of floats or doubles, size N (on device)
		On entry, input vector x

	ldx     (input) int
		On entry, leading dimension of x (not used)

	y       (output) array of floats or doubles, size N (on device)
		On exit, y = A*x

	ldy     (input) int
		On entry, leading dimension of y (not used)

	mvparams (input) pointer to an opparams<T> structure
		On entry, information about the matrix A


================================================================================================== */

template <typename T>
void spmv_GPU(int N, int Nvect, T * x, int ldx, T * y, int ldy, void *mvparams)
{

	opparams<T> *params = (opparams<T> *) mvparams;

	cudaMemset(y, 0, N*sizeof(T));

	int natoms = N/3;

	// Nothing to compute. A launch with 0 blocks would be an invalid configuration.
	if(natoms <= 0) return;

	int NBLOCKS = (natoms+THREADS_PER_BLOCK-1)/THREADS_PER_BLOCK;

	// CSR part first: spmv_kernel ASSIGNS y = H_csr * x for every atom.
	int *ia   = params->csrHessian->ia;
	int *ja   = params->csrHessian->ja;
	T   *val  = params->csrHessian->val;

	spmv_kernel<T><<<NBLOCKS, THREADS_PER_BLOCK>>> (natoms, ia, ja, val, x, y);

	// Go part second: spmv_go_kernel ACCUMULATES into y. Both launches use the same
	// (default) stream, so this runs after the CSR kernel and adds to its result.
	// Launching it first would let the CSR kernel overwrite the Go contribution.
	if(params->csrHessian->type==2) {

		T *bonds  = params->csrHessian->bonds;
		T *angles = params->csrHessian->angles;
		T *diheds = params->csrHessian->diheds;

		spmv_go_kernel<T><<<NBLOCKS, THREADS_PER_BLOCK>>> (natoms, bonds, angles, diheds, x, y);
	}
}

/* ======================================================================================
 spmm_csr_2d: C = A * B for the CSR-stored Hessian applied to a block of vectors.

 A has one CSR row per atom. Each stored entry (i, j) holds a 3-vector a_ij
 (A_values[3*nz .. 3*nz+2]). For one column b of B, the result for atom i is
 		c_i = sum_j <b_i - b_j, a_ij> a_ij
 where b_i is the 3-component slice of b for atom i and < , > is the dot product.
 B and C are dense and column-major, with leading dimension 3*m and N columns.

 Launch layout: 1D blocks of BLOCKSIZE*BLOCKSIZE threads on a 2D grid. Thread t of block
 (bx, by) handles atom bx*BLOCKSIZE + t/BLOCKSIZE and column by*BLOCKSIZE + t%BLOCKSIZE.
 C is overwritten (not accumulated) for every in-range (atom, column) pair.
   ====================================================================================== */

template <typename T>
__global__ void spmm_csr_2d(
	int m,               // number of atoms: B and C have 3*m rows
	int N,               // number of columns of B and C
	int* A_rowptr,       // CSR row pointers [m+1]
	int* A_colind,       // CSR column indices [NNZ]
	T* A_values,         // CSR values, 3 per non-zero [3*NNZ]
	T* B,                // dense input  [3*m x N], column-major
	T* C )               // dense output [3*m x N], column-major
{
	// Position of this thread inside its BLOCKSIZE x BLOCKSIZE tile
	const int local_row = threadIdx.x / BLOCKSIZE;
	const int local_col = threadIdx.x % BLOCKSIZE;

	const int row = blockIdx.x * BLOCKSIZE + local_row;
	const int col = blockIdx.y * BLOCKSIZE + local_col;

	if (row < m && col < N) {
		const int ld = 3*m;
		const T *Bcol = B + col*ld;   // this thread's column of B
		T *Ccol = C + col*ld;         // this thread's column of C

		T acc[3] = {0, 0, 0};
		const int first = A_rowptr[row];
		const int last  = A_rowptr[row + 1];

		for (int nz = first; nz < last; ++nz) {
			const int nbr = A_colind[nz];

			T a[3];
			#pragma unroll
			for (int c = 0; c < 3; ++c) a[c] = A_values[3*nz + c];

			// projection of (b_row - b_nbr) onto a
			T proj = 0;
			for (int c = 0; c < 3; ++c) proj += (Bcol[3*row + c] - Bcol[3*nbr + c]) * a[c];

			for (int c = 0; c < 3; ++c) acc[c] += proj * a[c];
		}

		for (int c = 0; c < 3; ++c) Ccol[3*row + c] = acc[c];
	}
}

/* ======================================================================================
 spmm_go_2d: Y += H_go * X, the multi-vector version of spmv_go_kernel for the Go
 backbone potential. X and Y hold N vectors of length 3*m, stored column-major with
 leading dimension 3*m.

 H_go is not stored explicitly. Each interaction k involves consecutive atoms starting
 at atom k and has one precomputed vector: 3 values per bond, 9 per angle and 12 per
 dihedral (for a bond with stored vector b, the full vector is v = (b, -b)).
 For every interaction the atom belongs to, the thread adds <v, X_interaction> times
 the atom's 3-component slice of v.

 Launch layout: 1D blocks of BLOCKSIZE*BLOCKSIZE threads. The grid must have at least
 ceil(m/BLOCKSIZE) blocks in x and ceil(N/BLOCKSIZE) blocks in y. Thread t of block
 (bx, by) handles atom bx*BLOCKSIZE + t/BLOCKSIZE and column by*BLOCKSIZE + t%BLOCKSIZE.
 Y is only accumulated into (+=), so the caller must initialise it beforehand.
   ====================================================================================== */

template <typename T>
__global__ void spmm_go_2d(
	int m,               // number of atoms: X and Y have 3*m rows
	int N,               // number of columns (vectors) in X and Y
	T* bonds,            // 3 values per bond      (m-1 bonds)
	T* angles,           // 9 values per angle     (m-2 angles)
	T* diheds,           // 12 values per dihedral (m-3 dihedrals)
	T* X,                // dense input  [3*m x N], column-major
	T* Y )               // dense output [3*m x N], column-major, accumulated into
{

	const int M = 3*m;

	// Each thread handles one (atom, column) pair of the output
	int atom = blockIdx.x * BLOCKSIZE + (threadIdx.x / BLOCKSIZE);
	int col  = blockIdx.y * BLOCKSIZE + (threadIdx.x % BLOCKSIZE);

	// Early exit if thread is out of bounds
	if (atom >= m || col >= N) return;

	// This thread's column of X and Y. The offset is 64-bit so col*M cannot overflow int.
	const T *x = X + (size_t)col * M;
	T       *y = Y + (size_t)col * M;

	T val;
	T Xi[12];
	T vi[12];

	int iatm2, iatm1, iat0, iat1, iat2;

/*================================================================================================
 Hessian-vect: bond contribution (bond atom-1 and bond atom)
================================================================================================== */

	iat0 = atom-1; iat1 = atom; iat2 = atom+1;
	if(atom > 0) {
		for(int k = 0; k < 3; k++) {
			vi[k]   = bonds[3*iat0+k];
			vi[k+3] = -vi[k];
			Xi[k]   = x[3*iat0+k];
			Xi[k+3] = x[3*iat1+k];
		}
		val = 0.0;
		for(int k = 0; k < 6; k++) val += Xi[k]*vi[k];
		for(int k = 0; k < 3; k++) y[3*iat1+k] += val*vi[k+3];
	}
	if(atom < m-1) {
		for(int k = 0; k < 3; k++) {
			vi[k]   = bonds[3*iat1+k];
			vi[k+3] = -vi[k];
			Xi[k]   = x[3*iat1+k];
			Xi[k+3] = x[3*iat2+k];
		}
		val = 0.0;
		for(int k = 0; k < 6; k++) val += Xi[k]*vi[k];
		for(int k = 0; k < 3; k++) y[3*iat1+k] += val*vi[k];
	}

/*================================================================================================
 Hessian-vect: angle contribution (atom is 3rd, 2nd, then 1st atom of the angle)
================================================================================================== */

	iatm1 = atom-2; iat0 = atom-1; iat1 = atom; iat2 = atom+1;
	if(atom > 1) {
		for(int k = 0; k < 9; k++) {
			vi[k]   = angles[9*iatm1+k];
			Xi[k]   = x[3*iatm1+k];
		}
		val = 0.0;
		for(int k = 0; k < 9; k++) val += Xi[k]*vi[k];
		for(int k = 0; k < 3; k++) y[3*iat1+k] += val*vi[k+6];
	}
	if(atom>0 && atom < m-1) {
		for(int k = 0; k < 9; k++) {
			vi[k]   = angles[9*iat0+k];
			Xi[k]   = x[3*iat0+k];
		}
		val = 0.0;
		for(int k = 0; k < 9; k++) val += Xi[k]*vi[k];
		for(int k = 0; k < 3; k++) y[3*iat1+k] += val*vi[k+3];
	}
	if(atom < m-2) {
		for(int k = 0; k < 9; k++) {
			vi[k]   = angles[9*iat1+k];
			Xi[k]   = x[3*iat1+k];
		}
		val = 0.0;
		for(int k = 0; k < 9; k++) val += Xi[k]*vi[k];
		for(int k = 0; k < 3; k++) y[3*iat1+k] += val*vi[k];
	}

/*================================================================================================
 Hessian-vect: dihedral angle contribution (atom is 4th, 3rd, 2nd, then 1st atom)
================================================================================================== */

	iatm2 = atom -3; iatm1 = atom - 2; iat0 = atom - 1; iat1 = atom;

	if(atom > 2) {
		for(int k = 0; k < 12; k++) {
			vi[k]   = diheds[12*iatm2+k];
			Xi[k]   = x[3*iatm2+k];
		}

		val = 0;
		for(int k = 0; k < 12; k++) val += vi[k]*Xi[k];
		for(int k = 0; k < 3; k++) y[3*iat1+k] += val*vi[k+9];
	}

	if(atom > 1 && atom < m-1) {
		for(int k = 0; k < 12; k++) {
			vi[k]   = diheds[12*iatm1+k];
			Xi[k]   = x[3*iatm1+k];
		}

		val = 0;
		for(int k = 0; k < 12; k++) val += vi[k]*Xi[k];
		for(int k = 0; k < 3; k++) y[3*iat1+k] += val*vi[k+6];
	}

	if(atom > 0 && atom < m - 2) {
		for(int k = 0; k < 12; k++) {
			vi[k]   = diheds[12*iat0+k];
			Xi[k]   = x[3*iat0+k];
		}

		val = 0;
		for(int k = 0; k < 12; k++) val += vi[k]*Xi[k];
		for(int k = 0; k < 3; k++) y[3*iat1+k] += val*vi[k+3];
	}

	if(atom < m - 3) {
		for(int k = 0; k < 12; k++) {
			vi[k]   = diheds[12*iat1+k];
			Xi[k]   = x[3*iat1+k];
		}

		val = 0;
		for(int k = 0; k < 12; k++) val += vi[k]*Xi[k];
		for(int k = 0; k < 3; k++) y[3*iat1+k] += val*vi[k];
	}

}

/*================================================================================================
  spmm_GPU

	Purpose:
	========
	Computes y = A * x
	where A is a symmetric sparse matrix corresponding to the Hessian and x is a block of
	M vectors. A always includes the CSR-stored part (params->csrHessian->ia/ja/val). When
	params->csrHessian->type == 2 it also includes the Go backbone part
	(bonds/angles/diheds), and the two contributions are summed.

	Arguments:
	==========

	N       (input) integer
		On entry, number of rows / cols of the matrix A (3 * number of atoms)

	M       (input) integer
		On entry, the number of columns of x

	x       (input) array of floats or doubles, size N*M (on device)
		On entry, input matrix x, column-major with leading dimension N

	ldx     (input) int
		On entry, leading dimension of x (not used: N is assumed)

	y       (output) array of floats or doubles, size N*M (on device)
		On exit, y = A*x, column-major with leading dimension N

	ldy     (input) int
		On entry, leading dimension of y (not used: N is assumed)

	mvparams (input) pointer to an opparams<T> structure
		On entry, information about the matrix A


================================================================================================== */

template <typename T>
void spmm_GPU(int N, int M, T * x, int ldx, T * y, int ldy, void *mvparams)
{

	opparams<T> *params = (opparams<T> *) mvparams;

	// N is the row dimension (3 coordinates per atom) and M is the number of vectors.
	// This matches spmv_GPU and the call spmm_GPU(N, Nvect, ...) in spmmED_GPU.
	int natoms = N/3;

	cudaMemset(y, 0, (size_t)N*M*sizeof(T));

	// An empty problem would give a grid with a zero dimension (invalid launch configuration).
	if(natoms <= 0 || M <= 0) return;

	// One thread per (atom, column) pair. Each block of BLOCKSIZE*BLOCKSIZE threads covers a
	// BLOCKSIZE x BLOCKSIZE tile; grid.x spans the atoms and grid.y spans the columns.
	dim3 block(BLOCKSIZE*BLOCKSIZE);
	dim3 grid((natoms + BLOCKSIZE - 1) / BLOCKSIZE,
	          (M + BLOCKSIZE - 1) / BLOCKSIZE);

	// CSR part first: spmm_csr_2d ASSIGNS y = H_csr * x.
	int *ia   = params->csrHessian->ia;
	int *ja   = params->csrHessian->ja;
	T   *val  = params->csrHessian->val;

	spmm_csr_2d<T><<<grid, block>>> (natoms, M, ia, ja, val, x, y);

	// Go part second: spmm_go_2d ACCUMULATES into y. Both launches use the same (default)
	// stream, so this runs after the CSR kernel and adds to its result.
	if(params->csrHessian->type==2) {

		T *bonds  = params->csrHessian->bonds;
		T *angles = params->csrHessian->angles;
		T *diheds = params->csrHessian->diheds;

		spmm_go_2d<T><<<grid, block>>> (natoms, M, bonds, angles, diheds, x, y);
	}

}

/*================================================================================================
  spmmED_GPU

	Purpose:
	========
	Computes y = (A + U Sigma U^T) * x
	where A is the symmetric sparse Hessian, x is a block of vectors, and U (params->Ud,
	N x need, leading dimension N) with the diagonal Sigma (params->sigma, read on the host)
	implement explicit deflation when computing eigenpairs of the matrix A.

	Arguments:
	==========

	N   	(input) integer
		On entry, number of rows / cols of the matrix A

	Nvect	(input) integer
		On entry, the number of columns of x

	d_x	(input) array of floats or doubles, size N*Nvect
		On entry, input matrix x (on device), column-major

	ldx	(input) int
		On entry, leading dimension of x (forwarded to spmm_GPU; the deflation
		step uses N as leading dimension)

	d_y	(output) array of floats or doubles, size N*Nvect
		On exit, y = (A + U Sigma U^T) * x (on device)

	ldy	(input) int
		On entry, leading dimension of y (forwarded to spmm_GPU; the deflation
		step uses N as leading dimension)

	mvparams (input) pointer to an opparams<T> structure
		On entry, information about A, U (Ud), Sigma (sigma), the number of
		deflation vectors (need) and a workspace (space, at least need*Nvect values)


================================================================================================== */

template <typename T>
void spmmED_GPU(int N, int Nvect, T *d_x, int ldx, T *d_y, int ldy,  void *mvparams) {

	opparams<T> *params = (opparams<T> *) mvparams;

	int n_eed = params->need;
	T *d_Space  = params->space;
	T *h_sigma  = params->sigma;

	T alpha, beta;
	char Trans = 'T'; char noTrans = 'N';

/*      ==========================================================================================
	compute A*X using spmm
	========================================================================================== */

	spmm_GPU(N, Nvect, d_x, ldx, d_y, ldy,  mvparams);

	// No deflation vectors: y = A*x already. Returning here also avoids a gemm with
	// m = 0 and ldc = 0, which standard (cu)BLAS rejects (ldc must be >= max(1, m)).
	if(n_eed <= 0) return;

/*      ==========================================================================================
	Now compute U Sigma U^T X using dgemm
	========================================================================================== */

	// d_Space = U^T * X   (need x Nvect, leading dimension need)
	alpha = 1; beta = 0.0;
	eig_dgemm_(&Trans, &noTrans, &n_eed, &Nvect, &N, &alpha, params->Ud, &N, d_x, &N, 
		&beta, d_Space, &n_eed);

	// Scale row i of d_Space (stride need) by sigma_i: d_Space = Sigma * U^T * X
	int inc = n_eed;
	for(int i = 0; i < n_eed; i++) {
		alpha = h_sigma[i]; 
		eig_dscal_(&Nvect, &alpha, &d_Space[i], &inc);
	} 

	// d_y += U * d_Space
	alpha = 1; beta = 1.0;
	eig_dgemm_(&noTrans, &noTrans, &N, &Nvect, &n_eed, &alpha, params->Ud, &N, 
		d_Space, &n_eed, &beta, d_y, &N);

}

/*================================================================================================
  spmvED_GPU

	Purpose:
	========
	Computes y = (A + U Sigma U^T) * x
	where A is the symmetric sparse Hessian, x is a vector, and U (params->Ud, N x need,
	leading dimension N) with the diagonal Sigma (params->d_sigma, on the device)
	implement explicit deflation when computing eigenpairs of the matrix A.

	Arguments:
	==========

	N   	(input) integer
		On entry, number of rows / cols of the matrix A

	M	(input) integer
		On entry, just 1! (used for consistency)

	d_x	(input) array of floats or doubles, size N*M
		On entry, input vector x (on device)

	ldx	(input) int
		On entry, leading dimension of x

	d_y	(output) array of floats or doubles, size N*M
		On exit, y = (A + U Sigma U^T) * x (on device)

	ldy	(input) int
		On entry, leading dimension of y

	mvparams (input) pointer to an opparams<T> structure
		On entry, information about A, U (Ud), Sigma (d_sigma), the number of
		deflation vectors (need) and a workspace (space, at least need values)


================================================================================================== */

template <typename T>
void spmvED_GPU(int N, int M, T *d_x, int ldx, T *d_y, int ldy, void *mvparams) {

	opparams<T> *params = (opparams<T> *) mvparams;

	int n_eed   = params->need;

	T *d_Space  = params->space;
	T *d_sigma  = params->d_sigma;

	T alpha, beta;
	int inc = 1;
	char Trans = 'T'; char noTrans = 'N';

/*      ==========================================================================================
	compute A*X using spmv
	========================================================================================== */

	spmv_GPU(N, M, d_x, ldx, d_y, ldy, mvparams);

	// No deflation vectors: y = A*x already. Returning here also avoids launching
	// elementWise with 0 blocks, which is an invalid launch configuration.
	if(n_eed <= 0) return;

/*      ==========================================================================================
	Now compute U Sigma U^T X using dgemv
	========================================================================================== */

	// d_Space = U^T * x
	alpha = 1; beta = 0.0; 
	eig_dgemv_(&Trans, &N, &n_eed, &alpha, params->Ud, &N, d_x, &inc, 
		&beta, d_Space, &inc);

	// d_Space = Sigma * d_Space (entry-wise, on the device)
	int NBLOCKS = (n_eed+THREADS_PER_BLOCK-1)/THREADS_PER_BLOCK;
	elementWise<T><<<NBLOCKS, THREADS_PER_BLOCK>>>(d_sigma, d_Space, n_eed);

	// d_y += U * d_Space
	alpha = 1; beta = 1.0;
	eig_dgemv_(&noTrans, &N, &n_eed, &alpha, params->Ud, &N, 
		d_Space, &inc, &beta, d_y, &inc);

}

/*================================================================================================
  inplace_dgemm_GPU

	Purpose:
	========
	Computes A <- A * B in place, where A is N x M and B is M x M (both on the device).

	For M <= MAX_M (75), the kernel specialised on that bound is launched. Otherwise the
	large-M kernel is launched with M*sizeof(T) bytes of dynamic shared memory.
	The grid has N blocks of 128 threads. The routine blocks until the device is idle.

	Arguments:
	==========

	N	(input) integer
		On entry, number of rows of the matrix A

	M	(input) integer
		On entry, number of cols of the matrix A (and order of B)

	d_A	(input/output) array of floats or doubles, size N*M (on device)
		On entry, matrix A; on exit, matrix A * B

	lda	(input) int
		On entry, leading dimension of A

	d_B	(input) array of floats or doubles, size M*M (on device)
		On entry, matrix B

	ldb	(input) int
		On entry, leading dimension of B

	NOTE(review): launch errors are not checked here. For large M, the dynamic shared
	memory request (M*sizeof(T)) may exceed the default per-block limit; confirm the
	expected range of M.
================================================================================================== */

template <typename T>
void inplace_dgemm_GPU(int N, int M, T* d_A, int lda, T* d_B, int ldb) {
	const int MAX_M = 75;        // largest M handled by the statically sized kernel

	const int nthreads = 128;    // threads per block
	const int nblocks  = N;      // grid size equals the number of rows of A

	if (M > MAX_M) {
		size_t smem_bytes = M * sizeof(T);   // dynamic shared memory for the large-M kernel
		inplace_dgemm_largeM_kernel<T><<<nblocks, nthreads, smem_bytes>>>(N, M, d_A, lda, d_B, ldb);
	} else {
		inplace_dgemm_kernel<T, MAX_M><<<nblocks, nthreads>>>(N, M, d_A, lda, d_B, ldb);
	}

	// Wait for completion before returning (also surfaces asynchronous errors when debugging)
	cudaDeviceSynchronize();
}

#endif
