#include <cmath>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <iostream>

#ifdef DQMC_CUDA
// #include <cublas_v2.h>
#include <cublas.h>
#include <cuda_runtime.h>
#include "multb.h"
#endif

#include "blaslapack.h"
#include "profile.h"

// Exported functions called from Fortran.
// Every argument is passed by pointer and the symbols carry a trailing
// underscore; all four are defined further down in this file.
extern "C" {
  // Allocates the work buffers shared by the routines below (call first).
  void computeg_init_(int *n, int *L, int *fixw);
  // Computes G together with the sign and log-magnitude of its determinant.
  void computeg_(int *n, double *B, int *L, int *nOrth, double *h, int *il, double *G, double *sgn, double *det);
  // Releases everything allocated by computeg_init_.
  void computeg_free_(void);
  // Overwrites G with h * B * G * Bi / h.
  void swapg_(int *n, double *B, double *Bi, double *h, double *G);
}

// Copies an m-by-n block of doubles from A to B.
//   m, n : number of rows / columns to copy
//   A    : source, consecutive columns are lda elements apart
//   B    : destination, consecutive columns are ldb elements apart
// When both leading dimensions equal m the data is contiguous and is moved
// with a single memcpy; otherwise each column is copied separately.
void mat_copy(int m, int n, const double *A, int lda, double *B, int ldb)
{
  const bool contiguous = (lda == m) && (ldb == m);
  if (contiguous) {
    memcpy(B, A, sizeof(double) * m * n);
    return;
  }

  const std::size_t column_bytes = sizeof(double) * m;
  for (int col = 0; col < n; ++col) {
    const double *src = A + lda * col;
    double *dst = B + ldb * col;
    memcpy(dst, src, column_bytes);
  }
}

// Derives sign(det) and log(|det|) from an LU factorisation.
//   n    : matrix order
//   M    : factorised matrix; only the diagonal entries M[i * ldm + i] are read
//   ldm  : distance between consecutive columns of M
//   ipiv : 1-based pivot rows; ipiv[i] != i + 1 marks a row interchange
//   sgn  : out, +1.0 or -1.0
//   det  : out, sum of log(|diagonal entry|)
// Each row interchange and each negative diagonal entry flips the sign.
void sgndet(int n, double *M, int ldm, int *ipiv, double &sgn, double &det)
{
  sgn = 1.0;
  det = 0.0;
  for (int k = 0; k < n; ++k) {
    const double pivot = M[k * ldm + k];
    const bool swapped = (ipiv[k] != k + 1);
    if (swapped) sgn = -sgn;
    if (pivot < 0.0) {
      sgn = -sgn;
      det += log(-pivot);
    } else {
      det += log(pivot);
    }
  }
}

// Dynamic memory storage (reused between calls).
// Everything below is allocated in computeg_init_ and released in
// computeg_free_; sizes refer to the n and L passed to computeg_init_.
static double *tau;       // n entries: reflector scalars of the QR factorisations
static int *ipiv;         // n entries: column permutation for QR, pivot rows for LU
static double *A;         // n*n scratch matrix
static double *Q;         // n*n matrix holding the current QR factor
static double *T;         // n*n accumulated triangular factor
static double **M_cache;  // L lazily filled n*n block products, or NULL when fixw == 0
static int L;             // number of M_cache slots (value of *pL seen by computeg_init_)
static double *D;         // n entries: diagonal of R
static double *Db;        // n entries: column norms, later the "big" part of D
static int lwork;         // length of work: max(n*n, LAPACK workspace queries)
static double *work;      // LAPACK workspace, also used as n*n scratch matrix

#ifdef DQMC_CUDA
// Device-side buffers (allocated with cudaMalloc).
// static cublasHandle_t handle;
static double *B_gpu;     // n*n copy of B
static double *Bi_gpu;    // n*n copy of Bi (used by swapg_)
static double *h_gpu;     // n*L copy of h
static double *w1_gpu;    // n*n work matrix
static double *w2_gpu;    // n*n work matrix
#endif

