#!/usr/bin/env python

# mpr
# tree-relations
# tree-orth
# tree-paralog

# python libraries
import sys
import optparse

# treefix libraries
import treefix

# rasmus, compbio imports
from rasmus import treelib, util
from compbio import phylo

#=============================================================================
# parser

# Command-line interface.  Positional arguments are gene tree files; the
# *ext options are the file extensions used for the input trees and for each
# kind of output file.
o = optparse.OptionParser(usage="%prog [options] <gene tree file> ...")
o.add_option("-s", "--stree", dest="stree",
             metavar="<species tree>",
             help="species tree file in newick format")
o.add_option("-S", "--smap", dest="smap",
             metavar="<species map>",
             help="gene to species map")
o.add_option("-T", "--treeext", dest="treeext",
             metavar="<tree file extension>",
             default=".tree",
             help="tree file extension (default: \".tree\")")
o.add_option("-R", "--reconext", dest="reconext",
             metavar="<reconciliation file extension>",
             default=".mpr.recon",
             help="reconciliation file extension (default: \".mpr.recon\")")
o.add_option("--nhxext", dest="nhxext",
             metavar="<NHX tree file extension>",
             default=".nhx.tree",
             help="NHX tree file extension (default: \".nhx.tree\")")
o.add_option("--relext", dest="relext",
             metavar="<relations file extension>",
             default=".rel.txt",
             help="relations file extension (default: \".rel.txt\")")
o.add_option("--orthext", dest="orthext",
             metavar="<orthologs file extension>",
             default=".orth.txt",
             help="orthologs file extension (default: \".orth.txt\")")
o.add_option("--paralogext", dest="paralogext",
             metavar="<paralogs file extension>",
             default=".paralog.txt",
             help="paralogs file extension (default: \".paralog.txt\")")
conf, args = o.parse_args()

# The species tree and the gene-to-species map have no defaults and are both
# read unconditionally by the main section, so report a missing one as a
# usage error (o.error prints the usage and exits) instead of passing None
# on to the file readers.
if not conf.stree:
    o.error("-s/--stree is required")
if not conf.smap:
    o.error("-S/--smap is required")

#=============================================================================
# utilities

def read_filenames(stream):
    """Yield the file names listed one per line in stream.

    Trailing whitespace (including the newline) is stripped from each line.
    Lines that are empty after stripping are skipped, so a blank or trailing
    empty line is not passed on as an empty file name.
    """
    for line in stream:
        filename = line.rstrip()
        if filename:
            yield filename

def get_tree_relations(tree, recon, events, loss):
    """Generate one relation record (a list) per event of a reconciled tree.

    Records are produced in this order:

      ["gene", <leaf name>]
          for every leaf of tree;
      [<"dup" or "spec">, <leaf names of child 1>, ..., <recon[node].name>]
          for every node whose entry in events is "dup" or "spec"; each
          child contributes its sorted list of leaf names, and the lists
          are ordered by their relation_format() string;
      ["loss", <sorted leaf names of the gene branch>, <species name as str>]
          for every (gene branch, species branch) pair in loss.
    """

    # gene records: one per leaf
    for gene in tree.leaves():
        yield ["gene", gene.name]

    # duplication and speciation records
    for node in tree:
        event = events[node]

        # leaves were already reported above
        if event == "gene":
            continue

        groups = [sorted(child.leaf_names()) for child in node.children]
        groups.sort(key=relation_format)

        # nodes labelled with anything else produce no record
        if event in ("dup", "spec"):
            yield [event] + groups + [recon[node].name]

    # loss records
    for gene_branch, species_branch in loss:
        lost_from = sorted(gene_branch.leaf_names())
        yield ["loss", lost_from, str(species_branch.name)]

def relation_format(val):
    """Return val as a string, flattening lists/tuples with commas.

    A list or tuple is formatted element by element (recursively) and the
    pieces are joined with ","; any other value is converted with str().
    """
    if not isinstance(val, (list, tuple)):
        return str(val)

    pieces = [relation_format(item) for item in val]
    return ",".join(pieces)

#=============================================================================
# main

