/* ===============================================================================================
  Comp3DShapes: Comparing two surfaces using OT   

   Author:  Patrice Koehl with Henri Orland
   Date:    7/17/2019
   Version: 1
   =============================================================================================== */

/* ===============================================================================================
   System includes
   =============================================================================================== */

#include <iostream>
#include <iomanip>
#include <sstream>
#include <string>
#include <stdlib.h>
#include <fstream>
#include <cmath>
#include <ctime>
#include <unistd.h>
#include <cstdlib>
#include <limits>
#include <sys/time.h>

/* ===============================================================================================
   Static variables and Local Includes
   =============================================================================================== */

#include "Comp3DShapes.h"


/* ===============================================================================================
   Main program
   =============================================================================================== */

int main(int argc, char **argv)
{

	int nthreads = sysconf( _SC_NPROCESSORS_ONLN );
	if(nthreads==0) nthreads = 1;

/*	==========================================================================================
	Show usage if needed
	========================================================================================== */

	if( argc < 2 )
	{
		usage(argv);
		return -1;
	}

	std::string input = argv[1];
	if( input == "-h" || input == "-help" )
	{
		usage(argv);
		return -1;
	}

/*	==========================================================================================
	Read in all inputs (some values may have been preset)
	========================================================================================== */

	std::string INfile1;
	std::string INfile2;
	std::string OUTfile=" ";
	int method = 1;
	int desc_type = 0;
	int flag_kpts = 0;
	double a1 = 1;
	double a2 = 1;
	int flag_refine = 0;
	int nkeep = 0;

        if (!parse_args(argc, argv, &INfile1, &INfile2, &method, &a1, &a2,
		&desc_type, &flag_kpts, &nkeep, &flag_refine, &OUTfile)) return 1;

	if(desc_type==1) desc_type = 4;

	double a1_d = a1; double a2_d = a1; double a1_r = a2; double a2_r = a2;

/*	==========================================================================================
	Read in mesh1 from input file
	========================================================================================== */

	int flag_scale = 1;
	Vector cm1;
	double scale1;

	Mesh model1;

	std::string error;
	std::size_t found = INfile1.find("obj");
	bool info;
	if(found !=std::string::npos) {
		info = MeshIO::readOBJ(INfile1, model1, flag_scale, cm1, &scale1, error);
		if(info==false) {
			std::cout << " " << std::endl;
			std::cout << "Problem reading mesh 1; check file" << std::endl;
			std::cout << " " << std::endl;
			exit(1);
		}
	} else {
		found = INfile1.find("off");
		if(found !=std::string::npos) {
			info = MeshIO::readOFF(INfile1, model1, flag_scale, cm1, &scale1, error);
			if(info==false) {
				std::cout << " " << std::endl;
				std::cout << "Problem reading mesh 1; check file" << std::endl;
				std::cout << " " << std::endl;
				exit(1);
			}
		} else {
			std::cout << " " << std::endl;
			std::cout << "Input file format not recognized; program can only read OFF and OBJ files" << std::endl;
			std::cout << " " << std::endl;
			exit(1);
		}
	}

/*	==========================================================================================
	Characteristics of the mesh 1
	========================================================================================== */

	int nvertices1 = model1.vertices.size();
	int nfaces1   = model1.faces.size();
	int nbound1   = model1.boundaries.size();
	int euler1    = model1.eulerCharacteristic();
	double genus1 = (2-euler1)/2;

	std::cout << " " << std::endl;
	std::cout << "Number of vertices in mesh1        : " << nvertices1 << std::endl;
	std::cout << "Number of faces in mesh1           : " << nfaces1 << std::endl;
	std::cout << "Number of boundaries in mesh1      : " << nbound1 << std::endl;
	std::cout << "Euler characteristics              : " << euler1 << std::endl;
	std::cout << "Genus                              : " << genus1 << std::endl;

/*	==========================================================================================
	Read in mesh2 from input file
	========================================================================================== */

	Vector cm2;
	double scale2;
	Mesh model2;

	found = INfile2.find("obj");
	if(found !=std::string::npos) {
		info = MeshIO::readOBJ(INfile2, model2, flag_scale, cm2, &scale2, error);
		if(info==false) {
			std::cout << " " << std::endl;
			std::cout << "Problem reading mesh 2; check file" << std::endl;
			std::cout << " " << std::endl;
			exit(1);
		}
	} else {
		found = INfile2.find("off");
		if(found !=std::string::npos) {
			info = MeshIO::readOFF(INfile2, model2, flag_scale, cm2, &scale2, error);
			if(info==false) {
				std::cout << " " << std::endl;
				std::cout << "Problem reading mesh 2; check file" << std::endl;
				std::cout << " " << std::endl;
				exit(1);
			}
		} else {
			std::cout << " " << std::endl;
			std::cout << "Input file format not recognized; program can only read OFF and OBJ files" << std::endl;
			std::cout << " " << std::endl;
			exit(1);
		}
	}

/*	==========================================================================================
	Characteristics of mesh 2
	========================================================================================== */

	int nvertices2 = model2.vertices.size();
	int nfaces2   = model2.faces.size();
	int nbound2   = model2.boundaries.size();
	int euler2    = model2.eulerCharacteristic();
	double genus2 = (2-euler2)/2;

	std::cout << " " << std::endl;
	std::cout << "Number of vertices in mesh2        : " << nvertices2 << std::endl;
	std::cout << "Number of faces in mesh2           : " << nfaces2 << std::endl;
	std::cout << "Number of boundaries in mesh2      : " << nbound2 << std::endl;
	std::cout << "Euler characteristics              : " << euler2 << std::endl;
	std::cout << "Genus                              : " << genus2 << std::endl;

/*	==========================================================================================
	Generate subset of mesh, if needed
	========================================================================================== */

	std::vector<int> keypoints1;
	std::vector<int> keypoints2;

	if(flag_kpts == 0 ) {
		int idx;
		for (VertexIter v = model1.vertices.begin(); v != model1.vertices.end(); v++) {
			idx = v->index;
			keypoints1.push_back(idx);
		}
		for (VertexIter v = model2.vertices.begin(); v != model2.vertices.end(); v++) {
			idx = v->index;
			keypoints2.push_back(idx);
		}
	} else if(flag_kpts==2) {
		meshdist.farthestPointSampling(model1, nkeep, keypoints1);
		meshdist.farthestPointSampling(model2, nkeep, keypoints2);
/*
		int *color = new int[nvertices1];
		memset(color, 0, nvertices1*sizeof(int));
		for(int i = 0; i < keypoints1.size(); i++) color[keypoints1[i]] = 1;
		std::string out = "mesh1_color.off";
		MeshIO::writeCOFF(out, model1, color);
		exit(1);
*/
	}


/*	==========================================================================================
	Compute descriptors of both meshes
	========================================================================================== */

	std::vector<std::vector<double> > descriptors1;
	std::vector<std::vector<double> > descriptors2;

	if(desc_type < 3) {

		descriptors.genDescriptors(model1, desc_type, flag_kpts, keypoints1, descriptors1);
		descriptors.genDescriptors(model2, desc_type, flag_kpts, keypoints2, descriptors2);

	} else {
		int neig;
		int laplace_type = 0;
		int neigen    = 500;
		int nfeat     = 200;
		int d_type = desc_type - 3;

		std::vector<double> steps;
		int flag_step = 0;

		if(flag_kpts==2) {

			std::vector<std::vector<double> > desc1;
			std::vector<std::vector<double> > desc2;
			neig = std::min(neigen, nvertices1);
			spectral.genDescriptors(model1, neig, laplace_type, d_type, nfeat, 
			flag_step, steps, desc1, nthreads);

			flag_step = 1;
			neig = std::min(neigen, nvertices2);
			spectral.genDescriptors(model2, neig, laplace_type, d_type, nfeat, 
			flag_step, steps, desc2, nthreads);

			for(int i = 0; i < keypoints1.size(); i++) {
				descriptors1.push_back(desc1[keypoints1[i]]);
			}
			for(int i = 0; i < keypoints2.size(); i++) {
				descriptors2.push_back(desc2[keypoints2[i]]);
			}

		} else {

			neig = std::min(neigen, nvertices1);
			spectral.genDescriptors(model1, neig, laplace_type, d_type, nfeat, 
			flag_step, steps, descriptors1, nthreads);

			flag_step = 1;
			neig = std::min(neigen, nvertices2);
			spectral.genDescriptors(model2, neig, laplace_type, d_type, nfeat, 
			flag_step, steps, descriptors2, nthreads);
		}
	}

	int nfeatures = descriptors1[0].size();
	int npoint1 = keypoints1.size();
	int npoint2 = keypoints2.size();

/*
	std::string file1 = "desc1.dat";
	std::ofstream outfile1;
	outfile1.open(file1);

	for(int i = 0; i < npoint1; i++) {
		for(int j = 0; j < nfeatures; j++) {
			outfile1 << descriptors1[i][j] << " ";
		}
		outfile1 << std::endl;
	}
	outfile1.close();

	std::string file2 = "desc2.dat";
	std::ofstream outfile2;
	outfile2.open(file2);

	for(int i = 0; i < npoint2; i++) {
		for(int j = 0; j < nfeatures; j++) {
			outfile2 << descriptors2[i][j] << " ";
		}
		outfile2 << std::endl;
	}
	outfile2.close();

*/

/*	==========================================================================================
	Compute distance between the two meshes using Optimal Transport
	========================================================================================== */

	double val, sum;
	double *G, *Cost;
	double *lambda, *mu;
	double *m1, *m2;
	
/*  ===============================================================================================
	Compute p-Wasserstein between the selected sets of keypoints on images 1 and 2
    =============================================================================================== */

	double d_OT=0.0, d_Tot=0;

        lambda = new double[npoint1];
        mu = new double[npoint2];
	m1   = new double[npoint1];
	m2   = new double[npoint2];
	G = new double[npoint1*npoint2];
	Cost = new double[npoint1*npoint2];

        memset(lambda, 0, npoint1*sizeof(double));
        memset(mu, 0, npoint2*sizeof(double));
	memset(Cost, 0, npoint1*npoint2*sizeof(double));
	memset(G, 0, npoint1*npoint2*sizeof(double));

	for(int i = 0; i < npoint1; i++) m1[i] = 1./npoint1;
	for(int i = 0; i < npoint2; i++) m2[i] = 1./npoint2;
	
	double Cmax = 0.0;
	for(int j = 0; j < npoint2; j++) {
		for(int i = 0; i < npoint1; i++) {
			sum = 0;
			for(int k = 0; k < nfeatures; k++) {
				val = descriptors1[i][k]-descriptors2[j][k];
				double val2 = descriptors1[i][k]+descriptors2[j][k];
				if(desc_type==4) {
					if(val2>0) {
						sum += std::abs(val/val2);
					}
				} else {
					sum += val*val;
				}
			}
			if(desc_type!=4) sum = std::sqrt(sum);
			Cost[j*npoint1+i] = sum;
			Cmax = std::max(Cmax, sum);
		}
	}
	Cmax = 0.1*Cmax;
	for(int i = 0; i < npoint1*npoint2; i++) Cost[i] /= Cmax;

	clock_t start_s = clock();
	timeval tim;
	double t1, t2, u1, u2, diff;
	gettimeofday(&tim,NULL);
	t1 = tim.tv_sec;
	u1 = tim.tv_usec;

	double beta1 = 1000.;
	double beta_inf = 1.e11;
	int iprint = 1;
	int init = 0;
	double x=0;
	double F=0;
	if(method==0) {
		d_OT = ot.ot1(npoint1, m1, npoint2, m2, Cost, G, lambda, mu,
			beta1, beta_inf, &F, iprint, init, nthreads);
			d_Tot = d_OT;
	} else {
//		a2_d = npoint2*npoint2*(a1_d/(npoint1*npoint1));
		d_OT = otw.ot1_w(npoint1, m1, npoint2, m2, Cost, G, lambda, mu,
		&x, beta1, beta_inf, &d_Tot, iprint, init, a1_d, a2_d, nthreads);
	}

	clock_t stop_s = clock();
	gettimeofday(&tim,NULL);
	t2 = tim.tv_sec;
	u2 = tim.tv_usec;
	diff = (t2-t1) + (u2-u1)*1.e-6;

	std::cout << " " << std::endl;
	std::cout << "OT running time: " << (stop_s-start_s)/double(CLOCKS_PER_SEC) << " seconds" << std::endl;
	std::cout << "OT clock   time: " << diff << " seconds" << std::endl;

	delete [] lambda; delete [] mu;

/*  ===============================================================================================
    Print data about the comparison of the two images 
    =============================================================================================== */

	std::cout << std::endl;
	std::cout << "Number of keypoints on mesh 1                        : " << npoint1 << std::endl;
	std::cout << "Number of keypoints on mesh 2                        : " << npoint2 << std::endl;
	std::cout << "OT distance between the two meshes                   : " << d_OT*Cmax << std::endl;
	std::cout << "U  distance between the two meshes                   : " << d_Tot*Cmax << std::endl;
        std::cout << " " << std::endl;

/*	==========================================================================================
	Write results
	========================================================================================== */

	if(OUTfile!=" ") {
		int flag_mass = 0;
		if(method==1) flag_mass = 1;
		write_res.write(OUTfile, flag_mass, npoint1, m1, npoint2, m2, Cost, G, d_OT);
	}

/*  ===============================================================================================
    Move mesh 2 based on transport plan and refine, if requested
    =============================================================================================== */

	align.alignMesh2(model1, model2, keypoints1, keypoints2, G, m1);

	std::string filename;
	if(OUTfile!=" ") {
		filename = OUTfile;
		filename.append("_init.off");
		if(flag_scale==1) {
			for(VertexIter v = model2.vertices.begin(); v != model2.vertices.end(); v++) {
				Vector p = v->position2;
				p = p/scale1 + cm1;
				v->position2 = p;
			}
		}
		MeshIO::writeOFF(filename, model2);
	}

	delete [] Cost; delete [] G;
	delete [] m1; delete [] m2;

	if(flag_refine==1) {
		int niter = 10;
//		a2_r = nvertices2*nvertices2 *(a1_r/(nvertices1*nvertices1));
		double dist = optim.refine(model1, model2, a1_r, a2_r, niter, method, nthreads);
		std::cout << "OT distance after refinement                             : " << dist << std::endl;
        	std::cout << " " << std::endl;
		if(OUTfile!=" ") {
			filename = OUTfile;
			filename.append("_ref.off");
			if(flag_scale==1) {
				for(VertexIter v = model2.vertices.begin(); v != model2.vertices.end(); v++) {
					Vector p = v->position2;
					p = p/scale1 + cm1;
					v->position2 = p;
				}
			}
			MeshIO::writeOFF(filename, model2);
		}
	}

	return 0;

}

