/* ===============================================================================================
   SeqKernel.cpp			Version 1 6/8/2016		

   Authors: Saghi Nojoomi and Patrice Koehl

   This program computes the distance between two sets of sequences (in FASTA format)
   using the idea of a kernel on the space of (discrete) protein sequences

   =============================================================================================== */

#include <sys/types.h>
#include <sys/stat.h>
#include <dirent.h>
#include <errno.h>
#include <math.h>

#include <algorithm>
#include <cctype>
#include <cstddef>
#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

using namespace std;

/* ===============================================================================================
   Default
   =============================================================================================== */

// Default exponent applied to every substitution-matrix entry (see scale_kmat);
// main() uses it unless the -b flag supplies another value.
const double BETA = 0.01;
// Default longest k-mer length; K3() treats -1 (and 0) as "use every k up to
// the length of the shorter sequence". Overridden by the -k flag.
const int KMAX = -1;
// Default weighting scheme passed to K3(); 0 selects a weight of 1 for every k.
// Overridden by the -w flag.
const int IW = 0;

/* ===============================================================================================
   Prototypes for all functions
   =============================================================================================== */

// Prints the help banner on standard output.
static void usage(char** argv);

// Scans argv for -i/-j/-o/-p/-b/-k/-w and stores the token following each flag
// through the matching output pointer.
void read_flags(int argc, char** argv, string *fmat, string *fseq1, string *fseq2, string *fresult, 
double *beta, int *kmax, int *iweight);

// Maps a one-letter amino-acid code to its index in the internal table, or -1.
int corresp(char letter, int sizeK);

// Reads the substitution matrix from file fmat into the sizeK x sizeK row-major array sm.
void read_kmat(string fmat, double *sm, int sizeK);

// Replaces each entry of sm by pow(entry, beta).
void scale_kmat(double *sm, double beta, int sizeK);

// Reads FASTA records from file; appends sequences and names, and sets *nseq to the record count.
void readseq_fa(string file, vector<string> *seq, vector<string> *name, int *nseq);

// Computes the normalised kernel value K3(1,2)/sqrt(k3_11*k3_22) into *dotprod
// and sqrt(2 - 2*dotprod) into *dist.
void K3_hat(int kmax, int iweight, int *iseq1, int length1, int *iseq2, int length2, double *sm, 
int sizeK, double *mat, double k3_11, double k3_22, double *dotprod, double *dist);

// Fills mat with the 1-mer kernel values for every position pair and returns their sum.
double K3_1mer(int *iseq1, int length1, int *iseq2, int length2, double *sm, int sizeK, double *mat);

// Updates mat in place from the (k-1)-mer values to the k-mer values and returns their sum.
double K3_kmer(int k, int *iseq1, int length1, int *iseq2, int length2, double *sm, int sizeK, double *mat);

// Accumulates the weighted k-mer contributions for k = 1 .. klast (derived from kmax).
double K3(int kmax, int iweight, int *iseq1, int length1, int *iseq2, int length2, double *sm, int sizeK, double *mat);

// Returns the largest value among the first `size` entries of array.
int getmax(int *array, int size);

/* ===============================================================================================
   Main program
   =============================================================================================== */

/* Returns true when `path` names a file that can be opened for reading. */
static bool seqkernel_can_read(const std::string &path)
{
    std::ifstream probe(path.c_str());
    return probe.is_open();
}

/* Pointer to the first element of a vector, or a null pointer when it is empty
   (taking &v[0] of an empty vector is undefined). */
static int *seqkernel_ptr(std::vector<int> &v)
{
    return v.empty() ? 0 : &v[0];
}

static double *seqkernel_ptr(std::vector<double> &v)
{
    return v.empty() ? 0 : &v[0];
}

/* Converts every sequence to residue indices with corresp(). Returns false, after
   reporting the offending sequence and position, if a residue is not recognized;
   such a residue would otherwise be used as a negative matrix index. */