def _write_gene_pairs(out, pairs, gene2species):
    """Write one tab-delimited row per gene pair to the open stream out.

    pairs yields 5-tuples (gene1, gene2, spcnt1, spcnt2, snode), as produced
    by phylo.find_orthologs / phylo.find_paralogs with species_branch=True.
    gene2species maps a gene name to its species name.

    Each row is: sp1, sp2, gene1, gene2, spcnt1, spcnt2, snode.name, where
    the two genes (and the values that belong to them) are swapped if needed
    so that sp1 <= sp2.
    """
    for gene1, gene2, spcnt1, spcnt2, snode in pairs:
        sp1 = gene2species(gene1)
        sp2 = gene2species(gene2)

        # report each pair in species order
        if sp1 > sp2:
            sp1, sp2 = sp2, sp1
            gene1, gene2 = gene2, gene1
            spcnt1, spcnt2 = spcnt2, spcnt1

        toks = (sp1, sp2, gene1, gene2, spcnt1, spcnt2, snode.name)
        out.write("\t".join(map(str, toks)) + "\n")


# read inputs
stree = treelib.read_tree(conf.stree)
gene2species = phylo.read_gene2species(conf.smap)

# gene tree files come from the command line or, when none are given,
# from stdin (one file name per line)
if args:
    filenames = args
else:
    filenames = read_filenames(sys.stdin)

# process tree files
for treefile in filenames:
    tree = treelib.read_tree(treefile)

    # prefix for all output files of this tree
    # (presumably treefile with the tree extension removed -- see
    # util.replace_ext)
    base = util.replace_ext(treefile, conf.treeext, "")

    # check tree; raise explicitly rather than assert so that the check
    # is not stripped when python runs with -O
    if not treelib.is_rooted(tree):
        raise ValueError("gene tree is not rooted: %s" % treefile)

    # perform MPR, then infer spec/dup/loss
    recon = phylo.reconcile(tree, stree, gene2species)
    events = phylo.label_events(tree, recon)
    loss = phylo.find_loss(tree, stree, recon)

    # make (and annotate) NHX tree
    nhxtree = tree.copy(copyData=False)
    for node in tree:
        n = nhxtree.nodes[node.name]
        # NOTE(review): assumes every copied node has a "tree" entry in its
        # data even with copyData=False; a missing key raises KeyError --
        # confirm against treelib
        del n.data["tree"]
        if "boot" in node.data:
            n.data["B"] = node.data["boot"]                     # bootstrap
        n.data["S"] = str(recon[node].name)                     # species
        if not node.is_leaf():
            n.data["D"] = "Y" if events[node] == "dup" else "N" # spec/dup
        n.data["L"] = ""                                        # loss
    # append "-<species>" to the gene node of every inferred loss
    for node, snode in loss:
        n = nhxtree.nodes[node.name]
        n.data["L"] += "-" + str(snode.name)
    # drop the loss tag from nodes that ended up without any loss
    for n in nhxtree:
        if n.data["L"] == "":
            del n.data["L"]
    nhxtree.write(base + conf.nhxext, writeData=treelib.write_nhx_data)

    # output recon and events
    phylo.write_recon_events(base + conf.reconext, recon, events)

    # output relations (one tab-delimited record per line)
    out = util.open_stream(base + conf.relext, 'w')
    try:
        for rel in get_tree_relations(tree, recon, events, loss):
            out.write("\t".join(relation_format(val) for val in rel) + "\n")
    finally:
        # close the stream even if writing a record fails
        out.close()

    # output orthologs
    out = util.open_stream(base + conf.orthext, 'w')
    try:
        orths = phylo.find_orthologs(tree, stree, recon, events,
                                     species_branch=True)
        _write_gene_pairs(out, orths, gene2species)
    finally:
        out.close()

    # output paralogs
    out = util.open_stream(base + conf.paralogext, 'w')
    try:
        paralogs = phylo.find_paralogs(tree, stree, recon, events,
                                       species_branch=True,
                                       split=False)
        _write_gene_pairs(out, paralogs, gene2species)
    finally:
        out.close()
