/*================================================================================================
  Network.h
  Version 1: 9/29/2019

  Purpose: Builds elastic-network edge lists: pairs of atoms closer than a given cutoff (found
           with a cell-grid approach) or pairs joined by an edge of the Delaunay triangulation

Copyright (c) Patrice Koehl.

================================================================================================== */

#ifndef _NETWORK_H_
#define _NETWORK_H_


/*================================================================================================
 Includes
================================================================================================== */

#include "Edges.h"
#include "../../Delcx/include/delcx.h"
#include "Network_thread_tools.h"

// DELCX object (type declared in delcx.h) used by Network<T>::delaunay(), which calls its
// setup(), regular3D() and delaunayEdges() members.
// NOTE(review): this is a non-inline definition of a global object inside a header. If this
// header is included from more than one translation unit the program will fail to link
// (multiple definition of `delcx`). Consider declaring it `extern` here and defining it in a
// single .cpp file (or marking it `inline` if the project builds with C++17).
// NOTE(review): being a single shared global, concurrent calls to delaunay() would share it.
DELCX delcx;

/*================================================================================================
 Network class
================================================================================================== */

template <typename T>
class Network {

	public:

		// defines elastic network
		void network(std::vector<Atoms<T> >& atoms, T cutoff, int mindiff, 
		std::vector<Edges<T> >& List, int nthreads);

		// elastic network based on Delaunay
		void delaunay(std::vector<Atoms<T> >& atoms, int mindiff, int potential,
		std::vector<Edges<T> >& List);

	private:

		// size of the box containing all atoms
		void BoxSize(std::vector<Atoms<T> >& atoms, T cutoff, T *LowerCorner, int *Nx, int *Ny, int *Nz);

		// organize cell linked lists
		void CellLists(std::vector<Atoms<T> >& atoms, T *LowerCorner, int Nx, int Ny, int Nz,
		T cutoff, int *ListCell, int *Head);

		// builds neighbour list (parallel)
		void GetNeighbours_thread(std::vector<Atoms<T> >& atoms, T cutoff, int mindiff, 
		int Nx, int Ny, int Nz, int *ListCell, int *Head, std::vector<Edges<T> >& Listcontact, 
		int nthreads);

		// builds neighbour list (serial)
		void GetNeighbours(std::vector<Atoms<T> >& atoms, T cutoff, int mindiff, int Nx, int Ny, int Nz, 
		int *ListCell, int *Head, std::vector<Edges<T> >& Listcontact);

		// builds neighbour list (serial)
		void GetNeighbours_bf(std::vector<Atoms<T> >& atoms, T cutoff, int mindiff,
		std::vector<Edges<T> >& Listcontact);

	};

/*================================================================================================
 Define size of the box that contains all the atoms
================================================================================================== */

template <typename T>
void Network<T>::BoxSize(std::vector<Atoms<T> >& atoms, T cutoff, T *LowerCorner, int *Nx, int *Ny, int *Nz)
{

/*	==========================================================================================
	Defines the grid of cubic cells (edge = cutoff/2) that covers all atoms.

	Input:
		atoms      : class for atom information
		cutoff     : cutoff for defining neighbours (expected to be > 0)
	Output:
		LowerCorner: position of the lower, left, front corner of the grid considered
		Nx         : number of cells along X direction
		Ny         : number of cells along Y direction
		Nz         : number of cells along Z direction

	Along each axis the grid spans [min - cutoff, max + cutoff], i.e. the bounding box of
	the atoms padded by one cutoff (two cells) on every side, so that all non-empty cells
	are "inside" the grid. An empty atom list is treated as a single point at the origin,
	so the outputs are always well defined.
	========================================================================================== */

	int natoms = atoms.size();

	// Bounding box of the atoms, one axis at a time (d = 0, 1, 2 for X, Y, Z)
	T lo[3], hi[3];
	for (int d = 0; d < 3; d++) {
		lo[d] = (natoms > 0) ? T(atoms[0].coord[d]) : T(0);
		hi[d] = lo[d];
	}

	for (int i = 1; i < natoms; i++) {
		for (int d = 0; d < 3; d++) {
			T v = atoms[i].coord[d];
			if (v < lo[d]) lo[d] = v;
			if (v > hi[d]) hi[d] = v;
		}
	}

	// Lower corner shifted down by one cutoff; the cell count reaches up to max + cutoff
	T cube_size = cutoff/2;
	int *counts[3] = {Nx, Ny, Nz};
	for (int d = 0; d < 3; d++) {
		LowerCorner[d] = lo[d] - cutoff;
		*counts[d] = static_cast<int>(ceil((hi[d] + cutoff - LowerCorner[d])/cube_size));
	}
}

