/*================================================================================================
  Simple Reverse Cuthill-McKee (RCM) Implementation
  
  Input: CSR matrix (n_rows, row_ptr, col_idx)
  Output: permutation vector perm where perm[i] = old_vertex_id
          (i.e., new position i should contain old vertex perm[i])
================================================================================================== */

#ifndef RCM_H_
#define RCM_H_

/*================================================================================================
 Includes
================================================================================================== */

#include <algorithm>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <queue>
#include <tuple>
#include <utility>
#include <vector>

/*================================================================================================
 Class definition
================================================================================================== */

template<typename T>
class RCM {

	public:

		/*----------------------------------------------------------------------------------------
		 Ordering
		 Computes the Reverse Cuthill-McKee ordering of the sparsity pattern given in CSR form
		 (row_ptr holds n_rows + 1 offsets into col_idx).
		 The returned vector maps new position -> old vertex id. An empty vector is returned
		 when no valid permutation could be produced.
		------------------------------------------------------------------------------------------ */
		std::vector<int> computeRCM(int n_rows, int* row_ptr, int* col_idx);

		/*----------------------------------------------------------------------------------------
		 Permuted matrix
		 Applies perm (new position -> old vertex id) to both rows and columns of the input CSR
		 matrix and writes the result to new_row_ptr / new_col_idx / new_values, which the
		 caller must have allocated. Column indices of every output row are sorted ascending.
		 values and new_values store three interleaved components (x, y, z) per stored entry.
		------------------------------------------------------------------------------------------ */
		void buildRCM(int n_rows, int* row_ptr, int* col_idx, T* values,
			std::vector<int>& perm,
			int * new_row_ptr, int *new_col_idx, T *new_values);

		/*----------------------------------------------------------------------------------------
		 Diagnostics
		------------------------------------------------------------------------------------------ */

		// Bandwidth of the pattern: the largest |row - column| over all stored entries
		int computeBandwidth(int n_rows, int* row_ptr, int* col_idx);

		// Largest (*d1) and second largest (*d2) number of stored entries per row
		void computeDegrees(int n_rows, int *row_ptr, int *d1, int *d2);

};

/*================================================================================================
 Main routine: compute reordering based on row_pointer and col_pointer
================================================================================================== */

