import gzip
import multiprocessing
import os
import json
from sys import argv
import pysam
import numpy as np
#import matplotlib.pyplot as plt
from __builtin__ import range
import re
import cPickle
__author__ = 'basir'


class ParseFastQ(object):
    """Read-by-read FASTQ parser analogous to file.readline().

    Each step yields one record as the 4-tuple
    ``(seqHeader, seqStr, qualHeader, qualStr)`` with newlines stripped.
    It is an iterator under both the Python 2 (``next``) and the Python 3
    (``__next__``) protocol.
    """

    def __init__(self, filePath, headerSymbols=('@', '+')):
        """Open *filePath* for record-by-record parsing.

        Example: ``parser.next()`` -- OR -- iterate directly::

            for rec in parser:
                ... do something with rec ...

        :param filePath: FASTQ path; a name ending in ``.gz`` is opened with
            gzip, anything else with ``open(..., 'rU')``.
        :param headerSymbols: two-item sequence giving the expected first
            character of the sequence-header line and of the quality-header
            line.  The default is a tuple so that no mutable default object
            is shared between instances.
        """
        if filePath.endswith('.gz'):
            self._file = gzip.open(filePath)
        else:
            self._file = open(filePath, 'rU')
        self._currentLineNumber = 0
        self._hdSyms = headerSymbols

    def __iter__(self):
        return self

    def close(self):
        """Close the underlying file."""
        self._file.close()

    def _fail(self, problem):
        """Raise AssertionError describing a malformed record.

        An explicit ``raise`` is used rather than ``assert`` so the checks
        still run when Python is started with ``-O``.
        """
        raise AssertionError(
            "** ERROR: %s\n"
            "               Please check FastQ file near line number %s "
            "(plus or minus ~4 lines) and try again**"
            % (problem, self._currentLineNumber))

    def next(self):
        """Read, check and return the next record.

        :returns: tuple ``(seqHeader, seqStr, qualHeader, qualStr)``.
        :raises StopIteration: when the file ends exactly on a record boundary.
        :raises AssertionError: on a truncated record, an empty line, a header
            line with the wrong leading symbol, or sequence and quality
            strings of different lengths.
        """
        elemList = []
        for _ in range(4):
            line = self._file.readline()
            self._currentLineNumber += 1  # increment file position
            # readline() returns '' only at EOF; record a missing line as None.
            elemList.append(line.strip('\n') if line else None)

        # A clean end of file: nothing at all was read for this record.
        if elemList.count(None) == 4:
            raise StopIteration
        # All four lines must be present and non-empty.
        if not all(elemList):
            self._fail("It looks like I encountered a premature EOF or empty line.")
        # Make sure we are in the correct "register".
        if not elemList[0].startswith(self._hdSyms[0]):
            self._fail("The 1st line in fastq element does not start with '%s'."
                       % self._hdSyms[0])
        if not elemList[2].startswith(self._hdSyms[1]):
            self._fail("The 3rd line in fastq element does not start with '%s'."
                       % self._hdSyms[1])
        # The sequence and quality lines must have equal lengths.
        if len(elemList[1]) != len(elemList[3]):
            self._fail("The length of Sequence data and Quality data of the "
                       "last record aren't equal.")
        return tuple(elemList)

    # Python 3 iterator protocol uses __next__ instead of next.
    __next__ = next

def _xa_positions(xa_tag):
    """Convert a BWA ``XA`` tag into signed 0-based start positions.

    The tag is a list of ``chr,pos,CIGAR,NM;`` entries whose ``pos`` is
    1-based and carries the strand as its sign (``+``/``-``).  Each position
    is shifted to 0-based and keeps its sign, matching the convention that
    collect_id_position_table uses for the primary alignment.
    """
    positions = []
    for hit in xa_tag.split(";"):
        if not hit:  # the tag is ';'-terminated, leaving an empty last field
            continue
        signed_pos = int(hit.split(",")[-3])
        offset = abs(signed_pos) - 1
        positions.append(-offset if signed_pos < 0 else offset)
    return positions


