/*================================================================================================
  filterNet.h
  Version 1: 7/28/2025

  Purpose: Filter the elastic network using a rigidity approach

Copyright (c) Patrice Koehl.

================================================================================================== */

#ifndef _FILTERNET_H_
#define _FILTERNET_H_

/*================================================================================================
 Includes
================================================================================================== */

#include <algorithm>
#include <iostream>
#include <numeric>
#include <utility>
#include <vector>

#include "Edges.h"
#include "pebble_component.h"

/*================================================================================================
 Network class
================================================================================================== */

template <typename T>
class filterNet
{
public:

	/* minimGraph: sort listIN by increasing edge length (listIN is re-ordered in place), feed the
	   edges to a pebble game with K = L = 6, append the accepted edges to listOUT and sort
	   listOUT with sortEdges. When type == 2, chain-local pseudo-bonds (i,i+1), (i,i+2), (i,i+3)
	   are inserted into the game first. nthreads is forwarded to the pebble game.
	   Returns true when the network built from listIN is judged (K,L)-tight. */
	bool minimGraph(int type, std::vector<Atoms<T>>& atoms, std::vector<Edges<T>>& listIN,
		std::vector<Edges<T>>& listOUT, int nthreads);

private:

	/* edgeLength: fill Length (cleared first) with one value per edge of list, namely the
	   squared Euclidean distance between the two endpoint atoms (no square root is taken). */
	void edgeLength(std::vector<Atoms<T>>& atoms, std::vector<Edges<T>>& list, std::vector<T>& Length);
};

