#!/usr/bin/env python

import sys

# Accepted values for the first command-line argument; they select whether the
# output table holds one total-count column per input file or a REF/OTHER pair
# of columns per input file.
READ_COUNTS_PER_ALLELE = 'read_counts_per_allele'
READ_COUNTS = 'read_counts'

# Fail fast with a usage message unless an option, an output file and at least
# one statistics-per-locus input file were supplied.
# The call form of `raise` is used because the legacy `raise Exc, msg` form is
# a syntax error on Python 3; the exception type and message are unchanged.
if len(sys.argv) < 4:
	raise IOError('Usage: create_read_counts_file_for_all_replicates.py [ read_counts | read_counts_per_allele ] output_file statistics_per_locus_file_1 [ statistics_per_locus_file_2 . . . ]')

def main(option, output_filename, statistics_per_locus_filename_list):
	"""Merge read counts from several statistics-per-locus files into one table.

	The output is tab-separated. Its header starts with 'locus_name' followed
	by one column per input file (option READ_COUNTS) or a '<name>_REF' /
	'<name>_OTHER' column pair per input file (option READ_COUNTS_PER_ALLELE),
	where <name> is the input file's base name with '_statistics_per_locus'
	removed. One row is written per locus, in sorted locus-name order.

	Args:
		option: READ_COUNTS or READ_COUNTS_PER_ALLELE.
		output_filename: path of the table to write (overwritten if present).
		statistics_per_locus_filename_list: input file paths, in column order.

	Raises:
		IOError: if `option` is not one of the two accepted values, or if a
			locus found in one input file is missing from another.
	"""
	# Validate up front so that a bad option neither parses any input nor
	# truncates an existing output file.
	if option not in (READ_COUNTS, READ_COUNTS_PER_ALLELE):
		raise IOError('Bad option: \'%s\', must be one of either \'%s\' or \'%s\'' % (option, READ_COUNTS, READ_COUNTS_PER_ALLELE))
	per_allele = (option == READ_COUNTS_PER_ALLELE)

	all_loci = set()
	read_counts_by_filename = {}
	allele_counts_by_filename = {}
	header_fields = ['locus_name']

	for statistics_per_locus_filename in statistics_per_locus_filename_list:
		read_counts, allele_counts = get_locus_to_read_count_dict(statistics_per_locus_filename)
		read_counts_by_filename[statistics_per_locus_filename] = read_counts
		allele_counts_by_filename[statistics_per_locus_filename] = allele_counts
		all_loci.update(read_counts)

		# Column name: last '/'-separated path component, minus the suffix tag.
		local_filename = statistics_per_locus_filename.split('/')[-1]
		shortened_filename = local_filename.replace('_statistics_per_locus', '')
		if per_allele:
			header_fields.append(shortened_filename + '_REF')
			header_fields.append(shortened_filename + '_OTHER')
		else:
			header_fields.append(shortened_filename)

	# Build every row before touching the output file, so an inconsistency
	# between inputs never leaves a partially written table behind.
	output_lines = ['\t'.join(header_fields) + '\n']
	for locus in sorted(all_loci):
		row_fields = [locus]
		for statistics_per_locus_filename in statistics_per_locus_filename_list:
			read_counts = read_counts_by_filename[statistics_per_locus_filename]
			allele_counts = allele_counts_by_filename[statistics_per_locus_filename]
			if locus not in read_counts or locus not in allele_counts:
				raise IOError('Internal Error: Each locus should be present in all statistics per locus files')
			if per_allele:
				row_fields.append('%d' % allele_counts[locus]['ref'])
				row_fields.append('%d' % allele_counts[locus]['other'])
			else:
				row_fields.append('%d' % read_counts[locus])
		output_lines.append('\t'.join(row_fields) + '\n')

	# The context manager guarantees the handle is flushed and closed even if
	# writing fails part-way.
	with open(output_filename, 'w') as output_file:
		output_file.writelines(output_lines)
	return

def get_locus_to_read_count_dict(statistics_per_locus_filename):
	"""Parse one tab-separated statistics-per-locus file.

	Each data line is read as: column 0 = locus name, column 1 = total read
	count, column 3 = read count for the reference allele, column 4 = read
	count for the other allele (column 2 is ignored). Lines whose first field
	is 'locus_name' are treated as headers and skipped, as are blank lines.

	Args:
		statistics_per_locus_filename: path of the file to read.

	Returns:
		A tuple (locus_to_read_count_dict, locus_to_read_count_per_allele_dict):
		the first maps locus -> total read count (int); the second maps
		locus -> {'ref': int, 'other': int}.

	Raises:
		IndexError: if a data line has fewer than five tab-separated fields.
		ValueError: if a count field is not an integer.
	"""
	locus_to_read_count_dict = {}
	locus_to_read_count_per_allele_dict = {}

	# `with` closes the input file even when a malformed line raises.
	with open(statistics_per_locus_filename, 'r') as statistics_per_locus_file:
		for line in statistics_per_locus_file:
			line = line.rstrip('\n\r')
			if not line.strip():
				## blank line (e.g. trailing newline at end of file)
				continue
			field_list = line.split('\t')
			locus = field_list[0]
			if locus == 'locus_name':
				## header line
				continue
			num_reads = int(field_list[1])
			num_reads_for_ref_allele = int(field_list[3])
			num_reads_for_other_allele = int(field_list[4])

			locus_to_read_count_dict[locus] = num_reads
			locus_to_read_count_per_allele_dict[locus] = {
				'ref': num_reads_for_ref_allele,
				'other': num_reads_for_other_allele,
			}
	return locus_to_read_count_dict, locus_to_read_count_per_allele_dict

# Script entry: argv[1] is the output mode, argv[2] the output table path, and
# every remaining argument is a statistics-per-locus input file.
selected_option = sys.argv[1]
output_table_path = sys.argv[2]
input_statistics_paths = sys.argv[3:]
main(selected_option, output_table_path, input_statistics_paths)