static bool seqkernel_encode(const std::vector<std::string> &seq, const std::vector<std::string> &name,
                             int nseq, int sizeK,
                             std::vector<std::vector<int> > &iseq, std::vector<int> &length)
{
    iseq.assign(nseq, std::vector<int>());
    length.assign(nseq, 0);

    for(int i = 0; i < nseq; ++i)
    {
        const std::string &residues = seq[i];
        length[i] = static_cast<int>(residues.length());
        iseq[i].resize(residues.length());
        for(std::size_t j = 0; j < residues.length(); ++j)
        {
            const int type = corresp(residues[j], sizeK);
            if(type < 0)
            {
                std::cerr << "Error: sequence '" << name[i] << "' contains unrecognized residue '"
                          << residues[j] << "' at position " << (j + 1) << std::endl;
                return false;
            }
            iseq[i][j] = type;
        }
    }
    return true;
}

/* Writes the result table (header line plus one line per sequence pair) to `out`. */
static void seqkernel_write_table(std::ostream &out,
                                  const std::vector<std::string> &name1, const std::vector<int> &length1,
                                  const std::vector<std::string> &name2, const std::vector<int> &length2,
                                  const std::vector<std::vector<double> > &dotprod,
                                  const std::vector<std::vector<double> > &dist)
{
    out << "     Seq 1        Seq 2        Length1        Length2        <Seq1,Seq2>       Dist(Seq1,Seq2)" << std::endl;
    for(std::size_t i = 0; i < dotprod.size(); ++i) {
        for(std::size_t j = 0; j < dotprod[i].size(); ++j) {
            out << std::setw(10) << name1[i] << std::setw(13) << name2[j] << std::setw(15) << length1[i] << std::setw(15)<< length2[j] << std::setw(18) << dotprod[i][j] << std::setw(18) << dist[i][j] << std::endl;
        }
    }
}

/* Program entry point.
   Reads the command-line flags, the substitution matrix and the two FASTA files,
   computes the normalised kernel value and the distance for every pair
   (sequence of set 1, sequence of set 2), prints the table on standard output and,
   when -o was given, writes it to that file.
   Returns 0 on success and -1 when the usage is shown or an input problem is detected. */