/* ===============================================================================================
   Usage
   =============================================================================================== */

/* Prints the help banner listing every option accepted by parse_args(). argv is unused. */
static void usage(char** argv)
{
    std::cout << "\n\n" <<std::endl;
    std::cout << "     " << "================================================================================================"<<std::endl;
    std::cout << "     " << "================================================================================================"<<std::endl;
    std::cout << "     " << "=                                                                                              ="<<std::endl;
    std::cout << "     " << "=                                         Comp3DShapes                                         ="<<std::endl;
    std::cout << "     " << "=                                                                                              ="<<std::endl;
    std::cout << "     " << "=     This program compares two meshes by assigning features to each vertex of the meshes,     ="<<std::endl;
    std::cout << "     " << "=     computing a Cost matrix between those vertices based on their feature representation,    ="<<std::endl;
    std::cout << "     " << "=     and computing the distance between those meshes using this Cost matrix and Optimal       ="<<std::endl;
    std::cout << "     " << "=     transport                                                                                ="<<std::endl;
    std::cout << "     " << "=     Usage is:                                                                                ="<<std::endl;
    std::cout << "     " << "=          Comp3DShapes.exe -i1 FILE1 -i2 FILE2 -m method -a1 alpha1 -a2 alpha2                ="<<std::endl;
    std::cout << "     " << "=                           -d desc_type -k flag_kpts -n nkeep -r flag_refine -o OUTfile       ="<<std::endl;
    std::cout << "     " << "=     where:                                                                                   ="<<std::endl;
    std::cout << "     " << "=                 -i1 FILE1       --> Input Mesh file 1 (OBJ or OFF format)                    ="<<std::endl;
    std::cout << "     " << "=                 -i2 FILE2       --> Input Mesh file 2 (OBJ or OFF format)                    ="<<std::endl;
    std::cout << "     " << "=                 -m method       --> 0: fixed masses, 1: optimized masses                     ="<<std::endl;
    std::cout << "     " << "=                 -d desc_type    --> 0: Sift 1: WKS (default: 0)                              ="<<std::endl;
    std::cout << "     " << "=                 -k flag_kpts    --> 0: all vertices, 2: farthest point sampling (default: 0) ="<<std::endl;
    std::cout << "     " << "=                 -n nkeep        --> number of keypoints kept when -k 2                       ="<<std::endl;
    std::cout << "     " << "=                 -a1 alpha1      --> if m=1, mass weight used for the OT distance             ="<<std::endl;
    std::cout << "     " << "=                 -a2 alpha2      --> if m=1, mass weight used during refinement (-r 1)        ="<<std::endl;
    std::cout << "     " << "=                 -r flag_refine  --> flag: 1, refine position of mesh 2, 0: do not refine     ="<<std::endl;
    std::cout << "     " << "=                 -o OUTfile      --> base name for output files (optional)                    ="<<std::endl;
    std::cout << "     " << "================================================================================================"<<std::endl;
    std::cout << "     " << "================================================================================================"<<std::endl;
    std::cout << "\n\n" <<std::endl;
}

