/* ====OT1_OPTM_H ==================================================================================
 *
 * Author: Patrice Koehl (in collaboration with Henri Orland), November 2018
 * Department of Computer Science
 * University of California, Davis
 *
 * This file implements different methods needed to solve the optimal transport problem with variable
 * mass using the minimization of a free energy
 *
 =============================================================================================== */

#ifndef _OT1_OPTM_H_
#define _OT1_OPTM_H_

  #include <cmath>
  #include <cstddef>
  #include <cstdlib>
  #include <cstring>
  #include <algorithm>
  #include <functional>
  #include <iomanip>
  #include <iostream>
  #include <vector>
  #include "VectorOps.h"

/* ===============================================================================================
   prototypes for BLAS and LAPACK functions 
   =============================================================================================== */

  extern "C" {

	// BLAS1: copy one vector into another
	void dcopy_(int * n, double * X, int *incx, double * Y, int *incy);

	// BLAS1: scale a vector
	void dscal_(int * n, double *scale, double *Y, int *incy);

	// BLAS1: Y := alpha*X + Y  (called by computedX_direct2)
	void daxpy_(int * n, double *alpha, double * X, int *incx, double * Y, int *incy);

	// BLAS1: dot product of two vectors	
	double ddot_(int * n, double * u, int * incu, double * v, int *incv);

	// BLAS1: norm of a vector
	double dnrm2_(int * n, double * X, int *incx);

	// BLAS2: perform Y := alpha*op( A )*X  + beta*Y
	void dgemv_(char * trans, int * m, int * n, double * alpha, double *A,
		int *lda, double * X, int * incx, double *beta, double * Y, int * incy);

	// BLAS3: perform C := alpha*op( A )* op(B)  + beta*C
	void dgemm_(char * transa, char * transb, int * m, int * n, int * k,
		double * alpha, double * A, int * lda,
		double * B, int * ldb, double * beta, double * C, int * ldc);

	// LAPACK: solve a real system of linear equations A * X = B, where A is a symmetric matrix 
	void dsysv_(char *uplo, int *n, int *nrhs, double *A, int *lda, int *ipiv, double *b,
		int *ldb, double *work, int *lwork, int *info);

	// LAPACK (expert driver): solve A * X = B for symmetric A; additionally returns the
	// factorization (AF), a reciprocal condition estimate (rcond) and error bounds (ferr, berr)
	void dsysvx_(char *fact, char *uplo, int *n, int *nrhs, double *A, int *lda, double *AF,
		int *ldaf, int *ipiv, double *b, int *ldb, double *x, int *ldx, 
		double *rcond, double *ferr, double *berr,
		double *work, int *lwork, int *iwork, int *info);

  }

/* ===============================================================================================
   The OT1_OPTM class
   =============================================================================================== */

  class OT1_OPTM{

  public:

	OT1_OPTM() = default;

	// Releases the work arrays allocated by initOT1. All buffer pointers start as nullptr,
	// so destroying an object whose initOT1 was never called is safe (delete[] nullptr is a no-op).
	~OT1_OPTM()
	{
		delete [] Work1; delete [] Work2; delete [] Work;
		delete [] Jac;   delete [] Jac2;  delete [] B;    delete [] X;
		delete [] Al;    delete [] Am;    delete [] Ad;
		delete [] dl;    delete [] dm;    delete [] dd;
		delete [] D1;    delete [] D2;    delete [] v1;   delete [] v2;
		delete [] w;     delete [] T;
		delete [] vone;
		delete [] IPIV;  delete [] iwork;
	}

	// The object owns raw heap buffers: a member-wise copy would free them twice.
	OT1_OPTM(const OT1_OPTM&) = delete;
	OT1_OPTM& operator=(const OT1_OPTM&) = delete;

	// Solve for G and d at infinite beta (iterative over beta)
	double ot1_w(int n1, double *m1, int n2, double *m2, double *C,
	double *G, double *lambda, double *mu, double *x, double beta_init, double beta_final,
	double *Fopt, int iprint, int init, double a1, double a2, int nthreads);

	// Solve for G at a given beta value
	void solveG(int n1, double *m1, int n2, double *m2, double *C, double *G, double *lambda, 
	double *mu, double *x, double beta, double tol, int *niter, int nthreads);

	// Store alpha1/alpha2 and allocate all work arrays
	void initOT1(int n1, int n2, double a1, double a2);

  private:

	// compute Free Energy
	double computeF(int n1, double *m1, int n2, double *m2, double *C, double beta,
	double *lambda, double *mu, double x);

	// Check marginals: row sums and column sums of the transport plan G
	void checkMarginals(int n1, double *m1, int n2,  double *m2, double *G, double *err_row,
		double *err_col);

	// Compute the transport plan G based on the auxiliary variables lambda, mu and x
	void computeG(int n1, int n2, double *C, double beta,
	double *lambda, double *mu, double x, double *G);

	// Residuals of the current Jacobian system
	void computeRC(int n1, double *m1, int n2, double *m2, double *C, double beta,
	double *lambda, double *mu, double x, double *err, int nthreads);

	// Solve Jacobian system for updates in Lambda and Mu, using a dense direct solver
	// on the full (n1+n2+1) system
	void computedX_direct(int n1, int n2, double beta, double *C, double *lambda, double *mu, 
	double x, int nthreads);

	// Solve Jacobian system for updates in Lambda and Mu, using a direct solver on a
	// reduced system of size n1 or n2 (block elimination)
	void computedX_direct2(int n1, int n2, double beta, double *C, double *lambda, double *mu, 
	double x, int nthreads);

	// Solve Jacobian system for updates in Lambda and Mu, using iterative solver (CG)
	void computedX_iter(int n1, int n2, double beta, double *C, double *lambda, double *mu, 
	double x, int nthreads);

	// internal variables (buffers are allocated by initOT1 and released by the destructor)

	double alpha1 = 0, alpha2 = 0;
	double *Work1 = nullptr, *Work2 = nullptr, *Work = nullptr; 
	double *Jac = nullptr, *Jac2 = nullptr, *B = nullptr, *X = nullptr;
	double *Al = nullptr, *Am = nullptr, *Ad = nullptr;
	double *dl = nullptr, *dm = nullptr, *dd = nullptr;
	double *D1 = nullptr, *D2 = nullptr, *v1 = nullptr, *v2 = nullptr;
	double *w = nullptr, *T = nullptr;
	double *vone = nullptr;
	int *IPIV = nullptr, *iwork = nullptr;

  };

