import gzip
import multiprocessing
import os
import json
from sys import argv
import pysam
import numpy as np
import matplotlib.pyplot as plt
from __builtin__ import range
import re
__author__ = 'basir'


class ParseFastQ(object):
    """Read-by-read FastQ parser, analogous to file.readline().

    The parser is its own iterator, so both forms work:

        rec = parser.next()
        for rec in parser:
            ... do something with rec ...

    Each ``rec`` is a tuple ``(seqHeader, seqStr, qualHeader, qualStr)`` with
    the trailing newline removed from every field.  The underlying file is
    closed automatically once the end of the file is reached; call
    ``close()`` (or use the parser in a ``with`` block) when iteration is
    abandoned early.
    """

    # Every format error shares this layout; only the first sentence varies.
    _ERROR_TEMPLATE = ("** ERROR: %s\n"
                       "               Please check FastQ file near line number %s "
                       "(plus or minus ~4 lines) and try again**")

    def __init__(self, filePath, headerSymbols=('@', '+')):
        """Open ``filePath`` for record-wise reading.

        Args:
            filePath: path to the FastQ file; a name ending in ``.gz`` is
                opened through ``gzip``.
            headerSymbols: two-item sequence holding the expected first
                character(s) of the sequence header line and of the quality
                header line.  An immutable tuple is used as default so the
                default cannot be shared-and-mutated between instances.
        """
        if filePath.endswith('.gz'):
            self._file = gzip.open(filePath)
        else:
            self._file = open(filePath, 'rU')
        self._currentLineNumber = 0
        self._hdSyms = headerSymbols

    def __iter__(self):
        return self

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def close(self):
        """Close the underlying file; calling it again is harmless."""
        self._file.close()

    def _fail(self, detail):
        """Raise AssertionError for a malformed record near the current line.

        An explicit raise is used instead of the ``assert`` statement so the
        checks still run under ``python -O``; the exception type is kept so
        existing callers that catch AssertionError are unaffected.
        """
        raise AssertionError(
            self._ERROR_TEMPLATE % (detail, self._currentLineNumber))

    def next(self):
        """Read the next four lines, verify them minimally and return them.

        Returns:
            tuple: (seqHeader, seqStr, qualHeader, qualStr)

        Raises:
            StopIteration: at a clean end of file (four consecutive empty
                reads), and on any call made after the file was closed.
            AssertionError: on a truncated record, an empty line, a header
                line that does not start with the expected symbol, or a
                sequence/quality length mismatch.
        """
        # The file is closed at EOF; keep later calls well-behaved instead of
        # letting readline() fail on a closed handle.
        if getattr(self._file, 'closed', False):
            raise StopIteration

        # ++++ Get Next Four Lines ++++
        elemList = []
        for _ in range(4):
            line = self._file.readline()
            self._currentLineNumber += 1  # track the file position
            elemList.append(line.strip('\n') if line else None)

        # ++++ Check Lines For Expected Form ++++
        # -- Acceptable end of file: nothing at all could be read --
        if elemList.count(None) == 4:
            self.close()
            raise StopIteration
        # -- Make sure we got 4 full (non-empty) lines of data --
        if not all(elemList):
            self._fail("It looks like I encountered a premature EOF or empty line.")
        # -- Make sure we are in the correct "register" --
        if not elemList[0].startswith(self._hdSyms[0]):
            self._fail("The 1st line in fastq element does not start with '%s'."
                       % self._hdSyms[0])
        if not elemList[2].startswith(self._hdSyms[1]):
            self._fail("The 3rd line in fastq element does not start with '%s'."
                       % self._hdSyms[1])
        # -- Make sure the seq line and qual line have equal lengths --
        if len(elemList[1]) != len(elemList[3]):
            self._fail("The length of Sequence data and Quality data of the "
                       "last record aren't equal.")

        # ++++ Return fastQ data as tuple ++++
        return tuple(elemList)

    # Alias so the object also satisfies the Python 3 iterator protocol.
    __next__ = next