def collect_id_position_table(bam_file):
    """Map every mapped read in *bam_file* to its signed alignment positions.

    Keys are ``"@" + qname + "/1"`` (or ``"/2"`` when ``is_read2``).  Values
    are numpy arrays of 0-based start positions, negated for reverse-strand
    hits, taken from the record itself and from every alternative hit in its
    ``XA`` tag.  Records seen more than once for the same key are merged.

    The merged table (before de-duplication) is pickled to
    ``id_position.pickle`` in the working directory; the returned table has
    every array passed through ``np.unique`` (sorted, duplicates removed).

    NOTE: a position of 0 cannot carry a sign, so a reverse hit at 0 looks
    the same as a forward hit at 0.
    """
    table = {}
    counter = 0
    print("Collecting the id-position talbe...")
    file_handle = pysam.Samfile(bam_file, "rb")
    try:
        for aligned_reads in file_handle.fetch():
            if aligned_reads.is_unmapped:
                continue
            if counter % 1000000 == 999999:
                print(str(counter + 1) + "\treads has been processed..")
            if aligned_reads.is_reverse:
                positions = [-aligned_reads.pos]
            else:
                positions = [aligned_reads.pos]
            tags = dict(aligned_reads.tags)
            if "XA" in tags:
                # np.append returns a copy, so collect into a list instead.
                positions.extend(_xa_positions(tags["XA"]))
            read_id = "@" + aligned_reads.qname + "/" + ("2" if aligned_reads.is_read2 else "1")
            table.setdefault(read_id, []).extend(positions)
            counter += 1
    finally:
        file_handle.close()

    table = dict((key, np.array(value)) for key, value in table.items())
    with open('id_position.pickle', 'wb') as handle:
        cPickle.dump(table, handle, -1)
    return dict((key, np.unique(value)) for key, value in table.items())


# Lower-case complement table: a/c/g/t/n plus the IUPAC ambiguity codes.
# '\n' maps to '' so a trailing newline is dropped from the result.
_COMPLEMENT = {
    'a': 't', 'c': 'g', 'g': 'c', 't': 'a', 'n': 'n',
    'r': 'y', 'y': 'r', 's': 's', 'w': 'w', 'k': 'm', 'm': 'k',
    'b': 'v', 'v': 'b', 'd': 'h', 'h': 'd',
    '\n': '',
}


def reverse_complement(string):
    """Return the lower-case reverse complement of a nucleotide string.

    Input case is ignored and the result is always lower case.  Newlines are
    dropped.  A character outside the table raises KeyError.
    """
    # Every complement is at most one character, so complementing the
    # reversed input equals reversing the complemented input.
    return ''.join(_COMPLEMENT[c] for c in reversed(string.lower()))


def _circular_window(landscape, start, length):
    """Return ``length`` landscape values starting at index ``start``.

    Indices that run past the end wrap around to the beginning, i.e. the
    reference is treated as circular.
    """
    return landscape[np.arange(start, start + length) % len(landscape)]


def _shreds_for(landscape, starts, read_length, max_k):
    """Comma-separated per-base k values for a read aligned at ``starts``.

    Each base gets ``landscape + 1`` maximised over all given alignments and
    capped at ``max_k``.  Values are assumed integral (they are truncated
    with ``int``) -- TODO confirm for the landscape files in use.
    """
    windows = [_circular_window(landscape, s, read_length) + 1 for s in starts]
    per_base = np.minimum(np.max(windows, axis=0), max_k)
    return ",".join(str(int(v)) for v in per_base)


