#!/usr/bin/env python
# select_threshold_for_BA-THRESHER.py

import sys
from process_mutation_candidates_helper_functions import *



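# mutation_calling_option:
# r: regular - each candidate mutation has a range (t1, t2), and when the threshold leaves the bottom of the range, the mutation is no longer predicted
# t: top - the mutation is still predicted even if you leave the bottom of the range
# a: all - keep all calls (expects all feasible well calls for each candidate to be in the mutation candidates file); as with 't', predictions are never dropped, and the same mutation may be gained at more than one threshold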

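# output_option:
# q: prints the threshold at which selection stops, followed by a warning message that is printed only if more predictions are returned than were requested (applies when selection_method == "number")
# v: prints more detailed information
# NOTE: output_option is passed through to select_top_mutation_candidates(), but nothing in this script currently branches on its value.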

# Exactly eight positional arguments are required (sys.argv[0] is the script name itself).
if len(sys.argv) != 9:
	# Call-style raise: the 'raise X, msg' form is a syntax error on Python 3,
	# while this form behaves identically on Python 2.
	raise IOError('Usage: ./select_threshold_for_BA-THRESHER.py mutation_candidates_file selection_method selection_parameter is_best_threshold_high mutation_calling_option output_option num_dimensions_in_pooling_scheme output_mutation_candidates_file')

def main(mutation_candidates_filename, selection_method, selection_parameter_string, best_threshold_is_high_string, mutation_calling_option, output_option, num_dimensions_in_pooling_scheme_string, output_mutation_candidates_filename):
	"""Validate the command-line options, select candidates and write them out.

	Every argument is the raw string taken from the command line.

	Args:
		mutation_candidates_filename: path handed to process_mutation_candidates_file().
		selection_method: forwarded to select_top_mutation_candidates(), which
			accepts 'number' or 'threshold'.
		selection_parameter_string: forwarded to select_top_mutation_candidates()
			(candidate count or cutoff threshold, depending on selection_method).
		best_threshold_is_high_string: 'true' or 'false' (case-insensitive).
		mutation_calling_option: one of 'r', 't' or 'a'.
		output_option: forwarded unchanged to select_top_mutation_candidates().
		num_dimensions_in_pooling_scheme_string: '2' or '3'.
		output_mutation_candidates_filename: path of the file that receives the
			selected candidate lines.

	Raises:
		IOError: if the pooling dimension, mutation_calling_option or
			is_best_threshold_high value is not one of the accepted values.
		ValueError: if num_dimensions_in_pooling_scheme_string is not an integer.
	"""
	# NOTE: exceptions are raised in call form, which is valid on both
	# Python 2 and Python 3 (the 'raise X, msg' form is Python 2 only).
	num_dimensions_in_pooling_scheme = int(num_dimensions_in_pooling_scheme_string)
	if num_dimensions_in_pooling_scheme == 2:
		use_3D_pooling = False
	elif num_dimensions_in_pooling_scheme == 3:
		use_3D_pooling = True
	else:
		raise IOError('Unexpected pooling scheme dimension: %d' % num_dimensions_in_pooling_scheme)

	if mutation_calling_option not in ('r', 't', 'a'):
		raise IOError('Parameter \'mutation_calling_option\' must be set to one of: ' +
				'a (keep all calls - expects all feasible well calls for each candidate to be in mutation candidates file),\n' +
				'r (regular - discard redundant calls), or t (keep top scoring wells).  See notes in script for explanation.')

	best_threshold_flag = best_threshold_is_high_string.lower()
	if best_threshold_flag == 'true':
		best_threshold_is_high = True
	elif best_threshold_flag == 'false':
		best_threshold_is_high = False
	else:
		raise IOError('Parameter \'is_best_threshold_high\' must be set to one of \'true\' or \'false\'')

	threshold_dict, mutation_to_line_dict, thresh_to_gained_candidate_dict, thresh_to_lost_candidate_dict = \
			process_mutation_candidates_file(mutation_candidates_filename, use_3D_pooling)

	## When trying to get a given number of predictions, there may be multiple thresholds that give this number of predictions
	## We move the threshold from most to least conservative, until we get at least the specified number of predictions
	current_candidate_dict = select_top_mutation_candidates(threshold_dict,
							thresh_to_gained_candidate_dict, thresh_to_lost_candidate_dict,
							selection_method, selection_parameter_string,
							best_threshold_is_high, mutation_calling_option, output_option)

	print_current_mutations_to_output_file(current_candidate_dict, mutation_to_line_dict, output_mutation_candidates_filename)

def select_top_mutation_candidates(threshold_dict,
			thresh_to_gained_candidate_dict, thresh_to_lost_candidate_dict,
			selection_method, selection_parameter_string,
			best_threshold_is_high, mutation_calling_option, output_option):
	"""Walk the thresholds from most to least conservative and collect candidates.

	Args:
		threshold_dict: mapping whose keys are the thresholds to visit (only the
			keys are used here).
		thresh_to_gained_candidate_dict: threshold -> iterable of mutations that
			become predicted at that threshold.
		thresh_to_lost_candidate_dict: threshold -> iterable of mutations that
			stop being predicted at that threshold (only honoured when
			mutation_calling_option == 'r').
		selection_method: 'number' (stop once at least N candidates are held) or
			'threshold' (stop on reaching the cutoff threshold).
		selection_parameter_string: N (parsed with int) for 'number', or the
			cutoff (parsed with float) for 'threshold'.
		best_threshold_is_high: if true, thresholds are visited in descending
			order, otherwise ascending.
		mutation_calling_option: 'r', 't' or 'a'; see the notes at the top of
			the script.
		output_option: accepted for interface compatibility; not used here.

	Returns:
		dict mapping each selected mutation to 0.

	Raises:
		IOError: for an unknown selection_method, a non-positive candidate
			count, a mutation gained twice (unless option 'a'), or a mutation
			lost that was never gained (option 'r').

	Side effects:
		With selection_method == 'number', prints the threshold at which the
		requested count was reached (plus an overshoot marker if exceeded).
		Nothing is printed if the count is never reached.
	"""
	if selection_method == 'number':
		num_candidates_to_select = int(selection_parameter_string)
		# The message promises "more than 0", so negative counts are rejected too.
		if num_candidates_to_select <= 0:
			raise IOError('Must specify a selection number of more than 0 candidates')
		cutoff_threshold = None
	elif selection_method == 'threshold':
		num_candidates_to_select = None
		cutoff_threshold = float(selection_parameter_string)
	else:
		raise IOError('unknown selection method: ' + selection_method)

	# sorted() works on both Python 2 and 3 (dict.keys() has no .sort() on Python 3).
	threshold_list = sorted(threshold_dict, reverse=bool(best_threshold_is_high))

	current_candidate_dict = {}
	current_num_candidates = 0

	for current_threshold in threshold_list:
		if selection_method == 'threshold':
			# The cutoff itself is excluded: stop as soon as it is reached or passed.
			if best_threshold_is_high:
				reached_cutoff = current_threshold <= cutoff_threshold
			else:
				reached_cutoff = current_threshold >= cutoff_threshold
			if reached_cutoff:
				break

		## consider the effect of setting the FrNn threshold to be just below 'current_threshold'
		gained_candidates = thresh_to_gained_candidate_dict.get(current_threshold, {})
		lost_candidates = thresh_to_lost_candidate_dict.get(current_threshold, {})

		for gained_mutation in gained_candidates:
			## check for seeing a given mutation twice
			if (mutation_calling_option != 'a') and (gained_mutation in current_candidate_dict):
				raise IOError('error: unexpected (for regular candidates)')
			current_candidate_dict[gained_mutation] = 0
			# NOTE(review): under option 'a' a repeated mutation still bumps this
			# counter, so it can exceed len(current_candidate_dict) - confirm intended.
			current_num_candidates += 1

		# Only consider losing mutations by lowering the threshold, under the 'regular' mutation_calling_option:
		if mutation_calling_option == 'r':
			for lost_mutation in lost_candidates:
				if lost_mutation not in current_candidate_dict:
					raise IOError('Error: found mutation \'%s\' lost at threshold %.20e that is not in the current candidate dict' % (lost_mutation, current_threshold))
				del current_candidate_dict[lost_mutation]
				current_num_candidates -= 1

		if (selection_method == 'number') and (current_num_candidates >= num_candidates_to_select):
			output_line = '%.3e' % current_threshold
			if current_num_candidates > num_candidates_to_select:
				output_line += ('\tERROR_overshot_number_of_candidates_to_select__predicted_%d_instead_of_%d_candidates' %
						(current_num_candidates, num_candidates_to_select))
			# Single-argument print(...) emits the same text on Python 2 and Python 3.
			print(output_line)
			break

	return current_candidate_dict

def print_current_mutations_to_output_file(current_candidate_dict, mutation_to_line_dict, output_mutation_candidates_filename):
	"""Write the stored line of every selected candidate to the output file.

	Args:
		current_candidate_dict: mapping whose keys are the selected candidates.
		mutation_to_line_dict: candidate -> line of text to emit for it
			(written followed by a newline).
		output_mutation_candidates_filename: path of the file to (over)write.

	Raises:
		KeyError: if a selected candidate has no entry in mutation_to_line_dict.
	"""
	# 'with' guarantees the handle is closed even if a lookup below raises.
	with open(output_mutation_candidates_filename, 'w') as output_mutation_candidates_file:
		for candidate in current_candidate_dict:
			output_mutation_candidates_file.write(mutation_to_line_dict[candidate] + '\n')

# Hand the eight command-line arguments to main() in positional order
# (the length check near the top of the script guarantees there are exactly eight).
main(*sys.argv[1:9])