int main(int argc, char** argv)
{

/* ===============================================================================================
   Show usage if needed
   =============================================================================================== */

    if( argc < 2 )
    {
        usage(argv);
        return -1;
    }

    const std::string first = argv[1];
    if( first == "-h" || first == "-help" )
    {
        usage(argv);
        return -1;
    }

/* ===============================================================================================
   Initialize parameters (scaling of substitution matrix; max length of k-mer; weights)
   and read in the values given on the command line
   =============================================================================================== */

    std::string fkmat;
    std::string fseq1;
    std::string fseq2;
    std::string fresult;

    double beta = BETA;
    int kmax = KMAX;
    int iweight = IW;

    read_flags(argc, argv, &fkmat, &fseq1, &fseq2, &fresult, &beta, &kmax, &iweight);

    // The three input files are mandatory; check them before any reader is called.
    if( fkmat.empty() || fseq1.empty() || fseq2.empty() )
    {
        std::cerr << "Error: the flags -i, -j and -p are all required" << std::endl;
        usage(argv);
        return -1;
    }
    const std::string *inputs[3] = { &fkmat, &fseq1, &fseq2 };
    for(int f = 0; f < 3; ++f)
    {
        if( !seqkernel_can_read(*inputs[f]) )
        {
            std::cerr << "Error: cannot open input file '" << *inputs[f] << "'" << std::endl;
            return -1;
        }
    }

/* ===============================================================================================
   Read in substitution matrix and scale it
   =============================================================================================== */

    const int sizeK = 20;
    std::vector<double> sm(sizeK*sizeK);	// value-initialised, so unread entries are 0

    read_kmat(fkmat, seqkernel_ptr(sm), sizeK);
    scale_kmat(seqkernel_ptr(sm), beta, sizeK);

/* ===============================================================================================
   Read in the two sets of sequences and convert them to residue indices
   =============================================================================================== */

    std::vector<std::string> seq1, name1;
    int nseq1 = 0;
    readseq_fa(fseq1, &seq1, &name1, &nseq1);

    std::vector<std::string> seq2, name2;
    int nseq2 = 0;
    readseq_fa(fseq2, &seq2, &name2, &nseq2);

    if( nseq1 <= 0 || nseq2 <= 0 )
    {
        std::cerr << "Error: no sequence found in '" << (nseq1 <= 0 ? fseq1 : fseq2) << "'" << std::endl;
        return -1;
    }

    std::vector<std::vector<int> > iseq1, iseq2;
    std::vector<int> length1, length2;
    if( !seqkernel_encode(seq1, name1, nseq1, sizeK, iseq1, length1) ||
        !seqkernel_encode(seq2, name2, nseq2, sizeK, iseq2, length2) )
    {
        return -1;
    }

/* ===============================================================================================
   Compute kernel distance between all pairs of sequences
   =============================================================================================== */

    const int lmax = std::max(getmax(seqkernel_ptr(length1), nseq1), getmax(seqkernel_ptr(length2), nseq2));

    // Work matrix shared by all kernel evaluations; sized in size_t so that
    // lmax*lmax cannot overflow int.
    std::vector<double> mat(static_cast<std::size_t>(lmax) * static_cast<std::size_t>(lmax));

    std::vector<std::vector<double> > dotprod(nseq1, std::vector<double>(nseq2));
    std::vector<std::vector<double> > dist(nseq1, std::vector<double>(nseq2));

    // Self-kernels K3(s,s), needed to normalise every pair.
    std::vector<double> k3_11(nseq1);
    std::vector<double> k3_22(nseq2);

    for(int i = 0; i < nseq1; i++) {
        k3_11[i] = K3(kmax, iweight, seqkernel_ptr(iseq1[i]), length1[i], seqkernel_ptr(iseq1[i]), length1[i],
                      seqkernel_ptr(sm), sizeK, seqkernel_ptr(mat));
    }
    for(int j = 0; j < nseq2; j++) {
        k3_22[j] = K3(kmax, iweight, seqkernel_ptr(iseq2[j]), length2[j], seqkernel_ptr(iseq2[j]), length2[j],
                      seqkernel_ptr(sm), sizeK, seqkernel_ptr(mat));
    }

    for(int i = 0; i < nseq1; i++) {
        for(int j = 0; j < nseq2; j++) {
            K3_hat(kmax, iweight, seqkernel_ptr(iseq1[i]), length1[i], seqkernel_ptr(iseq2[j]), length2[j],
                   seqkernel_ptr(sm), sizeK, seqkernel_ptr(mat), k3_11[i], k3_22[j], &dotprod[i][j], &dist[i][j]);
        }
    }

/* ===============================================================================================
   Write result on screen and in the output file
   =============================================================================================== */

    std::cout << "\n" << std::endl;
    seqkernel_write_table(std::cout, name1, length1, name2, length2, dotprod, dist);
    std::cout << "\n" << std::endl;

    if( !fresult.empty() )
    {
        std::ofstream outfile(fresult.c_str());
        if( !outfile )
        {
            std::cerr << "Error: cannot open output file '" << fresult << "'" << std::endl;
            return -1;
        }
        seqkernel_write_table(outfile, name1, length1, name2, length2, dotprod, dist);
    }

    return 0;
}

/* ===============================================================================================
   Usage
   =============================================================================================== */

/* Prints the help banner on standard output.
   argv is not used; the parameter is kept so existing callers compile unchanged.
   The text states the defaults the program actually applies: beta defaults to
   0.01 (constant BETA) and the output file is overwritten, not appended to. */