void computeg_init_(int *pn, int *pL, int *fixw)
{
  try {
  L = *pL;
  int n = *pn, m = 2 * n;

#ifdef DQMC_CUDA
  // Start CUBLAS library
  CUDACHECK(cublasInit());
  // CUDACHECK(cublasCreate(&handle));
  // Allocate buffers in the graphics memory
  CUDACHECK(cudaMalloc((void**)&B_gpu, n * n * sizeof(double)));
  CUDACHECK(cudaMalloc((void**)&Bi_gpu, n * n * sizeof(double)));
  CUDACHECK(cudaMalloc((void**)&h_gpu, n * L * sizeof(double)));
  CUDACHECK(cudaMalloc((void**)&w1_gpu, n * n * sizeof(double)));
  CUDACHECK(cudaMalloc((void**)&w2_gpu, n * n * sizeof(double)));
#endif

  // Allocate memory 
  tau = new double[n];
  ipiv = new int[n];
  A = new double[n * n];
  Q = new double[n * n];
  T = new double[n * n];
  if (*fixw) {
    M_cache = new double*[L];
    for (int i = 0; i < L; i++) M_cache[i] = NULL;
  } else M_cache = NULL;
  D = new double[n];
  Db = new double[n];

  // Allocate workspace for LAPACK
  lwork = n * n;
  double temp;
  lapack_dgeqp3(n, n, Q, n, ipiv, tau, &temp, -1);
  if (temp > lwork) lwork = (int)temp;
  lapack_dgeqrf(n, n, Q, n, tau, &temp, -1);
  if (temp > lwork) lwork = (int)temp;
  lapack_dormqr("RN", n, n, n, Q, n, tau, A, n, &temp, -1);  
  if (temp > lwork) lwork = (int)temp;
  lapack_dorgqr(n, n, n, Q, n, tau, &temp, -1);
  if (temp > lwork) lwork = (int)temp;
  work = new double[lwork];

  } catch (std::exception& e) {
    std::cerr << "computeg_init: " << e.what() << std::endl;
    exit(1);
  }
}

#ifndef DQMC_CUDA

// CPU version: multiplies k consecutive slice matrices into M.
// Each factor is B with its rows scaled by one slice of h; the slice index il
// is advanced (wrapping from L - 1 back to 0) before each factor is applied,
// so the first factor uses slice il + 1. Later factors are multiplied from
// the left.
//   n  : matrix order
//   B  : n*n matrix, element (i, j) stored at B[j * n + i]
//   L  : number of slices in h
//   h  : n*L array, slice s occupies h[s * n .. s * n + n - 1]
//   il : in/out slice index; on return it is the last slice used
//   k  : number of factors (at least one factor is always applied)
//   M  : out, n*n result
// Uses the file-level buffer `work` as n*n scratch space.
void compute_M(int n, double *B, int L, double *h, int &il, int k, double *M)
{
  int step, row, col;
  PROFILE_BEGIN();
  il = (il + 1 >= L) ? 0 : il + 1;

  // first factor: M = diag(h_il) * B
  PROFILE_BEGIN();
  #ifdef _OPENMP
  #pragma omp parallel for shared(n, il, M, B, h) private(row, col) schedule(static)
  #endif
  for (col = 0; col < n; col++) {
    const int base = col * n;
    for (row = 0; row < n; row++)
      M[base + row] = h[il * n + row] * B[base + row];
  }
  PROFILE_END(profile_scalerow, n * n);

  // remaining factors: M = diag(h_il) * (B * M)
  for (step = 1; step < k; ++step) {
    il = (il + 1 >= L) ? 0 : il + 1;
    blas_dgemm("NN", n, n, n, 1.0, B, n, M, n, 0.0, work, n);
    PROFILE_BEGIN();
    #ifdef _OPENMP
    #pragma omp parallel for shared(n, il, M, work, h) private(row, col) schedule(static)
    #endif
    for (col = 0; col < n; col++) {
      const int base = col * n;
      for (row = 0; row < n; row++)
        M[base + row] = h[il * n + row] * work[base + row];
    }
    PROFILE_END(profile_scalerow, n * n);
  }
  PROFILE_END(profile_computem, 0);
}

#else

