#!/usr/bin/env python

## Currently supports GFF3 sequence feature annotation files for Oryza sativa (MSU v6.1) and Arabidopsis thaliana (TAIR 9)

import sys, math
from chromosome_helper_functions__module import *

def get_exon_data(sequence_feature_annotations_filename):
	"""Gather exon information for every chromosome of a GFF3 annotation file.

	The file is processed one chromosome at a time with
	get_exon_information_for_each_locus_in_chromosome, and the per-chromosome
	results are merged into genome-wide dicts.

	Returns a 6-tuple:
		locus_to_models_dict,
		model_to_exons_dict,
		model_to_trimmed_exons_dict,
		locus_to_exons_of_selected_model_dict,
		locus_to_combined_length_of_non_overlapping_regions_of_selected_exon_model_dict,
		general_location_to_loci_dict  ({chromosome: {bin_index: [locus, ...]}})

	Raises IOError if the same locus or model appears in the results of more
	than one chromosome.
	"""
	def merge_without_duplicates(genome_wide_dict, chromosome_specific_dict, error_message):
		## copy entries across, refusing to overwrite a key already contributed by another chromosome
		for key in chromosome_specific_dict:
			if key in genome_wide_dict:
				raise IOError(error_message)
			genome_wide_dict[key] = chromosome_specific_dict[key]

	locus_to_models_dict = {}
	model_to_exons_dict = {}
	model_to_trimmed_exons_dict = {}
	locus_to_exons_of_selected_model_dict = {}
	locus_to_combined_length_of_non_overlapping_regions_of_selected_exon_model_dict = {}
	## filled in place by get_exon_information_for_each_locus_in_chromosome
	general_location_to_loci_dict = {}

	all_chromosomes = get_dictionary_of_all_chromosomes_from_sequence_feature_annotations_file(sequence_feature_annotations_filename)
	for chromosome in all_chromosomes:
		(chromosome_locus_to_models,
		 chromosome_model_to_exons,
		 chromosome_model_to_trimmed_exons,
		 chromosome_locus_to_selected_exons,
		 chromosome_locus_to_selected_length) = get_exon_information_for_each_locus_in_chromosome(
				sequence_feature_annotations_filename, chromosome, general_location_to_loci_dict)

		merge_without_duplicates(locus_to_models_dict, chromosome_locus_to_models,
				'internal error: saw locus twice')
		merge_without_duplicates(model_to_exons_dict, chromosome_model_to_exons,
				'internal error: saw model twice')
		merge_without_duplicates(model_to_trimmed_exons_dict, chromosome_model_to_trimmed_exons,
				'internal error: saw model twice (message 2)')
		merge_without_duplicates(locus_to_exons_of_selected_model_dict, chromosome_locus_to_selected_exons,
				'internal error: saw locus twice (message 3)')
		merge_without_duplicates(locus_to_combined_length_of_non_overlapping_regions_of_selected_exon_model_dict,
				chromosome_locus_to_selected_length,
				'internal error: saw locus twice (message 4)')

	return (locus_to_models_dict, model_to_exons_dict, model_to_trimmed_exons_dict,
			locus_to_exons_of_selected_model_dict,
			locus_to_combined_length_of_non_overlapping_regions_of_selected_exon_model_dict,
			general_location_to_loci_dict)

def get_dictionary_of_all_chromosomes_from_sequence_feature_annotations_file(sequence_feature_annotations_filename):
	"""Return a dict whose keys are the base names of all chromosomes in the file.

	The base name is obtained with get_base_chromosome_name from the first
	tab-separated column of every data line; every value is 0 (the dict is
	used as a set).

	Blank lines and lines starting with '#' (GFF3 comments and '##'
	directives such as '##gff-version 3') are skipped.  Before this was
	done, the header line itself was registered as a "chromosome", unlike
	in the other readers of this file, which already skip '##' lines.
	"""
	chromosome_dict = {}
	## 'with' closes the file even if get_base_chromosome_name raises
	with open(sequence_feature_annotations_filename, 'r') as sequence_feature_annotations_file:
		for line in sequence_feature_annotations_file:
			line = line.rstrip('\n\r')
			if (not line.strip()) or line.startswith('#'):
				continue
			field_list = line.split('\t')
			chromosome = get_base_chromosome_name(field_list[0])
			chromosome_dict[chromosome] = 0
	return chromosome_dict