/* ===============================================================================================
   Common variables
   =============================================================================================== */

/* ===============================================================================================
   checkMarginals

   Notes:
   =====
   Check the marginals of a coupling matrix: the row sums of G are compared with m1 and the
   column sums with m2; the Euclidean norm of each difference is returned.

   Input:
   =====
	n1: 	number of points on space 1
	m1:	measure on space 1 (length n1)
	n2:	number of points on space 2
	m2:	measure on space 2 (length n2)
	G:	coupling matrix, n1 x n2, column-major (leading dimension n1)
  Output:
	err_row: || G*1   - m1 ||_2
	err_col: || G^T*1 - m2 ||_2
   =============================================================================================== */

  void OT1_OPTM::checkMarginals(int n1, double *m1, int n2,  double *m2, double *G, double *err_row,
		double *err_col)
  {

	// Scratch vectors are owned by std::vector so they are released on return
	// (the former new[] buffers were never freed, leaking on every call).
	int nmax = std::max(n1,n2);
	std::vector<double> ones(nmax, 1.0);
	std::vector<double> vect(nmax);

	double alpha = 1.0;
	double beta  = 0.0;
	int inc = 1;
	char Trans   = 'T';
	char NoTrans = 'N';

	// Row marginals: vect = G*1 - m1
	dgemv_(&NoTrans, &n1, &n2, &alpha, G, &n1, 
			ones.data(), &inc, &beta, vect.data(), &inc);
	for(int i = 0; i < n1; i++) { vect[i] = vect[i] - m1[i]; }
	double val = ddot_(&n1, vect.data(), &inc, vect.data(), &inc);
	*err_row = std::sqrt(val);

	// Column marginals: vect = G^T*1 - m2
	dgemv_(&Trans, &n1, &n2, &alpha, G, &n1, 
			ones.data(), &inc, &beta, vect.data(), &inc);
	for(int i = 0; i < n2; i++) { vect[i] = vect[i] - m2[i]; }
	val = ddot_(&n2, vect.data(), &inc, vect.data(), &inc);
	*err_col = std::sqrt(val);
  }
/* ===============================================================================================
   computeG

   Notes:
   =====
   Builds the transport plan entry by entry as
	G_ij = coupling(C_ij + lambda_i + mu_j + x, beta)
   then snaps entries below 1.e-8 to exactly 0 and entries above 1 - 1.e-8 to exactly 1.

   Input:
   =====
	n1:		number of points for point set 1
	n2:		number of points for point set 2
	C:		cost matrix, n1 x n2, column-major (entry (i,j) at i + j*n1)
	beta:		current beta
	lambda:		current values of Lagrange multipliers lambda (length n1)
	mu:		current values of Lagrange multipliers mu (length n2)
	x:		scalar added to every entry before applying coupling()

   Output:
   ======
	G:		coupling matrix, same layout as C
   =============================================================================================== */

  void OT1_OPTM::computeG(int n1, int n2, double *C, double beta,
	double *lambda, double *mu, double x, double *G)
  {
	const double cutoff = 1.e-8;

	for(int col = 0; col < n2; ++col) {
		const double *Ccol = C + col*n1;	// column col of C
		double *Gcol = G + col*n1;		// column col of G

		for(int row = 0; row < n1; ++row) {
			double arg = Ccol[row] + lambda[row] + mu[col] + x;
			double g = coupling(arg, beta);

			// clamp values numerically indistinguishable from 0 or 1
			Gcol[row] = (g < cutoff) ? 0.0 : ((g > 1.- cutoff) ? 1.0 : g);
		}
	}
  }
/* ===============================================================================================
   computeF

   Notes:
   =====
   Free energy
	F = -x - |lambda|^2/(4*alpha1) - |mu|^2/(4*alpha2) - (1/beta) * sum_ij log( (1-exp(-y_ij))/y_ij )
   with y_ij = beta*(C_ij + lambda_i + mu_j + x).

   Input:
   =====
	n1:	number of points for point set 1
	m1:		measure on points1 (not used)
	n2:	number of points for point set 2
	m2:		measure on points2 (not used)
	C:		cost matrix, n1 x n2, column-major
	beta:		current beta
	lambda:		current values of Lagrange multipliers lambda (length n1)
	mu:		current values of Lagrange multipliers mu (length n2)
	x:		scalar added to every entry of C + lambda + mu

   Output:
   ======
	F:		free energy (return value)
   =============================================================================================== */

  double OT1_OPTM::computeF(int n1, double *m1, int n2, double *m2, double *C, double beta,
	double *lambda, double *mu, double x)
  {

	double s1 = 0;
	for(int i = 0; i < n1; i++) {
		s1 += lambda[i]*lambda[i];
	}
	s1 /=(4*alpha1);

	double s2 = 0;
	for(int j = 0; j < n2; j++) {
		s2 += mu[j]*mu[j];
	}
	s2 /=(4*alpha2);

	// log((1-exp(-y))/y) is evaluated as  max(-y,0) + log( -expm1(-|y|) / |y| ).
	// Both forms are mathematically equal, but this one never overflows (exp(-y) is inf
	// for y < ~-709, which made the old code return log(inf)) and stays accurate for
	// small |y| (the old code replaced the ratio by 1 whenever |y| < 1e-5).
	double s3 = 0;
	for(int j = 0; j < n2; j++) {
		for(int i = 0; i < n1; i++) {
			double y  = beta*(C[i+j*n1] + lambda[i] + mu[j]+x);
			double ay = std::abs(y);
			if(ay == 0.0) continue;		// limit of the ratio is 1 => log term is 0
			s3 += std::max(-y, 0.0) + std::log(-std::expm1(-ay)/ay);
		}
	}
	double F = -x -s1 -s2 -s3/beta;
	return F;

  }