/*================================================================================================
 Linked-list cell algorithm:
 atoms are organized into one linked list per cell, and only the head of each list is stored
 for the cell; this saves space, as the list of atoms in each cell never needs to be stored explicitly
================================================================================================== */

template <typename T>
void Network<T>::CellLists(std::vector<Atoms<T> >& atoms, T *LowerCorner, int Nx, int Ny, int Nz,
		T cutoff, int *ListCell, int *Head)
{

/*	==========================================================================================
	Distributes the atoms over the cells of the grid using singly linked lists.

	Input:
		atoms      : class for atom information
		LowerCorner: position of the lower, left, front corner of the grid
		Nx, Ny, Nz : number of cells along X, Y and Z
		cutoff     : cutoff for defining neighbours (cell edge is cutoff/2)
	Output:
		ListCell   : for each atom, index of the next atom in the same cell (-1 ends the list)
		Head       : for each cell, index of the first atom in the cell (-1 if empty)

	Cells are numbered ix + Nx*iy + Nx*Ny*iz. Atoms are pushed at the front of their
	cell's list, so each list holds its atoms in decreasing index order.
	========================================================================================== */

	const int EMPTY  = -1;
	const int nxy    = Nx*Ny;
	const int ntotal = nxy*Nz;
	const T   edge   = cutoff/2;

	// Every cell starts out empty
	for (int c = 0; c < ntotal; c++)
		Head[c] = EMPTY;

	const int natoms = atoms.size();
	for (int a = 0; a < natoms; a++) {
		const int ix = floor((atoms[a].coord[0] - LowerCorner[0])/edge);
		const int iy = floor((atoms[a].coord[1] - LowerCorner[1])/edge);
		const int iz = floor((atoms[a].coord[2] - LowerCorner[2])/edge);

		const int cell = ix + Nx*iy + nxy*iz;

		// Insert atom a at the front of the list of its cell
		ListCell[a] = Head[cell];
		Head[cell]  = a;
	}
}

/*================================================================================================
 Find list of neighbours, i.e. list of atoms that are within "cutoff" from each other, using a
 multi-threaded algorithm
================================================================================================== */