def collect_id_position_table(bam_file):
    """Build a ``read id -> position`` table from a BAM file.

    Every mapped alignment returned by ``fetch()`` contributes one candidate
    entry.  The key is ``"@" + qname + "/1"`` for forward-strand alignments
    and ``"@" + qname + "/2"`` for reverse-strand ones; the value is the
    largest ``pos`` seen for that key.  The finished table is also dumped to
    ``id_position.json`` in the current working directory.

    Progress/debug information (signed position, alternative alignments from
    the ``XA`` tag, the id with its position list) is printed for every
    mapped alignment; those alternative positions are only printed, they are
    not stored in the table.

    Args:
        bam_file: path of the BAM file to read.

    Returns:
        dict mapping read id (str) to position (int).
    """
    table = {}
    print("Collecting the id-position talbe...")
    file_handle = pysam.Samfile(bam_file, "rb")
    try:
        for aligned_read in file_handle.fetch():
            if aligned_read.is_unmapped:
                continue
            is_reverse = aligned_read.is_reverse
            # Reverse-strand hits are reported with a negated position.
            pos = -aligned_read.pos if is_reverse else aligned_read.pos
            print(pos)
            positions = [pos]
            tags = dict(aligned_read.tags)
            if "XA" in tags:
                print("alternative alignments")
                # Entries are ';'-terminated, hence the dropped last piece.
                # The third-from-last comma field is taken as the position
                # (presumably bwa's "chr,pos,CIGAR,NM" layout - confirm).
                for aa in tags["XA"].split(";")[:-1]:
                    print(aa)
                    positions.append(int(aa.split(",")[-3]))

            # NOTE(review): the /1 vs /2 suffix is derived from the strand,
            # not from the mate flags - confirm this is intended.
            read_id = "@" + aligned_read.qname + "/" + ("2" if is_reverse else "1")

            print(read_id + " : " + str(positions))
            if read_id not in table or table[read_id] < aligned_read.pos:
                table[read_id] = aligned_read.pos
            print("*************************************")
    finally:
        # Release the BAM handle even when an alignment cannot be processed.
        file_handle.close()
    with open('id_position.json', 'w') as f:
        json.dump(table, f)
    return table

def annotate_k(table, reads, landscape):
    """Copy a FastQ file, appending ``___<k>`` to every sequence header.

    The output is written to ``data/<basename of reads>_annotated.fastq``.
    For each record the id used for the table lookup is the header with its
    last two characters removed.

    * Id found in ``table``: ``k`` is one more than the maximum of
      ``landscape`` over the ``read_length`` entries starting at the stored
      position.  If that exceeds ``read_length - 1`` the default (55) is
      written instead.
    * Id not found: the default (55) is written when ``read_length - 1`` is
      larger than it, otherwise ``read_length - 1``.

    Args:
        table: mapping of read id to a start index into ``landscape``.
        reads: path of the FastQ file to annotate.
        landscape: sliceable numeric sequence accepted by ``np.max``.
    """
    defaultk = 55
    dk_str = str(defaultk)
    parser = ParseFastQ(reads)
    out_path = "data/" + file_wo_ext(reads) + "_annotated.fastq"
    # The with-block guarantees the output is flushed and closed even if a
    # malformed record aborts the loop.
    with open(out_path, "w") as annotated_file:
        for counter, record in enumerate(parser):
            read = list(record)
            if counter % 1000000 == 999999:
                print(str(counter + 1) + "\t reads has been processed..")
            read_id = read[0][:-2]
            read_length = len(read[1])
            if read_id in table:  # reads that do align to the genome
                pos = table[read_id]
                k = int(np.max(landscape[pos:pos + read_length]) + 1)
                suffix = dk_str if k > read_length - 1 else str(k)
            else:  # reads that do not align to the reference genome
                if read_length - 1 > defaultk:
                    suffix = dk_str
                else:
                    suffix = str(read_length - 1)
            read[0] += "___" + suffix
            for field in read:
                annotated_file.write(field + '\n')

def annotate_reads(table, coverage, landscape, read1, read2=None):
    """Annotate one or two FastQ files with k values via ``annotate_k``.

    Args:
        table: read id -> position mapping handed on to ``annotate_k``.
        coverage: currently unused; kept so existing callers keep working.
        landscape: path of a text file loadable with ``np.loadtxt``.
        read1: path of the first FastQ file.
        read2: optional path of the second FastQ file; skipped when None.
    """
    lss = np.loadtxt(landscape)
    print("Annotating first set of reads ..." + read1)
    annotate_k(table, read1, lss)
    # The parameter is ``read2``; the previous check referenced an undefined
    # name and raised NameError on every call.
    if read2 is None:
        return
    print("Annotating second set of reads ..." + read2)
    annotate_k(table, read2, lss)

