#include <cmath>
#include <cstdlib>
#include <iostream>
#include <vector>

using namespace std;

#include "blaslapack.h"
#include "profile.h"

// computeg.c++ functions
extern "C" {
  void cpp_gfun_init_(long *cpp_data, int *n, int *L, int *nOrth, int *nWrap, int *fixw);
  void cpp_gfun_computeg_(long *cpp_data, int *il, double *sgn, double *G, double *h,
                          double *B, int *nOrth, double *det);
  void cpp_gfun_free_(long *cpp_data);
  void cpp_gfun_swapg_(long *cpp_data, double *B, double *Bi, double *h, double *G);
  void cpp_gfun_invalid_cache_(long *cpp_data, int *j);
}

// Column-scaled copy of an m x n block.
// For every column j in [0, n) and row i in [0, m):
//     B[i + j*ldb] = A[i + j*lda] * S[j]
// A and B are column-major with leading dimensions lda and ldb. Entries of B
// outside the m x n block (rows m..ldb-1 of each column) are not written.
void mat_copy_scale_col(int m, int n, const double *A, int lda, const double *S,
                        double *B, int ldb)
{
  for (int col = 0; col < n; ++col) {
    const double *src = A + col * lda;  // start of column `col` in A
    double *dst = B + col * ldb;        // start of column `col` in B
    for (int row = 0; row < m; ++row)
      dst[row] = src[row] * S[col];
  }
}

// Compute B = exp(M) and its inverse Bi = exp(-M), where M is the symmetric
// n x n matrix (n = m*m) that holds -dtau for every nearest-neighbour pair of
// an m x m lattice with periodic boundaries (site index s = x + m*y) and 0
// elsewhere. Entries are assigned, not accumulated, so for m == 2 the two
// periodic bonds between a pair of sites collapse to a single -dtau.
//
// Method: eigendecompose M = V diag(w) V^T with dsyev, then form
//   B  = V diag(exp(w))  V^T
//   Bi = V diag(exp(-w)) V^T
// using a column scaling followed by one GEMM each.
//
// Parameters:
//   m    - linear lattice size; when m <= 0 nothing is computed or written.
//   B    - output, n x n, column-major, leading dimension n.
//   Bi   - output, n x n, column-major, leading dimension n.
//   dtau - the off-diagonal entries of M are -dtau.
//
// All scratch storage is owned by std::vector, so it is released on return.
void compute_B(int m, double *B, double *Bi, double dtau)
{
  // Empty lattice: B and Bi are 0 x 0. Negative m would otherwise give a
  // positive n and out-of-bounds writes in the bond loops below.
  if (m <= 0)
    return;

  const int n = m * m;
  const std::size_t nn = static_cast<std::size_t>(n) * n;
  std::vector<double> K(nn, 0.0);  // M; overwritten by its eigenvectors V
  std::vector<double> T(nn);       // scratch for V * diag(...)
  std::vector<double> w(n);        // eigenvalues of M

  // Workspace query (lwork = -1): dsyev only reports the optimal lwork in temp.
  // NOTE(review): assumes "VL" packs JOBZ='V' (eigenvectors) and UPLO='L'
  // (lower triangle) for the lapack_dsyev wrapper; only the lower triangle
  // of M is filled below.
  double temp = 0.0;
  lapack_dsyev("VL", n, K.data(), n, w.data(), &temp, -1);
  int lwork = static_cast<int>(temp);
  // work is also reused below to hold the n diagonal exponentials.
  if (lwork < n)
    lwork = n;
  std::vector<double> work(lwork);

  // Fill the lower triangle of M. Storage is column-major: element
  // (row, col) lives at K[row + col*n].
  // Bonds along x: each lattice row y is a ring of m sites.
  for (int y = 0; y < m; y++) {
    const int base = y * m;
    for (int x = 1; x < m; x++)
      K[(base + x) + (base + x - 1) * n] = -dtau;
    K[(base + m - 1) + base * n] = -dtau;  // periodic wrap x = m-1 <-> x = 0
  }
  // Bonds along y: site s couples to s + m ...
  for (int s = 0; s < n - m; s++)
    K[(s + m) + s * n] = -dtau;
  // ... and the last lattice row wraps around to the first.
  for (int s = 0; s < m; s++)
    K[(s + n - m) + s * n] = -dtau;

  // Eigendecomposition M = V diag(w) V^T; V overwrites K.
  lapack_dsyev("VL", n, K.data(), n, w.data(), work.data(), lwork);

  // B = (V diag(exp(w))) * V^T
  for (int i = 0; i < n; i++)
    work[i] = std::exp(w[i]);
  mat_copy_scale_col(n, n, K.data(), n, work.data(), T.data(), n);
  blas_dgemm("NT", n, n, n, 1.0, T.data(), n, K.data(), n, 0.0, B, n);

  // Bi = (V diag(exp(-w))) * V^T. exp(-w) stays accurate where exp(w)
  // overflows (1/exp(w) would collapse to 0).
  for (int i = 0; i < n; i++)
    work[i] = std::exp(-w[i]);
  mat_copy_scale_col(n, n, K.data(), n, work.data(), T.data(), n);
  blas_dgemm("NT", n, n, n, 1.0, T.data(), n, K.data(), n, 0.0, Bi, n);
}

