/*================================================================================================
  MatVect.h
  Version 1: 12/1/2017

  Purpose: Set of routines for matrix (Hessian) - vector multiplication

Copyright (c) Patrice Koehl.

>>> SOURCE LICENSE >>>

  This library is free software; you can redistribute it and/or
  modify it under the terms of the GNU Lesser General Public
  License as published by the Free Software Foundation; either
  version 2.1 of the License, or (at your option) any later version.

  This library is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  Lesser General Public License for more details.

  You should have received a copy of the GNU Lesser General Public
  License along with this library; if not, write to the Free Software
  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA

>>> END OF LICENSE >>>

================================================================================================== */

#ifndef _MATVECT_
#define _MATVECT_

/*================================================================================================
 Includes
================================================================================================== */

#if defined(GPU)
#include "MatVect_GPU.h"
#else
#include "MatVect_CPU.h"
#endif

/*================================================================================================
  spmv

	Purpose:
	========
	Computes y = (A * x)
	where A is a symmetric sparse matrix stored in CSR format and x is a vector

	This is the driver routine that chooses between CPU and GPU version of the procedure

	Arguments:
	==========

	N   	(input) integer
		On entry, number of rows / cols of the matrix A

	M	(input) integer
		On entry, number of columns of X; always 1 here (kept so that
		the signature matches spmm)

	X	(input) array of floats or double, size N
		On entry, input vector x 

	ldx	(input) int
		On entry, leading dimension of x
		
	Y	(output) array of floats or double, size N*M
		On exit, vector y = A*x

	ldy	(input) int
		On entry, leading dimension of y
		
	mvparams (input) pointer to a structure (passed as void *)
		On entry, information about the matrix A


================================================================================================== */

template <typename T>
void spmv(int N, int M, T *X, int ldx, T *Y, int ldy, void *mvparams)
{
	/* Compile-time backend selection: the CPU kernel unless GPU is defined.
	   Every argument is forwarded unchanged to the chosen kernel. */
#ifndef GPU
	spmv_CPU(N, M, X, ldx,
		 Y, ldy, mvparams);
#else
	spmv_GPU(N, M, X, ldx,
		 Y, ldy, mvparams);
#endif
}

/*================================================================================================
  spmm

	Purpose:
	========
	Optimized sparse matrix-matrix multiplication: computes Y = A X,
	where A is a sparse N x N matrix and X is an N x Nvec matrix

	This is the driver routine that chooses between CPU and GPU version of the procedure

        Arguments:
        =========

        N       (input) integer
                On entry, number of rows / cols of the matrix A

        Nvec    (input) integer
                On entry, number of columns of X 

        X       (input) float or double array of size N x Nvec
                On entry, input matrix X

        ldx     (input) integer
                leading dimension of X

        Y       (output) float or double array of size N x Nvec
                On exit, the result of Y = A X

        ldy     (input) integer
                leading dimension of Y

        mvparams (input) structure
                On entry, information about the matrix A

================================================================================================== */

template <typename T>
void spmm(int N, int Nvec, T *X, int ldx, T *Y, int ldy, void *mvparams) {
	/* Compile-time backend selection: the CPU kernel unless GPU is defined.
	   Every argument is forwarded unchanged to the chosen kernel. */
#ifndef GPU
	spmm_CPU(N, Nvec, X, ldx,
		 Y, ldy, mvparams);
#else
	spmm_GPU(N, Nvec, X, ldx,
		 Y, ldy, mvparams);
#endif
}

/*================================================================================================
  spmvED 

	Purpose:
	========
	Computes the product of a matrix B in CSR format with a vector X
	where B = A + U U^T
	(used for explicit deflation when computing eigenpairs of the matrix A)

	Chooses between CPU and GPU

	Arguments:
	=========

	N   	(input) integer
		On entry, number of rows / cols of the matrix A

	Nvec	(input) integer
		On entry, number of columns of X (always 1)

	X	(input) float or double array of size N
		On entry, input vector X

	ldx	(input) integer
		leading dimension of X

	Y	(output) float or double array of size N
		On exit, the result of Y = (A + U U^T) X

	ldy	(input) integer
		leading dimension of Y

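	mvparams (input) structure
		On entry, information about the matrices A, U, and Sigma
		(NOTE(review): Sigma does not appear in B = A + U U^T above;
		confirm whether B is actually A + U Sigma U^T)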

================================================================================================== */

template <typename T>
void spmvED(int N, int Nvec, T *X, int ldx, T *Y, int ldy, void *mvparams) {
	/* Explicit-deflation product: pick the CPU kernel unless GPU is
	   defined, and pass all arguments through untouched. */
#ifndef GPU
	spmvED_CPU(N, Nvec, X, ldx,
		   Y, ldy, mvparams);
#else
	spmvED_GPU(N, Nvec, X, ldx,
		   Y, ldy, mvparams);
#endif
}

/*================================================================================================
  spmmED

	Purpose:
	========
	Computes the product of a matrix B in CSR format with a matrix X
	where B = A + U U^T
	(used for explicit deflation when computing eigenpairs of the matrix A)

	Chooses between GPU and CPU

	Arguments:
	=========

	N   	(input) integer
		On entry, number of rows / cols of the matrix A

	Nvec	(input) integer
		On entry, number of columns of X 

	X	(input) float or double array of size N x Nvec
		On entry, input matrix X

	ldx	(input) integer
		leading dimension of X

	Y	(output) float or double array of size N x Nvec
		On exit, the result of Y = (A + U U^T) X

	ldy	(input) integer
		leading dimension of Y

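	mvparams (input) structure
		On entry, information about the matrices A, U, and Sigma
		(NOTE(review): Sigma does not appear in B = A + U U^T above;
		confirm whether B is actually A + U Sigma U^T)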

================================================================================================== */

template <typename T>
void spmmED(int N, int Nvec, T *X, int ldx, T *Y, int ldy, void *mvparams) {
	/* Explicit-deflation product on a block of Nvec columns: pick the CPU
	   kernel unless GPU is defined, and pass all arguments through untouched. */
#ifndef GPU
	spmmED_CPU(N, Nvec, X, ldx,
		   Y, ldy, mvparams);
#else
	spmmED_GPU(N, Nvec, X, ldx,
		   Y, ldy, mvparams);
#endif
}

#endif