// GPU version: multiplies k consecutive slice matrices and downloads the
// result into M. Works on the device copies B_gpu / h_gpu (uploaded by the
// caller) and ping-pongs between the device work matrices w1_gpu and w2_gpu;
// the host arguments B and h are not read here.
//   n  : matrix order
//   L  : number of slices in h_gpu
//   il : in/out slice index, advanced with wrap-around before each factor
//   k  : number of factors (at least one factor is always applied)
//   M  : out, n*n host matrix
void compute_M(int n, double *B, int L, double *h, int &il, int k, double *M)
{
  PROFILE_BEGIN();

  // first factor: w1 = scalerow(h_il, B)
  il = (il + 1 >= L) ? 0 : il + 1;
  scalerow_gpu(n, h_gpu + il * n, B_gpu, w1_gpu);

  for (int step = 1; step < k; ++step) {
    il = (il + 1 >= L) ? 0 : il + 1;
    // w2 = B * w1 (legacy CUBLAS interface), then w1 = scalerow(h_il, w2)
    cublas_dgemm('N', 'N', n, n, n, 1.0, B_gpu, n, w1_gpu, n, 0.0, w2_gpu, n);
    scalerow_gpu(n, h_gpu + il * n, w2_gpu, w1_gpu);
  }

  // copy the product back to host memory
  CUDACHECK(cublasGetMatrix(n, n, sizeof(double), w1_gpu, n, M, n));
  PROFILE_END(profile_computem, 0);
}

#endif

// Cached front end for compute_M.
// The cache slot is selected by the value of il on entry. A stored product is
// reused only when caching is enabled, the slot is filled, `update` is false
// and k != 1; in every other case the product is recomputed and, when caching
// is enabled and k != 1, stored back into the slot.
//   n, B, L, h, k, M : forwarded to compute_M
//   il               : in/out slice index; advanced by k steps (with
//                      wrap-around at L) on both paths
//   update           : force recomputation and refresh the cached copy
void compute_M_cache(int n, double *B, int L, double *h, int &il, int k, double *M, bool update)
{
  const int slot = il;
  const bool hit = (M_cache != NULL) && (M_cache[slot] != NULL) && !update && (k != 1);

  if (hit) {
    // reuse the stored product ...
    memcpy(M, M_cache[slot], n * n * sizeof(double));
    // ... and move il forward exactly as compute_M would have done
    for (int step = 0; step < k; ++step)
      il = (il + 1 >= L) ? 0 : il + 1;
    return;
  }

  compute_M(n, B, L, h, il, k, M);

  if (M_cache && k != 1) {
    // slots are allocated lazily on first use
    double *&entry = M_cache[slot];
    if (!entry) entry = new double[n * n];
    memcpy(entry, M, n * n * sizeof(double));
  }
}