def get_exon_information_for_each_locus_in_chromosome(sequence_feature_annotations_filename, chromosome, general_location_to_loci_dict):
	"""Compute exon information for the loci of one chromosome.

	Side effect: every locus of this chromosome is added to
	general_location_to_loci_dict (via update_location_to_loci_dict).

	Returns a 5-tuple:
		locus_to_models_dict, model_to_exons_dict, model_to_trimmed_exons_dict,
		locus_to_exons_of_selected_model_dict,
		locus_to_combined_length_of_non_overlapping_regions_of_selected_exon_model_dict
	where the "trimmed" exons of a model are its exons with the regions covered
	by the union exons of range-overlapping loci removed, and the combined
	length is the total length of the selected model's trimmed exons.
	"""
	(locus_to_models_dict,
	 model_to_exons_dict,
	 locus_to_selected_model_dict,
	 locus_to_exons_of_selected_model_dict) = process_sequence_feature_annotations_filename(
			sequence_feature_annotations_filename, chromosome)

	## register each locus in the genome-wide bin index
	locus_to_range_dict = get_locus_to_range_dict(locus_to_models_dict, model_to_exons_dict)
	update_location_to_loci_dict(locus_to_range_dict, chromosome, general_location_to_loci_dict)

	## (start, locus) pairs ordered by start; the trimming step relies on this order
	start_pos_locus_tuple_list = sorted(
		(locus_range[0], locus) for locus, locus_range in locus_to_range_dict.items())

	union_exons_by_locus = get_locus_to_exons_from_union_of_models_dict(locus_to_models_dict, model_to_exons_dict)

	model_to_trimmed_exons_dict, length_by_locus_then_model = get_model_to_trimmed_exons_dict(
			locus_to_models_dict, model_to_exons_dict, locus_to_range_dict,
			union_exons_by_locus, start_pos_locus_tuple_list)

	## keep only the length belonging to each locus's selected model
	selected_lengths = dict(
		(locus, length_by_locus_then_model[locus][selected_model])
		for locus, selected_model in locus_to_selected_model_dict.items())

	return (locus_to_models_dict, model_to_exons_dict, model_to_trimmed_exons_dict,
			locus_to_exons_of_selected_model_dict, selected_lengths)

def update_location_to_loci_dict(locus_to_range_dict, chromosome, general_location_to_loci_dict):
	"""Record every locus under each location bin that its range touches.

	general_location_to_loci_dict is updated in place into the shape
	{chromosome: {bin_index: [locus, ...]}}.  The bins of a locus run from
	get_bin_index(range start) to get_bin_index(range stop), inclusive.
	"""
	for locus in locus_to_range_dict:
		locus_range = locus_to_range_dict[locus]
		first_bin = get_bin_index(locus_range[0])
		last_bin = get_bin_index(locus_range[1])
		for bin_index in range(first_bin, last_bin + 1):
			bins_of_chromosome = general_location_to_loci_dict.setdefault(chromosome, {})
			bins_of_chromosome.setdefault(bin_index, []).append(locus)

def get_bin_index(position):
	"""Return the 0-based location bin containing a 1-based position.

	Positions 1..bin_size fall into bin 0, the next bin_size positions into
	bin 1, and so on.  The bin size affects runtime only, not results.
	"""
	return int(math.floor((position - 1) / float(get_bin_size())))

def get_bin_size():
	"""Return the width (in positions) of one location bin, as a float."""
	return 10000.0