/* ===============================================================================================
   computeRC

   Notes:
   =====
   Evaluates the residuals of the current nonlinear system at (lambda, mu, x).
   With G obtained by applying vect_coupling in place to C + lambda*1^T + 1*mu^T + x
   (presumably the vectorised form of coupling() used in computeG), it forms
	Al = G*1    - lambda/(2*alpha1)
	Am = G^T*1  - mu/(2*alpha2)
	Ad = sum(G) - 1
   and returns the total L1 residual  err = |Al|_1 + |Am|_1 + |Ad|.

   Input:
   =====
	n1:	number of points for point set 1
	n2:	number of points for point set 2
	C:		cost matrix, n1 x n2, column-major
	beta:		current beta
	lambda:		current values of Lagrange multipliers lambda (length n1)
	mu:		current values of Lagrange multipliers mu (length n2)
	x:		scalar added to every entry (paired with the total-mass residual Ad)
	nthreads:	number of threads passed to vect_coupling
	
   Output:
   ======
	m1:		OVERWRITTEN with lambda/(2*alpha1)
	m2:		OVERWRITTEN with mu/(2*alpha2)
	err:		total L1 residual (see above)
   Side effects on class members:
	Work1:		holds G (n1 x n2)
	Al, Am, Ad:	residuals as above (read afterwards by the computedX_* routines)
   =============================================================================================== */

  void OT1_OPTM::computeRC(int n1, double *m1, int n2, double *m2, double *C, double beta,
	double *lambda, double *mu, double x, double *err, int nthreads)
  {

	int n1n2 = n1*n2;
	char Trans   = 'T';
	char NoTrans = 'N';
	int inc = 1;
	int one = 1;
	double a, b;
	
	// Work1 = C + lambda*1^T + 1*mu^T + x   (two rank-1 updates, then the scalar shift)
	dcopy_(&n1n2, C, &inc, Work1, &inc);
	a = 1.0; b = 1;
	dgemm_(&NoTrans, &Trans, &n1, &n2, &one, &a, lambda, &n1, vone, &n2, &b, 
		Work1, &n1);
	dgemm_(&NoTrans, &Trans, &n1, &n2, &one, &a, vone, &n1, mu, &n2, &b, 
		Work1, &n1);
	for(int i = 0; i < n1n2; i++) Work1[i] += x;

	// Work1 <- G (in place)
	vect_coupling(n1n2, Work1, beta, nthreads);

	// m1 <- lambda/(2*alpha1)  (caller's buffer is overwritten)
	dcopy_(&n1, lambda, &inc, m1, &inc);
	a = 1./(2.0*alpha1);
	dscal_(&n1, &a, m1, &inc);

	// Al = G*1 - m1
	dcopy_(&n1, m1, &inc, Al, &inc);
	a = 1; b = -1;
	dgemv_(&NoTrans, &n1, &n2, &a, Work1, &n1, vone, &inc, &b, Al, &inc);

	// m2 <- mu/(2*alpha2)  (caller's buffer is overwritten)
	dcopy_(&n2, mu, &inc, m2, &inc);
	a = 1./(2.0*alpha2);
	dscal_(&n2, &a, m2, &inc);

	// Am = G^T*1 - m2
	dcopy_(&n2, m2, &inc, Am, &inc);
	a = 1; b = -1;
	dgemv_(&Trans, &n1, &n2, &a, Work1, &n1, vone, &inc, &b, Am, &inc);

	// Ad = sum(G) - 1
	Ad[0] = -1;
	for(int i = 0; i < n1n2; i++) Ad[0] += Work1[i];

	// L1 norms of the three residual blocks
	double err_l = 0;
	for(int i = 0; i < n1; i++) err_l += std::abs(Al[i]);

	double err_m = 0;
	for(int j = 0; j < n2; j++) err_m += std::abs(Am[j]);

	double err_d = std::abs(Ad[0]);

	*err = err_l+err_m+err_d;

  }

