/*================================================================================================
  Network.h
  Version 1: 9/29/2019

  Purpose: Finds all pairs of atoms that lie within a given cutoff of each other (elastic network),
           using either a grid (cell linked lists) or a brute-force search

Copyright (c) Patrice Koehl.

================================================================================================== */

#ifndef _NETWORK_H_
#define _NETWORK_H_


/*================================================================================================
 Includes
================================================================================================== */

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <iterator>
#include <vector>

#include "Links.h"
#include "Network_thread_tools.h"

/*================================================================================================
 Network class
================================================================================================== */

class Network {

	/*	All member functions are defined below, in this header, outside the class body.
		Declaring them inline here makes those out-of-class definitions inline, so the
		header can be included from several translation units without violating the
		one-definition rule (multiple-definition link errors). */

	public:

		// Defines the elastic network: atom pairs closer than "cutoff" that pass the
		// residue-separation filter "mindiff" are added to List
		inline void network(std::vector<Atoms>& atoms, double cutoff, int mindiff, 
		std::vector<Links>& List, int nthreads);

	private:

		// Lower corner and number of cells (cell edge = cutoff/2) of the grid enclosing all atoms
		inline void BoxSixe(std::vector<Atoms>& atoms, double cutoff, double *LowerCorner, int *Nx, int *Ny, int *Nz);

		// Organizes the atoms into one linked list per cell (Head: first atom, ListCell: next atom)
		inline void CellLists(std::vector<Atoms>& atoms, double *LowerCorner, int Nx, int Ny, int Nz,
		double cutoff, int *ListCell, int *Head);

		// Builds neighbour list from the grid (parallel, pthreads)
		inline void GetNeighbours_thread(std::vector<Atoms>& atoms, double cutoff, int mindiff, 
		int Nx, int Ny, int Nz, int *ListCell, int *Head, std::vector<Links>& Listcontact, 
		int nthreads);

		// Builds neighbour list from the grid (serial)
		inline void GetNeighbours(std::vector<Atoms>& atoms, double cutoff, int mindiff, int Nx, int Ny, int Nz, 
		int *ListCell, int *Head, std::vector<Links>& Listcontact);

		// Builds neighbour list by testing every pair of atoms (no grid)
		inline void FullNeighbours(std::vector<Atoms>& atoms, double cutoff, int mindiff, std::vector<Links>& Listcontact);
	};

/*================================================================================================
 Define size of the box that contains all the atoms
================================================================================================== */

void Network::BoxSixe(std::vector<Atoms>& atoms, double cutoff, double *LowerCorner, int *Nx, int *Ny, int *Nz)
{
/*	==========================================================================================
	Defines the grid that encloses all the atoms (function name kept as is for callers).

	Input:
		atoms      : class for atom information
		cutoff     : cutoff for defining neighbours (assumed > 0); the cell edge is cutoff/2
	Output:
		LowerCorner: position of the lower, left, front corner of the grid considered
		Nx         : number of cells along X direction
		Ny         : number of cells along Y direction
		Nz         : number of cells along Z direction

	If atoms is empty there is no bounding box: LowerCorner is set to (0,0,0) and
	Nx = Ny = Nz = 0 (reading atoms[0] of an empty vector would be undefined behaviour).
	========================================================================================== */

	const int natoms = atoms.size();
	int *ncells[3] = { Nx, Ny, Nz };

	if (natoms == 0) {
		for (int d = 0; d < 3; d++) {
			LowerCorner[d] = 0.0;
			*ncells[d] = 0;
		}
		return;
	}

/*	==========================================================================================
  	Maximum and minimum values for each coordinate
	========================================================================================== */

	double lo[3], hi[3];
	for (int d = 0; d < 3; d++) {
		lo[d] = atoms[0].coord[d];
		hi[d] = lo[d];
	}

	for (int i = 1; i < natoms; i++) {
		for (int d = 0; d < 3; d++) {
			const double x = atoms[i].coord[d];
			if (x < lo[d]) lo[d] = x;
			if (x > hi[d]) hi[d] = x;
		}
	}

/*	==========================================================================================
  	Lower left front corner of the grid: at (minx-cutoff,miny-cutoff,minz-cutoff), and the
	grid extends cutoff beyond the maximum on the other side, so that all non-empty cells
	are "inside". Number of cubes along each axis is based on the cell edge cutoff/2.
	========================================================================================== */

	const double cube_size = cutoff/2;

	for (int d = 0; d < 3; d++) {
		LowerCorner[d] = lo[d] - cutoff;
		*ncells[d] = static_cast<int>(std::ceil((hi[d] + cutoff - LowerCorner[d])/cube_size));
	}

}