static void usage(char** argv)
{
    (void) argv;

    static const char* const banner[] = {
        "================================================================================================",
        "================================================================================================",
        "=                                                                                              =",
        "=                                        WSeqKernel                                            =",
        "=                                                                                              =",
        "=     This program reads in two sets of sequences and computes the distances between them      =",
        "=     using a kernel-based approach adapted from the string kernel originally proposed by:     =",
        "=     Shen et al, Found. Comput. Math., 2013,14:951-984                                        =",
        "=     SeqKernel is fully described in                                                          =",
        "=     Nojoomi and Koehl, BMC Bioinformatics, 2017, 18:137:1-137:15                             =",
        "=     Nojoomi and Koehl, BMC Bioinformatics, submitted                                         =",
        "=                                                                                              =",
        "=     Usage is:                                                                                =",
        "=                 WSeqKernel.exe                                                               =",
        "=                                 -i  <file containing first set of sequences>                 =",
        "=                                 -j  <file containing second set of sequences>                =",
        "=                                 -p  <path to substitution matrix SM>                         =",
        "=                                 -o  <path to output file (overwritten)>                      =",
        "=                                 -b  <Beta value (scale for SM; default 0.01)>                =",
        "=                                 -k  <kmax (longest kmer considered; -1 (default)  means all) =",
        "=                                 -w  <flag for weights: 0 (no weight) 1 (mean ) 2 (WS)>       =",
        "=                                     where                                                    =",
        "=                                     0 means no weight, as in Shen et al.                     =",
        "=                                     1 means mean weight, described in Nojoomi and Koehl      =",
        "=                                     2 means weighted degree, as in Ratsch et al,             =",
        "=                                       Bioinformatics. 2005;21:i369–i377                      =",
        "=                                                                                              =",
        "================================================================================================",
        "================================================================================================"
    };
    const std::size_t nlines = sizeof(banner) / sizeof(banner[0]);

    std::cout << "\n\n" << std::endl;
    for(std::size_t i = 0; i < nlines; ++i)
    {
        std::cout << "     " << banner[i] << std::endl;
    }
    std::cout << "\n\n" << std::endl;
}

/* ===============================================================================================
   Procedure to read in flag values
   =============================================================================================== */

/* Parses the command line.

   Recognized flags, each followed by one value:
       -i  first sequence file      -> *fseq1
       -j  second sequence file     -> *fseq2
       -o  output file              -> *fresult
       -p  substitution matrix file -> *fmat
       -b  beta                     -> *beta    (converted with atof)
       -k  kmax                     -> *kmax    (converted with atoi)
       -w  weighting scheme         -> *iweight (converted with atoi)

   Outputs whose flag is absent are left untouched, so callers can preload defaults.
   A flag given as the very last argument has no value: it is reported on standard
   error and ignored instead of reading past the end of argv.
   As before, the value token is not skipped: it is itself examined as a possible
   flag on the next iteration. */
void read_flags(int argc, char** argv, std::string *fmat, std::string *fseq1, std::string *fseq2, std::string *fresult, double *beta, int *kmax, int *iweight)
{
    for(int i = 1; i < argc; i++)
    {
        const std::string flag = argv[i];

        const bool known = (flag == "-i" || flag == "-j" || flag == "-o" || flag == "-p" ||
                            flag == "-b" || flag == "-k" || flag == "-w");
        if(!known) {
            continue;
        }

        if(i + 1 >= argc) {
            std::cerr << "Warning: flag " << flag << " expects a value; ignored" << std::endl;
            continue;
        }
        const char *value = argv[i + 1];

        if (flag == "-i") {
            *fseq1 = value;
        }
        else if (flag == "-j") {
            *fseq2 = value;
        }
        else if (flag == "-o") {
            *fresult = value;
        }
        else if (flag == "-p") {
            *fmat = value;
        }
        else if (flag == "-b") {
            *beta = std::atof(value);
        }
        else if (flag == "-k") {
            *kmax = std::atoi(value);
        }
        else {	// "-w"
            *iweight = std::atoi(value);
        }
    }
}

/* ===============================================================================================
   Procedure that assigns a number to each amino acid (given in 1-letter code)
   =============================================================================================== */

/* Maps a one-letter amino-acid code to its 0-based position in the alphabet
   A C D E F G H I K L M N P Q R S T V W Y.

   letter : residue code; the comparison is case-sensitive (upper case only)
   sizeK  : number of residue types to consider; only the first sizeK entries of
            the table are searched, and never more than the 20 the table holds
   returns: the index of the residue, or -1 if it is not recognized */