def file_wo_ext(filename):
    """Return the final path component of ``filename`` minus its last extension."""
    base_name = os.path.basename(filename)
    stem, _extension = os.path.splitext(base_name)
    return stem


def align_reads_to_reference(reference, reads1, reads2):
    """Align two read files to a reference and produce BAM and coverage files.

    Runs these shell commands in order: ``bwa index``, ``bwa mem`` (one
    thread per CPU reported by ``multiprocessing``), ``samtools view`` piped
    into ``samtools sort``, ``samtools index`` and ``bedtools genomecov``
    piped into ``cut``.  Command exit statuses are not inspected.

    All outputs live under ``data/`` and are named
    ``<reference stem>_<reads1 stem>`` plus ``.sam`` / ``.bam`` / ``.cov``.

    Returns:
        tuple ``(bam path, coverage path)``.
    """
    bam_file = "data/" + file_wo_ext(reference) + "_" + file_wo_ext(reads1)
    sam_file = bam_file + ".sam"
    output_bam = bam_file + ".bam"
    coverage_file = bam_file + ".cov"
    cores = str(multiprocessing.cpu_count())

    # Indexing the reference genome
    print("Indexing the genome ...")
    os.system(" ".join(["bwa index", reference]))

    print("Aligning the reads....")
    os.system(" ".join(
        ["bwa mem -t", cores, reference, reads1, reads2, ">", sam_file]))

    print("Converting to BAM...")
    # samtools sort is given the output prefix; it appends ".bam" itself.
    os.system(" ".join(
        ["samtools view -bS", sam_file, "| samtools sort -", bam_file]))
    os.system(" ".join(["samtools index", output_bam]))
    os.system("".join([
        "../pathogen_detection/bedtools-2.17.0/bin/bedtools genomecov -d -ibam ",
        output_bam,
        " | cut -f 3-  > ",
        coverage_file,
    ]))

    print("BAM file is ready: " + output_bam)

    return output_bam, coverage_file

def compute_landscape(reference):
    """Run ``bin/build_dawg`` on ``reference`` and return the output path.

    The reference path is passed twice to the tool, followed by the output
    path ``data/<reference stem>.lss``.  The exit status is not inspected.
    """
    landscape = "data/" + file_wo_ext(reference) + ".lss"
    print("Here is the file i am going to build the landscape for ... " + reference)
    command = " ".join(["bin/build_dawg", reference, reference, landscape])
    os.system(command)
    print("Done from the python script! ")
    return landscape

def draw_coverage_landscape(landscape,coverage):
    """Plot the landscape and coverage text files on one figure and show it.

    Both arguments are paths readable by ``np.loadtxt``; both files are
    loaded before anything is drawn.
    """
    series = [np.loadtxt(path) for path in (landscape, coverage)]
    for values in series:
        plt.plot(values)
    plt.show()

# ---------------------------------------------------------------------------
# Script body.  It runs whenever this module is executed or imported (there
# is no __main__ guard).  At present only the id-position table is built,
# from the BAM file named by the first command-line argument; the remaining
# pipeline steps are kept below as disabled (commented-out) code.
# ---------------------------------------------------------------------------
# def marker(landscape_file, read1,read2)
# Disabled: original command-line inputs for the full pipeline.
#reference = argv[1]
#reads1 = argv[2]
#reads2 = argv[3]

# Disabled: landscape computation and alignment, plus hard-coded sample paths.
#maximal_landscape = compute_landscape(reference)
#bamFile,coverage = align_reads_to_reference(reference, reads1, reads2)
#maximal_landscape = "data/ecoli_rc.lss"
#bamFile = "data/ecoli_rc_mc.orig.1.fq.00.cor.bam"
bamFile = argv[1]#"ecoli.bam"
#coverage = "data/ecoli_rc_mc.orig.1.fq.00.cor.cov"
#draw_coverage_landscape(maximal_landscape,coverage)
# Builds the table and, as a side effect, writes id_position.json.
table = collect_id_position_table(bamFile)
# Disabled: read annotation step.
#annotate_reads(table, coverage, maximal_landscape, reads1, reads2)
