/*================================================================================================
  MatVect.h
  Version 1: 12/1/2017

  Purpose: Set of routines for multiplying the Hessian (a sparse matrix) with vectors and matrices

Copyright (c) Patrice Koehl.

>>> SOURCE LICENSE >>>

  This library is free software; you can redistribute it and/or
  modify it under the terms of the GNU Lesser General Public
  License as published by the Free Software Foundation; either
  version 2.1 of the License, or (at your option) any later version.

  This library is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  Lesser General Public License for more details.

  You should have received a copy of the GNU Lesser General Public
  License along with this library; if not, write to the Free Software
  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA

>>> END OF LICENSE >>>

================================================================================================== */

#ifndef _MATVECTCPU_
#define _MATVECTCPU_

/*================================================================================================
 Includes
================================================================================================== */

#include <math.h>
#include <cstdlib>
#include <cstring>
#include <vector>
#include <pthread.h>

#if defined(__ARM_NEON)
#include <arm_neon.h>
#endif

#if !defined(GPU)
#if defined(__x86_64__)
#include <immintrin.h>
#endif
#endif

/*================================================================================================
 spmv_thread: computes the product of the Hessian with a vector, using the form in which the Hessian
	      is written as a sum of tensor products (parallel code)
		The Hessian is stored in csr format; 
================================================================================================== */

template <typename T>
void* spmv_thread(void* data)
{
	int threadid = *((int *) data);
	threads_params<T> *tdata;
	tdata = (threads_params<T>*)&thparams[threadid];
	
	int start = tdata->start;
	int end   = tdata->end;

	T *X      = tdata->X;
	T *Y      = tdata->Y;
	int *ia   = tdata->ia;
	int *ja   = tdata->ja;
	T *val    = tdata->val;

	int col_start, col_end, j_col;
	T temp;

	T Xi[3], Xj[3], gi[3];
	T sum[3];

	for (int row = start; row < end; row++)
	{
		for(int k = 0; k < 3; k++) sum[k] = 0;
		for(int k = 0; k < 3; k++) Xi[k] = X[3*row+k];

		col_start = ia[row];
		col_end   = ia[row+1];

		for (int j = col_start; j < col_end; j++) {
			
			j_col = ja[j];
			for(int k = 0; k < 3; k++) {
				Xj[k] = X[3*j_col+k];
				gi[k] = val[3*j+k];
			}

			temp = gi[0]*(Xi[0]-Xj[0])+gi[1]*(Xi[1]-Xj[1])+gi[2]*(Xi[2]-Xj[2]);
			for(int k = 0; k < 3; k++) {
				sum[k] += temp*gi[k];
			}
		}

		for(int k = 0; k < 3; k++) {
			Y[3*row+k] += sum[k];
		}
	}

	return 0;
}

/*================================================================================================
 spmm_thread: computes the product of the Hessian with a matrix, using the form in which the Hessian
	      is written as a sum of tensor products (parallel code)
		The Hessian is stored in csr format; the matrix is stored in column major order
================================================================================================== */

template <typename T>
void* spmm_thread(void* data)
{
	int threadid = *((int *) data);
	threads_params<T> *tdata;
	tdata = (threads_params<T>*)&thparams[threadid];
	
	int start  = tdata->start;
	int end    = tdata->end;
	int Nvec   = tdata->Nvec;
	int Ncoord = tdata->Ncoord;

	T *X      = tdata->X;
	T *Y      = tdata->Y;
	int *ia   = tdata->ia;
	int *ja   = tdata->ja;
	T *val    = tdata->val;

	int col_start, col_end, j_col;
	T temp;

	T Xi[3], Xj[3], gi[3];
	T sum[3];

	int offset;

	for (int row = start; row < end; row++)
	{

		col_start = ia[row];
		col_end   = ia[row+1];

		offset = 0;
		for(int mcol = 0; mcol < Nvec; mcol++) { 

			for(int k = 0; k < 3; k++) sum[k] = 0;

			for(int k = 0; k < 3; k++) {
				Xi[k] = X[offset + 3*row+k];
			}

			for (int j = col_start; j < col_end; j++) {
			
				j_col = ja[j];
				for(int k = 0; k < 3; k++) {
					Xj[k] = X[offset+3*j_col+k];
					gi[k] = val[3*j+k];
				}

				temp = gi[0]*(Xi[0]-Xj[0])+gi[1]*(Xi[1]-Xj[1])+gi[2]*(Xi[2]-Xj[2]);
				for(int k = 0; k < 3; k++) {
					sum[k] += temp*gi[k];
				}
			}

			for(int k = 0; k < 3; k++) {
				Y[offset + 3*row+k] += sum[k];
			}

			offset += Ncoord;
		}
	
	}

	return 0;

}

/*================================================================================================
 spmv_go_thread: computes the product of the Hessian with a vector, where the Hessian is derived from
		the bond, angle, and dihedral angle part of the Go potential
================================================================================================== */

template <typename T>
void* spmv_go_thread(void* data)
{
	int threadid = *((int *) data);

	threads_params<T> *tdata;
	tdata = (threads_params<T>*)&thparams[threadid];
	
	int start = tdata->start;
	int end   = tdata->end;

	int Ncoord = tdata->Ncoord;
	int natoms = Ncoord/3;

	T *bonds   = tdata->bonds;
	T *angles  = tdata->angles;
	T *diheds  = tdata->diheds;

	T *X      = tdata->X;
	T *Y      = tdata->Y;

	T val;

	T Xi[12];
	T vi[12];

/*================================================================================================
 	Hessian-vect: bond contribution
================================================================================================== */

	int iatm2, iatm1, iat0, iat1, iat2;

	for (int atom = start; atom < end ; atom++)
	{
		iat0 = atom-1; iat1 = atom; iat2 = atom+1;
		if(atom > 0) {
			for(int k = 0; k < 3; k++) {
				vi[k]   = bonds[3*iat0+k];
				vi[k+3] = -vi[k];
				Xi[k]   = X[3*iat0+k];
				Xi[k+3] = X[3*iat1+k];
			}
			val = 0.0;
			for(int k = 0; k < 6; k++) val += Xi[k]*vi[k];
			for(int k = 0; k < 3; k++) Y[3*iat1+k] += val*vi[k+3];
		}
		if(atom < natoms-1) {
			for(int k = 0; k < 3; k++) {
				vi[k]   = bonds[3*iat1+k];
				vi[k+3] = -vi[k];
				Xi[k]   = X[3*iat1+k];
				Xi[k+3] = X[3*iat2+k];
			}
			val = 0.0;
			for(int k = 0; k < 6; k++) val += Xi[k]*vi[k];
			for(int k = 0; k < 3; k++) Y[3*iat1+k] += val*vi[k];
		}
	}

/*================================================================================================
 	Hessian-vect: angle contribution
================================================================================================== */

	for (int atom = start; atom < end ; atom++)
	{
		iatm1 = atom-2; iat0 = atom-1; iat1 = atom; iat2 = atom+1;
		if(atom > 1) {
			for(int k = 0; k < 9; k++) {
				vi[k]   = angles[9*iatm1+k];
				Xi[k]   = X[3*iatm1+k];
			}
			val = 0.0;
			for(int k = 0; k < 9; k++) val += Xi[k]*vi[k];
			for(int k = 0; k < 3; k++) Y[3*iat1+k] += val*vi[k+6];
		}
		if(atom>0 && atom < natoms-1) {
			for(int k = 0; k < 9; k++) {
				vi[k]   = angles[9*iat0+k];
				Xi[k]   = X[3*iat0+k];
			}
			val = 0.0;
			for(int k = 0; k < 9; k++) val += Xi[k]*vi[k];
			for(int k = 0; k < 3; k++) Y[3*iat1+k] += val*vi[k+3];
		}
		if(atom < natoms-2) {
			for(int k = 0; k < 9; k++) {
				vi[k]   = angles[9*iat1+k];
				Xi[k]   = X[3*iat1+k];
			}
			val = 0.0;
			for(int k = 0; k < 9; k++) val += Xi[k]*vi[k];
			for(int k = 0; k < 3; k++) Y[3*iat1+k] += val*vi[k];
		}
	}

/*================================================================================================
 	Hessian-vect: dihedral angle contribution
================================================================================================== */

	for (int atom = start; atom < end ; atom++)
	{

		iatm2 = atom -3; iatm1 = atom - 2; iat0 = atom - 1; iat1 = atom;

		if(atom > 2) {
			for(int k = 0; k < 12; k++) {
				vi[k]   = diheds[12*iatm2+k];
				Xi[k]   = X[3*iatm2+k];
			}

			val = 0;
			for(int k = 0; k < 12; k++) val += vi[k]*Xi[k];
			for(int k = 0; k < 3; k++) Y[3*iat1+k] += val*vi[k+9];
		}

		if(atom > 1 && atom < natoms-1) {
			for(int k = 0; k < 12; k++) {
				vi[k]   = diheds[12*iatm1+k];
				Xi[k]   = X[3*iatm1+k];
			}

			val = 0;
			for(int k = 0; k < 12; k++) val += vi[k]*Xi[k];
			for(int k = 0; k < 3; k++) Y[3*iat1+k] += val*vi[k+6];
		}

		if(atom > 0 && atom < natoms - 2) {
			for(int k = 0; k < 12; k++) {
				vi[k]   = diheds[12*iat0+k];
				Xi[k]   = X[3*iat0+k];
			}

			val = 0;
			for(int k = 0; k < 12; k++) val += vi[k]*Xi[k];
			for(int k = 0; k < 3; k++) Y[3*iat1+k] += val*vi[k+3];
		}

		if(atom < natoms - 3) {
			for(int k = 0; k < 12; k++) {
				vi[k]   = diheds[12*iat1+k];
				Xi[k]   = X[3*iat1+k];
			}

			val = 0;
			for(int k = 0; k < 12; k++) val += vi[k]*Xi[k];
			for(int k = 0; k < 3; k++) Y[3*iat1+k] += val*vi[k];
		}
	}

	return 0;
}

/*================================================================================================
 spmm_go_thread: computes the product of the Hessian with a matrix, where the Hessian is derived from
		the bond, angle, and dihedral angle part of the Go potential
================================================================================================== */

template <typename T>
void* spmm_go_thread(void* data)
{
	int threadid = *((int *) data);

	threads_params<T> *tdata;
	tdata = (threads_params<T>*)&thparams[threadid];
	
	int start = tdata->start;
	int end   = tdata->end;

	int Nvec   = tdata->Nvec;
	int Ncoord = tdata->Ncoord;
	int natoms = Ncoord/3;

	T *bonds   = tdata->bonds;
	T *angles  = tdata->angles;
	T *diheds  = tdata->diheds;

	T *X      = tdata->X;
	T *Y      = tdata->Y;

	T val;

	T Xi[12];
	T vi[12];

	int offset = 0;

/*================================================================================================
 	Hessian-vect: bond contribution
================================================================================================== */

	int iatm2, iatm1, iat0, iat1, iat2;

	for(int c = 0; c < Nvec; c++) {

		for (int atom = start; atom < end ; atom++)
		{
			iat0 = atom-1; iat1 = atom; iat2 = atom+1;
			if(atom > 0) {
				for(int k = 0; k < 3; k++) {
					vi[k]   = bonds[3*iat0+k];
					vi[k+3] = -vi[k];
					Xi[k]   = X[offset+3*iat0+k];
					Xi[k+3] = X[offset+3*iat1+k];
				}
				val = 0.0;
				for(int k = 0; k < 6; k++) val += Xi[k]*vi[k];
				for(int k = 0; k < 3; k++) Y[offset+3*iat1+k] += val*vi[k+3];
			}
			if(atom < natoms-1) {
				for(int k = 0; k < 3; k++) {
					vi[k]   = bonds[3*iat1+k];
					vi[k+3] = -vi[k];
					Xi[k]   = X[offset+3*iat1+k];
					Xi[k+3] = X[offset+3*iat2+k];
				}
				val = 0.0;
				for(int k = 0; k < 6; k++) val += Xi[k]*vi[k];
				for(int k = 0; k < 3; k++) Y[offset+3*iat1+k] += val*vi[k];
			}
		}

/*================================================================================================
 	Hessian-vect: angle contribution
================================================================================================== */

		for (int atom = start; atom < end ; atom++)
		{
			iatm1 = atom-2; iat0 = atom-1; iat1 = atom; iat2 = atom+1;
			if(atom > 1) {
				for(int k = 0; k < 9; k++) {
					vi[k]   = angles[9*iatm1+k];
					Xi[k]   = X[offset+3*iatm1+k];
				}
				val = 0.0;
				for(int k = 0; k < 9; k++) val += Xi[k]*vi[k];
				for(int k = 0; k < 3; k++) Y[offset+3*iat1+k] += val*vi[k+6];
			}
			if(atom>0 && atom < natoms-1) {
				for(int k = 0; k < 9; k++) {
					vi[k]   = angles[9*iat0+k];
					Xi[k]   = X[offset+3*iat0+k];
				}
				val = 0.0;
				for(int k = 0; k < 9; k++) val += Xi[k]*vi[k];
				for(int k = 0; k < 3; k++) Y[offset+3*iat1+k] += val*vi[k+3];
			}
			if(atom < natoms-2) {
				for(int k = 0; k < 9; k++) {
					vi[k]   = angles[9*iat1+k];
					Xi[k]   = X[offset+3*iat1+k];
				}
				val = 0.0;
				for(int k = 0; k < 9; k++) val += Xi[k]*vi[k];
				for(int k = 0; k < 3; k++) Y[offset+3*iat1+k] += val*vi[k];
			}
		}

/*================================================================================================
 	Hessian-vect: dihedral angle contribution
================================================================================================== */

		for (int atom = start; atom < end ; atom++)
		{

			iatm2 = atom -3; iatm1 = atom - 2; iat0 = atom - 1; iat1 = atom;

			if(atom > 2) {
				for(int k = 0; k < 12; k++) {
					vi[k]   = diheds[12*iatm2+k];
					Xi[k]   = X[offset+3*iatm2+k];
				}

				val = 0;
				for(int k = 0; k < 12; k++) val += vi[k]*Xi[k];
				for(int k = 0; k < 3; k++) Y[offset+3*iat1+k] += val*vi[k+9];
			}

			if(atom > 1 && atom < natoms-1) {
				for(int k = 0; k < 12; k++) {
					vi[k]   = diheds[12*iatm1+k];
					Xi[k]   = X[offset+3*iatm1+k];
				}

				val = 0;
				for(int k = 0; k < 12; k++) val += vi[k]*Xi[k];
				for(int k = 0; k < 3; k++) Y[offset+3*iat1+k] += val*vi[k+6];
			}

			if(atom > 0 && atom < natoms - 2) {
				for(int k = 0; k < 12; k++) {
					vi[k]   = diheds[12*iat0+k];
					Xi[k]   = X[offset+3*iat0+k];
				}

				val = 0;
				for(int k = 0; k < 12; k++) val += vi[k]*Xi[k];
				for(int k = 0; k < 3; k++) Y[offset+3*iat1+k] += val*vi[k+3];
			}

			if(atom < natoms - 3) {
				for(int k = 0; k < 12; k++) {
					vi[k]   = diheds[12*iat1+k];
					Xi[k]   = X[offset+3*iat1+k];
				}

				val = 0;
				for(int k = 0; k < 12; k++) val += vi[k]*Xi[k];
				for(int k = 0; k < 3; k++) Y[offset+3*iat1+k] += val*vi[k];
			}
		}

		offset += Ncoord;
	}

	return 0;
}

/*================================================================================================
 spmv_CPU: computes the product of the Hessian with a single vector, using the form in which the Hessian
	      is written as a sum of tensor products (Nvec, ldx and ldy are not used)
================================================================================================== */

template <typename T>
void spmv_CPU(int N, int Nvec, T *X, int ldx, T *Y, int ldy, void *mvparams)
{
	// Y = H X for a single vector X of length N (N = 3 * number of atoms).
	// Nvec, ldx and ldy are not used: X and Y are treated as contiguous length-N vectors.
	// Work is split over params->nthreads threads, each owning a contiguous slice of atoms.
	opparams<T> *params;
	params = (opparams<T> *) mvparams;

	// A non-positive thread count would divide by zero when computing slice sizes.
	int nthreads = params->nthreads;
	if (nthreads < 1) nthreads = 1;

	memset(Y, 0, (size_t)N*sizeof(T));

	int Natoms = N/3;
	int nval   = Natoms / nthreads;

	// Run 'kernel' on every slice and wait for completion. If a thread cannot be
	// created, its slice is processed inline by the calling thread, so Y is always
	// complete and pthread_join is only called on threads that actually exist.
	auto run_slices = [&](void* (*kernel)(void*)) {
		std::vector<char> launched(static_cast<size_t>(nthreads), 0);
		for (int i = 0; i < nthreads; i++) {
			if (pthread_create(&threads[i], NULL, kernel, (void*) &threadids[i]) == 0) {
				launched[i] = 1;
			} else {
				kernel((void*) &threadids[i]);
			}
		}
		for (int i = 0; i < nthreads; i++) {
			if (launched[i]) pthread_join(threads[i], NULL);
		}
	};

	// Slice of atoms [start, end) per thread; the last thread also takes the remainder.
	for (int i = 0; i < nthreads; i++)
	{
		threadids[i] = i;

		threads_params<T> *tdata = &thparams[i];
		int start = i*nval;
		tdata->start = start;
		tdata->end   = (i == nthreads-1) ? Natoms : start + nval;
		tdata->X     = X;
		tdata->Y     = Y;
	}

	// Bonded Go-model terms (bonds, angles, dihedrals), only for Hessians of type 2.
	if (params->csrHessian->type == 2) {
		for (int i = 0; i < nthreads; i++)
		{
			threads_params<T> *tdata = &thparams[i];
			tdata->bonds  = params->csrHessian->bonds;
			tdata->angles = params->csrHessian->angles;
			tdata->diheds = params->csrHessian->diheds;
			tdata->Ncoord = N;
		}
		run_slices(spmv_go_thread<T>);
	}

	// Tensor-product part of the Hessian stored in CSR format.
	for (int i = 0; i < nthreads; i++)
	{
		threads_params<T> *tdata = &thparams[i];
		tdata->ia  = params->csrHessian->ia;
		tdata->ja  = params->csrHessian->ja;
		tdata->val = params->csrHessian->val;
	}
	run_slices(spmv_thread<T>);
}

/*================================================================================================
 spmm_CPU: computes the product of the Hessian with multiple vectors, using the form in which the Hessian
	      is written as a sum of tensor products; the Nvec vectors are stored one after the other, each of length N
================================================================================================== */

template <typename T>
void spmm_CPU(int N, int Nvec, T *X, int ldx, T *Y, int ldy, void *mvparams)
{
	// Y = H X for Nvec vectors of length N stored one after the other (vector c starts
	// at offset c*N); ldx and ldy are not used (the layout assumes leading dimension N).
	// Work is split over params->nthreads threads, each owning a contiguous slice of atoms.
	opparams<T> *params;
	params = (opparams<T> *) mvparams;

	// A non-positive thread count would divide by zero when computing slice sizes.
	int nthreads = params->nthreads;
	if (nthreads < 1) nthreads = 1;

	// Compute the byte count in size_t: N*Nvec can exceed INT_MAX for large systems.
	memset(Y, 0, (size_t)N*(size_t)Nvec*sizeof(T));

	int Natoms = N/3;
	int nval   = Natoms / nthreads;

	// Run 'kernel' on every slice and wait for completion. If a thread cannot be
	// created, its slice is processed inline by the calling thread, so Y is always
	// complete and pthread_join is only called on threads that actually exist.
	auto run_slices = [&](void* (*kernel)(void*)) {
		std::vector<char> launched(static_cast<size_t>(nthreads), 0);
		for (int i = 0; i < nthreads; i++) {
			if (pthread_create(&threads[i], NULL, kernel, (void*) &threadids[i]) == 0) {
				launched[i] = 1;
			} else {
				kernel((void*) &threadids[i]);
			}
		}
		for (int i = 0; i < nthreads; i++) {
			if (launched[i]) pthread_join(threads[i], NULL);
		}
	};

	// Slice of atoms [start, end) per thread; the last thread also takes the remainder.
	for (int i = 0; i < nthreads; i++)
	{
		threadids[i] = i;

		threads_params<T> *tdata = &thparams[i];
		int start = i*nval;
		tdata->start  = start;
		tdata->end    = (i == nthreads-1) ? Natoms : start + nval;
		tdata->Nvec   = Nvec;
		tdata->Ncoord = N;
		tdata->X      = X;
		tdata->Y      = Y;
	}

	// Bonded Go-model terms (bonds, angles, dihedrals), only for Hessians of type 2.
	if (params->csrHessian->type == 2) {
		for (int i = 0; i < nthreads; i++)
		{
			threads_params<T> *tdata = &thparams[i];
			tdata->bonds  = params->csrHessian->bonds;
			tdata->angles = params->csrHessian->angles;
			tdata->diheds = params->csrHessian->diheds;
		}
		run_slices(spmm_go_thread<T>);
	}

	// Tensor-product part of the Hessian stored in CSR format.
	for (int i = 0; i < nthreads; i++)
	{
		threads_params<T> *tdata = &thparams[i];
		tdata->ia  = params->csrHessian->ia;
		tdata->ja  = params->csrHessian->ja;
		tdata->val = params->csrHessian->val;
	}
	run_slices(spmm_thread<T>);
}

/*================================================================================================
  spmvED_CPU 

	Purpose:
	========
	Computes the product of a matrix B with a vector X,
	where B = A + U Sigma U^T and A is stored in CSR format
	(used for explicit deflation when computing eigenpairs of the matrix A)
	The product A X is computed with threads (see spmv_CPU)

	Arguments:
	=========

	N   	(input) integer
		On entry, number of rows / cols of the matrix A

	Nvec	(input) integer
		On entry, number of columns of X (always 1)

	X	(input) float or double array of size N
		On entry, input matrix X

	ldx	(input) integer
		leading dimension of X

	Y	(output) float or double array of size N
		On exit, the result of Y = (A + U Sigma U^T) X

	ldy	(input) integer
		leading dimension of Y

	mvparams (input) structure
		On entry, information about the matrices A, U, and Sigma

================================================================================================== */

template <typename T>
void spmvED_CPU(int N, int Nvec, T *X, int ldx, T *Y, int ldy, void *mvparams)
{
	opparams<T> *params;
	params = (opparams<T> *) mvparams;

	int n_eed = params->need;
	T *Space = params->space;
	T *Sigma = params->sigma;

	char Trans = 'T'; char noTrans = 'N';
	T alpha; T beta;
	int inc = 1;

/*      ==========================================================================================
	Compute A X
	========================================================================================== */

	spmv_CPU<T>(N, Nvec, X, ldx, Y, ldy, mvparams);

/*      ==========================================================================================
	Now compute U Sigma U^T X using dgemv
	========================================================================================== */

	alpha = 1; beta = 0.0;
	eig_dgemv_(&Trans, &N, &n_eed, &alpha, params->Ud, &N, X, &inc, 
		&beta, Space, &inc);

	for(int i = 0; i < n_eed; i++) {
		alpha = Sigma[i]; 
		Space[i] *= alpha;
	} 

	alpha = 1; beta = 1.0;
	eig_dgemv_(&noTrans, &N, &n_eed, &alpha, params->Ud, &N, 
		Space, &inc, &beta, Y, &inc);

}

/*================================================================================================
  spmmED_CPU

	Purpose:
	========
	Computes the product of a matrix B with a matrix X,
	where B = A + U Sigma U^T and A is stored in CSR format
	(used for explicit deflation when computing eigenpairs of the matrix A)
	The product A X is computed with threads (see spmm_CPU)

	Arguments:
	=========

	N   	(input) integer
		On entry, number of rows / cols of the matrix A

	Nvec	(input) integer
		On entry, number of columns of X 

	X	(input) float or double array of size N x Nvec
		On entry, input matrix X

	ldx	(input) integer
		leading dimension of X

	Y	(output) float or double array of size N x Nvec
		On exit, the result of Y = (A + U Sigma U^T) X

	ldy	(input) integer
		leading dimension of Y

	mvparams (input) structure
		On entry, information about the matrices A, U, and Sigma

================================================================================================== */

template <typename T>
void spmmED_CPU(int N, int Nvec, T *X, int ldx, T *Y, int ldy, void *mvparams)
{
	// Y = (A + U Sigma U^T) X for Nvec vectors (column major, leading dimension N).
	// U is params->Ud (N x need, ld N), Sigma is params->sigma, and params->space is a
	// need x Nvec work array.
	opparams<T> *params;
	params = (opparams<T> *) mvparams;

	int n_eed = params->need;
	T *Space = params->space;
	T *Sigma = params->sigma;

	char Trans = 'T'; char noTrans = 'N';
	T alpha; T beta;
	int inc;

/*      ==========================================================================================
	Compute A X
	========================================================================================== */

	spmm_CPU<T>(N, Nvec, X, ldx, Y, ldy, mvparams);

	// With no deflation vectors B = A, so we are done. Returning here also avoids
	// passing n_eed = 0 as a leading dimension (ldc / ldb) to dgemm, which the BLAS
	// specification requires to be >= 1 (assuming eig_dgemm_ forwards to dgemm).
	if (n_eed <= 0) return;

/*      ==========================================================================================
	Now compute U Sigma U^T X using dgemm
	========================================================================================== */

	// Space = U^T X   (n_eed x Nvec)
	alpha = 1; beta = 0.0;
	eig_dgemm_(&Trans, &noTrans, &n_eed, &Nvec, &N, &alpha, params->Ud, &N, X, &N, 
		&beta, Space, &n_eed);

	// Scale row i of Space by Sigma[i]; consecutive elements of a row are n_eed apart.
	inc = n_eed;
	for(int i = 0; i < n_eed; i++) {
		alpha = Sigma[i]; 
		eig_dscal_(&Nvec, &alpha, &Space[i], &inc);
	} 

	// Y = Y + U Space
	alpha = 1; beta = 1.0;
	eig_dgemm_(&noTrans, &noTrans, &N, &Nvec, &n_eed, &alpha, params->Ud, &N, 
		Space, &n_eed, &beta, Y, &N);

}

/*================================================================================================
  work thread for inplace_dgemm
================================================================================================== */

template <typename T>
void* inplace_thread(void* data) {

	int threadid = *((int *) data);
	threads_mat<T> *tdata;
	tdata = (threads_mat<T>*)&thmat[threadid];
	
	int M     = tdata->M;
	int start = tdata->start;
	int end   = tdata->end;
	T *A      = tdata->A;
	T *B      = tdata->B;
	int lda   = tdata->lda;
	int ldb   = tdata->ldb;
	T *row    = tdata->row;

	T sum;
	for (int i = start; i < end; ++i) {
		// Load A[i, :] into temp row
		for (int j = 0; j < M; ++j)
			row[j] = A[i + j * lda];  // A(i,j)

		// Compute new A[i, j] = dot(row, B[:,j])
		for (int j = 0; j < M; ++j) {
			sum = 0;
			for (int k = 0; k < M; ++k)
				sum += row[k] * B[k + j * ldb];
			A[i + j * lda] = sum;
		}
	}

	return nullptr;
}

/*================================================================================================
  inplace_dgemm

        Purpose:
        ========
        Computes A = A * B in place (matrix-matrix product, overwriting A)

        A is N*M, while B = M*M

        Arguments:
        ==========

        N       (input) integer
                On entry, number of rows of the matrix A

        M       (input) integer
                On entry, number of cols of the matrix A

        A       (input) array of floats or double, size N*M
                On entry, input matrix A
                On exit, matrix A * B

        lda     (input) int
                On entry, leading dimension of A
                
        B       (input) array of floats or double, size M*M
                On entry, input matrix B

        ldb     (input) int
                On entry, leading dimension of B

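	space	(input) array of floats or double, work array
		On entry, work array; thread i uses space[i*M .. i*M+M-1] as a row
		buffer, so it must hold at least nthreads*M elements

	nthreads (input) int
		On entry, number of threads to use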
                
================================================================================================== */

template <typename T>
void inplace_dgemm_CPU(int N, int M, T* A, int lda, T* B, int ldb, T *space, int nthreads) 
{
	// A (N x M, column major, leading dimension lda) is overwritten with A * B, where
	// B is M x M (leading dimension ldb). Rows of A are split over threads; each thread
	// copies a row into its own length-M buffer in 'space' before overwriting it.

	// Empty A: nothing to compute (for M == 0, A * B leaves A unchanged).
	if (N <= 0 || M <= 0) return;

	// Each thread uses space[i*M .. i*M+M-1]. Capping the thread count at M keeps all
	// buffers inside a work array of M*M elements; at least one thread is needed to
	// avoid a division by zero below.
	if (nthreads > M) nthreads = M;
	if (nthreads < 1) nthreads = 1;

	int nval  = N / nthreads;
	int start = 0;

	// Slices whose thread cannot be created are processed inline, so every row is
	// still updated, and only threads that were actually created are joined.
	std::vector<char> launched(static_cast<size_t>(nthreads), 0);

	for(int i = 0; i < nthreads; i++) {
		threadids[i] = i;
		
		threads_mat<T> *tdata = &thmat[i];
		tdata->N     = N;
		tdata->M     = M;
		tdata->A     = A;
		tdata->lda   = lda;
		tdata->B     = B;
		tdata->ldb   = ldb;
		tdata->start = start;
		// The last thread also takes the remainder rows.
		tdata->end   = (i == nthreads-1) ? N : start + nval;
		tdata->row   = &space[i*M];

		if (pthread_create(&threads[i], NULL, inplace_thread<T>, (void*) &threadids[i]) == 0) {
			launched[i] = 1;
		} else {
			inplace_thread<T>((void*) &threadids[i]);
		}
		start += nval;
	}

	for (int i = 0; i < nthreads; i++) {
		if (launched[i]) pthread_join(threads[i], NULL);
	}
}

#endif