def annotate_k(table, reads, landscape):
    """Write a copy of the FASTQ *reads* whose headers carry per-base k values.

    Output goes to ``data/<reads stem>_annotated.fastq``.  For each record:

    * if its header has forward-strand positions (> 0) in *table*, the record
      is written with ``"___" + k-list`` appended to the header;
    * if it has reverse-strand positions (<= 0), an extra record is written
      with ``"_dual___" + k-list`` appended, the reverse complement of the
      sequence, and the quality string reversed so it stays aligned with the
      reverse-complemented bases;
    * if it has no positions, the header gets ``"___"`` plus ``def_k`` for
      every base.

    A strand's k-list is ``landscape[pos:pos+len] + 1`` (wrapping past the
    end of *landscape*), maximised over that strand's positions and capped
    at ``max_k``.

    :param table: dict mapping a FASTQ header line to signed 0-based
        positions (<= 0 means reverse strand).
    :param reads: path of the FASTQ file to annotate.
    :param landscape: 1-D numeric array indexed by reference position.
    """
    def_k = 55
    max_k = 77
    parser = ParseFastQ(reads)
    out_path = "data/" + file_wo_ext(reads) + "_annotated.fastq"
    pending = []  # output text buffered between periodic flushes
    try:
        with open(out_path, "w") as annotated_file:
            for counter, (header, seq, qual_header, qual) in enumerate(parser):
                if counter % 100000 == 99999:
                    print(str(counter + 1) + "\t reads has been processed..")
                    annotated_file.write("".join(pending))
                    pending = []
                read_length = len(seq)
                positions = np.atleast_1d(table[header]) if header in table else ()
                if len(positions):
                    # reads that align to the genome
                    forward = [abs(p) for p in positions if p > 0]
                    reverse = [abs(p) for p in positions if p <= 0]
                    if forward:
                        shreds = _shreds_for(landscape, forward, read_length, max_k)
                        pending.append("%s___%s\n%s\n%s\n%s\n"
                                       % (header, shreds, seq, qual_header, qual))
                    if reverse:
                        shreds = _shreds_for(landscape, reverse, read_length, max_k)
                        pending.append("%s_dual___%s\n%s\n%s\n%s\n"
                                       % (header, shreds, reverse_complement(seq),
                                          qual_header, qual[::-1]))
                else:
                    # reads that don't align to the reference genome
                    default_shredding = (str(def_k) + ",") * (read_length - 1) + str(def_k)
                    pending.append("%s___%s\n%s\n%s\n%s\n"
                                   % (header, default_shredding, seq, qual_header, qual))
            annotated_file.write("".join(pending))
    finally:
        parser._file.close()


def annotate_reads(table, coverage, landscape, read1, read2=None):
    """Load the landscape file and write the k-annotated FASTQ for *read1*.

    :param table: read-id -> signed positions map, passed to annotate_k.
    :param coverage: coverage file path; currently unused.
    :param landscape: path of a numeric text file, loaded with ``np.loadtxt``.
    :param read1: FASTQ path of the first read set.
    :param read2: FASTQ path of the second read set; accepted but not
        annotated at present.
    """
    lss = np.loadtxt(landscape)
    print("Annotating first set of reads ..." + read1)
    annotate_k(table, read1, lss)
    if read2 is None:
        return
    # NOTE: annotation of the second read set is currently disabled:
    # print("Annotating second set of reads ..." + read2)
    # annotate_k(table, read2, lss)

def file_wo_ext(filename):
    """Return the base name of *filename* with its last extension removed."""
    stem, _extension = os.path.splitext(os.path.basename(filename))
    return stem