/* ===============================================================================================
   computedX_direct

   Notes:
   =====
   Newton direction obtained by a dense solve of the full symmetric (n1+n2+1) system
	[ diag(D1)   dG         v1 ] [dl]   [-Al]
	[ dG^T       diag(D2)   v2 ] [dm] = [-Am]
	[ v1^T       v2^T       d  ] [dd]   [-Ad]
   where dG is vect_dcoupling applied in place to C + lambda*1^T + 1*mu^T + x,
   v1 = dG*1, v2 = dG^T*1, d = sum(dG), D1 = v1 - 1/(2*alpha1), D2 = v2 - 1/(2*alpha2).
   The residuals Al, Am, Ad are the class members last set by computeRC.
   If LAPACK dsysvx reports info != 0, the diagonal approximation
   dl = -Al/D1, dm = -Am/D2, dd = -Ad/d is kept instead.

   Input:
   =====
	n1:	number of points for point set 1
	n2:	number of points for point set 2
	beta:		current beta
	C:		Cost matrix, n1 x n2, column-major
	lambda:		current values of Lagrange multipliers lambda
	mu:		current values of Lagrange multipliers mu
	x:		current scalar shift
	nthreads:	number of threads passed to vect_dcoupling

   Output (class members):
   ======
	dl, dm, dd:	Newton updates for lambda, mu and x
	(Work1, v1, v2, D1, D2, B, X are overwritten)
   =============================================================================================== */

  void OT1_OPTM::computedX_direct(int n1, int n2, double beta, double *C, double *lambda, double *mu, 
		double x, int nthreads)
  {

	int n1n2 = n1*n2;
	int ntot = n1 + n2 + 1;
	char Trans   = 'T';
	char NoTrans = 'N';
	int inc = 1;
	int one = 1;
	double a, b;

	// Workspaces for the ntot x ntot system. They are local because the member buffers
	// Jac, Jac2, IPIV, iwork and Work are sized in initOT1 for max(n1,n2) only; using them
	// for an (n1+n2+1)-sized system overflowed them.
	const std::size_t nsq = static_cast<std::size_t>(ntot)*static_cast<std::size_t>(ntot);
	std::vector<double> jac(nsq, 0.0);	// system matrix, column-major
	std::vector<double> jacf(nsq);		// factorization written by dsysvx (AF)
	std::vector<int> ipiv(ntot);
	std::vector<int> iwk(ntot);
	int lwork = 128*ntot;
	std::vector<double> work(lwork);

	// Column-major accessor for element (row, col) of jac
	auto J = [&jac, ntot](int row, int col) -> double& {
		return jac[static_cast<std::size_t>(col)*ntot + row];
	};

	// Work1 = dG = dcoupling(C + lambda*1^T + 1*mu^T + x)
	dcopy_(&n1n2, C, &inc, Work1, &inc);
	a = 1.0; b = 1;
	dgemm_(&NoTrans, &Trans, &n1, &n2, &one, &a, lambda, &n1, vone, &n2, &b, 
		Work1, &n1);
	dgemm_(&NoTrans, &Trans, &n1, &n2, &one, &a, vone, &n1, mu, &n2, &b, 
		Work1, &n1);
	for(int i = 0; i < n1n2; i++) Work1[i] += x;

	vect_dcoupling(n1n2, Work1, beta, nthreads);

	// v1 = dG*1, v2 = dG^T*1
	a = 1; b = 0;
	dgemv_(&NoTrans, &n1, &n2, &a, Work1, &n1, vone, &inc, &b, v1, &inc);
	dgemv_(&Trans, &n1, &n2, &a, Work1, &n1, vone, &inc, &b, v2, &inc);

	// Diagonal blocks
	dcopy_(&n1, vone, &inc, D1, &inc);
	a = -1./(2*alpha1);
	dscal_(&n1, &a, D1, &inc);
	for(int i = 0; i < n1; i++) D1[i] += v1[i];

	dcopy_(&n2, vone, &inc, D2, &inc);
	a = -1./(2*alpha2);
	dscal_(&n2, &a, D2, &inc);
	for(int i = 0; i < n2; i++) D2[i] += v2[i];

	double d = 0;
	for(int i = 0; i < n1n2; i++) d+= Work1[i];

	// Assemble the symmetric matrix (both triangles) and the right-hand side
	for(int j = 0; j < n2; j++) {
		for(int i = 0; i < n1; i++) {
			double g = Work1[n1*j + i];
			J(i, n1+j) = g;
			J(n1+j, i) = g;
		}
	}
	for(int i = 0; i < n1; i++) {
		J(i, n1+n2) = v1[i];
		J(n1+n2, i) = v1[i];
		J(i, i)     = D1[i];
		B[i] = -Al[i];
	}
	for(int j = 0; j < n2; j++) {
		J(n1+j, n1+n2) = v2[j];
		J(n1+n2, n1+j) = v2[j];
		J(n1+j, n1+j)  = D2[j];
		B[n1+j] = -Am[j];
	}
	J(ntot-1, ntot-1) = d;
	B[ntot-1] = -Ad[0];

	// Fallback: diagonal approximation, kept if the dense solve fails
	for(int i = 0; i < n1; i++) {
		dl[i] = B[i]/D1[i];
	}
	for(int j = 0; j < n2; j++) {
		dm[j] = B[n1+j]/D2[j];
	}
	dd[0] = B[ntot-1]/d;

	int nrhs = 1; char U  = 'U'; int info;
	char F = 'N';		// FACT='N': dsysvx factors the matrix itself and writes AF
	double Rcond, Ferr, Berr;

	dsysvx_(&F, &U, &ntot, &nrhs, jac.data(), &ntot, jacf.data(), &ntot, ipiv.data(), B, &ntot,
		X, &ntot, &Rcond, &Ferr, &Berr, work.data(), &lwork, iwk.data(), &info);

	if(info ==0) {
		std::copy(X, X + n1, dl);
		std::copy(X + n1, X + n1 + n2, dm);
		dd[0] = X[ntot-1];
	} 

  }