// Computes G from a chain of L / nOrth block products using a stratified QR
// scheme, and returns the sign and log-magnitude of the related determinant.
//   pn    : pointer to the matrix order n
//   B     : n*n matrix, element (i, j) at B[j * n + i]
//   pL    : pointer to the number of slices in h
//   nOrth : pointer to k, the number of slices multiplied into each block
//   h     : n*L array of per-slice row scalings
//   pil   : pointer to the 1-based starting slice (converted to 0-based here)
//   G     : out, n*n result
//   psgn  : out, sgn1 * sgn2 * sgn3 (LU factor, large diagonal part, reflectors)
//   pdet  : out, -(log|det| of the LU factor + sum of log|Db[i]|)
// Relies on the file-level buffers allocated by computeg_init_ (tau, ipiv, A,
// Q, T, D, Db, work); presumably n and L must match the values given to
// computeg_init_ - TODO confirm against the Fortran caller.
// Any std::exception is reported on stderr and terminates the process.
void computeg_(int *pn, double *B, int *pL, int *nOrth, double *h, int *pil, double *G, double *psgn, double *pdet)
{
  try {
  PROFILE_ENABLE();
  PROFILE_BEGIN();
  int n = *pn, L = *pL, il = *pil - 1, k = *nOrth, i, j, l;  

#ifdef DQMC_CUDA
  // Upload B (n x n) and h (n x L) so that compute_M can work on the device
  CUDACHECK(cublasSetMatrix(n, n, sizeof(double), B, n, B_gpu, n)); 
  CUDACHECK(cublasSetMatrix(n, L, sizeof(double), h, n, h_gpu, n));
#endif

  double tmp;
  // Compute G using the ASQRD algorithm
  double max, dot;
  int p, t;
  
  // First block: QR with column pivoting. ipiv is zeroed before the call
  // (see the DGEQP3 documentation for the meaning of a zero JPVT entry).
  compute_M_cache(n, B, L, h, il, k, Q, false); // Q = B_1
  for (i = 0; i < n; i++) ipiv[i] = 0;
  lapack_dgeqp3(n, n, Q, n, ipiv, tau, work, lwork); // QRP

  for (i = 0; i < n; i++) D[i] = Q[i * n + i]; // D = diag(R)

  // Scatter the row-scaled upper triangle of R into T; ipiv is 1-based here
  PROFILE_BEGIN();
  #ifdef _OPENMP
  #pragma omp parallel for shared(n, ipiv, T, Q, D) private(i, j, p) schedule(static)
  #endif
  for (j = 0; j < n; j++) {
    p = ipiv[j] - 1;
    for (i = 0; i <= j; i++)
      T[p * n + i] = Q[j * n + i] / D[i]; // T = D^-1*R*P
    for (; i < n; i++)
      T[p * n + i] = 0;
  }
  PROFILE_END(profile_scalerow, n * n / 2);

  // Remaining blocks. The last two iterations pass update = true, which
  // makes compute_M_cache recompute the block instead of using the cache.
  for (l = 1; l < L / k; l++) {
    // Q = (B_l*Q)*D
    compute_M_cache(n, B, L, h, il, k, A, l >= L / k - 2); // A = B_l
    lapack_dormqr("RN", n, n, n, Q, n, tau, A, n, work, lwork); // A = A * Q
    
#if 0
    // Use QRP for stratification
    // A = A * D
    PROFILE_BEGIN();
    #ifdef _OPENMP
    #pragma omp parallel for shared(n, A, D) private(i, j) schedule(static)
    #endif
    for (j = 0; j < n; j++)
      for (i = 0; i < n; i++)
        Q[j * n + i] = A[j * n + i] * D[j];
    PROFILE_END(profile_scalerow, n * n);
    
    for (i = 0; i < n; i++) ipiv[i] = 0;
    lapack_dgeqp3(n, n, Q, n, ipiv, tau, work, lwork); // QRP
    for (i = 0; i < n; i++) ipiv[i]--;
    
#else
    // Prepivoting
    // A = A * D and compute the norm of each column
    // (Db receives the squared 2-norm of every scaled column)
    PROFILE_BEGIN();
    #ifdef _OPENMP
    #pragma omp parallel for shared(n, A, D) private(i, j, tmp, dot) schedule(static)
    #endif
    for (j = 0; j < n; j++) {
      dot = 0.0;
      for (i = 0; i < n; i++) {
        tmp = A[j * n + i] * D[j];
	A[j * n + i] = tmp;
	dot += tmp * tmp;
      }
      Db[j] = dot;
    }
    PROFILE_END(profile_normcol, n * n * 3);
  
    // compute a permutation P that sorts the columns of A
    // (selection sort by decreasing norm; ipiv is 0-based from here on)
    for (j = 0; j < n; j++) ipiv[j] = j;
    for (j = 0; j < n - 1; j++) {
      // find column with maximum norm
      max = Db[j];
      p = j;
      for (i = j + 1; i < n; i++)
        if (Db[i] > max) {
	  max = Db[i];
	  p = i;
	}
      // swap columns
      if (p != j) {
        t = ipiv[j]; ipiv[j] = ipiv[p]; ipiv[p] = t;
	tmp = Db[j]; Db[j] = Db[p]; Db[p] = tmp;
      }
    }
    
    // apply the permutation
    #ifdef _OPENMP
    #pragma omp parallel for shared(n, ipiv, Q, A) private(i, j, p) schedule(static)
    #endif
    for (j = 0; j < n; j++) {
      p = ipiv[j];
      for (i = 0; i < n; i++)
	Q[j * n + i] = A[p * n + i];
    }
    
    // standard QR
    lapack_dgeqrf(n, n, Q, n, tau, work, lwork);
#endif

    for (i = 0; i < n; i++) D[i] = Q[i * n + i]; // D = diag(R)
    
    // T = D^-1 * R * P * T
#if 0
    // DTRMM version
    PROFILE_BEGIN();
    #ifdef _OPENMP
    #pragma omp parallel for shared(n, Q, D) private(i, j) schedule(static, 1)
    #endif
    for (j = 0; j < n; j++)
      for (i = 0; i <= j; i++)
         Q[j * n + i] = Q[j * n + i] / D[i]; // R = D^-1 * R
    PROFILE_END(profile_scalerow, n * n / 2);
    #ifdef _OPENMP
    #pragma omp parallel for shared(n, work, ipiv, T) private(i, j) schedule(static)
    #endif
    for (j = 0; j < n; j++) {
      for (i = 0; i < n; i++)
        work[j * n + i] = T[j * n + i];
      for (i = 0; i < n; i++)
        T[j * n + i] = work[j * n + ipiv[i]]; // T = P * T
    }
    blas_dtrmm("LUNN", n, n, 1.0, Q, n, T, n); // T = R * T
#else
    // DGEMM version
    // A is rebuilt as the row-scaled upper triangle of R with its columns
    // scattered by ipiv, then multiplied into T through the work buffer
    PROFILE_BEGIN();
    #ifdef _OPENMP
    #pragma omp parallel for shared(n, ipiv, A, Q, D) private(i, j, p) schedule(static)
    #endif
    for (j = 0; j < n; j++) {
      p = ipiv[j];
      for (i = 0; i <= j; i++)
	A[p * n + i] = Q[j * n + i] / D[i]; // A = D^-1*R*P
      for (; i < n; i++)
	A[p * n + i] = 0;
    }
    PROFILE_END(profile_scalerow, n * n / 2);
    blas_dgemm("NN", n, n, n, 1.0, A, n, T, n, 0.0, work, n); // T = A * T
    memcpy(T, work, n * n * sizeof(double));
#endif
  }

  lapack_dorgqr(n, n, n, Q, n, tau, work, lwork); // build Q

  // compute G = (Db * Q' + Ds * T)^-1 * Db * Q'
  
  // split D and compute det(Db)
  // Entries with |D[i]| > 1 are moved to Db (and D[i] is reset to 1); all
  // other entries keep their value in D and get Db[i] = 1.
  double sgn2 = 1.0, det2 = 0.0;
  for (i = 0; i < n; i++)
    if (fabs(D[i]) > 1) {
      Db[i] = D[i]; D[i] = 1.0;
      if (Db[i] < 0.0) {
        sgn2 = -sgn2;
        det2 += log(-Db[i]);
      } else {
        det2 += log(Db[i]); 
      }
    } else {
      Db[i] = 1.0;
    }
  
  // G = Db * Q' ; A = Db * Q' + Ds * T
  // (in the code below Q is read transposed and divided by Db[i])
  PROFILE_BEGIN();
  #ifdef _OPENMP
  #pragma omp parallel for shared(n, Q, Db, G, A, D, T) private(i, j, tmp) schedule(static)
  #endif
  for (j = 0; j < n; j++)
    for (i = 0; i < n; i++) {
      tmp = Q[i * n + j] / Db[i];
      G[j * n + i] = tmp;
      A[j * n + i] = tmp + D[i] * T[j * n + i];
    }
  PROFILE_END(profile_scalerow, n * n * 3);
  
  // G = A^-1 * G
  // (LU factorisation, sign / log-determinant of the factor, then solve)
  lapack_dgetrf(n, n, A, n, ipiv);
  double sgn1, det1; sgndet(n, A, n, ipiv, sgn1, det1);
  lapack_dgetrs("N", n, n, A, n, ipiv, G, n);

  // compute det(Q)
  // one sign flip for every non-zero tau among the first n - 1 reflectors
  double sgn3 = 1.0;
  for (i = 0; i < n - 1; i++) 
    if (tau[i] != 0) sgn3 = -sgn3;
  
  *psgn = sgn1 * sgn2 * sgn3; *pdet = - det1 - det2;
 
  PROFILE_END(profile_computeg, 0);
  PROFILE_DISABLE();

  } catch (std::exception& e) {
    std::cerr << "computeg: " << e.what() << std::endl;
    exit(1);
  }
}