int corresp(char letter, int sizeK)
{
	static const char names[]={'A','C','D','E','F','G','H','I','K','L','M','N','P','Q','R','S','T','V','W','Y'};
	const int ntypes = static_cast<int>(sizeof(names)/sizeof(names[0]));

	// Never scan past the end of the table, even if the caller's sizeK is larger.
	const int nscan = (sizeK < ntypes) ? sizeK : ntypes;

	for(int i = 0; i < nscan; i++) {
		if(letter == names[i]) {
			return i;
		}
	}

	return -1; // letter was not recognized
}

/* ===============================================================================================
   Procedure to read substitution matrix SM
   =============================================================================================== */

void read_kmat(string fkmat, double *sm, int sizeK)
{
        ifstream inFile;
        inFile.open(fkmat.c_str());
	string record;

        int i, idx, iread;
	char letter;
	double val;
	int *order= new int[sizeK];

	iread = 0;
	while ( !inFile.eof () ) {    
		getline(inFile,record);
		if(!record.empty()) {
			istringstream iss(record);
			if(iread==0) {
				i = 0;
				while (iss >> letter)
				{
					idx = corresp(letter,sizeK);
					order[i] = idx;
					i = i + 1;
				}
			}
			else {
				iss >> letter;
				idx = corresp(letter,sizeK);
				for(int j = 0; j < sizeK; j++) {
					iss >> val;
					sm[idx*sizeK + order[j]]=val;
				}
			}
			iread++;
		}
	}

	inFile.close();
}

/* ===============================================================================================
   Procedure to scale all elements of SM by beta
   =============================================================================================== */

/* Raises every entry of the sizeK x sizeK matrix sm to the power beta, in place.

   sm    : matrix values, sizeK*sizeK contiguous entries
   beta  : exponent handed to pow() for each entry
   sizeK : matrix dimension */
void scale_kmat(double *sm, double beta, int sizeK)
{
	const int nentries = sizeK*sizeK;
	double *cell = sm;

	for(int n = 0; n < nentries; ++n, ++cell) {
		*cell = pow(*cell, beta);
	}
}
			
/* ===============================================================================================
   Procedure to read sequences from input file (assumed to be in FASTA format)
   =============================================================================================== */

/* Reads every record of a FASTA file.

   file    : path of the FASTA file
   seq     : output, one upper-case sequence is appended per record
   seqname : output, one name is appended per record; the name is the text that
             follows '>' (leading blanks skipped) up to the next space
   nseq    : output, set to the number of records read from this file

   Whitespace inside sequence lines, including the '\r' of Windows line endings,
   is removed. Lines found before the first '>' are ignored. A record is kept even
   when its name is empty. If the file cannot be opened a message is written on
   standard error and *nseq is set to 0. */
void readseq_fa(std::string file, std::vector<std::string> *seq, std::vector<std::string> *seqname, int *nseq)
{
	*nseq = 0;

	std::ifstream inFile(file.c_str());
	if(!inFile) {
		std::cerr << "Error: cannot open sequence file '" << file << "'" << std::endl;
		return;
	}

	std::string record;
	std::string name;
	std::string sequence;
	bool in_record = false;		// true once a '>' line has been seen

	// Testing the stream state after each getline ends the loop on any failure;
	// looping on !eof() never terminates when the stream is in a failed state.
	while(std::getline(inFile, record)) {

		if(!record.empty() && record[0] == '>') {	// FASTA identifier line

			if(in_record) {		// store the record that just ended
				(*seq).push_back(sequence);
				(*seqname).push_back(name);
				(*nseq)++;
			}

			// Drop trailing whitespace (e.g. '\r') so it does not end up in the name.
			std::string::size_type end = record.size();
			while(end > 1 && std::isspace(static_cast<unsigned char>(record[end-1]))) {
				--end;
			}
			// Skip blanks between '>' and the name, then stop at the next space.
			std::string::size_type start = 1;
			while(start < end && record[start] == ' ') {
				++start;
			}
			std::string::size_type stop = record.find(' ', start);
			if(stop == std::string::npos || stop > end) {
				stop = end;
			}

			name = record.substr(start, stop - start);
			sequence.clear();
			in_record = true;
		}
		else if(in_record) {	// sequence line of the current record
			for(std::string::size_type c = 0; c < record.size(); ++c) {
				const unsigned char ch = static_cast<unsigned char>(record[c]);
				if(!std::isspace(ch)) {
					sequence += static_cast<char>(std::toupper(ch));
				}
			}
		}
	}

	if(in_record) {		// store the last record of the file
		(*seq).push_back(sequence);
		(*seqname).push_back(name);
		(*nseq)++;
	}
}