def align_reads_to_reference(reference, reads1, reads2, bwa="6.2/bwa",
                             samtools="samtools",
                             bedtools="bed/bin/bedtools"):
    """Align a paired read set to *reference* and derive a BAM and coverage.

    Runs, through the shell: ``bwa index``, ``bwa aln`` for each read file,
    ``bwa sampe -n 10000``, ``samtools view | samtools sort``,
    ``samtools index`` and ``bedtools genomecov -d`` (keeping columns 3
    onward).  Outputs go under ``data/`` and are named after the reference
    and the first read file.  If both the sorted BAM and the coverage file
    already exist, nothing is run.

    :param reference: reference FASTA path.
    :param reads1: FASTQ path of the first mates.
    :param reads2: FASTQ path of the second mates.
    :param bwa: BWA executable (default ``6.2/bwa``, relative to the cwd).
    :param samtools: samtools executable.
    :param bedtools: bedtools executable.
    :returns: ``(bam_path, coverage_path)``.
    :raises RuntimeError: if a shell command exits with a non-zero status.
        For pipelines this is the status of the last command only.
    """
    prefix = "data/" + file_wo_ext(reference) + "_"
    sam_file = prefix + file_wo_ext(reads1) + ".sam"
    sai1_file = prefix + file_wo_ext(reads1) + "1.sai"
    sai2_file = prefix + file_wo_ext(reads2) + "2.sai"
    bam_file = prefix + file_wo_ext(reads1)
    # `samtools sort - <prefix>` is expected to write <prefix>.bam (the
    # pre-1.0 samtools syntax) -- check against the installed version.
    output_bam = bam_file + ".bam"
    coverage = bam_file + ".cov"
    if os.path.exists(output_bam) and os.path.exists(coverage):
        return output_bam, coverage

    def run(command):
        status = os.system(command)
        if status != 0:
            raise RuntimeError("command failed (status %d): %s" % (status, command))

    cores = str(multiprocessing.cpu_count())
    try:
        print("Indexing reference genome")
        run(bwa + " index " + reference)
        print("Aligning the reads....")
        run(bwa + " aln -t " + cores + " -N " + reference + " " + reads1 + " > " + sai1_file)
        run(bwa + " aln -t " + cores + " -N " + reference + " " + reads2 + " > " + sai2_file)
        run(bwa + " sampe -n 10000 " + reference + " " + sai1_file + " " + sai2_file
            + " " + reads1 + " " + reads2 + " > " + sam_file)
        print("Converting to BAM...")
        run(samtools + " view -bS " + sam_file + " | " + samtools + " sort - " + bam_file)
        run(samtools + " index " + output_bam)
        run(bedtools + " genomecov -d -ibam " + output_bam + " | cut -f 3-  > " + coverage)
    except RuntimeError:
        # Drop partial outputs so the cache check above cannot reuse them.
        for path in (output_bam, coverage):
            if os.path.exists(path):
                os.remove(path)
        raise
    print("BAM file is ready: " + output_bam)
    return output_bam, coverage

def compute_landscape(reference):
    """Build the landscape file for *reference*, reusing an existing one.

    The output path is ``data/<reference stem>.lss``.  It is produced by
    running ``bin/build_dawg <reference> <reference> <output>`` through the
    shell.

    :returns: the landscape path, whether it was just built or already there.
    :raises RuntimeError: if ``bin/build_dawg`` exits with a non-zero status.
    """
    landscape = "data/" + file_wo_ext(reference) + ".lss"
    if os.path.exists(landscape):
        # Reuse the cached file, but still hand its path back to the caller.
        return landscape
    print("Here is the file i am going to build the landscape for ... " + reference)
    command = "bin/build_dawg " + reference + " " + reference + " " + landscape
    status = os.system(command)
    if status != 0:
        # Remove a partial output so the next run does not treat it as cached.
        if os.path.exists(landscape):
            os.remove(landscape)
        raise RuntimeError("landscape build failed (status %d): %s" % (status, command))
    print("Done from the python script! ")
    return landscape

def draw_coverage_landscape(landscape, coverage):
    """Plot the landscape and coverage tracks on one matplotlib figure.

    :param landscape: path of a numeric text file readable by ``np.loadtxt``.
    :param coverage: path of a numeric text file readable by ``np.loadtxt``.
    """
    # The module-level matplotlib import is commented out, so `plt` was
    # undefined here; import lazily, presumably so the rest of the pipeline
    # does not require matplotlib.
    import matplotlib.pyplot as plt
    ls = np.loadtxt(landscape)
    cg = np.loadtxt(coverage)
    plt.plot(ls)
    plt.plot(cg)
    plt.show()


def main():
    """Run the whole pipeline from the command line.

    Usage: ``<script> <reference> <reads1> <reads2>``.  Builds the landscape
    for the reference, aligns the read files, collects per-read alignment
    positions from the BAM and annotates the reads.
    """
    if len(argv) < 4:
        raise SystemExit("usage: %s <reference> <reads1> <reads2>" % argv[0])
    reference, reads1, reads2 = argv[1], argv[2], argv[3]
    maximal_landscape = compute_landscape(reference)
    bamFile, coverage = align_reads_to_reference(reference, reads1, reads2)
    # collect_id_position_table also pickles its table to 'id_position.pickle';
    # when iterating on the annotation step that file can be loaded with
    # cPickle.load instead of re-reading the BAM.
    table = collect_id_position_table(bamFile)
    annotate_reads(table, coverage, maximal_landscape, reads1, reads2)

# Run the pipeline only when this file is executed as a script, not on import.
if __name__=="__main__":
	main()