// Benchmark driver.
// Usage: flops m L dtau U nOrth fixw
// Steps:
//   1. Fills L slices of n = m*m entries with random values exp(+lambda) or
//      exp(-lambda), where cosh(lambda) = exp(U*dtau/2).
//   2. Builds B / Bi via compute_B and initialises the cpp_gfun state.
//   3. Repeatedly calls cpp_gfun_computeg_ followed by nOrth cpp_gfun_swapg_
//      calls.
//   4. Prints the profile.
// Returns 1 on a usage or parameter error, otherwise 0.
int main(int argc, const char *argv[])
{
  if (argc < 7) {
    std::cout << "Usage: flops m L dtau U nOrth fixw" << std::endl;
    return 1;
  }

  // Command-line parameters (atoi/atof: unparsable input becomes 0).
  int m = std::atoi(argv[1]);
  int L = std::atoi(argv[2]);
  double dtau = std::atof(argv[3]);
  double U = std::atof(argv[4]);
  int nOrth = std::atoi(argv[5]);
  int fixw = std::atoi(argv[6]);

  // h holds L slices of n entries. The swap loop passes h + j*n for every
  // j < nOrth, which is only a valid pointer into h when nOrth <= L.
  // Non-positive sizes would give empty or invalid buffers.
  if (m <= 0 || L <= 0 || nOrth <= 0 || nOrth > L) {
    std::cerr << "Error: need m > 0, L > 0 and 0 < nOrth <= L" << std::endl;
    return 1;
  }
  // acosh is only defined on [1, inf), i.e. exp(U*dtau/2) >= 1. Otherwise
  // every entry of h would be NaN.
  if (U * dtau < 0) {
    std::cerr << "Error: need U * dtau >= 0" << std::endl;
    return 1;
  }

  int n = m * m;  // non-const: its address is passed to cpp_gfun_init_
  const double lambda = std::acosh(std::exp(U * dtau / 2.0));
  const double lambda1 = std::exp(lambda);
  const double lambda2 = std::exp(-lambda);

  // size_t sizes so n*n and L*n cannot overflow int.
  const std::size_t nn = static_cast<std::size_t>(n) * n;
  std::vector<double> h(static_cast<std::size_t>(L) * n);
  std::vector<double> B(nn);
  std::vector<double> Bi(nn);
  std::vector<double> G(nn);

  // rand() is never seeded, so every run gets the same configuration
  // (presumably intended, for reproducible timings).
  for (std::size_t i = 0; i < h.size(); i++)
    h[i] = (std::rand() < RAND_MAX / 2) ? lambda1 : lambda2;

  // compute B
  compute_B(m, B.data(), Bi.data(), dtau);

  // initialize computeg.c++
  long cpp_data = 0;
  // NOTE(review): nOrth is passed for both the nOrth and nWrap arguments;
  // confirm this is intended.
  cpp_gfun_init_(&cpp_data, &n, &L, &nOrth, &nOrth, &fixw);

  // compute G repeatedly so the profile has something to measure
  int il = 1;
  double sgn = 0.0;
  double det = 0.0;
  const int kRepeats = 10;
  for (int rep = 0; rep < kRepeats; rep++) {
    cpp_gfun_computeg_(&cpp_data, &il, &sgn, G.data(), h.data(), B.data(), &nOrth, &det);
    for (int j = 0; j < nOrth; j++) {
      double *slice = h.data() + static_cast<std::size_t>(j) * n;
      cpp_gfun_swapg_(&cpp_data, B.data(), Bi.data(), slice, G.data());
    }
  }

  // print results
  profile_print_();

  // Release the cpp_gfun state. The vectors are freed afterwards, at scope
  // exit, which is the same order as before.
  cpp_gfun_free_(&cpp_data);

  return 0;
}

