#!/usr/bin/env python

import sys
from process_gene_models_helper_functions import *

if len(sys.argv) != 4:
	# A wrong argument count is a command-line usage error, not an I/O failure:
	# sys.exit prints the usage message to stderr and exits with status 1.
	# (The call form also parses under Python 3, unlike `raise IOError, '...'`.)
	sys.exit('Usage: ./get_SNPs_per_locus.py SNP_filename sequence_feature_annotations_filename output_filename')

def main(SNP_filename, sequence_feature_annotations_filename, output_filename):
	"""Count SNPs per locus and write the distribution of those counts.

	Loads annotation data with get_exon_data(), assigns the SNPs listed in
	SNP_filename to loci with get_locus_to_num_SNPs_dict(), and writes the
	binned per-locus counts to output_filename.
	"""
	exon_data = get_exon_data(sequence_feature_annotations_filename)
	# get_exon_data's result is unpacked into six values; only four are used here.
	(models_by_locus, exons_by_model, _trimmed_exons_by_model, selected_model_exons_by_locus,
			_non_overlapping_length_by_locus, loci_by_location) = exon_data

	SNP_counts = get_locus_to_num_SNPs_dict(SNP_filename, models_by_locus, exons_by_model,
			selected_model_exons_by_locus, loci_by_location)
	print_distribution_of_num_SNPs_per_locus(SNP_counts, output_filename)

def get_locus_to_num_SNPs_dict(SNP_filename, locus_to_models_dict, model_to_exons_dict, locus_to_exons_of_selected_model_dict, general_location_to_loci_dict):
	"""Count, for each locus, the SNPs that fall inside an exon of exactly one locus.

	SNP_filename: tab-separated file; column 1 is the chromosome name and
		column 2 the integer SNP position. A line whose second column is 'pos'
		or 'position' (case-insensitive) is treated as a header and skipped.
		Blank lines are skipped.
	locus_to_models_dict: locus -> iterable of gene model identifiers.
	model_to_exons_dict: model -> iterable of [start, stop] exon coordinates;
		both ends are treated as inclusive.
	locus_to_exons_of_selected_model_dict: not used by this function; kept
		for interface compatibility.
	general_location_to_loci_dict: chromosome -> {bin index as returned by
		get_bin_index(position): list of candidate loci for that bin}.

	A SNP is counted for a locus only if it lies inside an exon of some model
	of that locus and inside no exon of any other candidate locus. SNPs that
	overlap several loci, or none, are ignored. SNPs on chromosomes with no
	entry in general_location_to_loci_dict are ignored rather than raising
	KeyError.

	Returns a dict mapping every locus in locus_to_models_dict to its SNP
	count (0 if no SNP was assigned to it).
	"""
	## initialize num SNPs per locus
	locus_to_num_SNPs_dict = dict((locus, 0) for locus in locus_to_models_dict)

	## process SNPs (the with-block closes the file even if a line fails to parse)
	with open(SNP_filename, 'r') as SNP_file:
		for line in SNP_file:
			line = line.rstrip('\n\r')
			if not line.strip():
				# e.g. a trailing empty line; there is no position column to read
				continue
			field_list = line.split('\t')
			chromosome = field_list[0]
			position_string = field_list[1].lower()
			if position_string in ('pos', 'position'):
				# header line
				continue
			position = int(position_string)

			## chromosomes/contigs without annotated loci (e.g. unplaced scaffolds)
			## have no candidate loci, so the SNP cannot be counted
			bins_on_chromosome = general_location_to_loci_dict.get(chromosome)
			if not bins_on_chromosome:
				continue
			candidate_locus_list = bins_on_chromosome.get(get_bin_index(position), [])

			## collect every candidate locus with an exon (of any model) containing the SNP
			containing_loci = set()
			for locus in candidate_locus_list:
				for model in locus_to_models_dict[locus]:
					if any(start <= position <= stop for start, stop in model_to_exons_dict[model]):
						containing_loci.add(locus)
						# one containing exon is enough for this locus
						break

			## count the SNP only if it maps to exactly one locus
			if len(containing_loci) == 1:
				containing_locus = next(iter(containing_loci))
				if containing_locus not in locus_to_num_SNPs_dict:
					raise IOError('This should not happen')
				locus_to_num_SNPs_dict[containing_locus] += 1
	return locus_to_num_SNPs_dict

def print_distribution_of_num_SNPs_per_locus(locus_to_num_SNPs_dict, output_filename):
	"""Write a histogram of per-locus SNP counts to output_filename.

	The output is tab-separated with the header 'SNP_ct<TAB>numLoci', followed
	by one row per bucket, in this order: '0', '1 to 1', '2 to 2', '3 to 5',
	'6 to 10', '11+'. Each row gives the number of loci whose SNP count falls in
	that bucket:
	- The first bucket is an exact match on the first threshold.
	- Each middle bucket covers (previous threshold, threshold].
	- The last bucket covers everything above the final threshold.

	locus_to_num_SNPs_dict: locus -> SNP count.
	output_filename: path of the file to create or overwrite.
	"""
	threshold_list = [0, 1, 2, 5, 10]
	num_buckets = len(threshold_list) + 1

	## assign each locus to its bucket in a single pass over the loci
	bucket_counts = [0] * num_buckets
	for num_SNPs in locus_to_num_SNPs_dict.values():
		if num_SNPs == threshold_list[0]:
			bucket_counts[0] += 1
		elif num_SNPs > threshold_list[-1]:
			bucket_counts[-1] += 1
		else:
			for i in range(1, len(threshold_list)):
				if threshold_list[i-1] < num_SNPs <= threshold_list[i]:
					bucket_counts[i] += 1
					break

	## write results; the with-block closes the file even if a write fails
	with open(output_filename, 'w') as output_file:
		output_file.write('SNP_ct\tnumLoci\n')
		for i in range(num_buckets):
			# descriptor for the current range
			if i == 0:
				range_string = '%d' % threshold_list[0]
			elif i == len(threshold_list):
				range_string = '%d+' % (threshold_list[-1] + 1)
			else:
				range_string = '%d to %d' % (threshold_list[i-1] + 1, threshold_list[i])
			output_file.write('%s\t%d\n' % (range_string, bucket_counts[i]))

# Run only when executed as a script, so importing this module does not
# immediately process files named by whatever happens to be in sys.argv.
if __name__ == '__main__':
	main(sys.argv[1], sys.argv[2], sys.argv[3])