/* ===============================================================================================
   computedX_direct2

   Notes:
   =====
   Newton direction for the same (n1+n2+1) symmetric system as computedX_direct, solved by
   block elimination so that only an n1 x n1 or an n2 x n2 dense system is factored:
	(a) eliminate dd:  Work1 <- G' = dG - v1*v2^T/d
	(b) reduced right-hand sides b'1, b'2, written IN PLACE into Al and Am
	(c)-(f) if n1 < n2+10: eliminate dm using
	        (diag(D2) - v2 v2^T/d)^{-1} = diag(D2)^{-1} + alp2 * w w^T,
	        w = v2/D2, alp2 = 1/(d - v2.w)   (Sherman-Morrison),
	        and assemble the n1 x n1 matrix in Jac; otherwise the symmetric case for dl
	        and an n2 x n2 matrix
	(g) solve the reduced system with LAPACK dsysvx
	(h) back-substitute the eliminated block, then dd
   NOTE(review): the dsysvx info code is not checked; if the reduced matrix is singular
   X may not hold a valid solution.

   Input:
   =====
	n1:	number of points for point set 1
	n2:	number of points for point set 2
	beta:		current beta
	C:		Cost matrix, n1 x n2, column-major
	lambda:		current values of Lagrange multipliers lambda
	mu:		current values of Lagrange multipliers mu
	x:		current scalar shift
	nthreads:	number of threads passed to vect_dcoupling
	
   Output (class members):
   ======
	dl, dm, dd:	Newton updates for lambda, mu and x
	(Work1, Al, Am, Jac, Jac2, Work, B, X, w, T, v1, v2, D1, D2 are overwritten)
   =============================================================================================== */

  void OT1_OPTM::computedX_direct2(int n1, int n2, double beta, double *C, double *lambda, double *mu, 
		double x, int nthreads)
  {

	int n1n2 = n1*n2;
	int nmax = std::max(n1, n2);
	char Trans   = 'T';
	char NoTrans = 'N';
	int inc = 1;
	int one = 1;
	double a, b;

	// Jac is used as an n1 x n1 or n2 x n2 matrix; it is allocated nmax x nmax in initOT1
	memset(Jac, 0, nmax*nmax*sizeof(double));
	
	// Work1 = dG = dcoupling(C + lambda*1^T + 1*mu^T + x)
	dcopy_(&n1n2, C, &inc, Work1, &inc);
	a = 1.0; b = 1;
	dgemm_(&NoTrans, &Trans, &n1, &n2, &one, &a, lambda, &n1, vone, &n2, &b, 
		Work1, &n1);
	dgemm_(&NoTrans, &Trans, &n1, &n2, &one, &a, vone, &n1, mu, &n2, &b, 
		Work1, &n1);
	for(int i = 0; i < n1n2; i++) Work1[i] += x;

	vect_dcoupling(n1n2, Work1, beta, nthreads);

	// v1 = dG*1, v2 = dG^T*1
	a = 1; b = 0;
	dgemv_(&NoTrans, &n1, &n2, &a, Work1, &n1, vone, &inc, &b, v1, &inc);
	dgemv_(&Trans, &n1, &n2, &a, Work1, &n1, vone, &inc, &b, v2, &inc);

	// D1 = v1 - 1/(2*alpha1), D2 = v2 - 1/(2*alpha2)
	dcopy_(&n1, vone, &inc, D1, &inc);
	a = -1./(2*alpha1);
	dscal_(&n1, &a, D1, &inc);
	for(int i = 0; i < n1; i++) D1[i] += v1[i];

	dcopy_(&n2, vone, &inc, D2, &inc);
	a = -1./(2*alpha2);
	dscal_(&n2, &a, D2, &inc);
	for(int i = 0; i < n2; i++) D2[i] += v2[i];

	// d = sum(dG): the (dd,dd) entry of the full system
	double d = 0;
	for(int i = 0; i < n1n2; i++) d+= Work1[i];

	// (a) Compute G'
	a = -1.0/d; b = 1;
	dgemm_(&NoTrans, &Trans, &n1, &n2, &one, &a, v1, &n1, v2, &n2, &b, 
		Work1, &n1);

	// (b) Compute b'1 and b'2 (overwrites the residuals Al and Am)
	for(int i = 0; i < n1; i++) Al[i] = -Al[i] + Ad[0]*v1[i]/d;
	for(int i = 0; i < n2; i++) Am[i] = -Am[i] + Ad[0]*v2[i]/d;
	
	int nrhs = 1; char U  = 'L'; int info;
	char F = 'N';
	int lwork; 
	double Rcond, Ferr, Berr;

	if(n1 < n2+10) {

		// (c) Compute w2 and alp2
		for(int i = 0; i < n2; i++) {
			w[i] = v2[i]/D2[i];
		}
		double alp2 = ddot_(&n2, v2, &inc, w, &inc);
		alp2 = 1.0/(d - alp2);

		// (d) Compute b''1 = b'1 - G' D2^{-1} b'2 - alp2 (w.b'2) G' w
		for(int i = 0; i < n2; i++) Work[i] = Am[i]/D2[i];
		dcopy_(&n1, Al, &inc, B, &inc);
		a = -1.0; b = 1;
		dgemv_(&NoTrans, &n1, &n2, &a, Work1, &n1, Work, &inc, &b, B, &inc);
		a = ddot_(&n2, w, &inc, Am, &inc);
		a = -alp2*a; b = 1;
		dgemv_(&NoTrans, &n1, &n2, &a, Work1, &n1, w, &inc, &b, B, &inc);

		// (e) Compute T1 = G' w
		a = 1.0; b = 0;
		dgemv_(&NoTrans, &n1, &n2, &a, Work1, &n1, w, &inc, &b, T, &inc);

		// (f) Compute J = diag(D1) - G' D2^{-1} G'^T - v1 v1^T/d - alp2 T1 T1^T
		for(int j = 0; j < n2; j++) {
			for(int i = 0; i < n1; i++) {
				Work[i+j*n1] = Work1[i+j*n1]/D2[j];
			}
		}

		a = -1.0; b = 0;
		dgemm_(&NoTrans, &Trans, &n1, &n1, &n2, &a, Work, &n1, 
			Work1, &n1, &b, Jac, &n1);
		for(int i = 0; i < n1; i++) Jac[i+n1*i] += D1[i];

		a = -1.0/d; b = 1;
		dgemm_(&NoTrans, &Trans, &n1, &n1, &one, &a, v1, &n1, v1, &n1, &b, 
			Jac, &n1);
		a = -alp2; b = 1;
		dgemm_(&NoTrans, &Trans, &n1, &n1, &one, &a, T, &n1, T, &n1, &b, 
			Jac, &n1);

		// (g) Solve Jac x1 = b''1 (Work is reused as LAPACK workspace)
		lwork = 128*n1;

		dsysvx_(&F, &U, &n1, &nrhs, Jac, &n1, Jac2, &n1, IPIV, B, &n1, X, &n1,
		&Rcond, &Ferr, &Berr, Work, &lwork, iwork, &info);

		for(int i = 0; i < n1; i++) {
			dl[i] = X[i];
		}

		// (h) Compute x2 = D2^{-1} r + alp2 (w.r) w,  r = b'2 - G'^T x1
		dcopy_(&n2, Am, &inc, dm, &inc);
		a = -1.0; b = 1.0;
		dgemv_(&Trans, &n1, &n2, &a, Work1, &n1, X, &inc, &b, dm, &inc);
		double bet = ddot_(&n2, dm, &inc, w, &inc);
		for(int i = 0; i < n2; i++) dm[i] = dm[i]/D2[i];
		a = alp2*bet;
		daxpy_(&n2, &a, w, &inc, dm, &inc);

	} else {
	
		// (c) Compute w1 and alp1
		for(int i = 0; i < n1; i++) {
			w[i] = v1[i]/D1[i];
		}
		double alp1 = ddot_(&n1, v1, &inc, w, &inc);
		alp1 = 1.0/(d - alp1);

		// (d) Compute b''2 = b'2 - G'^T D1^{-1} b'1 - alp1 (w.b'1) G'^T w
		for(int i = 0; i < n1; i++) Work[i] = Al[i]/D1[i];
		dcopy_(&n2, Am, &inc, B, &inc);
		a = -1.0; b = 1;
		dgemv_(&Trans, &n1, &n2, &a, Work1, &n1, Work, &inc, &b, B, &inc);
		a = ddot_(&n1, w, &inc, Al, &inc);
		a = -alp1*a; b = 1;
		dgemv_(&Trans, &n1, &n2, &a, Work1, &n1, w, &inc, &b, B, &inc);

		// (e) Compute T2 = G'^T w
		a = 1.0; b = 0;
		dgemv_(&Trans, &n1, &n2, &a, Work1, &n1, w, &inc, &b, T, &inc);

		// (f) Compute J = diag(D2) - G'^T D1^{-1} G' - v2 v2^T/d - alp1 T2 T2^T
		for(int j = 0; j < n2; j++) {
			for(int i = 0; i < n1; i++) {
				Work[i+j*n1] = Work1[i+j*n1]/D1[i];
			}
		}

		a = -1.0; b = 0;
		dgemm_(&Trans, &NoTrans, &n2, &n2, &n1, &a, Work1, &n1, 
			Work, &n1, &b, Jac, &n2);
		for(int i = 0; i < n2; i++) Jac[i+n2*i] += D2[i];
		a = -1.0/d; b = 1;
		dgemm_(&NoTrans, &Trans, &n2, &n2, &one, &a, v2, &n2, v2, &n2, &b, 
			Jac, &n2);
		a = -alp1; b = 1;
		dgemm_(&NoTrans, &Trans, &n2, &n2, &one, &a, T, &n2, T, &n2, &b, 
			Jac, &n2);

		// (g) Solve Jac x2 = b''2 (Work is reused as LAPACK workspace)
		lwork = 128*n2;

		dsysvx_(&F, &U, &n2, &nrhs, Jac, &n2, Jac2, &n2, IPIV, B, &n2, X, &n2,
		&Rcond, &Ferr, &Berr, Work, &lwork, iwork, &info);

		for(int i = 0; i < n2; i++) {
			dm[i] = X[i];
		}

		// (h) Compute x1 = D1^{-1} r + alp1 (w.r) w,  r = b'1 - G' x2
		dcopy_(&n1, Al, &inc, dl, &inc);
		a = -1.0; b = 1.0;
		dgemv_(&NoTrans, &n1, &n2, &a, Work1, &n1, X, &inc, &b, dl, &inc);
		double bet = ddot_(&n1, dl, &inc, w, &inc);
		for(int i = 0; i < n1; i++) dl[i] = dl[i]/D1[i];
		a = alp1*bet;
		daxpy_(&n1, &a, w, &inc, dl, &inc);
	}

	//  Compute x3 = (-Ad - v1.dl - v2.dm)/d  (last row of the full system)
	double alp2 = ddot_(&n1, v1, &inc, dl, &inc);
	double bet = ddot_(&n2, v2, &inc, dm, &inc);

	dd[0] = (-Ad[0] - alp2 - bet)/d;

  }