def process_sequence_feature_annotations_filename(sequence_feature_annotations_filename, specified_chromosome):
	"""Parse the exon-level features of one chromosome from a GFF3 file.

	Only data lines whose base chromosome name (get_base_chromosome_name of
	column 1) equals specified_chromosome are used:

	* 'exon' / 'pseudogenic_exon': grouped under their 'Parent' model, and
	  under the locus that get_model_to_locus_dict assigns to that model.
	  For each such locus the "selected" model is the one with the lowest
	  model number (the second '.'-separated field of the model name, with
	  any 'm' removed).
	* 'transposon_fragment': grouped under their 'Parent' TE, which is
	  treated as a locus whose single model has the TE's own name.
	* 'transposable_element_gene': the TE named by 'Derives_from' is removed
	  from the results afterwards.

	Blank lines and lines starting with '#' (comments and '##' directives)
	are skipped.

	Returns (locus_to_models_dict, model_to_exons_dict,
	locus_to_selected_model_dict, locus_to_exons_of_selected_model_dict);
	each exon list is a sorted list of [start, stop] pairs.

	Raises IOError on a tie for the lowest model number of a locus, on a
	'.' in a TE name, or from check_exon_models.
	"""
	locus_to_models_dict = {}
	locus_to_lowest_model_number_dict = {}
	locus_to_selected_model_dict = {}
	model_to_exons_dict = {}
	locus_to_exons_of_selected_model_dict = {}
	transposable_elements_also_listed_as_genes_dict = {}

	## this will handle each locus model that is explicitly represented by a line of the sequence annotation file,
	## for the supported organisms
	## NOTE: get_model_to_locus_dict re-reads the whole annotation file on every call
	model_to_locus_dict = get_model_to_locus_dict(sequence_feature_annotations_filename)

	## 'with' guarantees the file is closed even when one of the IOErrors below is raised
	with open(sequence_feature_annotations_filename, 'r') as sequence_feature_annotations_file:
		for line in sequence_feature_annotations_file:
			line = line.rstrip('\n\r')
			## skip blank lines and comment/directive lines ('#' and '##')
			if (not line.strip()) or line.startswith('#'):
				continue
			field_list = line.split('\t')
			chromosome = get_base_chromosome_name(field_list[0])
			if chromosome != specified_chromosome:
				continue
			feature_type = field_list[2]
			start_position = int(field_list[3])
			stop_position = int(field_list[4])
			information_string = field_list[8]
			if feature_type in ['exon', 'pseudogenic_exon']:
				model = get_value_of_key_from_information_string('Parent', information_string)
				model_number_string = model.split('.')[1]
				# the following line is to handle the model name format in the gff3 for Oryza sativa
				model_number_string = model_number_string.replace('m', '')
				model_number = int(model_number_string)

				## update the list of all models and the selected (lowest-numbered) model for the given locus
				locus = model_to_locus_dict[model]
				if locus not in locus_to_models_dict:
					# if this is the first time we are seeing this locus
					locus_to_models_dict[locus] = [model]
					locus_to_lowest_model_number_dict[locus] = model_number
					locus_to_selected_model_dict[locus] = model
				else:
					if model not in locus_to_models_dict[locus]:
						locus_to_models_dict[locus].append(model)
					if model_number < locus_to_lowest_model_number_dict[locus]:
						locus_to_lowest_model_number_dict[locus] = model_number
						locus_to_selected_model_dict[locus] = model
					if (model_number == locus_to_lowest_model_number_dict[locus]) and (model != locus_to_selected_model_dict[locus]):
						raise IOError('Unexpected: found two different models for locus \'%s\' that are tied for the lowest model number so far: %d' % (locus, model_number))

				## update model to exons dict
				model_to_exons_dict.setdefault(model, []).append([start_position, stop_position])
			elif feature_type == 'transposon_fragment':
				TE = get_value_of_key_from_information_string('Parent', information_string)
				if '.' in TE:
					raise IOError('Did not expect \'.\' in name of TE')
				## a TE is its own locus with a single model of the same name
				if TE not in locus_to_models_dict:
					locus_to_models_dict[TE] = [TE]
					model_to_exons_dict[TE] = []
				model_to_exons_dict[TE].append([start_position, stop_position])
			elif feature_type == 'transposable_element_gene':
				TE = get_value_of_key_from_information_string('Derives_from', information_string)
				## -1 means the line has no 'Derives_from' attribute
				if TE != -1:
					transposable_elements_also_listed_as_genes_dict[TE] = 0

	## remove transposable elements that are also listed as transposable element genes;
	## the TE may have no fragments on this chromosome, hence the membership checks
	for TE in transposable_elements_also_listed_as_genes_dict:
		if TE in locus_to_models_dict:
			del locus_to_models_dict[TE]
		if TE in model_to_exons_dict:
			del model_to_exons_dict[TE]

	## sort each exon list in 'model_to_exons_dict'
	for exon_list in model_to_exons_dict.values():
		exon_list.sort()

	## check the format in which the exon models are stored
	check_exon_models(locus_to_models_dict, model_to_exons_dict, locus_to_selected_model_dict)

	## get values for locus_to_exons_of_selected_model_dict (copies, so later edits do not alias)
	for locus in locus_to_models_dict:
		## non-TE loci use their selected model; a TE locus is its own model
		model = locus_to_selected_model_dict.get(locus, locus)
		locus_to_exons_of_selected_model_dict[locus] = [list(exon) for exon in model_to_exons_dict[model]]

	return locus_to_models_dict, model_to_exons_dict, locus_to_selected_model_dict, locus_to_exons_of_selected_model_dict