// Releases every buffer allocated by computeg_init_.
// Each pointer is reset to NULL after it is freed, so the host buffers are
// not double-freed if this routine is called twice and stale pointers are
// never left behind (delete[] and cudaFree both accept a null pointer).
// In CUDA builds the device buffers are freed and the legacy CUBLAS library
// is shut down; failures are reported the same way as in computeg_init_ and
// computeg_ (message on stderr, then exit(1)).
// NOTE(review): this assumes CUDACHECK signals errors by throwing, as the
// try/catch blocks around it in the other entry points suggest - confirm in
// multb.h. cublasShutdown is still invoked on every call.
void computeg_free_(void) {
  // Free allocated host memory
  delete []tau;  tau = NULL;
  delete []ipiv; ipiv = NULL;
  delete []A;    A = NULL;
  delete []Q;    Q = NULL;
  delete []T;    T = NULL;
  delete []D;    D = NULL;
  delete []Db;   Db = NULL;
  if (M_cache) {
    // slots that were never filled are NULL; delete[] ignores them
    for (int i = 0; i < L; i++)
      delete []M_cache[i];
    delete []M_cache;
    M_cache = NULL;
  }
  delete []work; work = NULL;

#ifdef DQMC_CUDA
  try {
    CUDACHECK(cudaFree(B_gpu));  B_gpu = NULL;
    CUDACHECK(cudaFree(Bi_gpu)); Bi_gpu = NULL;
    CUDACHECK(cudaFree(h_gpu));  h_gpu = NULL;
    CUDACHECK(cudaFree(w1_gpu)); w1_gpu = NULL;
    CUDACHECK(cudaFree(w2_gpu)); w2_gpu = NULL;
    // CUDACHECK(cublasDestroy(handle));
    CUDACHECK(cublasShutdown());
  } catch (std::exception& e) {
    std::cerr << "computeg_free: " << e.what() << std::endl;
    exit(1);
  }
#endif
}

