/* ====WriteRes_H ===================================================================================
 *
 * Author: Patrice Koehl (in collaboration with Henri Orland), November 2018
 * Department of Computer Science
 * University of California, Davis
 *
 * class for writing the results of an OT computation to a set of text files
 *
 =============================================================================================== */

#ifndef _WRITERES_H_
#define _WRITERES_H_

  #include <cstddef>
  #include <fstream>
  #include <iomanip>
  #include <iostream>
  #include <string>

/* ===============================================================================================
   The class
   =============================================================================================== */

  /**
   * Stateless helper that dumps the results of an OT computation to text files.
   * All files share a caller-supplied prefix; see write() for the list of files.
   */
  class WriteRes{

  public:

	/**
	 * Outputs the OT distance, the plan G, the cost matrix C, and optionally
	 * the masses m1 and m2 (if they have been optimized).
	 *
	 * @param output    file name prefix; a '.' is appended if it does not already end with one
	 * @param flag_mass masses are written (files "mass1" and "mass2") only if > 0
	 * @param npoint1   number of rows of C and G, and number of entries read from m1
	 * @param m1        masses of the first point set; only read when flag_mass > 0
	 * @param npoint2   number of columns of C and G, and number of entries read from m2
	 * @param m2        masses of the second point set; only read when flag_mass > 0
	 * @param C         cost matrix; element (i,j) is read at C[i + j*npoint1]
	 * @param G         plan; element (i,j) is read at G[i + j*npoint1]
	 * @param d_OT      OT distance between the two point sets
	 */
	void write(std::string output, int flag_mass, int npoint1, double *m1, 
	int npoint2, double *m2, double *C, double *G, double d_OT);

  };

/* ===============================================================================================
   Main function:
	outputs OT dist, plan G, cost matrix C, and optionally masses (if they have been optimized)

	Files written (each name is the prefix, terminated by '.', followed by the suffix):
		dist  : the OT distance
		plan  : G, one row per line, values separated by a space
		mass1 : m1, one value per line (only if flag_mass > 0)
		mass2 : m2, one value per line (only if flag_mass > 0)
		cost  : C, one row per line, values separated by a space

	A file that cannot be opened or written is reported on std::cerr; the remaining
	files are still processed.

	The definition is inline because it lives in a header: without it, including this
	header from several translation units yields multiple definitions at link time.
   =============================================================================================== */

  inline void WriteRes::write(std::string output, int flag_mass, int npoint1, double *m1, 
	int npoint2, double *m2, double *C, double *G, double d_OT)
  {
	// Make sure the prefix ends with '.'. The emptiness test comes first because
	// calling back() on an empty string is undefined behaviour.
	if(output.empty() || output.back() != '.') output.append(".");

	// Opens <prefix><suffix> for writing; reports the failure and returns false otherwise.
	auto open_output = [&output](std::ofstream &stream, const char *suffix) -> bool {
		const std::string filename = output + suffix;
		stream.open(filename);
		if(!stream.is_open()) {
			std::cerr << "WriteRes::write: cannot open file " << filename << " for writing" << std::endl;
			return false;
		}
		return true;
	};

	// Closes the stream and reports any error that occurred while writing or flushing.
	auto close_output = [&output](std::ofstream &stream, const char *suffix) {
		stream.close();
		if(stream.fail()) {
			std::cerr << "WriteRes::write: error while writing file " << output << suffix << std::endl;
		}
	};

	// Writes a npoint1 x npoint2 matrix, one row per line; element (i,j) is stored at i + j*npoint1.
	// The offset is computed in std::size_t so that large matrices do not overflow int.
	auto write_matrix = [npoint1, npoint2](std::ofstream &stream, const double *mat) {
		const std::size_t nrow = static_cast<std::size_t>(npoint1);
		for(int i = 0; i < npoint1; i++) {
			for(int j = 0; j < npoint2; j++) {
				stream << mat[i + j*nrow] << " ";
			}
			// '\n' rather than std::endl: same bytes, without a flush per row
			stream << '\n';
		}
	};

	// Writes the first npoint entries of a vector, one value per line.
	auto write_vector = [](std::ofstream &stream, const double *vect, int npoint) {
		for(int i = 0; i < npoint; i++) {
			stream << vect[i] << '\n';
		}
	};

	// OT distance
	std::ofstream outfile1;
	if(open_output(outfile1, "dist")) {
		outfile1 << '\n';
		outfile1 << "OT distance between the two point sets   : " << d_OT << '\n';
		outfile1 << " " << '\n';
		close_output(outfile1, "dist");
	}

	// Plan
	std::ofstream outfile2;
	if(open_output(outfile2, "plan")) {
		write_matrix(outfile2, G);
		close_output(outfile2, "plan");
	}

	// Masses: m1 and m2 are only read when flag_mass > 0
	if(flag_mass > 0 ) {
		std::ofstream outfile3;
		if(open_output(outfile3, "mass1")) {
			write_vector(outfile3, m1, npoint1);
			close_output(outfile3, "mass1");
		}

		std::ofstream outfile4;
		if(open_output(outfile4, "mass2")) {
			write_vector(outfile4, m2, npoint2);
			close_output(outfile4, "mass2");
		}
	}

	// Cost matrix
	std::ofstream outfile5;
	if(open_output(outfile5, "cost")) {
		write_matrix(outfile5, C);
		close_output(outfile5, "cost");
	}

  }

#endif