/*================================================================================================
 Define linked list cell algorithm:
 atoms are organized into linked lists, whose heads are assigned to the cell containing the
 atoms in the list; this saves space, as the per-cell lists of atoms do not need to be stored explicitly
================================================================================================== */

void Network::CellLists(std::vector<Atoms>& atoms, double *LowerCorner, int Nx, int Ny, int Nz,
		double cutoff, int *ListCell, int *Head)
{
/*	==========================================================================================
	Sorts the atoms into the cells of the grid, keeping one singly linked list per cell.
	The lists are threaded through ListCell, so no per-cell container is allocated.

	Input:
		atoms      : class for atom information
		LowerCorner: position of the lower, left, front corner of the grid
		Nx, Ny, Nz : number of cells along X, Y and Z
		cutoff     : cutoff for defining neighbours; the cell edge is cutoff/2
	Output:
		Head       : for each cell, index of the first atom in that cell (-1 if empty)
		ListCell   : for each atom i, index of the next atom in the same cell as i
		             (-1 marks the end of a list)

	Cell (Ic, Jc, Kc) is stored at index Ic + Nx*(Jc + Ny*Kc). Each atom is pushed at the
	front of its cell's list, so a list runs in decreasing atom index.
	========================================================================================== */

	const int EMPTY = -1;
	const double edge = cutoff/2;
	const int natoms = atoms.size();

	// Every cell starts with an empty list
	for (int cell = Nx*Ny*Nz - 1; cell >= 0; --cell) {
		Head[cell] = EMPTY;
	}

	// Prepend each atom to the list of the cell that contains it
	for (int atom = 0; atom < natoms; ++atom) {
		const int ix = floor((atoms[atom].coord[0] - LowerCorner[0])/edge);
		const int iy = floor((atoms[atom].coord[1] - LowerCorner[1])/edge);
		const int iz = floor((atoms[atom].coord[2] - LowerCorner[2])/edge);
		const int cell = ix + Nx*(iy + Ny*iz);

		ListCell[atom] = Head[cell];
		Head[cell]     = atom;
	}

}

/*================================================================================================
 Find list of neighbours, i.e. list of atoms that are within "cutoff" from each other, using a
 multi-threaded algorithm
================================================================================================== */

void Network::GetNeighbours_thread(std::vector<Atoms>& atoms, double cutoff, int mindiff, int Nx, int Ny, int Nz, 
	int *ListCell, int *Head, std::vector<Links>& Listcontact, int nthreads)
{
/*	==========================================================================================
	Input:
		atoms      : class for atom information
		cutoff     : cutoff for defining neighbours
		mindiff	   : minimum "distance" in residue number
		Nx         : number of cells along X direction
		Ny         : number of cells along Y direction
		Nz         : number of cells along Z direction
		ListCell   : For each atom i, pointer to the next atom in the same cell as i
		Head       : For each cell, pointer to the first atom in the cell (-1 if empty)
	Output:
		Listcontact: list of contacts
	========================================================================================== */

/*	==========================================================================================
	Find list of non-empty cells
	========================================================================================== */

	int EMPTY = -1;
	int i, Icell;
	int N1,N2;
	double cut2 = cutoff * cutoff;

	int *Cells = new int[Nx*Ny*Nz];

	int Ncells = 0;
	for(int Ic = 0; Ic < Nx; Ic++) {
		for(int Jc = 0; Jc < Ny; Jc++) {
			for(int Kc = 0; Kc < Nz; Kc++) {
                		Icell = (Kc*Nx*Ny) + (Jc*Nx) + Ic;
				i = Head[Icell];
				if (i != EMPTY) {
					Cells[Ncells]=Icell;
					Ncells++;
				}
			}
		}
	}

/*	==========================================================================================
	Break list to all threads and send jobs
	========================================================================================== */

	int nval = Ncells/nthreads;

	for (int i=0; i < nthreads; i++)
	{
		N1 = i*nval;
		N2 = N1 + nval;
		if(i == nthreads-1) N2 = Ncells;

		threadids[i] = i;

		grids[i].firstcell = N1;
		grids[i].lastcell  = N2;
		grids[i].Nx        = Nx;
		grids[i].Ny        = Ny;
		grids[i].cutoff2   = cut2;
		grids[i].mindiff   = mindiff;
		grids[i].Cells     = Cells;
		grids[i].ListCell  = ListCell;
		grids[i].Head      = Head;
		grids[i].atoms     = atoms;

		std::vector<Links> List;
		grids[i].Contacts = List;

		pthread_create(&threads[i], NULL, neighbours_thread, (void*) &threadids[i]);
	}

/*	==========================================================================================
	Join all the threads (to make sure they are all finished)
	========================================================================================== */

	std::vector<Links> Temp;
	for (int i=0; i < nthreads; i++)
	{
		pthread_join(threads[i], NULL);
		Temp = Listcontact;
		Listcontact.clear();
		std::merge(Temp.begin(),Temp.end(),grids[i].Contacts.begin(),grids[i].Contacts.end(),
		std::back_inserter(Listcontact), sortLinks);
	}
	Temp.clear();

	delete [] Cells;

}