// Overwrites G with h * B * G * Bi / h, i.e. element (i, j) of the product
// B * G * Bi is multiplied by h[i] and divided by h[j].
//   pn : pointer to the matrix order n
//   B  : n*n matrix, element (i, j) at B[j * n + i]
//   Bi : n*n matrix multiplied from the right
//   h  : n scaling factors
//   G  : in/out, n*n matrix
// The CPU path uses the file-level buffer `work` as n*n scratch space and the
// CUDA path uses the device buffers, so computeg_init_ must have been called.
// Like the other entry points, any std::exception is reported on stderr and
// terminates the process instead of propagating into the Fortran caller.
void swapg_(int *pn, double *B, double *Bi, double *h, double *G) {
  try {
  PROFILE_ENABLE();
  PROFILE_BEGIN();
  int n = *pn;
#ifdef DQMC_CUDA
  // Upload the operands; G goes into w1, h into the first n entries of h_gpu
  CUDACHECK(cublasSetMatrix(n, n, sizeof(double), B, n, B_gpu, n));
  CUDACHECK(cublasSetMatrix(n, n, sizeof(double), Bi, n, Bi_gpu, n));
  CUDACHECK(cublasSetMatrix(n, n, sizeof(double), G, n, w1_gpu, n));
  CUDACHECK(cublasSetVector(n, sizeof(double), h, 1, h_gpu, 1));
  // W = B * G
  // cublas_dgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, n, n, 1.0, B_gpu, n, w1_gpu, n, 0.0, w2_gpu, n);
  cublas_dgemm('N', 'N', n, n, n, 1.0, B_gpu, n, w1_gpu, n, 0.0, w2_gpu, n);
  // G = W * Bi
  // cublas_dgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, n, n, 1.0, w2_gpu, n, Bi_gpu, n, 0.0, w1_gpu, n);
  cublas_dgemm('N', 'N', n, n, n, 1.0, w2_gpu, n, Bi_gpu, n, 0.0, w1_gpu, n);
  // G = h * G / h
  scalerowcol_gpu(n, h_gpu, w1_gpu);
  CUDACHECK(cublasGetMatrix(n, n, sizeof(double), w1_gpu, n, G, n));
#else
  int i, j;
  // W = B * G
  blas_dgemm("NN", n, n, n, 1.0, B, n, G, n, 0.0, work, n);
  // G = W * Bi
  blas_dgemm("NN", n, n, n, 1.0, work, n, Bi, n, 0.0, G, n);
  // G = h * G / h
  PROFILE_BEGIN();
  #ifdef _OPENMP
  #pragma omp parallel for shared(n, G, h) private(i, j) schedule(static)
  #endif
  for (j = 0; j < n; j++)
    for (i = 0; i < n; i++)
      G[j * n + i] = h[i] * G[j * n + i] / h[j];
  PROFILE_END(profile_scalerowcol, n * n * 2);
#endif
  PROFILE_END(profile_swapg, 0);
  PROFILE_DISABLE();

  } catch (std::exception& e) {
    std::cerr << "swapg: " << e.what() << std::endl;
    exit(1);
  }
}