/* ===============================================================================================
   Procedure that initialize the computation of the dot product <Seq1,Seq2>:

   Let Seq1= a_0 a_1 ............a_N1
   and Seq2 =b_0 b_1 ..................b_N2

   We compute K3_1mer(Seq1,Seq2), i.e. the contribution of all k-mers of length 1 to the dot product.
   We note that 
		K3_1mer(Seq1,Seq2) = Sum_{i=1,N1}_{j=1,N2} K2_1(a_i,b_j)
   where:
		K2_1(a_i,b_j) = SM(itype(a_i),itype(b_j))
   since we only look at K-mer of size 1

   SM is given in the array *sm; the array MAT will contain all values of K2_1(a_i,b_j); 
   K3_1mer(Seq1,Seq2) is simply the sum of the elements of MAT

   =============================================================================================== */


/* Contribution of all 1-mers to the dot product, and initialisation of MAT.

   iseq1, length1 : residue indices of sequence 1 and their number
   iseq2, length2 : residue indices of sequence 2 and their number
   sm, sizeK      : scaled substitution matrix, sizeK x sizeK, row-major
   mat            : output, at least length1*length2 entries; mat[i*length2+j]
                    receives sm[iseq1[i]*sizeK + iseq2[j]]
   returns        : the sum of all entries written to mat

   A residue index outside [0,sizeK) (e.g. the -1 returned by corresp for an
   unrecognized letter) contributes 0 instead of being used to index sm. */
double K3_1mer(int *iseq1, int length1, int *iseq2, int length2, double *sm, int sizeK, double *mat)
{
	double sum = 0;

	for(int i = 0; i < length1; i++) {
		const int type1 = iseq1[i];
		const bool known1 = (type1 >= 0 && type1 < sizeK);

		// Row offset computed in size_t: i*length2 can exceed the range of int.
		const std::size_t row = static_cast<std::size_t>(i) * static_cast<std::size_t>(length2);

		for(int j = 0; j < length2; j++) {
			const int type2 = iseq2[j];
			double val = 0;
			if(known1 && type2 >= 0 && type2 < sizeK) {
				val = sm[type1*sizeK + type2];
			}
			mat[row + j] = val;
			sum = sum + val;
		}
	}

	return sum;

}

/* ===============================================================================================
   Procedure that computes K3_kmer(Seq1,Seq2)

   Let Seq1= a_0 a_1 ............a_N1
   and Seq2 =b_0 b_1 ..................b_N2

   We compute K3_kmer(Seq1,Seq2), i.e. the contribution of all k-mers of length k to the dot product.

   We note that 
		K3_kmer(Seq1,Seq2) = Sum_{i=1,N1-k}_{j=1,N2-k} K2_k(u_k(i),v_k(j))
   where:
		K2_k(u_k(i),v_k(j)) is the contribution of the k-mer u_k(i) starting at 
				position a_i and the k-mer v_k(j) starting at position b_j, both 
				of length k

   Interestingly,
		K2_k(u_k(i),v_k(j)) = K2_{k-1}( u_{k-1}(i), v_{k-1}(j) ) * SM(a_{i+k-1},b_{j+k-1})

   This means that K2_k can be computed as an update of K2_{k-1}, which is stored in MAT!

   =============================================================================================== */