/* ===============================================================================================
   computedX_iter

   Input:
   =====
	n1:	number of points for point set 1
	n2:	number of points for point set 2
	beta:		current beta
	C:		Cost matrix
	lambda:		current values of Lagragians lambda
	mu:		current values of Lagragians mu
	nthreads:	number of threads for parallel computation
	
	
   =============================================================================================== */

  void OT1_OPTM::computedX_iter(int n1, int n2, double beta, double *C, double *lambda, double *mu, 
		double x, int nthreads)
  {

	int n1n2 = n1*n2;
	int ntot = n1 + n2;
	char Trans   = 'T';
	char NoTrans = 'N';
	int inc = 1;
	int one = 1;
	double a, b;

	dcopy_(&n1n2, C, &inc, Work1, &inc);
	a = 1.0; b = 1;
	dgemm_(&NoTrans, &Trans, &n1, &n2, &one, &a, lambda, &n1, vone, &n2, &b, 
		Work1, &n1);
	dgemm_(&NoTrans, &Trans, &n1, &n2, &one, &a, vone, &n1, mu, &n2, &b, 
		Work1, &n1);
	for(int i = 0; i < n1n2; i++) Work1[i] += x;

	vect_dcoupling(n1n2, Work1, beta, nthreads);

	a = 1; b = 0;
	dgemv_(&NoTrans, &n1, &n2, &a, Work1, &n1, vone, &inc, &b, v1, &inc);
	dgemv_(&Trans, &n1, &n2, &a, Work1, &n1, vone, &inc, &b, v2, &inc);

	dcopy_(&n1, vone, &inc, D1, &inc);
	a = -1./(2*alpha1);
	dscal_(&n1, &a, D1, &inc);
	for(int i = 0; i < n1; i++) D1[i] += v1[i];

	dcopy_(&n2, vone, &inc, D2, &inc);
	a = -1./(2*alpha2);
	dscal_(&n2, &a, D2, &inc);
	for(int i = 0; i < n2; i++) D2[i] += v2[i];

	double d = 0;
	for(int i = 0; i < n1n2; i++) d+= Work1[i];

	// (a) Compute G'
	a = -1.0/d; b = 1;
	dgemm_(&NoTrans, &Trans, &n1, &n2, &one, &a, v1, &n1, v2, &n2, &b, 
		Work1, &n1);

	// (b) Compute b'1 and b'2
	for(int i = 0; i < n1; i++) B[i] = -Al[i] + Ad[0]*v1[i]/d;
	for(int i = 0; i < n2; i++) B[i+n1] = -Am[i] + Ad[0]*v2[i]/d;

	double tol = 1.e-2;
	conjgrad_m.cgDriver(n1, n2, D1, D2, d, Work1, v1, v2, B, X, &Work2[0], &Work2[ntot],
	&Work2[2*ntot], &Work2[3*ntot], &Work2[4*ntot], &Work2[5*ntot], tol);

	for(int i = 0; i < n1; i++) {
		dl[i] = X[i];
	}
	for(int j = 0; j < n2; j++) {
		dm[j] = X[n1+j];
	}

	//  Compute x3
	double alp2 = ddot_(&n1, v1, &inc, dl, &inc);
	double bet = ddot_(&n2, v2, &inc, dm, &inc);

	dd[0] = (-Ad[0] - alp2 - bet)/d;

  }