def get_model_to_locus_dict(sequence_feature_annotations_filename):
	"""Map each transcript ("model") ID in the file to its parent locus ID.

	Only features whose type is one of the transcript types listed below are
	used.  The model is their 'ID' attribute and the locus their 'Parent'
	attribute, both read with get_value_of_key_from_information_string
	(which yields -1 for a missing attribute).

	Blank lines and lines starting with '#' are skipped.  Previously a blank
	line or a single-'#' comment line crashed here with an IndexError on
	the feature-type column.

	Raises IOError if two transcript-type features share the same model ID.
	"""
	transcript_feature_types = ('snoRNA', 'ncRNA', 'miRNA', 'tRNA', 'rRNA', 'mRNA', 'snRNA', 'pseudogenic_transcript')
	model_to_locus_dict = {}
	## 'with' closes the file even when the duplicate-model IOError is raised
	with open(sequence_feature_annotations_filename) as sequence_feature_annotations_file:
		for line in sequence_feature_annotations_file:
			line = line.rstrip('\n\r')
			## skip blank lines and comment/directive lines ('#' and '##')
			if (not line.strip()) or line.startswith('#'):
				continue
			field_list = line.split('\t')
			feature_type = field_list[2]
			if feature_type not in transcript_feature_types:
				continue
			information_string = field_list[8]
			model = get_value_of_key_from_information_string('ID', information_string)
			locus = get_value_of_key_from_information_string('Parent', information_string)
			if model in model_to_locus_dict:
				raise IOError('Unexpected: saw two transcript-type features for same model: \'%s\'' % model)
			model_to_locus_dict[model] = locus
	return model_to_locus_dict

def get_value_of_key_from_information_string(specified_key, information_string):
	"""Return the value of one attribute in a GFF3 column-9 attribute string.

	information_string is a ';'-separated list of 'key=value' pairs.  Empty
	pairs are ignored; the surrounding whitespace of each pair and of the
	returned value is stripped (the key itself must match exactly).

	Returns -1 when specified_key is absent.
	Raises IOError if a pair is not of the form 'key=value', or if
	specified_key occurs more than once.
	"""
	found_value = -1
	for raw_pair in information_string.split(';'):
		pair = raw_pair.strip()
		if not pair:
			continue
		pieces = pair.split('=')
		if len(pieces) != 2:
			raise IOError('Expected assignment string with format \'key=value\': %s' % pair)
		key, value = pieces
		if key != specified_key:
			continue
		if found_value != -1:
			raise IOError('Saw two entries for specified key \'%s\' in information string: %s' % (specified_key, information_string))
		found_value = value.strip()
	return found_value

def check_exon_models(locus_to_models_dict, model_to_exons_dict, locus_to_selected_model_dict):
	"""Validate the ordering and bounds of every stored exon list.

	A locus present in locus_to_selected_model_dict is a non-TE locus: in
	each of its models, consecutive exons must be strictly ordered with at
	least one position of gap between them.  Any other locus is a TE locus
	whose single model is the locus itself: consecutive fragments must be
	strictly increasing in (start, stop) order, but may overlap.  For every
	locus, each exon's start must not exceed its stop.

	Raises IOError on the first violation found; returns None otherwise.
	"""
	non_te_message = ('Found consecutively-listed exons for a non-TE locus that are either out of order, overlap, ' +
			'or are directly adjacent to one another: [ . . .; (%d, %d) ; (%d, %d) ; . . .]')
	te_message = ('Found consecutively-listed exons for a TE locus that are out of order: ' +
			'[ . . .; (%d, %d) ; (%d, %d) ; . . .]')

	for locus in locus_to_models_dict:
		if locus not in locus_to_selected_model_dict:
			## TE locus: the model has the same name as the locus
			fragments = model_to_exons_dict[locus]
			for current, following in zip(fragments, fragments[1:]):
				if (current[0], current[1]) >= (following[0], following[1]):
					raise IOError(te_message % (current[0], current[1], following[0], following[1]))
		else:
			## non-TE locus: every model must have gapped, strictly ordered exons
			for model in locus_to_models_dict[locus]:
				exons = model_to_exons_dict[model]
				for current, following in zip(exons, exons[1:]):
					if current[1] + 1 >= following[0]:
						raise IOError(non_te_message % (current[0], current[1], following[0], following[1]))

		## for both TE and non-TE loci
		for model in locus_to_models_dict[locus]:
			for [start, stop] in model_to_exons_dict[model]:
				if start > stop:
					raise IOError('Start position of exon is greater than stop position of exon')