/* Contribution of all k-mers of length k to the dot product.

   k              : k-mer length (>= 1)
   iseq1, length1 : residue indices of sequence 1 and their number
   iseq2, length2 : residue indices of sequence 2 and their number
   sm, sizeK      : scaled substitution matrix, sizeK x sizeK, row-major
   mat            : in/out, laid out as mat[i*length2+j]; on entry it holds the
                    (k-1)-mer values, and each entry with i <= length1-k and
                    j <= length2-k is multiplied by sm[iseq1[i+k-1]*sizeK + iseq2[j+k-1]]
   returns        : the sum of the updated entries (0 if no k-mer fits, or if k < 1)

   A residue index outside [0,sizeK) sets the affected entries to 0 instead of
   being used to index sm. */
double K3_kmer(int k, int *iseq1, int length1, int *iseq2, int length2, double *sm, int sizeK, double *mat)
{
	double sum = 0;

	if(k < 1) {
		return sum;	// i+k-1 would index before the start of the sequences
	}

	const int nstart1 = length1 - k + 1;	// number of k-mers in sequence 1
	const int nstart2 = length2 - k + 1;	// number of k-mers in sequence 2

	for(int i = 0; i < nstart1; i++) {
		const int type1 = iseq1[i+k-1];
		const bool known1 = (type1 >= 0 && type1 < sizeK);

		// Row offset computed in size_t: i*length2 can exceed the range of int.
		const std::size_t row = static_cast<std::size_t>(i) * static_cast<std::size_t>(length2);

		for(int j = 0; j < nstart2; j++) {
			const int type2 = iseq2[j+k-1];
			const std::size_t idx = row + j;
			if(known1 && type2 >= 0 && type2 < sizeK) {
				mat[idx] = mat[idx]*sm[type1*sizeK + type2];
			}
			else {
				mat[idx] = 0;
			}
			sum = sum + mat[idx];
		}
	}

	return sum;

}

/* ===============================================================================================
   Procedure that computes K3(Seq1,Seq2)

   Let Seq1= a_0 a_1 ............a_N1
   and Seq2 =b_0 b_1 ..................b_N2

   We compute K3(Seq1,Seq2), i.e. the contribution of all k-mers to the dot product.

   We note that 
		K3(Seq1,Seq2) = Sum_{k=1}^{min(N1,N2)} K3_kmer(Seq1,Seq2)
   where
		K3_kmer(Seq1,Seq2) = Sum_{i=1,N1-k}_{j=1,N2-k} K2_k(u_k(i),v_k(j))
   and
		K2_k(u_k(i),v_k(j)) is the contribution of the k-mer u_k(i) starting at position a_i 
				and the k-mer v_k(j) starting at position b_j, both of length k

   K3_kmer(Seq1,Seq2) is computed with procedure K3_kmer, for 1 < k <= min(N1,N2), and with 
   procedure K3_1mer for k = 1 (which serves to initialize the matrix MAT)

   =============================================================================================== */

/* Kernel value K3(Seq1,Seq2): weighted sum of the k-mer contributions.

   kmax    : longest k-mer considered; any value <= 0 means "all", i.e. up to the
             length of the shorter sequence
   iweight : weighting of the contribution S_k of the k-mers of length k
               0 : S_k
               1 : S_k / ((length1-k+1)*(length2-k+1))
               2 : S_k * 2*(d-k+1)/(d*(d+1)), with d = kmax, or d = the length of
                   the shorter sequence when kmax <= 0
               3 : S_k - 1/((length1-k+1)*(length2-k+1))
             any other value is treated as 0
   iseq1, length1, iseq2, length2 : the two sequences, as residue indices
   sm, sizeK : scaled substitution matrix
   mat       : work array of at least length1*length2 entries, overwritten
   returns   : K3(Seq1,Seq2); 0 if either sequence is empty

   k never exceeds the length of the shorter sequence: longer k-mers do not exist,
   and the counts (length-k+1) used by schemes 1 and 3 would reach zero there. */