/* ===============================================================================================
   Parse arguments from the command line

   =============================================================================================== */

/* Reads "-option value" pairs from argv[1..argc-1] into the output pointers.
   Recognized options: -i1, -i2 (input files), -m (method), -d (descriptor type),
   -k (keypoint flag), -n (number of keypoints), -a1, -a2 (weights), -r (refine flag),
   -o (output base name). An output is written only when its option is present, so values
   preset by the caller are kept otherwise. Numbers are converted with atoi/atof, which give
   0 for non-numeric text. Unrecognized options, and a final option with no value, are
   reported on std::cerr and ignored.
   Returns false when there is no argument at all, true otherwise. */
bool parse_args(int argc, char **argv, std::string *INfile1, std::string *INfile2, int *method,
	double *a1, double *a2, int *desc_type, int *flag_kpt, int *nkeep, int *flag_refine, std::string *OUTFILE) 
{
	if (argc <= 1) return false;

	int i = 1;
	for (; i < argc - 1; i = i + 2)
	{
		std::string param = argv[i];
		const char *value = argv[i + 1];

		if (param == "-i1") {
			*INfile1 = value;
		}
		else if (param == "-i2") {
			*INfile2 = value;
		}
		else if (param == "-m") {
			*method = std::atoi(value);
		}
		else if (param == "-d") {
			*desc_type = std::atoi(value);
		}
		else if (param == "-k") {
			*flag_kpt = std::atoi(value);
		}
		else if (param == "-n") {
			*nkeep = std::atoi(value);
		}
		else if (param == "-a1") {
			*a1 = std::atof(value);
		}
		else if (param == "-a2") {
			*a2 = std::atof(value);
		}
		else if (param == "-r") {
			*flag_refine = std::atoi(value);
		}
		else if (param == "-o") {
			*OUTFILE = value;
		}
		else {
			// Previously ignored silently; a typo here would go unnoticed
			std::cerr << "Warning: unrecognized option " << param << " ignored" << std::endl;
		}
	}

	// An odd number of arguments leaves a last option without its value
	if (i == argc - 1) {
		std::cerr << "Warning: option " << argv[i] << " has no value; ignored" << std::endl;
	}

	return true;
}