/* ===============================================================================================
   solveG

   Notes:
   =====
   Damped Newton iteration on (lambda, mu, x) at fixed beta. Each iteration:
	- takes the direction (dl, dm, dd) from computedX_iter;
	- halves the step (at most 30 times) until the residual from computeRC decreases.
   The iteration stops when:
	- the residual is <= tol, or
	- 100 steps have been accepted, or
	- the line search fails.
   G is then rebuilt from the final multipliers.
   On line-search failure, the buffers written by computeRC (m1, m2, Al, Am, Ad, Work1) hold
   the last rejected trial point.

   Input:
   =====
	n1:	number of points for point set 1
	m1:		buffer passed to computeRC (overwritten there)
	n2:	number of points for point set 2
	m2:		buffer passed to computeRC (overwritten there)
	C:		cost matrix, n1 x n2, column-major
	lambda, mu, x:	starting multipliers; updated in place
	beta:		parameter beta
	tol:		tolerance on the L1 residual returned by computeRC
	nthreads:	number of threads for parallel computation
	
   Output:
   ======
	G:		coupling matrix
	niter:		number of accepted Newton steps
   =============================================================================================== */

  void OT1_OPTM::solveG(int n1, double *m1, int n2, double *m2, double *C, double *G, 
	double *lambda, double *mu, double *x, double beta, double tol, int *niter, int nthreads)
  {
	const int max_newton  = 100;	// maximum number of accepted Newton steps
	const int max_halving = 30;	// maximum number of step halvings per line search

	// Trial point of the line search (owned by std::vector: released automatically)
	std::vector<double> l_try(n1, 0.0);
	std::vector<double> m_try(n2, 0.0);
	double x_try = 0;

	double err;
	computeRC(n1, m1, n2, m2, C, beta, lambda, mu, *x, &err, nthreads);

	int iter = 0;
	// The iteration cap is checked before computing a new direction, so no Newton step
	// is computed only to be discarded.
	while (err > tol && iter < max_newton)
	{
		const double err_old = err;
		computedX_iter(n1, n2, beta, C, lambda, mu, *x, nthreads);
//		computedX_direct2(n1, n2, beta, C, lambda, mu, *x, nthreads);

		// Backtracking line search: accept the first step that lowers the residual
		double step = 1.0;
		bool accepted = false;
		for(int k = 0; k < max_halving; k++) {
			for(int i = 0; i < n1; i++) l_try[i] = lambda[i] + step*dl[i];
			for(int j = 0; j < n2; j++) m_try[j] = mu[j] + step*dm[j];
			x_try = *x + step*dd[0];
			computeRC(n1, m1, n2, m2, C, beta, l_try.data(), m_try.data(), x_try, &err, nthreads);
			if(err < err_old) {
				accepted = true;
				break;
			}
			step = step/2;
		}
		if(!accepted) break;

		std::copy(l_try.begin(), l_try.end(), lambda);
		std::copy(m_try.begin(), m_try.end(), mu);
		*x = x_try;

		iter++;
	}

	*niter = iter;

	computeG(n1, n2, C, beta, lambda, mu, *x, G);
  }
		