// Compute a Reverse Cuthill-McKee ordering for the graph of a square CSR sparsity pattern.
//
// Parameters:
//   n_rows  : number of rows of the (square) matrix, i.e. number of graph vertices
//   row_ptr : CSR row offsets (n_rows + 1 entries)
//   col_idx : CSR column indices; every index must lie in [0, n_rows)
//
// Returns:
//   perm with perm[new_position] = old_vertex_id. An empty vector is returned when
//   n_rows <= 0 or when a column index outside [0, n_rows) is encountered.
//
// Notes:
//   - The pattern is symmetrized: an entry (i, j) links i to j AND j to i, and repeated
//     entries are collapsed, so a structurally non-symmetric matrix is ordered as the
//     undirected graph of A + A^T.
//   - Every connected component is traversed breadth-first, starting from its
//     lowest-degree vertex, so disconnected parts are reordered as well.
//   - Equal degrees are ordered by the lower vertex id, which makes the result independent
//     of how the standard library orders equivalent elements.
template<typename T>
std::vector<int> RCM<T>::computeRCM(int n_rows, int* row_ptr, int* col_idx) {

	// Nothing to order (also avoids indexing empty containers below)
	if (n_rows <= 0) {
		return {};
	}

	/*==========================================================================================
	Step 1: Build adjacency list (undirected graph from matrix structure)
	============================================================================================ */

	std::vector<std::vector<int>> adj(n_rows);

	for (int i = 0; i < n_rows; i++) {
		for (int idx = row_ptr[i]; idx < row_ptr[i + 1]; idx++) {
			const int j = col_idx[idx];

			// A column outside the square pattern cannot be a vertex of the graph
			if (j < 0 || j >= n_rows) {
				std::cout << "ERROR: Column index out of range!" << std::endl;
				return {};
			}

			if (i != j) {  // Skip diagonal
				adj[i].push_back(j);
				adj[j].push_back(i);  // mirror the edge so the graph is truly undirected
			}
		}
	}

	// Mirroring (and repeated CSR entries) create duplicates: keep each neighbor once
	for (std::vector<int>& neighbors : adj) {
		std::sort(neighbors.begin(), neighbors.end());
		neighbors.erase(std::unique(neighbors.begin(), neighbors.end()), neighbors.end());
	}

	/*==========================================================================================
	Step 2: Degrees and visiting order (ascending degree, ties by vertex id)
	============================================================================================ */

	std::vector<int> degree(n_rows);
	for (int i = 0; i < n_rows; i++) {
		degree[i] = static_cast<int>(adj[i].size());
	}

	const auto by_degree = [&degree](int a, int b) {
		return degree[a] != degree[b] ? degree[a] < degree[b] : a < b;
	};

	// Each vertex is expanded exactly once, so its neighbor list can be ordered once, in place
	for (std::vector<int>& neighbors : adj) {
		std::sort(neighbors.begin(), neighbors.end(), by_degree);
	}

	// Candidate start vertices: the first unvisited one is the lowest-degree vertex of the
	// next component
	std::vector<int> start_candidates(n_rows);
	std::iota(start_candidates.begin(), start_candidates.end(), 0);
	std::sort(start_candidates.begin(), start_candidates.end(), by_degree);

	/*==========================================================================================
	Step 3: Cuthill-McKee ordering using BFS, one traversal per connected component
	============================================================================================ */

	std::vector<int> cm_order;
	cm_order.reserve(n_rows);
	std::vector<char> visited(n_rows, 0);
	int n_components = 0;

	for (int root : start_candidates) {
		if (visited[root]) {
			continue;
		}

		n_components++;
		visited[root] = 1;

		// cm_order doubles as the BFS queue: 'head' is the next vertex to expand
		std::size_t head = cm_order.size();
		cm_order.push_back(root);

		while (head < cm_order.size()) {
			const int current = cm_order[head++];

			for (int neighbor : adj[current]) {
				if (!visited[neighbor]) {
					visited[neighbor] = 1;
					cm_order.push_back(neighbor);
				}
			}
		}
	}

	if (n_components > 1) {
		std::cout << "Warning: graph has " << n_components << " disconnected components" << std::endl;
	}

	/*==========================================================================================
	Step 4: Reverse for RCM
	============================================================================================ */

	// Every vertex was appended exactly once (guarded by 'visited'), so cm_order is a
	// permutation of 0..n_rows-1 by construction.
	std::reverse(cm_order.begin(), cm_order.end());

	std::cout << "RCM permutation generated successfully." << std::endl;
	std::cout << "First 10 entries: ";
	for (int i = 0; i < std::min(10, n_rows); i++) {
		std::cout << cm_order[i] << " ";
	}
	std::cout << std::endl;

	return cm_order;
}

/*================================================================================================
  Build RCM permuted matrix in CSR format
================================================================================================== */