template <typename T>
void Network<T>::GetNeighbours_thread(std::vector<Atoms<T> >& atoms, T cutoff, int mindiff, int Nx, int Ny, int Nz, 
	int *ListCell, int *Head, std::vector<Edges<T> >& Listcontact, int nthreads)
{
/*	==========================================================================================
	Multi-threaded contact search over the cell grid.

	Input:
		atoms      : class for atom information
		cutoff     : cutoff for defining neighbours
		mindiff	   : minimum "distance" in residue number
		Nx         : number of cells along X direction
		Ny         : number of cells along Y direction
		Nz         : number of cells along Z direction
		ListCell   : For each atom i, pointer to the next atom in the same cell as i
		Head       : For each cell, pointer to the first atom in the cell (-1 if empty)
		nthreads   : number of worker threads (values < 1 are treated as 1)
	Output:
		Listcontact: list of contacts; each worker's contacts are merged into it with
		             std::merge ordered by sortEdges

	The list of non-empty cells is split into nthreads contiguous chunks. Each chunk is
	handed to neighbours_thread<T> through the global arrays threadids / threads / grids.
	These are defined outside this file, presumably in Network_thread_tools.h.
	NOTE(review): nthreads must not exceed the capacity of those global arrays — confirm.
	NOTE(review): std::merge requires Listcontact and every grids[t].Contacts to already be
	sorted by sortEdges; this assumes neighbours_thread sorts its own contacts — confirm.
	========================================================================================== */

	int EMPTY = -1;
	int Icell;
	int N1, N2;
	T cut2 = cutoff * cutoff;

	// At least one worker: nthreads <= 0 would otherwise divide by zero below
	if (nthreads < 1) nthreads = 1;

/*	==========================================================================================
	Find list of non-empty cells
	(raw array on purpose: workers read it through grids[t].Cells, so it must stay alive
	until every worker has been joined, even if an exception is thrown while merging)
	========================================================================================== */

	int *Cells = new int[Nx*Ny*Nz];

	int Ncells = 0;
	for(int Ic = 0; Ic < Nx; Ic++) {
		for(int Jc = 0; Jc < Ny; Jc++) {
			for(int Kc = 0; Kc < Nz; Kc++) {
				Icell = (Kc*Nx*Ny) + (Jc*Nx) + Ic;
				if (Head[Icell] != EMPTY) {
					Cells[Ncells] = Icell;
					Ncells++;
				}
			}
		}
	}

/*	==========================================================================================
	Break list to all threads and send jobs (the last thread also takes the remainder)
	========================================================================================== */

	int nval = Ncells/nthreads;

	for (int t = 0; t < nthreads; t++)
	{
		N1 = t*nval;
		N2 = N1 + nval;
		if(t == nthreads-1) N2 = Ncells;

		threadids[t] = t;

		grids[t].firstcell = N1;
		grids[t].lastcell  = N2;
		grids[t].Nx        = Nx;
		grids[t].Ny        = Ny;
		grids[t].Nz        = Nz;
		grids[t].cutoff2   = cut2;
		grids[t].mindiff   = mindiff;
		grids[t].Cells     = Cells;
		grids[t].ListCell  = ListCell;
		grids[t].Head      = Head;
		grids[t].atoms     = atoms;

		std::vector<Edges<T> > List;
		grids[t].Contacts = List;

		pthread_create(&threads[t], NULL, neighbours_thread<T>, (void*) &threadids[t]);
	}

/*	==========================================================================================
	Join all the threads and merge their (sorted) contacts into Listcontact.
	Merging into a scratch vector and swapping avoids copying the whole accumulated list
	on every join.
	========================================================================================== */

	std::vector<Edges<T> > merged;
	for (int t = 0; t < nthreads; t++)
	{
		pthread_join(threads[t], NULL);

		merged.clear();
		merged.reserve(Listcontact.size() + grids[t].Contacts.size());
		std::merge(Listcontact.begin(), Listcontact.end(),
			grids[t].Contacts.begin(), grids[t].Contacts.end(),
			std::back_inserter(merged), sortEdges<T>);
		Listcontact.swap(merged);
	}

	delete [] Cells;

}

/*================================================================================================
 Find list of neighbours, i.e. list of atoms that are within "cutoff" from each other
 (serial version, i.e. on one processor)
================================================================================================== */