def get_locus_to_range_dict(locus_to_models_dict, model_to_exons_dict):
	"""Return {locus: [lowest exon start, highest exon stop]} over all of its models.

	A locus with no exons at all has no range and raises KeyError(locus).
	"""
	locus_to_range_dict = {}
	for locus, models in locus_to_models_dict.items():
		lowest_start = None
		highest_stop = None
		for model in models:
			for [start, stop] in model_to_exons_dict[model]:
				if lowest_start is None or start < lowest_start:
					lowest_start = start
				if highest_stop is None or stop > highest_stop:
					highest_stop = stop
		if lowest_start is None:
			raise KeyError(locus)
		locus_to_range_dict[locus] = [lowest_start, highest_stop]
	return locus_to_range_dict

def get_locus_to_exons_from_union_of_models_dict(locus_to_gene_models_dict, gene_model_to_exons_dict):
	"""Return {locus: union of the exons of all of that locus's models}.

	Each exon of each model is merged, in turn, into an initially empty
	range list with get_union_exon_list_after_adding_exon, which collapses
	overlapping or directly adjacent ranges into one.
	"""
	union_by_locus = {}
	for locus, gene_models in locus_to_gene_models_dict.items():
		union_so_far = []
		for gene_model in gene_models:
			for exon in gene_model_to_exons_dict[gene_model]:
				union_so_far = get_union_exon_list_after_adding_exon(exon, union_so_far)
		union_by_locus[locus] = union_so_far
	return union_by_locus

def get_union_exon_list_after_adding_exon(exon, union_exon_list):
	"""Return a new range list equal to union_exon_list with exon merged in.

	exon is a [start, stop] pair of inclusive positions.  union_exon_list is
	a list of such pairs, assumed pairwise disjoint and non-adjacent (as it
	is when built up from [] by repeated calls to this function).  Every
	range that overlaps exon or directly abuts it (stop + 1 == start) is
	collapsed together with exon into a single range, and the result stays
	ordered by start position.

	union_exon_list itself is left unmodified; a sorted copy is used instead
	of sorting the caller's list in place.
	"""
	[exon_start, exon_stop] = exon
	ordered_union = sorted(union_exon_list)

	## first range that the new exon touches or precedes (ends at or after exon_start - 1)
	idx1 = len(ordered_union)
	for i, (union_exon_start, union_exon_stop) in enumerate(ordered_union):
		if exon_start <= (union_exon_stop + 1):
			idx1 = i
			break

	## last range that the new exon touches or follows (starts at or before exon_stop + 1)
	idx2 = -1
	for i, (union_exon_start, union_exon_stop) in enumerate(ordered_union):
		if exon_stop >= (union_exon_start - 1):
			idx2 = i

	## ranges idx1..idx2 (possibly none) are absorbed into the new exon
	if idx1 == len(ordered_union):
		new_exon_start = exon_start
	else:
		new_exon_start = min(ordered_union[idx1][0], exon_start)

	if idx2 == -1:
		new_exon_stop = exon_stop
	else:
		new_exon_stop = max(ordered_union[idx2][1], exon_stop)

	return ordered_union[0:idx1] + [[new_exon_start, new_exon_stop]] + ordered_union[(idx2 + 1):]

def _merge_overlapping_sorted_exons(exon_list):
	"""Return copies of the ranges in exon_list with overlapping neighbours merged.

	exon_list is expected to be sorted by start position.  A range is merged
	into the previous (already merged) one when it starts at or before that
	one's stop; directly adjacent ranges stay separate.  exon_list is not
	modified.
	"""
	merged = []
	for exon in exon_list:
		if merged and (merged[-1][1] + 1) > exon[0]:
			merged[-1] = [merged[-1][0], max(merged[-1][1], exon[1])]
		else:
			merged.append([position for position in exon])
	return merged