// Build the symmetrically permuted matrix P*A*P^T in CSR format.
//
// Parameters:
//   n_rows      : number of rows of the (square) input matrix
//   row_ptr     : input CSR row offsets (n_rows + 1 entries)
//   col_idx     : input CSR column indices, each expected in [0, n_rows)
//   values      : input values, three interleaved components (x, y, z) per stored entry
//   perm        : perm[new_position] = old_vertex_id; must be a permutation of 0..n_rows-1
//   new_row_ptr : output row offsets, caller-allocated with n_rows + 1 entries
//   new_col_idx : output column indices, caller-allocated with row_ptr[n_rows] entries
//   new_values  : output values, caller-allocated with 3 * row_ptr[n_rows] entries
//
// If perm does not have n_rows entries or is not a permutation, an error is printed and the
// function returns without writing to any output array.
//
// Within every output row the entries are sorted by their new column index; entries that
// share a column keep their input order. T only needs to be copy-assignable.
template <typename T>
void RCM<T>::buildRCM(int n_rows, int* row_ptr, int* col_idx, T* values, 
	std::vector<int>& perm, int *new_row_ptr, int *new_col_idx, T *new_values)
{
	// Number of interleaved components stored per matrix entry (x, y, z)
	constexpr std::size_t kComponents = 3;

	// Reject a permutation of the wrong length (e.g. the empty vector computeRCM returns on
	// failure) before anything is indexed with it
	if (n_rows < 0 || perm.size() != static_cast<std::size_t>(n_rows)) {
		std::cout << "ERROR: Wrong permutation size!" << std::endl;
		return;
	}

	// Create inverse permutation: inv_perm[old_vertex] = new_position
	// (-1 marks "not assigned yet", which detects out-of-range and repeated entries)
	std::vector<int> inv_perm(n_rows, -1);
	for (int new_i = 0; new_i < n_rows; new_i++) {
		const int old_i = perm[new_i];
		if (old_i < 0 || old_i >= n_rows || inv_perm[old_i] != -1) {
			std::cout << "ERROR: Invalid permutation!" << std::endl;
			return;
		}
		inv_perm[old_i] = new_i;
	}

	// Build new_row_ptr: new row new_i has as many entries as old row perm[new_i]
	new_row_ptr[0] = 0;
	for (int new_i = 0; new_i < n_rows; new_i++) {
		const int old_i = perm[new_i];
		new_row_ptr[new_i + 1] = new_row_ptr[new_i] + (row_ptr[old_i + 1] - row_ptr[old_i]);
	}

	// Fill new matrix entries
	// Scratch list of (new column, source position) for one row, reused across rows
	std::vector<std::pair<int, int>> row_entries;

	for (int new_i = 0; new_i < n_rows; new_i++) {
		const int old_i = perm[new_i];  // old row corresponding to new row new_i

		row_entries.clear();
		for (int idx = row_ptr[old_i]; idx < row_ptr[old_i + 1]; idx++) {
			row_entries.emplace_back(inv_perm[col_idx[idx]], idx);
		}

		// Sort by new column index; the source position breaks ties, so the values
		// themselves are never compared
		std::sort(row_entries.begin(), row_entries.end());

		// Store sorted entries in interleaved format
		std::size_t dst = static_cast<std::size_t>(new_row_ptr[new_i]);
		for (const std::pair<int, int>& entry : row_entries) {
			const std::size_t src = static_cast<std::size_t>(entry.second);

			new_col_idx[dst] = entry.first;
			for (std::size_t c = 0; c < kComponents; c++) {
				new_values[kComponents * dst + c] = values[kComponents * src + c];
			}
			dst++;
		}
	}

	std::cout << "RCM permuted matrix built successfully." << std::endl;
}

/*================================================================================================
  Test the bandwidth reduction
================================================================================================== */

// Returns the bandwidth of a CSR sparsity pattern, i.e. the largest distance |row - column|
// found among the stored entries. The result is 0 when no row holds any entry.
//
//   n_rows  : number of rows
//   row_ptr : CSR row offsets (n_rows + 1 entries)
//   col_idx : CSR column indices
template <typename T>
int RCM<T>::computeBandwidth(int n_rows, int* row_ptr, int* col_idx) 
{
	int widest = 0;

	for (int row = 0; row < n_rows; ++row) {
		const int row_end = row_ptr[row + 1];

		for (int k = row_ptr[row]; k < row_end; ++k) {
			// Distance of this entry from the main diagonal
			const int offset = row - col_idx[k];
			const int distance = (offset < 0) ? -offset : offset;

			if (distance > widest) {
				widest = distance;
			}
		}
	}

	return widest;
}


/*================================================================================================
 Compute two largest degrees of matrix
================================================================================================== */

// Find the two largest row lengths (number of stored entries per row) of a CSR matrix.
//
//   n_rows  : number of rows
//   row_ptr : CSR row offsets (n_rows + 1 entries)
//   d1      : receives the largest row length
//   d2      : receives the second largest row length (equal to *d1 when the maximum
//             occurs in more than one row)
//
// Degenerate sizes: with a single row *d2 is set to 0; with n_rows <= 0 both are set to 0.
template <typename T>
void RCM<T>::computeDegrees(int n_rows, int *row_ptr, int *d1, int *d2)
{
	// No rows: there is no degree to report
	if (n_rows <= 0) {
		*d1 = 0;
		*d2 = 0;
		return;
	}

	int largest = row_ptr[1] - row_ptr[0];

	// A single row has no second-largest degree
	if (n_rows == 1) {
		*d1 = largest;
		*d2 = 0;
		return;
	}

	// Seed the running top-two with the first two rows, largest first
	int second = row_ptr[2] - row_ptr[1];
	if (second > largest) {
		const int tmp = largest;
		largest = second;
		second = tmp;
	}

	// Single pass over the remaining rows; no sorting or temporary storage needed
	for (int i = 2; i < n_rows; i++) {
		const int degree = row_ptr[i + 1] - row_ptr[i];

		if (degree > largest) {
			second = largest;
			largest = degree;
		} else if (degree > second) {
			second = degree;
		}
	}

	*d1 = largest;
	*d2 = second;
}
#endif