/* ===============================================================================================
   ot1_w

   Notes:
   =====
   Annealing driver. Starting from lambda = mu = 0 and x = 0, beta is multiplied by
   10^(1/4) after each stage until it reaches betaf. At every stage the coupling is solved
   with solveG and (optionally) a progress line is printed.
   If the loop does not run (beta1 >= betaf, or beta1 <= 0), the function returns 1.0 and
   sets *U = 0. With beta1 <= 0 the multiplication could never reach betaf, so that case is
   skipped instead of looping forever.

   Input:
   =====
	n1:	number of points for point set 1
	m1:		buffer of length n1 (overwritten by computeRC inside solveG)
	n2:	number of points for point set 2
	m2:		buffer of length n2 (overwritten by computeRC inside solveG)
	C:		cost matrix, n1 x n2, column-major
	beta1:		starting beta (must be > 0)
	betaf:		final beta (exclusive)
	iprint:		1 to print a progress table
	init:		0 to (re)allocate the work arrays via initOT1
	a1, a2:		penalty weights alpha1, alpha2
	nthreads:	number of threads for parallel computation
	
   Output:
   ======
	G:		coupling matrix at the last beta
	lambda, mu, x:	multipliers at the last beta
	U:		dist + a1*sum(m1^2) + a2*sum(m2^2) at the last beta
	return value:	dist = <G, C> at the last beta
   =============================================================================================== */

  double OT1_OPTM::ot1_w(int n1, double *m1, int n2, double *m2, double *C,
	double *G, double *lambda, double *mu, double *x, double beta1, double betaf, double *U, 
	int iprint, int init, double a1, double a2, int nthreads)
  {

	// Dimension and Initialize all arrays

	if(init==0) {
		initOT1(n1, n2, a1, a2);
	}

	// Initialize the multipliers to zero
	std::fill(lambda, lambda + n1, 0.0);
	std::fill(mu, mu + n2, 0.0);
	*x = 0;

	const double tol  = 1.e-5;			// convergence tolerance passed to solveG
	const double coef = std::sqrt(std::sqrt(10));	// beta growth factor per stage: 10^(1/4)

	int n1n2 = n1*n2;
	int inc = 1;
	int niter;
	double err_row, err_col, err_t;
	double dist = 1.0;
	double Utot = 0;

	if(iprint == 1) {
		std::cout << " " << std::endl;
		std::cout << "        " << "=================================================================================================================" << std::endl;
		std::cout << "        " << "       Beta           Iter              U                  Utot             Err_row         Err_col       Err_sum " << std::endl;
		std::cout << "        " << "=================================================================================================================" << std::endl;
	}

	double beta_val = beta1;
	while(beta_val > 0 && beta_val < betaf)
	{
		solveG(n1, m1, n2, m2, C, G, lambda, mu, x, beta_val, tol, &niter, nthreads);

		checkMarginals(n1, m1, n2,  m2, G, &err_row, &err_col);

		// Deviation of the total transported mass from 1
		err_t = 0.;
		for(int i = 0; i < n1n2; i++) err_t += G[i];
		err_t = std::abs(err_t - 1.0);

		dist = ddot_(&n1n2, G, &inc, C, &inc); 
		double xmass = 0;
		for(int i = 0; i < n1; i++) xmass += a1*m1[i]*m1[i];
		for(int i = 0; i < n2; i++) xmass += a2*m2[i]*m2[i];
		Utot = dist + xmass;

		if(iprint==1) {
			std::cout << "        " << "   " << std::setw(10)<< beta_val << "    ";
			std::cout << std::setw(10) << niter << "        " << std::setw(10) << dist << "        ";
			std::cout << std::setw(10) << Utot <<  "        ";
			std::cout << std::setw(10) << err_row <<  "        " << err_col << "     " << err_t << std::endl;
		}

		beta_val = beta_val*coef;
	}


	if(iprint==1) {
		std::cout << "        " << "=================================================================================================================" << std::endl;
		std::cout << " " << std::endl;
	}

	*U = Utot;
	return dist;

  }

/* ===============================================================================================
   Initialize arrays for EarchMover
   =============================================================================================== */

  void OT1_OPTM::initOT1(int n1, int n2, double a1, double a2)
  {

	alpha1 = a1;
	alpha2 = a2;
	int nmax = std::max(n1, n2);
	int n = std::max(nmax, 128);
	Work1 = new double[n1*n2];
	Work2 = new double[6*(n1+n2+1)];
	Work  = new double[n*nmax];
	Jac   = new double[nmax*nmax];
	Jac2  = new double[nmax*nmax];
	B     = new double[n1+n2+1];
	X     = new double[n1+n2+1];
	IPIV  = new int[nmax];
	iwork = new int[nmax];
	vone  = new double[nmax];
	Al    = new double[n1];
	Am    = new double [n2];
	Ad    = new double [1];
	D1    = new double[n1];
	D2    = new double [n2];
	v1    = new double [n1];
	v2    = new double [n2];
	w     = new double [nmax];
	T     = new double [nmax];
	dl    = new double [n1];
	dm    = new double [n2];
	dd    = new double [1];

	for(int i = 0; i < nmax; i++) vone[i] = 1.0;

  }

#endif
