#!/usr/bin/env python

# utility to compute the likelihood test results (pval, Dlnl) or the cost of gene trees

# python libraries
import optparse, sys

# treefix libraries
import treefix
from treefix import common

# rasmus libraries
from rasmus import treelib, util
from compbio import phylo, alignlib

# command-line parser
usage = "usage: %prog [options] <gene tree> ..."
parser = optparse.OptionParser(usage=usage)

# input/output options: the shared treefix options are registered first,
# followed by the file extensions specific to this script
io_group = optparse.OptionGroup(parser, "Input/Output")
common.add_common_options(
    io_group,
    infiles=True,
    reroot=True,
    stree=True,
    smap=True,
    alignext=True)
io_group.add_option(
    "-U", "--usertreeext",
    dest="usertreeext",
    metavar="<user tree file extension>",
    help="user tree used for optimizing model (default: input tree)")
io_group.add_option(
    "-o", "--oldext",
    dest="oldext",
    default=".tree",
    metavar="<old tree file extension>",
    help='old tree file extension (default: ".tree")')
io_group.add_option(
    "-n", "--newext",
    dest="newext",
    default=".output",
    metavar="<output file extension>",
    help='output file extension (default: ".output")')
parser.add_option_group(io_group)

# model
# default classes used when --module is not given: default_module for
# --type likelihood, default_smodule otherwise
default_module = "treefix.models.raxmlmodel.RAxMLModel"
default_smodule = "treefix.models.duplossmodel.DupLossModel"
grp_model = optparse.OptionGroup(parser, "Model")
grp_model.add_option("--type", dest="type",
                     default=None,
                     choices=["likelihood","cost"],
                     metavar="<likelihood|cost>",
                     help="module type")
grp_model.add_option("-m", "--module", dest="module",
                     metavar="<underlying module>",
                     help="module for likelihood/cost computations")
grp_model.add_option("-e", "--extra", dest="extra",
                     metavar="<extra arguments>",
                     help="extra arguments to pass to program")
# this flag previously had no help text and so showed up undocumented
# in the --help listing
grp_model.add_option("--show-help", dest="show_help",
                     default=False, action="store_true",
                     help="show help for the underlying module and exit")
parser.add_option_group(grp_model)

# likelihood test: which test statistic to use
lik_group = optparse.OptionGroup(parser, "Likelihood Test")
lik_group.add_option(
    "-t", "--test",
    dest="test",
    default="SH",
    choices=["AU", "NP", "BP", "KH", "SH", "WKH", "WSH"],
    metavar="<test statistic>",
    help='test statistic for likelihood equivalence (default: "SH")')
parser.add_option_group(lik_group)

# cost evaluation: the already-registered --reroot option is moved into
# this group and given a help string specific to this script
cost_group = optparse.OptionGroup(parser, "Cost Evaluation")
common.move_option(parser, "--reroot", cost_group)
reroot_option = parser.get_option("--reroot")
reroot_option.help = "set to reroot trees"
parser.add_option_group(cost_group)

# parse the command line
options, args = parser.parse_args()

#=============================
# helper functions
def get_module(mname):
    """Instantiate the model class named by a dotted path.

    Args:
        mname: qualified class name such as
            "treefix.models.raxmlmodel.RAxMLModel".  Everything before the
            last dot is imported as a module and the last component is
            looked up on it.  A name without any dot is looked up among
            this script's globals.

    Returns:
        The result of calling the resolved class with ``options.extra``.

    Raises:
        ImportError: if the module part cannot be imported.
        AttributeError: if the module has no attribute with the class name.
        NameError: if an undotted name is not defined in this script.
    """
    # Local import so this helper does not depend on a file-level import.
    # importlib replaces the former exec/eval of a string built from the
    # --module argument, which would execute arbitrary code.
    import importlib

    folder, _, name = mname.rpartition('.')
    if folder != '':
        factory = getattr(importlib.import_module(folder), name)
    else:
        try:
            factory = globals()[name]
        except KeyError:
            raise NameError("name '%s' is not defined" % name)
    return factory(options.extra)

#=============================
# check arguments

# the model type must always be given
if not options.type:
    parser.error("--type required")

# without an explicit module, fall back to the default for the chosen type
if not options.module:
    options.module = (default_module if options.type == "likelihood"
                      else default_smodule)

# print the selected module's own help and stop
if options.show_help:
    get_module(options.module).print_help()
    sys.exit(0)

# cost computations need both a species tree and a species map
if options.type == "cost" and not (options.stree and options.smap):
    parser.error("--stree and --smap are required")

# determine input files
treefiles = common.get_input_files(parser, options, args)

#=============================
# main

# import module
module = get_module(options.module)

# print version
sys.stdout.write("Model version: %s\n" % module.VERSION)

# get species tree and species map
if options.type == "cost":
    # read species tree and species map
    stree = treelib.read_tree(options.stree)
    gene2species = phylo.read_gene2species(options.smap)

# iterate through files
for treefile in treefiles:
    # setup files: the alignment (likelihood only) and the output file are
    # named after the tree file with its extension replaced
    if options.type == "likelihood":
        alnfile = util.replace_ext(treefile, options.oldext, options.alignext)
        aln = alignlib.fasta.read_fasta(alnfile)
    outfile = util.replace_ext(treefile, options.oldext, options.newext)
    out = util.open_stream(outfile, "w")

    # Close every output once its tree file is finished, even on error.
    # A single close() after the loop only closed the last file explicitly
    # and raised NameError when there were no input files.
    try:
        # read trees
        gtrees = treelib.read_trees(treefile)

        # read user trees (defaults to the first tree of the input file)
        if options.usertreeext:
            usertreefile = util.replace_ext(treefile, options.oldext,
                                            options.usertreeext)
            usertree = treelib.read_tree(usertreefile)
        else:
            usertree = gtrees[0]

        # optimize model once per file, using the user tree
        if options.type == "likelihood":
            module.optimize_model(usertree, aln)
        elif options.type == "cost":
            module.optimize_model(usertree, stree, gene2species)

        # iterate through trees; one output line is written per tree
        for gtree in gtrees:
            # zero out branch lengths and drop bootstrap values if present
            for node in gtree:
                node.dist = 0
                if "boot" in node.data:
                    del node.data["boot"]
            if "boot" in gtree.default_data:
                del gtree.default_data["boot"]

            # compute likelihood or cost
            if options.type == "likelihood":
                pval, Dlnl = module.compute_lik_test(gtree, options.test)
                out.write("%.6g\t%.6g\n" % (pval, Dlnl))

            elif options.type == "cost":
                if options.reroot:
                    # only the cost is used; the returned tree is ignored
                    _, cost = module.recon_root(gtree, newCopy=False,
                                                returnCost=True)
                else:
                    cost = module.compute_cost(gtree)
                out.write("%.6g\n" % cost)
    finally:
        out.close()