/*================================================================================================
 Find list of neighbours, i.e. list of atoms that are within "cutoff" from each other
 (serial version, i.e. on one processor)
================================================================================================== */

void Network::GetNeighbours(std::vector<Atoms>& atoms, double cutoff, int mindiff, int Nx, int Ny, int Nz, 
	int *ListCell, int *Head, std::vector<Links>& Listcontact)
{
/*	==========================================================================================
	Input:
		atoms      : class for atom information
		cutoff     : cutoff for defining neighbours
		mindiff    : minimum distance in residue number (only applied to atoms of the same chain)
		Nx         : number of cells along X direction
		Ny         : number of cells along Y direction
		Nz         : number of cells along Z direction
		ListCell   : For each atom i, pointer to the next atom in the same cell as i
		Head       : For each cell, pointer to the first atom in the cell (-1 if empty)
	Output:
		Listcontact: list of contacts, sorted with sortLinks on return (entries already present
		             on entry are sorted together with the new ones)

	Notes:
		Pairs inside one cell are taken once (i < j). Pairs across cells are taken for every
		cell listed by Neighbours(); presumably Neighbours() lists only half of the surrounding
		cells so that each such pair is found once - TODO confirm.
	========================================================================================== */

	const int EMPTY = -1;
	const double cut2 = cutoff * cutoff;
	const double kval = 0.;

	int List[1000];	// neighbouring cells of the current cell, filled by Neighbours()
	int N = 0;

/*	==========================================================================================
	Find list of non-empty cells
	========================================================================================== */

	std::vector<int> Cells;
	for(int Ic = 0; Ic < Nx; Ic++) {
		for(int Jc = 0; Jc < Ny; Jc++) {
			for(int Kc = 0; Kc < Nz; Kc++) {
				const int Icell = (Kc*Nx*Ny) + (Jc*Nx) + Ic;
				if (Head[Icell] != EMPTY) Cells.push_back(Icell);
			}
		}
	}

/*	==========================================================================================
	Scan all non-empty cells
	========================================================================================== */

	for (std::size_t c = 0; c < Cells.size(); c++) {
		const int Icell = Cells[c];

		// Neighbours() only takes (Icell, Nx, Ny): compute the neighbouring cells once per
		// cell rather than once per atom of the cell
		Neighbours(Icell, Nx, Ny, &N, List);

		for (int i = Head[Icell]; i != EMPTY; i = ListCell[i]) {

			// Pairs inside the same cell
			for (int j = Head[Icell]; j != EMPTY; j = ListCell[j]) {
				if (i >= j) continue;
				if (atoms[i].chainid == atoms[j].chainid
					&& std::abs(atoms[i].resid-atoms[j].resid) < mindiff) continue;

				const double dist = distancesq(atoms, i, j);
				if (dist < cut2) {
					Links l(i, j, atoms[i].resid, atoms[j].resid, kval, std::sqrt(dist));
					Listcontact.push_back(l);
				}
			}

			// Pairs with atoms of the neighbouring cells. j advances in the for-header, so a
			// pair rejected by the residue filter can no longer stall the loop (previously
			// j only advanced when the filter accepted the pair: infinite loop).
			for (int k = 0; k < N; k++) {
				for (int j = Head[List[k]]; j != EMPTY; j = ListCell[j]) {
					if (atoms[i].chainid == atoms[j].chainid
						&& std::abs(atoms[i].resid-atoms[j].resid) < mindiff) continue;

					const double dist = distancesq(atoms, i, j);
					if (dist < cut2) {
						const int i1 = std::min(i,j);
						const int j1 = std::max(i,j);
						Links l(i1, j1, atoms[i1].resid, atoms[j1].resid, kval, std::sqrt(dist));
						Listcontact.push_back(l);
					}
				}
			}
		}
	}

	std::sort(Listcontact.begin(),Listcontact.end(), sortLinks);

}
/*================================================================================================
 Find list of neighbours, i.e. list of atoms that are within "cutoff" from each other
 (brute-force version: every pair of atoms is tested, no grid is used)
================================================================================================== */