template <typename T>
void Network<T>::GetNeighbours(std::vector<Atoms<T> >& atoms, T cutoff, int mindiff, int Nx, int Ny, int Nz, 
	int *ListCell, int *Head, std::vector<Edges<T> >& Listcontact)
{
/*	==========================================================================================
	Single-threaded contact search over the cell grid.

	Input:
		atoms      : class for atom information
		cutoff     : cutoff for defining neighbours
		mindiff    : minimum distance in residue number
		Nx         : number of cells along X direction
		Ny         : number of cells along Y direction
		Nz         : number of cells along Z direction
		ListCell   : For each atom i, pointer to the next atom in the same cell as i
		Head       : For each cell, pointer to the first atom in the cell (-1 if empty)
	Output:
		Listcontact: contacts (i, j) with i < j are appended, then the whole list is sorted
		             with sortEdges

	A pair (i, j) is a contact when distancesq(atoms, i, j) < cutoff^2. It also needs either
	different chainid or |resid_i - resid_j| >= mindiff.
	For every non-empty cell, pairs inside the cell are tested once (i < j). Each of its
	atoms is then tested against the atoms of the cells returned by Neighbours().
	NOTE(review): correctness (no duplicate edges) relies on Neighbours() returning each
	neighbouring cell pair only once (a "half" stencil) — confirm.
	========================================================================================== */

	const int EMPTY = -1;
	const T cut2 = cutoff * cutoff;

	// Neighbours() (defined elsewhere) writes the indices of the cells to test against the
	// current cell into this buffer. 125 = 5*5*5 bounds any stencil reaching two cells in
	// each direction, which is the reach required when the cell edge is cutoff/2.
	// NOTE(review): confirm Neighbours() never writes more entries than this.
	int N;
	int List[125];

	for(int Ic = 0; Ic < Nx; Ic++) {
		for(int Jc = 0; Jc < Ny; Jc++) {
			for(int Kc = 0; Kc < Nz; Kc++) {

				int Icell = (Kc*Nx*Ny) + (Jc*Nx) + Ic;
				if (Head[Icell] == EMPTY) continue;

				Neighbours(Icell, Nx, Ny, Nz, &N, List);

				for (int i = Head[Icell]; i != EMPTY; i = ListCell[i]) {

					// Pairs within the same cell: each pair is visited twice, keep i < j
					for (int j = Head[Icell]; j != EMPTY; j = ListCell[j]) {
						if (i < j && ( (atoms[i].chainid != atoms[j].chainid) || (std::abs(atoms[i].resid-atoms[j].resid) >= mindiff))) {
							T dist = distancesq(atoms, i, j);
							if (dist < cut2) {
								Edges<T> l(i, j);
								Listcontact.push_back(l);
							}
						}
					}

					// Pairs with atoms of the neighbouring cells; store as (min, max)
					for (int k = 0; k < N; k++) {
						for (int j = Head[List[k]]; j != EMPTY; j = ListCell[j]) {
							if ((atoms[i].chainid != atoms[j].chainid) || (std::abs(atoms[i].resid-atoms[j].resid) >= mindiff)) {
								T dist = distancesq(atoms, i, j);
								if (dist < cut2) {
									Edges<T> l(std::min(i, j), std::max(i, j));
									Listcontact.push_back(l);
								}
							}
						}
					}
				}
			}
		}
	}

	std::sort(Listcontact.begin(), Listcontact.end(), sortEdges<T>);

}
/*================================================================================================
 Find list of neighbours, i.e. list of atoms that are within "cutoff" from each other
 (brute force, serial version, i.e. on one processor)
================================================================================================== */

template <typename T>
void Network<T>::GetNeighbours_bf(std::vector<Atoms<T> >& atoms, T cutoff, int mindiff,
	std::vector<Edges<T> >& Listcontact)
{
/*	==========================================================================================
	Brute-force contact search: every pair of atoms is tested (O(natoms^2)).

	Input:
		atoms      : class for atom information
		cutoff     : cutoff for defining neighbours
		mindiff    : minimum distance in residue number
	Output:
		Listcontact: contacts (a, b) with a < b are appended in increasing (a, b) order;
		             the list is not sorted afterwards

	A pair is kept when its atoms belong to different chains or their residue numbers
	differ by at least mindiff, and distancesq(atoms, a, b) < cutoff^2.
	========================================================================================== */

	const T cut2 = cutoff * cutoff;
	const int natoms = atoms.size();

	for (int a = 0; a + 1 < natoms; a++) {
		for (int b = a + 1; b < natoms; b++) {

			const bool farEnough = (atoms[a].chainid != atoms[b].chainid)
				|| (std::abs(atoms[a].resid - atoms[b].resid) >= mindiff);
			if (!farEnough) continue;

			const T d2 = distancesq(atoms, a, b);
			if (d2 < cut2) {
				// b > a, so the pair is already in (smaller, larger) order
				Edges<T> edge(a, b);
				Listcontact.push_back(edge);
			}
		}
	}
}

/*================================================================================================
 Define elastic network of a molecule based on a cutoff
================================================================================================== */

