/*================================================================================================
  BlasWrapper.h
  Version 1: 06/21/2024

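  Purpose: Define wrappers around BLAS and LAPACK functions; when GPU is defined most
           wrappers dispatch to cuBLAS / cuSOLVER instead (precision selected by DOUBLE)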

Copyright (c) Patrice Koehl.

>>> SOURCE LICENSE >>>

  This library is free software; you can redistribute it and/or
  modify it under the terms of the GNU Lesser General Public
  License as published by the Free Software Foundation; either
  version 2.1 of the License, or (at your option) any later version.

  This library is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  Lesser General Public License for more details.

  You should have received a copy of the GNU Lesser General Public
  License along with this library; if not, write to the Free Software
  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA

>>> END OF LICENSE >>>

================================================================================================== */

#ifndef _BLASWRAPPER_
#define _BLASWRAPPER_

/*================================================================================================
 Includes
================================================================================================== */

#include <math.h>
#include <iostream>
#include <cstdlib>

/*================================================================================================
  Prototypes for BLAS and LAPACK
================================================================================================== */

extern "C" {

	/* ------------------------------------------------------------------
	   Fortran BLAS / LAPACK entry points: trailing underscore, every
	   argument passed by pointer. Each d* routine has an s* counterpart
	   whose real arguments are all float instead of double.
	   ------------------------------------------------------------------ */

	int idamax_(int * n, double *X, int *incx);
	int isamax_(int * n, float *X, int *incx);

	double dasum_(int * n, double *X, int *incx);
	float sasum_(int * n, float *X, int *incx);

	void daxpy_(int * n ,double *alpha , double * X, int *incx, double * Y,int *incy);
	void saxpy_(int * n ,float *alpha , float * X, int *incx, float * Y,int *incy);

	double dnrm2_(int * n, double * X, int *incx);
	float snrm2_(int * n, float * X, int *incx);

	void dscal_(int * n, double * alpha, double * X, int *incx);
	void sscal_(int * n, float * alpha, float * X, int *incx);

	void dcopy_(int * n, double * X, int *incx, double * Y, int *incy);
	void scopy_(int * n, float * X, int *incx, float * Y, int *incy);

	void dswap_(int * n, double * X, int *incx, double *Y, int *incy);
	void sswap_(int * n, float * X, int *incx, float *Y, int *incy);

	double ddot_(int * n, double * u, int * incu, double * v, int *incv);
	float sdot_(int * n, float * u, int * incu, float * v, int *incv);

	void dgemv_(char * trans, int * m, int * n, double * alpha, double *A,
		int *lda, double * X, int * incx, double *beta, double * Y, int * incy);
	void sgemv_(char * trans, int * m, int * n, float * alpha, float *A,
		int *lda, float * X, int * incx, float *beta, float * Y, int * incy);

	void dgemm_(char * transa, char * transb, int * m, int * n, int * k,
		double * alpha, double * A, int * lda,
		double * B, int * ldb, double * beta, double * C, int * ldc);
	void sgemm_(char * transa, char * transb, int * m, int * n, int * k,
		float * alpha, float * A, int * lda,
		float * B, int * ldb, float * beta, float * C, int * ldc);

	void dsyrk_(char * uplo, char * trans, int * n, int * k, double * alpha,
		double * A, int *lda, double * beta, double *C, int *ldc);
	/* Single-precision SYRK: all real arguments are float, like every other
	   s* routine (eig_dsyrk_ passes T* == float* here in the float build). */
	void ssyrk_(char * uplo, char * trans, int * n, int * k, float * alpha,
		float * A, int *lda, float * beta, float *C, int *ldc);

	void dstev_(char *jobz, int *n,  double *d,  double *e, double *z, int *ldz,
		double *work, int *info);
	void sstev_(char *jobz, int *n,  float *d,  float *e, float *z, int *ldz,
		float *work, int *info);

	void dsyevd_(char * JOBZ, char * UPLO, int *N, double *A, int *LDA, double *W, 
		double *WORK, int *LWORK, int *IWORK, int *LIWORK, int *INFO);
	void ssyevd_(char * JOBZ, char * UPLO, int *N, float *A, int *LDA, float *W, 
		float *WORK, int *LWORK, int *IWORK, int *LIWORK, int *INFO);

	void dsyev_(char * JOBZ, char * UPLO, int *N, double *A, int *LDA, double *W, 
		double *WORK, int *LWORK, int *INFO);
	void ssyev_(char * JOBZ, char * UPLO, int *N, float *A, int *LDA, float *W, 
		float *WORK, int *LWORK, int *INFO);

	void dgeqrf_(int *M, int *N, double *A, int *LDA, double *TAU, double *WORK, 
		int *LWORK, int *INFO);
	void sgeqrf_(int *M, int *N, float *A, int *LDA, float *TAU, float *WORK, 
		int *LWORK, int *INFO);

	void dtrsm_(char *SIDE, char *UPLO, char *TRANSA, char *DIAG, int *m, int *n,
		double *alpha, double *A, int *lda, double *B, int *ldb);
	void strsm_(char *SIDE, char *UPLO, char *TRANSA, char *DIAG, int *m, int *n,
		float *alpha, float *A, int *lda, float *B, int *ldb);

	void dpotrf_(char *UPLO, int *N, double *A, int *LDA, int *INFO);
	void spotrf_(char *UPLO, int *N, float *A, int *LDA, int *INFO);

	void dgesvd_(char *JOBU, char *JOBV, int *m, int *n, double *a, int *lda, double *s,
		double *u, int *ldu, double *vt, int *ldv, double *work, int *lwork, int *info);

	void sgesvd_(char *JOBU, char *JOBV, int *m, int *n, float *a, int *lda, float *s,
		float *u, int *ldu, float *vt, int *ldv, float *work, int *lwork, int *info);


	/* ------------------------------------------------------------------
	   Precision-neutral wrappers, defined later in this header.
	   T is the working floating-point type (not defined in this file).
	   eig_* wrappers generally dispatch to cuBLAS / cuSOLVER when GPU is
	   defined and to the Fortran routines above otherwise (see each
	   definition for exceptions); host_* wrappers always call the Fortran
	   routines above.
	   ------------------------------------------------------------------ */

	int eig_idamax_(int * n ,T * X, int *incx);

	T eig_dasum_(int * n ,T * X, int *incx);
	T host_dasum_(int * n ,T * X, int *incx);

	void host_daxpy_(int * n ,T *alpha , T * X, int *incx, T * Y,int *incy);
	void eig_daxpy_(int * n ,T *alpha , T * X, int *incx, T * Y,int *incy);
	void eig_dscal_(int * n, T * alpha, T * X, int *incx);
	void host_dscal_(int * n, T * alpha, T * X, int *incx);
	void eig_dcopy_(int * n, T * X, int *incx, T * Y, int *incy);
	void eig_dswap_(int * n, T * X, int *incx, T * Y, int *incy);

	void host_dcopy_(int * n, T * X, int *incx, T * Y, int *incy);

	T eig_ddot_(int * n, T * u, int * incu, T * v, int *incv);
	T eig_dnrm2_(int * n, T * X, int *incx);

	void eig_dgemv_(char * trans, int * m, int * n, T * alpha, T *A,
		int *lda, T * X, int * incx, T *beta, T * Y, int * incy);
	void eig_dgemm_(char * transa, char * transb, int * m, int * n, int * k,
		T * alpha, T * A, int * lda,
		T * B, int * ldb, T * beta, T * C, int * ldc);
	void eig_dsyrk_(char * uplo, char * trans, int * n, int * k, T * alpha,
		T * A, int *lda, T * beta, T *C, int *ldc);

	void host_dstev_(char *jobz, int *n,  T *d,  T *e, T *z, int *ldz,
		T *work, int *info);
	void eig_dstev_(char *jobz, int *n,  T *d,  T *e, T *z, int *ldz,
		T *work, int *info);
	void eig_dsyevd_(char * JOBZ, char * UPLO, int *N, T *A, int *LDA, T *W, 
		T *WORK, int *LWORK, int *IWORK, int *LIWORK, int *INFO);
	void eig_dsyev_(char * JOBZ, char * UPLO, int *N, T *A, int *LDA, T *W, 
		T *WORK, int *LWORK, int *INFO);

	void eig_dgeqrf_(int *M, int *N, T *A, int *LDA, T *TAU, T *WORK, 
		int *LWORK, int *INFO);

	void eig_dtrsm_(char *SIDE, char *UPLO, char *TRANSA, char *DIAG, int *m, int *n,
		T *alpha, T *A, int *lda, T *B, int *ldb);
	void eig_dpotrf_(char *UPLO, int *N, T *A, int *LDA, T *WORK, int *INFO);

	void eig_dgesvd_(char *JOBU, char *JOBV, int *m, int *n, T *a, int *lda, T *s,
		T *u, int *ldu, T *vt, int *ldv, T *work, int *lwork, int *info);
}

/*================================================================================================
  Define interfaces to BLAS / LAPACK functions
================================================================================================== */

/* AXPY, Y := alpha*X + Y, on whichever back end the build selects:
   cuBLAS (handle `handle`, declared outside this file) when GPU is defined,
   otherwise the Fortran BLAS. DOUBLE selects double over single precision.
   NOTE(review): on GPU, alpha is handed to cuBLAS as a pointer, so it must match
   the handle's pointer mode (host vs device) — confirm. */
inline void eig_daxpy_(int * n ,T *alpha , T * X, int *incx, T * Y,int *incy)
{
#if defined(GPU) && defined(DOUBLE)
	cublasDaxpy(handle, *n, alpha, X, *incx, Y, *incy);
#elif defined(GPU)
	cublasSaxpy(handle, *n, alpha, X, *incx, Y, *incy);
#elif defined(DOUBLE)
	daxpy_(n, alpha, X, incx, Y, incy);
#else
	saxpy_(n, alpha, X, incx, Y, incy);
#endif
}

/* Host-only AXPY, Y := alpha*X + Y: always calls the Fortran BLAS routine,
   regardless of GPU. DOUBLE selects daxpy_, otherwise saxpy_. */
inline void host_daxpy_(int * n ,T *alpha , T * X, int *incx, T * Y,int *incy)
{
#if !defined(DOUBLE)
	saxpy_(n, alpha, X, incx, Y, incy);
#else
	daxpy_(n, alpha, X, incx, Y, incy);
#endif
}

/* Index of the entry of X with the largest absolute value.
   CPU: Fortran i?amax_. GPU: cublasI?amax, which writes the index into a host int.
   Both back ends are expected to return a 1-based (Fortran-style) index —
   NOTE(review): confirm callers subtract 1 before indexing. */
inline int eig_idamax_(int * n ,T * X, int *incx)
{
#if !defined(GPU)
	#if defined(DOUBLE)
		return idamax_(n, X, incx);
	#else
		return isamax_(n, X, incx);
	#endif
#else
	int pos;
	#if defined(DOUBLE)
		cublasIdamax(handle, *n, X, *incx, &pos);
	#else
		cublasIsamax(handle, *n, X, *incx, &pos);
	#endif
	return pos;
#endif
}

/* Sum of absolute values of the entries of X (BLAS ASUM).
   GPU: cuBLAS writes the result into a host local, which presumes the handle is
   in host pointer mode (TODO confirm). CPU: Fortran ?asum_. */
inline T eig_dasum_(int * n ,T * X, int *incx)
{
#if defined(GPU)
	T total;
	#if defined(DOUBLE)
		cublasDasum(handle, *n, X, *incx, &total);
	#else
		cublasSasum(handle, *n, X, *incx, &total);
	#endif
	return total;
#elif defined(DOUBLE)
	return dasum_(n, X, incx);
#else
	return sasum_(n, X, incx);
#endif
}

/* Host-only ASUM (sum of |X_i|): always the Fortran BLAS, regardless of GPU. */
inline T host_dasum_(int * n ,T * X, int *incx)
{
#if !defined(DOUBLE)
	return sasum_(n, X, incx);
#else
	return dasum_(n, X, incx);
#endif
}

/* In-place scaling X := alpha*X: cuBLAS when GPU is defined, Fortran BLAS
   otherwise; DOUBLE selects the precision. */
inline void eig_dscal_(int * n, T * alpha, T * X, int *incx)
{
#if defined(GPU) && defined(DOUBLE)
	cublasDscal(handle, *n, alpha, X, *incx);
#elif defined(GPU)
	cublasSscal(handle, *n, alpha, X, *incx);
#elif defined(DOUBLE)
	dscal_(n, alpha, X, incx);
#else
	sscal_(n, alpha, X, incx);
#endif
}

/* Host-only in-place scaling X := alpha*X via the Fortran BLAS. */
inline void host_dscal_(int * n, T * alpha, T * X, int *incx)
{
#if !defined(DOUBLE)
	sscal_(n, alpha, X, incx);
#else
	dscal_(n, alpha, X, incx);
#endif
}

/* Vector copy Y := X: cuBLAS when GPU is defined, Fortran BLAS otherwise. */
inline void eig_dcopy_(int * n, T * X, int *incx, T * Y, int *incy)
{
#if defined(GPU) && defined(DOUBLE)
	cublasDcopy(handle, *n, X, *incx, Y, *incy);
#elif defined(GPU)
	cublasScopy(handle, *n, X, *incx, Y, *incy);
#elif defined(DOUBLE)
	dcopy_(n, X, incx, Y, incy);
#else
	scopy_(n, X, incx, Y, incy);
#endif
}

/* Host-only vector copy Y := X via the Fortran BLAS. */
inline void host_dcopy_(int * n, T * X, int *incx, T * Y, int *incy)
{
#if !defined(DOUBLE)
	scopy_(n, X, incx, Y, incy);
#else
	dcopy_(n, X, incx, Y, incy);
#endif
}

/* Exchange the contents of vectors X and Y: cuBLAS when GPU is defined,
   Fortran BLAS otherwise. */
inline void eig_dswap_(int * n, T * X, int *incx, T * Y, int *incy)
{
#if defined(GPU) && defined(DOUBLE)
	cublasDswap(handle, *n, X, *incx, Y, *incy);
#elif defined(GPU)
	cublasSswap(handle, *n, X, *incx, Y, *incy);
#elif defined(DOUBLE)
	dswap_(n, X, incx, Y, incy);
#else
	sswap_(n, X, incx, Y, incy);
#endif
}

/* Dot product u . v.
   CPU: Fortran ?dot_. GPU: cuBLAS writes the result into a host local, which
   presumes the handle is in host pointer mode (TODO confirm). */
inline T eig_ddot_(int * n, T * u, int * incu, T * v, int *incv)
{
#if !defined(GPU)
	#if defined(DOUBLE)
		return ddot_(n, u, incu, v, incv);
	#else
		return sdot_(n, u, incu, v, incv);
	#endif
#else
	T dot;
	#if defined(DOUBLE)
		cublasDdot(handle, *n, u, *incu, v, *incv, &dot);
	#else
		cublasSdot(handle, *n, u, *incu, v, *incv, &dot);
	#endif
	return dot;
#endif
}

/* Euclidean norm ||X||_2 (BLAS NRM2).
   GPU: cuBLAS writes the result into a host local (host pointer mode presumed —
   TODO confirm). CPU: Fortran ?nrm2_. */
inline T eig_dnrm2_(int * n, T * X, int *incx)
{
#if defined(GPU)
	T norm;
	#if defined(DOUBLE)
		cublasDnrm2(handle, *n, X, *incx, &norm);
	#else
		cublasSnrm2(handle, *n, X, *incx, &norm);
	#endif
	return norm;
#elif defined(DOUBLE)
	return dnrm2_(n, X, incx);
#else
	return snrm2_(n, X, incx);
#endif
}

/* General matrix-vector product Y := alpha*op(A)*X + beta*Y, where op(A) is A
   or A^T depending on *trans.
   - CPU: Fortran ?gemv_, which accepts 'N','n','T','t','C','c'.
   - GPU: cublas?gemv on the cuBLAS handle `handle` (declared outside this file).
   On GPU, *trans is mapped the way reference BLAS interprets it: 'T','t','C'
   and 'c' all select the transpose (for real data the conjugate transpose is
   the transpose); any other character means no transpose. This keeps the two
   back ends consistent for lowercase or 'C' arguments. */
inline void eig_dgemv_(char * trans, int * m, int * n, T * alpha, T *A,
		int *lda, T * X, int * incx, T *beta, T * Y, int * incy) {

#if defined(GPU)
	cublasOperation_t trans_GPU = CUBLAS_OP_N;
	if(*trans=='T' || *trans=='t' || *trans=='C' || *trans=='c') trans_GPU = CUBLAS_OP_T;
	#if defined(DOUBLE)
		cublasDgemv(handle, trans_GPU, *m, *n, alpha, A, *lda, X, *incx, 
			beta, Y, *incy);
	#else
		cublasSgemv(handle, trans_GPU, *m, *n, alpha, A, *lda, X, *incx, 
			beta, Y, *incy);
	#endif
#else
	#if defined(DOUBLE)
		dgemv_(trans, m, n, alpha, A, lda, X, incx, beta, Y, incy);
	#else
		sgemv_(trans, m, n, alpha, A, lda, X, incx, beta, Y, incy);
	#endif
#endif

}

/* General matrix-matrix product C := alpha*op(A)*op(B) + beta*C, with op(X)
   equal to X or X^T according to *transa / *transb.
   - CPU: Fortran ?gemm_, which accepts 'N','n','T','t','C','c'.
   - GPU: cublas?gemm on the cuBLAS handle `handle` (declared outside this file).
   On GPU each trans flag is mapped like reference BLAS: 'T','t','C','c' select
   the transpose (identical for real data); anything else means no transpose,
   so both back ends interpret the same flags identically. */
inline void eig_dgemm_(char * transa, char * transb, int * m, int * n, int * k,
		T * alpha, T * A, int * lda,
		T * B, int * ldb, T * beta, T * C, int * ldc) {

#if defined(GPU)
	cublasOperation_t transa_GPU = CUBLAS_OP_N;
	if(*transa=='T' || *transa=='t' || *transa=='C' || *transa=='c') transa_GPU = CUBLAS_OP_T;
	cublasOperation_t transb_GPU = CUBLAS_OP_N;
	if(*transb=='T' || *transb=='t' || *transb=='C' || *transb=='c') transb_GPU = CUBLAS_OP_T;
	#if defined(DOUBLE)
		cublasDgemm(handle, transa_GPU, transb_GPU, *m, *n, *k, alpha, 
			A, *lda, B, *ldb, beta, C, *ldc);
	#else
		cublasSgemm(handle, transa_GPU, transb_GPU, *m, *n, *k, alpha, 
			A, *lda, B, *ldb, beta, C, *ldc);
	#endif
#else
	#if defined(DOUBLE)
		dgemm_(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
	#else
		sgemm_(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
	#endif
#endif

}

/* Symmetric rank-k update of the n x n matrix C (only the triangle chosen by
   *uplo is referenced/updated):
     *trans = 'N':        C := alpha*A*A^T + beta*C   (A is n x k)
     *trans = 'T' / 'C':  C := alpha*A^T*A + beta*C   (A is k x n)
   - CPU: Fortran ?syrk_.
   - GPU: cublas?syrk on the cuBLAS handle `handle` (declared outside this file).
   On GPU, 'T','t','C','c' select the transpose and 'U','u' the upper triangle,
   matching how the Fortran routine reads the same flags; anything else means
   no transpose / lower triangle. */
inline void eig_dsyrk_(char * uplo, char * trans, int * n, int * k, T * alpha,
		T * A, int *lda, T * beta, T *C, int *ldc) {

#if defined(GPU)
	cublasOperation_t trans_GPU = CUBLAS_OP_N;
	if(*trans=='T' || *trans=='t' || *trans=='C' || *trans=='c') trans_GPU = CUBLAS_OP_T;
	cublasFillMode_t uplo_GPU = CUBLAS_FILL_MODE_LOWER;
	if(*uplo=='U' || *uplo=='u') uplo_GPU = CUBLAS_FILL_MODE_UPPER;
	#if defined(DOUBLE)
		cublasDsyrk(handle, uplo_GPU, trans_GPU, *n, *k, alpha, 
			A, *lda, beta, C, *ldc);
	#else
		cublasSsyrk(handle, uplo_GPU, trans_GPU, *n, *k, alpha, 
			A, *lda, beta, C, *ldc);
	#endif
#else
	#if defined(DOUBLE)
		dsyrk_(uplo, trans, n, k, alpha, A, lda, beta, C, ldc);
	#else
		ssyrk_(uplo, trans, n, k, alpha, A, lda, beta, C, ldc);
	#endif
#endif

}

/* Eigen-decomposition of a symmetric tridiagonal matrix (diagonal d, off-diagonal e)
   via LAPACK ?stev. There is no GPU implementation: when GPU is defined, this
   prints a message and terminates the process with exit status 1. */
inline void eig_dstev_(char *jobz, int *n,  T *d,  T *e, T *z, int *ldz,
		T *work, int *info)
{
#if !defined(GPU)
	#if defined(DOUBLE)
		dstev_(jobz, n, d, e, z, ldz, work, info);
	#else
		sstev_(jobz, n, d, e, z, ldz, work, info);
	#endif
#else
	std::cout << "dstev is an undefined function on GPU" << std::endl;
	exit(1);
#endif
}


/* Host-only symmetric tridiagonal eigensolver (LAPACK ?stev), available in
   every build, including GPU builds. */
inline void host_dstev_(char *jobz, int *n,  T *d,  T *e, T *z, int *ldz,
		T *work, int *info)
{
#if !defined(DOUBLE)
	sstev_(jobz, n, d, e, z, ldz, work, info);
#else
	dstev_(jobz, n, d, e, z, ldz, work, info);
#endif
}

/* Symmetric eigensolver via LAPACK ?syevd (divide and conquer), with the
   caller-supplied WORK/IWORK workspaces. Not supported on GPU: when GPU is
   defined this prints a message and exits with status 1 (eig_dsyev_ is the
   GPU-capable entry point). */
inline void eig_dsyevd_(char * JOBZ, char * UPLO, int *N, T *A, int *LDA, T *W, 
		T *WORK, int *LWORK, int *IWORK, int *LIWORK, int *INFO)
{
#if !defined(GPU)
	#if defined(DOUBLE)
		dsyevd_(JOBZ, UPLO, N, A, LDA, W, WORK, LWORK, IWORK, LIWORK, INFO);
	#else
		ssyevd_(JOBZ, UPLO, N, A, LDA, W, WORK, LWORK, IWORK, LIWORK, INFO);
	#endif
#else
	std::cout << "dsyevd is different on GPU" << std::endl;
	exit(1);
#endif
}

/* Symmetric eigensolver: eigenvalues (and, unless JOBZ is 'N'/'n', eigenvectors
   written back into A) of the N x N symmetric matrix A; eigenvalues go to W.
   - CPU: LAPACK ?syev.
   - GPU: cuSOLVER cusolverDn?syevd on the handle cusolverH (declared outside
     this file); the status is copied back into the host integer *INFO.
   UPLO 'U'/'u' selects the upper triangle, anything else the lower one;
   lowercase flags are accepted on GPU just as LAPACK accepts them on CPU.
   NOTE(review): on GPU, WORK is handed to cuSOLVER with length *LWORK, so it
   presumably must be device memory sized by cusolverDn?syevd_bufferSize —
   confirm with callers. */
inline void eig_dsyev_(char * JOBZ, char * UPLO, int *N, T *A, int *LDA, T *W, 
		T *WORK, int *LWORK, int *INFO) {

#if defined(GPU)
	cusolverEigMode_t jobz_GPU    = CUSOLVER_EIG_MODE_VECTOR;
	if(*JOBZ=='N' || *JOBZ=='n') jobz_GPU = CUSOLVER_EIG_MODE_NOVECTOR;
	cublasFillMode_t uplo_GPU = CUBLAS_FILL_MODE_LOWER;
	if(*UPLO=='U' || *UPLO=='u') uplo_GPU = CUBLAS_FILL_MODE_UPPER;
	// cuSOLVER reports its status through a device-side integer.
	int *devInfo;
	cudaMalloc((void **)&devInfo, sizeof(int));
	#if defined(DOUBLE)
        	cusolverDnDsyevd(cusolverH, jobz_GPU, uplo_GPU, *N, A, *LDA, W, 
			WORK, *LWORK, devInfo);
	#else
        	cusolverDnSsyevd(cusolverH, jobz_GPU, uplo_GPU, *N, A, *LDA, W, 
			WORK, *LWORK, devInfo);
	#endif
	cudaMemcpy(INFO, devInfo, sizeof(int), cudaMemcpyDeviceToHost);
	// Release the status buffer so repeated calls do not leak device memory.
	cudaFree(devInfo);
#else
	#if defined(DOUBLE)
		dsyev_(JOBZ, UPLO, N, A, LDA, W, WORK, LWORK, INFO);
	#else
		ssyev_(JOBZ, UPLO, N, A, LDA, W, WORK, LWORK, INFO);
	#endif
#endif

}

/* QR factorisation of the M x N matrix A (leading dimension LDA), in place;
   TAU receives the Householder scalars and *INFO the status.
   - CPU: LAPACK ?geqrf.
   - GPU: cuSOLVER cusolverDn?geqrf on the handle cusolverH (declared outside
     this file) with the caller's WORK/LWORK; the device-side status is copied
     back into the host integer *INFO.
   NOTE(review): on GPU, WORK presumably must be device memory of at least
   cusolverDn?geqrf_bufferSize elements — confirm with callers. */
inline void eig_dgeqrf_(int *M, int *N, T *A, int *LDA, T *TAU, T *WORK, 
		int *LWORK, int *INFO) {

#if defined(GPU)
	int *devInfo;
	cudaMalloc((void **)&devInfo, sizeof(int));
	#if defined(DOUBLE)
		cusolverDnDgeqrf(cusolverH, *M, *N, A, *LDA, TAU, 
			WORK, *LWORK, devInfo);
	#else
		cusolverDnSgeqrf(cusolverH, *M, *N, A, *LDA, TAU, 
			WORK, *LWORK, devInfo);
	#endif
	cudaMemcpy(INFO, devInfo, sizeof(int), cudaMemcpyDeviceToHost);
	// Release the status buffer so repeated calls do not leak device memory.
	cudaFree(devInfo);
#else
	#if defined(DOUBLE)
		dgeqrf_(M, N, A, LDA, TAU, WORK, LWORK, INFO);
	#else
		sgeqrf_(M, N, A, LDA, TAU, WORK, LWORK, INFO);
	#endif
#endif

}

/* Triangular solve with multiple right-hand sides (BLAS TRSM), overwriting B
   (m x n) with X such that
     op(A)*X = alpha*B   when SIDE is 'L' (default), or
     X*op(A) = alpha*B   when SIDE is 'R'/'r',
   where A is triangular and op(A) is A or A^T.
   - CPU: Fortran ?trsm_.
   - GPU: cublas?trsm on the cuBLAS handle `handle` (declared outside this file).
   Flag mapping on GPU (case-insensitive, like the Fortran routine):
     SIDE 'R' -> right side; UPLO 'U' -> upper; TRANSA 'T' or 'C' -> transpose
     (identical for real data); DIAG 'U' -> unit diagonal.
   Any other character selects left / lower / no transpose / non-unit. */
inline void eig_dtrsm_(char *SIDE, char *UPLO, char *TRANSA, char *DIAG, int *m, int *n,
		T *alpha, T *A, int *lda, T *B, int *ldb) {

#if defined(GPU)
	cublasOperation_t trans_GPU = CUBLAS_OP_N;
	if(*TRANSA=='T' || *TRANSA=='t' || *TRANSA=='C' || *TRANSA=='c') trans_GPU = CUBLAS_OP_T;
	cublasFillMode_t uplo_GPU = CUBLAS_FILL_MODE_LOWER;
	if(*UPLO=='U' || *UPLO=='u') uplo_GPU = CUBLAS_FILL_MODE_UPPER;
	cublasSideMode_t side_GPU = CUBLAS_SIDE_LEFT;
	if(*SIDE=='R' || *SIDE=='r') side_GPU = CUBLAS_SIDE_RIGHT;
	cublasDiagType_t diag_GPU = CUBLAS_DIAG_NON_UNIT;
	if(*DIAG=='U' || *DIAG=='u') diag_GPU = CUBLAS_DIAG_UNIT;
	#if defined(DOUBLE)
		cublasDtrsm(handle, side_GPU, uplo_GPU, trans_GPU, diag_GPU, *m, *n, 
			alpha, A, *lda, B, *ldb);
	#else
		cublasStrsm(handle, side_GPU, uplo_GPU, trans_GPU, diag_GPU, *m, *n, 
			alpha, A, *lda, B, *ldb);
	#endif
#else
	#if defined(DOUBLE)
		dtrsm_(SIDE, UPLO, TRANSA, DIAG, m, n, alpha, A, lda, B, ldb);
	#else
		strsm_(SIDE, UPLO, TRANSA, DIAG, m, n, alpha, A, lda, B, ldb);
	#endif
#endif

}

/* Cholesky factorisation of the symmetric positive-definite N x N matrix A
   (leading dimension LDA), in place, using the triangle selected by UPLO
   ('U'/'u' upper, anything else lower); *INFO receives the status.
   - CPU: LAPACK ?potrf (WORK is unused).
   - GPU: cuSOLVER on the handle cusolverH (declared outside this file). The
     workspace size is queried with cusolverDn?potrf_bufferSize and WORK is
     passed to cusolverDn?potrf with that size; the device-side status is
     copied back into the host integer *INFO.
   NOTE(review): on GPU, WORK must be device memory holding at least the queried
   number of elements. The wrapper cannot verify this — confirm with callers. */
inline void eig_dpotrf_(char *UPLO, int *N, T *A, int *LDA, T *WORK, int *INFO) {

#if defined(GPU)
	cublasFillMode_t uplo_GPU = CUBLAS_FILL_MODE_LOWER;
	if(*UPLO=='U' || *UPLO=='u') uplo_GPU = CUBLAS_FILL_MODE_UPPER;
	int *devInfo;
	cudaMalloc((void **)&devInfo, sizeof(int));
	int bufferSize;
	#if defined(DOUBLE)
		cusolverDnDpotrf_bufferSize(cusolverH, uplo_GPU, *N, A, *LDA, &bufferSize);
		cusolverDnDpotrf(cusolverH, uplo_GPU, *N, A, *LDA, WORK, bufferSize, devInfo);
	#else
		// Single precision: cuSOLVER names these routines cusolverDnSpotrf*.
		cusolverDnSpotrf_bufferSize(cusolverH, uplo_GPU, *N, A, *LDA, &bufferSize);
		cusolverDnSpotrf(cusolverH, uplo_GPU, *N, A, *LDA, WORK, bufferSize, devInfo);
	#endif
	cudaMemcpy(INFO, devInfo, sizeof(int), cudaMemcpyDeviceToHost);
	// Release the status buffer so repeated calls do not leak device memory.
	cudaFree(devInfo);
#else
	#if defined(DOUBLE)
		dpotrf_(UPLO, N, A, LDA, INFO);
	#else
		spotrf_(UPLO, N, A, LDA, INFO);
	#endif
#endif

}

 /* Singular value decomposition of the m x n matrix a via LAPACK ?gesvd
    (DOUBLE selects dgesvd_, otherwise sgesvd_).
    Unlike most eig_* wrappers there is no GPU branch: the host LAPACK routine is
    called even when GPU is defined.
    NOTE(review): so all arrays presumably must be host memory in GPU builds
    too — confirm with callers. */
 inline void eig_dgesvd_(char *JOBU, char *JOBV, int *m, int *n, T *a, int *lda, T *s,
		T *u, int *ldu, T *vt, int *ldv, T *work, int *lwork, int *info)
{
#if !defined(DOUBLE)
	sgesvd_(JOBU, JOBV, m, n, a, lda, s, u, ldu, vt, ldv, work, lwork, info);
#else
	dgesvd_(JOBU, JOBV, m, n, a, lda, s, u, ldu, vt, ldv, work, lwork, info);
#endif
}

#endif