double K3(int kmax, int iweight, int *iseq1, int length1, int *iseq2, int length2, double *sm, int sizeK, double *mat)
{
	if(length1 <= 0 || length2 <= 0) {
		return 0;	// no k-mer at all
	}

	const int minlen = (length1 < length2) ? length1 : length2;
	const int depth = (kmax > 0) ? kmax : minlen;		// "d" of the weighted-degree scheme
	const int klast = (depth < minlen) ? depth : minlen;	// last k with at least one k-mer pair

	double total = 0;

	for(int k = 1; k <= klast; k++) {

		// k = 1 initialises MAT; every larger k updates it in place.
		const double sum = (k == 1)
			? K3_1mer(iseq1,length1,iseq2,length2,sm,sizeK,mat)
			: K3_kmer(k,iseq1,length1,iseq2,length2,sm,sizeK,mat);

		// Number of k-mer pairs, formed in double so the product cannot overflow int.
		const double npairs = static_cast<double>(length1-k+1) * static_cast<double>(length2-k+1);

		double weight = 1;
		double diff = 0;
		switch(iweight) {
			case 1:
				weight = 1.0/npairs;
				break;
			case 2:
				weight = 2.0*(static_cast<double>(depth) - k + 1)
					/ (static_cast<double>(depth) * static_cast<double>(depth+1));
				break;
			case 3:
				diff = 1.0/npairs;
				break;
			default:	// 0 and unknown schemes: unweighted
				break;
		}

		total = total + weight*sum - diff;
	}

	return total;

}

/* ===============================================================================================
	Procedure that computes K3^(Seq1,Seq2)
	where:
		K3^(Seq1,Seq2) = K3(Seq1,Seq2)/ sqrt( K3(Seq1,Seq1) K3(Seq2,Seq2) )

   =============================================================================================== */

/* Normalised kernel value and distance between two sequences.

   kmax, iweight, iseq1, length1, iseq2, length2, sm, sizeK, mat : forwarded to K3()
   k3_11, k3_22 : K3(Seq1,Seq1) and K3(Seq2,Seq2), computed by the caller
   dotprod      : output, K3(Seq1,Seq2) / sqrt(k3_11 * k3_22)
   dist         : output, sqrt(2 - 2*dotprod)

   Rounding can make dotprod marginally larger than 1 for nearly identical
   sequences; the argument of the square root is then clamped to 0 so that the
   distance is 0 rather than NaN. A NaN dotprod still yields a NaN distance. */
void K3_hat(int kmax, int iweight, int *iseq1, int length1, int *iseq2, int length2, double *sm, 
int sizeK, double *mat, double k3_11, double k3_22, double *dotprod, double *dist)
{

/* 	==========================================================================================
	Compute K3(Seq1,Seq2)
   	========================================================================================== */

	const double k3_12 = K3(kmax,iweight,iseq1,length1,iseq2,length2,sm,sizeK,mat);

/* 	==========================================================================================
	Compute K3_hat(Seq1,Seq2)
   	========================================================================================== */

	*dotprod = k3_12 / sqrt( k3_11 * k3_22);

/* 	==========================================================================================
	Compute Dist(Seq1,Seq2)
   	========================================================================================== */

	double dist2 = 2 - 2*(*dotprod);
	if(dist2 < 0) {		// false for NaN, which is therefore propagated unchanged
		dist2 = 0;
	}
	*dist = sqrt(dist2);

}
/* ===============================================================================================
	Procedure that finds the maximum value in an array
   =============================================================================================== */

/* Returns the largest of the first `size` values of `array`.

   array   : values to scan; not read when size <= 0
   size    : number of values
   returns : the maximum, or 0 when size <= 0 (there is no element to read) */
int getmax(int *array, int size)
{
	if(size <= 0) {
		return 0;
	}

	int vmax = array[0];
	for(int i = 1; i < size; i++) {
		if(array[i] > vmax) {
			vmax = array[i];
		}
	}

	return vmax;
}