template <typename T>
void Network<T>::network(std::vector<Atoms<T> >& atoms, T cutoff, int mindiff, std::vector<Edges<T> >& List, 
	int nthreads)
{
/*	==========================================================================================
	Builds the cutoff-based elastic network of a molecule.

	Input:
		atoms      : class for atom information
		cutoff     : distance cutoff defining a contact (must be > 0)
		mindiff    : minimum residue-number separation for two atoms of the same chain
		nthreads   : number of threads for the parallel search. Currently unused: the
		             serial GetNeighbours is the one called below.
	Output:
		List       : contacts are appended and the list is then sorted (see GetNeighbours)

	Returns without touching List when there are no atoms or when cutoff <= 0. Both cases
	would otherwise build an ill-defined grid, and no distance can be below a
	non-positive cutoff anyway.
	========================================================================================== */

	(void) nthreads;	// kept for the (commented-out) multi-threaded search below

	int natoms = atoms.size();
	if (natoms == 0) return;
	if (!(cutoff > 0)) return;

/*	==========================================================================================
	Build grid that contains the molecule
	========================================================================================== */

	T LowerCorner[3];
	int Nx, Ny, Nz;

	BoxSize(atoms, cutoff, LowerCorner, &Nx, &Ny, &Nz);

/*	==========================================================================================
	Build linked lists for atoms into cells (vectors release their memory even if the
	contact search throws)
	========================================================================================== */

	int Ncells = Nx*Ny*Nz;

	std::vector<int> ListCell(natoms);
	std::vector<int> Head(Ncells);

	CellLists(atoms, LowerCorner, Nx, Ny, Nz, cutoff, &ListCell[0], &Head[0]);

/*	==========================================================================================
	Now build list of contacts
	========================================================================================== */

//	GetNeighbours_thread(atoms, cutoff, mindiff, Nx, Ny, Nz, &ListCell[0], &Head[0], List, nthreads); 
	GetNeighbours(atoms, cutoff, mindiff, Nx, Ny, Nz, &ListCell[0], &Head[0], List); 
//	GetNeighbours_bf(atoms, cutoff, mindiff, List); 
}

/*================================================================================================
 Define elastic network of a molecule based on an Delaunay
================================================================================================== */

template <typename T>
void Network<T>::delaunay(std::vector<Atoms<T> >& atoms, int mindiff, int potential,
		std::vector<Edges<T> >& List)
{
/*	==========================================================================================
	Builds an elastic network from the Delaunay triangulation of the atom positions.

	Input:
		atoms      : class for atom information
		mindiff    : minimum residue-number separation (used only when potential == 2)
		potential  : when 2, edges between atoms of the same chain whose residue numbers
		             differ by less than mindiff are dropped
	Output:
		List       : one Edges<T>(i, j) per retained triangulation edge is appended
		             (not sorted, and i < j is not enforced)

	All radii are set to 1.0. With equal radii, the regular triangulation computed by
	delcx.regular3D is the plain Delaunay triangulation.
	Uses the global `delcx` object.
	========================================================================================== */

	int natoms = atoms.size();
	if (natoms == 0) return;

	// Coordinates packed as x0 y0 z0 x1 y1 z1 ... for the delcx interface
	std::vector<double> coord(3*natoms);
	std::vector<double> radii(natoms, 1.0);

	for (int i = 0; i < natoms; i++)
	{
		coord[3*i]   = atoms[i].coord[0];
		coord[3*i+1] = atoms[i].coord[1];
		coord[3*i+2] = atoms[i].coord[2];
	}

	std::vector<Vertex> vertices;
	std::vector<Tetrahedron> tetra;

	delcx.setup(natoms, &coord[0], &radii[0], vertices, tetra);
	delcx.regular3D(vertices, tetra);

	std::vector<std::pair<int, int> > edges;
	delcx.delaunayEdges(tetra, edges);

	int npairs = edges.size();
	for (int k = 0; k < npairs; k++) {

		// The -4 offset maps delcx vertex indices back to atom indices.
		// NOTE(review): presumably delcx stores 4 auxiliary vertices before the input
		// points — confirm in delcx.h.
		int i1 = edges[k].first  - 4;
		int j1 = edges[k].second - 4;

		// Ignore any edge whose endpoint is not one of the input atoms; such an index
		// would otherwise be used to read atoms[] out of range
		if (i1 < 0 || j1 < 0 || i1 >= natoms || j1 >= natoms) continue;

		if (potential == 2 && (atoms[i1].chainid == atoms[j1].chainid) &&
			(std::abs(atoms[i1].resid - atoms[j1].resid) < mindiff)) continue;

		Edges<T> l(i1, j1);
		List.push_back(l);
	}
}

#endif