def get_model_to_trimmed_exons_dict(locus_to_models_dict, model_to_exons_dict, locus_to_range_dict,
		locus_to_exons_from_union_of_models_dict, start_pos_locus_tuple_list):
	"""Trim every model's exons against the exons of range-overlapping loci.

	1. Each model's exon list (assumed sorted by start, as
	   process_sequence_feature_annotations_filename leaves it) is copied and
	   its mutually overlapping exons are merged, so no position is counted
	   twice (transposable elements may have overlapping fragments).
	2. For every pair of loci whose [start, stop] ranges overlap, the union
	   exons of each locus are trimmed out of every model of the other.  The
	   pairs are found by walking start_pos_locus_tuple_list, which must be a
	   list of (start, locus) tuples sorted by start.
	3. The remaining exon lengths are summed per model.

	Returns (model_to_trimmed_exons_dict,
	{locus: {model: combined length of that model's trimmed exons}}).
	Raises IOError if the same model is encountered twice.
	"""
	## step 1: per-model copies with self-overlaps merged
	model_to_trimmed_exons_dict = {}
	for locus in locus_to_models_dict:
		for model in locus_to_models_dict[locus]:
			merged_exons = _merge_overlapping_sorted_exons(model_to_exons_dict[model])
			if model in model_to_trimmed_exons_dict:
				raise IOError(('Unexpected: Saw model %s two times' % model))
			model_to_trimmed_exons_dict[model] = merged_exons

	def trim_models_of(target_locus, trimming_locus):
		## remove every union exon of trimming_locus from each model of target_locus
		for model in locus_to_models_dict[target_locus]:
			for union_exon in locus_to_exons_from_union_of_models_dict[trimming_locus]:
				model_to_trimmed_exons_dict[model] = \
					get_trimmed_exon_list(model_to_trimmed_exons_dict[model], union_exon)

	## step 2: trim each pair of overlapping loci against one another
	loci_count = len(start_pos_locus_tuple_list)
	for index in range(loci_count):
		locus = start_pos_locus_tuple_list[index][1]
		locus_stop = locus_to_range_dict[locus][1]
		next_index = index + 1
		while next_index < loci_count:
			[other_locus_start, other_locus] = start_pos_locus_tuple_list[next_index]
			if other_locus_start > locus_stop:
				## sorted by start, so no later locus can overlap this one either
				break
			trim_models_of(locus, other_locus)
			trim_models_of(other_locus, locus)
			next_index += 1

	## step 3: combined length of the trimmed exons of every model
	combined_lengths = {}
	for locus in locus_to_models_dict:
		combined_lengths[locus] = dict(
			(model, sum(exon_stop - exon_start + 1 for [exon_start, exon_stop] in model_to_trimmed_exons_dict[model]))
			for model in locus_to_models_dict[locus])
	return model_to_trimmed_exons_dict, combined_lengths

def get_trimmed_exon_list(exon_list, range_to_trim):
	"""Return the parts of the exons in exon_list that lie outside range_to_trim.

	All ranges are [start, stop] pairs of inclusive positions.

	- An exon entirely inside range_to_trim is dropped.
	- An exon overlapping only one end of range_to_trim is shortened.
	- An exon that strictly contains range_to_trim is split into the two
	  pieces on either side.  Previously such an exon was returned
	  unchanged, so the overlapping positions were still counted as
	  non-overlapping.
	- Exons not touching range_to_trim are copied unchanged.
	- Exons with start > stop are dropped.
	- An empty range_to_trim (start > stop) trims nothing.

	Returns a new list; exon_list is not modified.
	"""
	[trim_start, trim_stop] = range_to_trim
	if trim_start > trim_stop:
		## empty trim range: keep every (valid) exon as a copy
		return [[exon_start, exon_stop] for [exon_start, exon_stop] in exon_list if exon_start <= exon_stop]

	trimmed_exon_list = []
	for [exon_start, exon_stop] in exon_list:
		## piece of the exon to the left of the trimmed range, if any
		left_stop = min(exon_stop, trim_start - 1)
		if exon_start <= left_stop:
			trimmed_exon_list.append([exon_start, left_stop])
		## piece of the exon to the right of the trimmed range, if any
		right_start = max(exon_start, trim_stop + 1)
		if right_start <= exon_stop:
			trimmed_exon_list.append([right_start, exon_stop])
	return trimmed_exon_list

