/* ===============================================================================================
   Dijkstra.h

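   Implements Dijkstra's algorithm to compute geodesic distances on a 3D triangular
   mesh, where geodesic paths are restricted to mesh edges. Also provides all-pairs
   Euclidean distances between vertices and farthest-point sampling of vertices.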

   Authors:  Patrice Koehl
   Date:    8/20/2018
   Version: 1

   =============================================================================================== */

#ifndef _MESHDIST_H_
#define _MESHDIST_H_

  #include <algorithm>
  #include <cstdlib>
  #include <limits>
  #include <random>
  #include <set>
  #include <time.h>
  #include <utility>
  #include <vector>

  /* =============================================================================================
     MeshDist: vertex-to-vertex distances on a half-edge mesh

     Public interface
       eucliddist            - all-pairs straight-line distances between vertex positions
       geodist               - all-pairs distances along mesh edges (Dijkstra)
       farthestPointSampling - greedy farthest-point selection of nselect vertices

     For eucliddist / geodist, dist is used as a flat array with row stride npoint:
     the row for vertex v starts at dist[v->index * npoint].
   =============================================================================================== */

  class MeshDist
  {
  public:
	// All-pairs Euclidean distances, written row by row into dist (row stride npoint).
	void eucliddist(Mesh &mesh, int npoint, double *dist);

	// All-pairs edge-constrained geodesic distances, written row by row into dist (row stride npoint).
	void geodist(Mesh &mesh, int npoint, double *dist);

	// Appends the indices of up to nselect vertices chosen by farthest-point sampling.
	void farthestPointSampling(Mesh& mesh, int nselect, std::vector<int>& keypoints);

  private:
	// Single-source helpers: fill distances[] (indexed by vertex index) from one source vertex.
	void EuclideanDist(Mesh& mesh, VertexIter source, double *distances);
	void GeodesicDist(Mesh& mesh, VertexIter source, double *distances);
  };


/* ===============================================================================================
   EuclideanDist.h

   Input:
	Mesh:	mesh structure
	source: source vertex
   Output:
	distances: Euclidean distances of all vertices to the source
   =============================================================================================== */

  void MeshDist::EuclideanDist(Mesh& Mesh, VertexIter source, double *distances)
  {

	int id; 

/*      ==========================================================================================
        ========================================================================================== */

	Vector pointA = source->position;

	/* =======================================================================================
	Compute all distances to source
   	========================================================================================== */

	for(VertexIter v_iter = Mesh.vertices.begin(); v_iter != Mesh.vertices.end(); v_iter++)
	{
		Vector pointB = v_iter->position;
		id = v_iter->index;
        
		distances[id] = (pointA - pointB).norm();
	}
  }

/* ===============================================================================================
   GeodesicDist

   Single-source Dijkstra on the edge graph of the mesh. Geodesic paths are restricted to mesh
   edges, and each edge is weighted by the Euclidean distance between its two end points.

   Input:
	Mesh:	half-edge mesh structure (vertices expose he, half-edges expose flip/next/vertex)
	source: source vertex
   Output:
	distances: array indexed by vertex index. The first Mesh.vertices.size() entries are reset
	           to +infinity and then lowered to the shortest edge-path length from source.
	           Vertices that cannot be reached from source keep +infinity.
   =============================================================================================== */

  void MeshDist::GeodesicDist(Mesh& Mesh, VertexIter source, double *distances)
  {

	double inf = std::numeric_limits<double>::infinity();

	int nvertices = Mesh.vertices.size();
	for(int i = 0; i < nvertices; i++) distances[i] = inf;

	/* =======================================================================================
	An ordered set of (tentative distance, vertex) pairs acts as a priority queue that
	supports removal. When a vertex's tentative distance decreases, its old entry is erased
	and a new one is inserted. The set is seeded with the source.
   	========================================================================================== */

	std::set< std::pair<double, VertexIter> > frontier;
	frontier.insert(std::make_pair(0., source));
	distances[source->index] = 0;

	while( !frontier.empty() )
	{
		// Pop the vertex with the smallest tentative distance. Edge lengths are norms
		// (non-negative), so this distance is final.
		std::pair<double, VertexIter> closest = *(frontier.begin());
		frontier.erase(frontier.begin());

		VertexIter vit = closest.second;
		double distu = closest.first;
		Vector pointA = vit->position;

		// Visit the neighbours of vit. h->flip->vertex is the vertex across half-edge h.
		// h->flip->next presumably advances to the next half-edge around vit (standard
		// half-edge rotation — confirm against the Mesh implementation). The walk stops
		// when it returns to vit->he.
		HalfEdgeCIter h = vit->he;
		do {
			VertexIter v = h->flip->vertex;
			int id = v->index;
			Vector pointB = v->position;
			double length = (pointA - pointB).norm();
			if(distances[id] > distu + length)
			{
				// Erase by key instead of erase(find(...)). If the entry were ever missing,
				// find() would return end() and erase(end()) is undefined behavior.
				// erase(key) is simply a no-op in that case.
				if(distances[id] != inf) {
					frontier.erase(std::make_pair(distances[id], v));
				}

				distances[id] = distu + length;
				frontier.insert(std::make_pair(distances[id], v));
			}

			h = h->flip->next;
		} while (h != vit->he);

	}
  }

/*================================================================================================
 eucliddist: compute the Euclidean distances between all vertices in the mesh

 Calls EuclideanDist once per vertex. dist is a flat array with row stride npoint, and the row
 for vertex v starts at dist[v->index * npoint]. It must therefore hold at least
 (largest vertex index + 1) * npoint doubles (presumably npoint == number of vertices —
 confirm with callers).
================================================================================================== */

void MeshDist::eucliddist(Mesh &Mesh, int npoint, double *dist)
{

/*      ==========================================================================================
	Compute on main thread
        ========================================================================================== */

	// The row offset is computed in size_t. In int, idx*npoint overflows (undefined behavior)
	// as soon as the vertex count exceeds ~46340.
	const std::size_t stride = static_cast<std::size_t>(npoint);
	for(VertexIter v_iter = Mesh.vertices.begin(); v_iter != Mesh.vertices.end(); ++v_iter) {
		std::size_t offset = static_cast<std::size_t>(v_iter->index) * stride;
		EuclideanDist(Mesh, v_iter, dist + offset);
	}

}

/*================================================================================================
 geodist: compute the geodesic distances between all vertices in the mesh

 Calls GeodesicDist (Dijkstra along mesh edges) once per vertex. dist is a flat array with row
 stride npoint, and the row for vertex v starts at dist[v->index * npoint]. GeodesicDist writes
 Mesh.vertices.size() entries per row, so npoint should be at least that
 (presumably npoint == number of vertices — confirm with callers).
================================================================================================== */

void MeshDist::geodist(Mesh &Mesh, int npoint, double *dist)
{
/*      ==========================================================================================
	Compute on main thread
        ========================================================================================== */

	// The row offset is computed in size_t. In int, idx*npoint overflows (undefined behavior)
	// as soon as the vertex count exceeds ~46340.
	const std::size_t stride = static_cast<std::size_t>(npoint);
	for(VertexIter v_iter = Mesh.vertices.begin(); v_iter != Mesh.vertices.end(); ++v_iter) {
		std::size_t offset = static_cast<std::size_t>(v_iter->index) * stride;
		GeodesicDist(Mesh, v_iter, dist + offset);
	}

}

/* ===============================================================================================
   farthestPointSampling

   Greedy farthest-point sampling with Euclidean distances. The first vertex is picked at random.
   Each following vertex is the not-yet-selected vertex whose distance to its closest selected
   vertex is the largest.

   Input:
	Mesh     : mesh structure
	nselect  : number of vertices to select
   Output:
	keypoints: indices of the selected vertices are appended, in selection order. Nothing is
	           appended for an empty mesh or nselect <= 0. Fewer than nselect indices are
	           appended when no remaining vertex lies at a positive distance from every
	           selected vertex (e.g. nselect > number of distinct vertex positions).

   Side effect: reseeds the C library random generator with srand(time(NULL)).
   =============================================================================================== */

  void MeshDist::farthestPointSampling(Mesh& Mesh, int nselect, std::vector<int>& keypoints)
  {

	int nvertices = Mesh.vertices.size();
	if(nvertices == 0 || nselect <= 0) return;

	double inf = std::numeric_limits<double>::infinity();

	// Owned by std::vector so nothing leaks. The original leaked the status array.
	std::vector<double> distances(nvertices);     // distances to the latest selected vertex
	std::vector<int> status(nvertices, 0);         // 1 once a vertex has been selected
	std::vector<double> distS(nvertices, inf);     // distance to the closest selected vertex

	/* =======================================================================================
	Select first point randomly.
	Dividing by RAND_MAX + 1 (not RAND_MAX) keeps rnd in [0, 1), so first stays in
	[0, nvertices-1]. With RAND_MAX, rand() == RAND_MAX gives first == nvertices, and then no
	vertex matches.
   	========================================================================================== */

	srand (time(NULL));
	double rnd = static_cast<double>(rand()) / (static_cast<double>(RAND_MAX) + 1.0);
	int first = static_cast<int>(rnd * nvertices);

	// Fall back to the first vertex if no vertex carries that index, so v_select is never
	// left uninitialized.
	VertexIter v_select = Mesh.vertices.begin();
	for(VertexIter v_iter = Mesh.vertices.begin(); v_iter != Mesh.vertices.end(); ++v_iter) {
		if(v_iter->index == first) {
			v_select = v_iter;
			break;
		}
	}

	status[v_select->index] = 1;
	keypoints.push_back(v_select->index);
	int npick = 1;

	while(npick < nselect)
	{
		// Update each unselected vertex's distance to its closest selected vertex, using only
		// the vertex picked last, and track the farthest one.
		EuclideanDist(Mesh, v_select, distances.data());

		double dmax = 0;
		bool found = false;
		VertexIter v_sel = v_select;
		for(VertexIter v_iter = Mesh.vertices.begin(); v_iter != Mesh.vertices.end(); ++v_iter)
		{
			int idx = v_iter->index;
			if(status[idx] == 0)
			{
				distS[idx] = std::min(distS[idx], distances[idx]);
				if(distS[idx] > dmax) {
					dmax = distS[idx];
					v_sel = v_iter;
					found = true;
				}
			}
		}

		// No unselected vertex is at a positive distance. Stop rather than push a stale or
		// uninitialized iterator.
		if(!found) break;

		v_select = v_sel;
		status[v_select->index] = 1;
		keypoints.push_back(v_select->index);
		npick++;
	}
  }

#endif