void Network::FullNeighbours(std::vector<Atoms>& atoms, double cutoff, int mindiff, std::vector<Links>& Listcontact)
{
/*	==========================================================================================
	Brute-force search: every pair (i, j) with i < j is tested, i.e. N(N-1)/2 distance
	evaluations for N atoms.

	Input:
		atoms      : class for atom information
		cutoff     : cutoff for defining neighbours
		mindiff    : minimum distance in residue number; like in GetNeighbours, it is only
		             applied to atoms of the same chain (residue numbers of different chains
		             are not comparable)
	Output:
		Listcontact: contacts are appended in increasing (i, j) order; the list is not
		             re-sorted with sortLinks here
	========================================================================================== */

	const double cut2 = cutoff * cutoff;
	const double kval = 0.;
	const int natoms = atoms.size();

	for(int i = 0; i < natoms; i++) {
		for(int j = i+1; j < natoms; j++) {

			// Skip close-in-sequence pairs of the same chain (previously the chain was
			// ignored, which also dropped inter-chain contacts with close residue numbers)
			if (atoms[i].chainid == atoms[j].chainid
				&& std::abs(atoms[i].resid-atoms[j].resid) < mindiff) continue;

			const double dist = distancesq(atoms, i, j);
			if(dist < cut2) {
				Links l(i, j, atoms[i].resid, atoms[j].resid, kval, std::sqrt(dist));
				Listcontact.push_back(l);
			}
		}
	}

}
/*================================================================================================
 Define elastic network of a molecule based on a cutoff
================================================================================================== */

void Network::network(std::vector<Atoms>& atoms, double cutoff, int mindiff, std::vector<Links>& List, 
	int nthreads)
{
/*	==========================================================================================
	Input:
		atoms      : class for atom information
		cutoff     : cutoff for defining neighbours (assumed > 0)
		mindiff    : minimum distance in residue number
		nthreads   : number of threads; only used by GetNeighbours_thread, which is currently
		             disabled below
	Output:
		List       : list of contacts (new contacts are added to what List already holds)
	========================================================================================== */

	const int natoms = atoms.size();

	// Nothing to connect; this also keeps BoxSixe from reading atoms[0] of an empty vector
	if (natoms == 0) return;

/*	==========================================================================================
	Build grid that contains the molecule
	========================================================================================== */

	double LowerCorner[3];
	int Nx, Ny, Nz;

	BoxSixe(atoms, cutoff, LowerCorner, &Nx, &Ny, &Nz);

/*	==========================================================================================
	Build linked lists for atoms into cells.
	std::vector releases the buffers on every exit path, including an exception thrown
	while building List. With natoms > 0 and cutoff > 0 there is at least one cell, so
	&Head[0] is valid.
	========================================================================================== */

	const int Ncells = Nx*Ny*Nz;

	std::vector<int> ListCell(natoms);
	std::vector<int> Head(Ncells);

	CellLists(atoms, LowerCorner, Nx, Ny, Nz, cutoff, &ListCell[0], &Head[0]);

/*	==========================================================================================
	Now build list of contacts.
	The grid above is only consumed by the grid-based searches (commented out); the active
	brute-force search does not use it.
	========================================================================================== */

//	GetNeighbours_thread(atoms, cutoff, mindiff, Nx, Ny, Nz, &ListCell[0], &Head[0], List, nthreads); 
//	GetNeighbours(atoms, cutoff, mindiff, Nx, Ny, Nz, &ListCell[0], &Head[0], List);
	FullNeighbours(atoms, cutoff, mindiff, List);

}

#endif