/*================================================================================================
 build minimal network
================================================================================================== */

 template <typename T>
 bool filterNet<T>::minimGraph(int type, std::vector<Atoms<T>>& atoms, std::vector<Edges<T>>& listIN,
		std::vector<Edges<T>>& listOUT, int nthreads)
 {
	/* Build a minimal (K,L)-sparse subnetwork of the elastic network with a pebble game.

	   type     : when equal to 2, pseudo-bonds (i,i+1), (i,i+2) and (i,i+3) between atoms of the
	              same chain are inserted into the pebble game first; they are NOT added to listOUT
	   atoms    : atoms of the network (coordinates and chainid are read)
	   listIN   : candidate edges; on return it is re-ordered by increasing edge length
	   listOUT  : edges of listIN accepted by the pebble game are appended, then listOUT is
	              sorted with sortEdges
	   nthreads : forwarded to the pebbleComponent constructor

	   Returns true only if no edge of listIN was redundant (every one was accepted by the game)
	   and listIN holds exactly K*N-L edges. The type == 2 pseudo-bonds do not enter this test.
	*/

	const int N = static_cast<int>(atoms.size());
	// NOTE(review): K = L = 6 presumably selects the 3D body-bar (6,6) pebble game — confirm.
	const int K = 6;
	const int L = 6;

	const int nedges = static_cast<int>(listIN.size());

	/*==========================================================================================
	Sort edges in increasing order of lengths
	edgeLength returns squared lengths, which induce the same order. A stable sort keeps the
	input order among equal lengths, so the selected network is reproducible across standard
	library implementations (std::sort leaves the order of ties unspecified).
	============================================================================================ */

	{
		std::vector<T> lengths;
		edgeLength(atoms, listIN, lengths);

		std::vector<int> order(nedges);
		std::iota(order.begin(), order.end(), 0);
		std::stable_sort(order.begin(), order.end(),
			[&lengths](int a, int b) { return lengths[a] < lengths[b]; });

		std::vector<Edges<T>> sorted;
		sorted.reserve(nedges);
		for (int idx : order) sorted.push_back(listIN[idx]);
		listIN.swap(sorted);
	}

	/*==========================================================================================
	Create pebble game
	============================================================================================ */

	pebbleComponent game(N, K, L, nthreads);

	// Print the accumulated pebble-game timers, then reset them for the next reporting period.
	auto reportAndResetTimers = [&game]() {
		double ti = game.time_collect;
		double tc = game.time_component;
		std::cout << "time_collect: " << ti << " time_component: " << tc << std::endl;
		game.time_collect = 0;
		game.time_component = 0;
	};

	/*==========================================================================================
	If GO potential, process bonds, a pseudo bond for each angle, and a pseudo bond for each
	dihedral angle
	============================================================================================ */

	if (type == 2) {

		// Insert the pseudo-bond (i, i+offset) for every run of offset+1 consecutive atoms
		// sharing the same chainid; returns how many of them the pebble game rejected.
		auto insertChainEdges = [&](int offset) {
			int rejected = 0;
			for (int i = 0; i + offset < N; i++) {
				bool sameChain = true;
				for (int j = 1; j <= offset && sameChain; j++) {
					sameChain = (atoms[i].chainid == atoms[i + j].chainid);
				}
				if (sameChain && !game.insert_edge(i, i + offset)) rejected++;
			}
			return rejected;
		};

		std::cout << "Process bonds: " << std::endl;
		const int nb = insertChainEdges(1);
		std::cout << "nb not included: " << nb << std::endl;

		std::cout << "Process angles: " << std::endl;
		const int na = insertChainEdges(2);
		std::cout << "na not included: " << na << std::endl;

		std::cout << "Process diheds: " << std::endl;
		const int nd = insertChainEdges(3);
		std::cout << "nd not included: " << nd << std::endl;

		std::cout << std::endl;
		std::cout << "# of redundant bonds in Go network : " << nb << std::endl;
		std::cout << "# of redundant angles in Go network: " << na << std::endl;
		std::cout << "# of redundant diheds in Go network: " << nd << std::endl;
		std::cout << std::endl;

		reportAndResetTimers();
		std::cout << std::endl;
	}

	/*==========================================================================================
	Process elastic network edges one by one, shortest first
	============================================================================================ */

	bool is_sparse = true;

	for (int idx = 0; idx < nedges; idx++) {

		// Report progress before any skip, so a report is never silently dropped.
		if (idx % 50000 == 0) {
			std::cout << "Starting edge # " << idx << " out of " << nedges << " ( " << (100.*idx)/nedges << " % done)" << std::endl;
			std::cout << "# of edges in DAG : " << listOUT.size() << std::endl;
			reportAndResetTimers();
		}

		const int first  = listIN[idx].atm1;
		const int second = listIN[idx].atm2;

		// Both endpoints already lie in the same rigid component (label 0 presumably means
		// "no component" — confirm in pebble_component.h). The edge is redundant, so the input
		// graph is not sparse; skip the pebble search for it.
		if (game.component[first] != 0 && game.component[first] == game.component[second]) {
			is_sparse = false;
			continue;
		}

		if (game.insert_edge(first, second)) {
			listOUT.push_back(Edges<T>(first, second));
		} else {
			is_sparse = false;
		}
	}

	std::sort(listOUT.begin(), listOUT.end(), sortEdges<T>);

	const bool is_tight = is_sparse && (nedges == K*N - L);

	return is_tight;
 }

/*================================================================================================
 Compute edge lengths
================================================================================================== */

 template <typename T>
 void filterNet<T>::edgeLength(std::vector<Atoms<T>>& atoms, std::vector<Edges<T>>& list, std::vector<T>& Length)
 {
	/* Fill Length (cleared first) with one entry per edge of list, in the same order: the
	   SQUARED Euclidean distance between the edge's two atoms, summed over the 3 coordinates.
	   No square root is taken; callers only use these values for ordering. */

	Length.clear();

	for (const Edges<T>& edge : list) {
		Atoms<T>& a = atoms[edge.atm1];
		Atoms<T>& b = atoms[edge.atm2];

		T squared = 0;
		for (int k = 0; k < 3; ++k) {
			const T delta = a.coord[k] - b.coord[k];
			squared += delta * delta;
		}
		Length.push_back(squared);
	}
 }

#endif
