#!/usr/bin/env python

#
# This code requires an underlying module for likelihood testing.
#

# python libraries
import os, sys, optparse
import StringIO

# import random in order to seed since random is used in phylo
# import numpy.random since this is faster than the native random
import random
try:
    from numpy import random as nprnd
    NUMPY = True
except:
    import random as nprnd
    NUMPY = False

# string to binary representation
def bin_fallback(x):
    """
    bin_fallback(number) -> string
    Stringifies an int or long in base 2, in the same format as the
    built-in bin(): 5 -> '0b101', 0 -> '0b0', -6 -> '-0b110'.
    """
    if x < 0:
        return '-' + bin_fallback(-x)
    if x == 0:
        return '0b0'
    digits = []
    while x > 0:
        # least-significant bit first; reversed below
        digits.append('01'[x & 1])
        x >>= 1
    digits.reverse()
    return '0b' + ''.join(digits)

# only install the fallback on interpreters that lack a built-in bin()
try:
    bin
except NameError:
    bin = bin_fallback

# treefix library
import treefix
from treefix import common

# rasmus and compbio libraries
from rasmus import treelib, util, timer
from compbio import phylo

"""
All trees should be in newick format.

module should have at least one variable and four commands:
    -- rooted
       True if module uses rooted trees.
    -- init()
       Initializes the module.
    -- cleanup()
       Performs any cleanup of the module.
    -- optimize_model(treefile, seqfile, extra)
       Optimizes the underlying model in the module given the tree, seq (alignment),
       and extra parameter arguments.
    -- compute_lik_test(tree, test statistic)
       Computes the test statistic for tree likelihood equivalence.
       Returns the p-value and Dlnl (delta lnl = best lnl - current lnl).
"""

# command-line interface: options are registered group by group, in the
# order in which they should appear in the --help output
usage = "usage: %prog [options] <gene tree> ..."
parser = optparse.OptionParser(usage=usage)

# --- input/output ---
grp_io = optparse.OptionGroup(parser, "Input/Output")
common.add_common_options(
    grp_io,
    infiles=True,
    reroot=True,
    stree=True,
    smap=True,
    alignext=True)
grp_io.add_option(
    "-U", "--usertreeext",
    dest="usertreeext",
    metavar="<user tree file extension>",
    help="check if user tree is visited in search")
grp_io.add_option(
    "-o", "--oldext",
    dest="oldext",
    metavar="<old tree file extension>",
    default=".tree",
    help='old tree file extension (default: ".tree")')
grp_io.add_option(
    "-n", "--newext",
    dest="newext",
    metavar="<new tree file extension>",
    default=".treefix.tree",
    help='new tree file extension (default: ".treefix.tree")')
parser.add_option_group(grp_io)

# --- cost function ---
grp_cost = optparse.OptionGroup(parser, "Cost Function")
common.move_option(parser, "--reroot", grp_cost)
grp_cost.add_option(
    "-D", "--dupcost",
    dest="dupcost",
    metavar="<dup cost>",
    type="float",
    default=1.0,
    help="duplication cost (default: 1.0)")
grp_cost.add_option(
    "-L", "--losscost",
    dest="losscost",
    metavar="<loss cost>",
    type="float",
    default=1.0,
    help="loss cost (default: 1.0)")
parser.add_option_group(grp_cost)

# --- likelihood model ---
grp_model = optparse.OptionGroup(parser, "Likelihood Model")
grp_model.add_option(
    "-m", "--module",
    dest="module",
    metavar="<module for tree calculations>",
    default="raxml",
    help='module for tree calculations (default: "raxml")')
grp_model.add_option(
    "-e", "--extra",
    dest="extra",
    metavar="<extra arguments to module>",
    default="-m GTRGAMMA -n test -e 2.0",
    help="extra arguments to pass to program")
parser.add_option_group(grp_model)

# --- likelihood test ---
grp_test = optparse.OptionGroup(parser, "Likelihood Test")
grp_test.add_option(
    "-t", "--test",
    dest="test",
    metavar="<test statistic>",
    choices=["AU", "NP", "BP", "KH", "SH", "WKH", "WSH"],
    default="SH",
    help='test statistic for likelihood equivalence (default: "SH")')
grp_test.add_option(
    "-p", "--pval",
    dest="pval",
    metavar="<p-value>",
    type="float",
    default=0.05,
    help="p-value threshold (default: 0.05)")
parser.add_option_group(grp_test)

# --- search options ---
grp_search = optparse.OptionGroup(parser, "Search Options")
grp_search.add_option(
    "--seed",
    dest="seed",
    metavar="<seed>",
    type="int",
    help="seed value for random generator")
grp_search.add_option(
    "--niter",
    dest="niter",
    metavar="<# iterations>",
    type="int",
    default=100,
    help="number of iterations (default: 100)")
grp_search.add_option(
    "--nquickiter",
    dest="nquickiter",
    metavar="<# quick iterations>",
    type="int",
    default=50,
    help="number of subproposals (default: 50)")
grp_search.add_option(
    "--freconroot",
    dest="freconroot",
    metavar="<fraction reconroot>",
    type="float",
    default=0.05,
    help="fraction of search proposals to reconroot (default: 0.05)")
parser.add_option_group(grp_search)

# --- information ---
grp_info = optparse.OptionGroup(parser, "Information")
grp_info.add_option(
    "-V", "--verbose",
    dest="verbose",
    metavar="<verbosity level>",
    choices=["0", "1", "2", "3"],
    default="0",
    help="verbosity level (0=quiet, 1=low, 2=medium, 3=high)")
grp_info.add_option(
    "-l", "--log",
    dest="log",
    metavar="<log file>",
    default="-",
    help="log filename.  Use '-' to display on stdout.")
common.move_option(parser, "--help", grp_info)
parser.add_option_group(grp_info)

# --- debug ---
grp_debug = optparse.OptionGroup(parser, "Debug")
grp_debug.add_option(
    "--debug",
    dest="debug",
    metavar="<debug mode>",
    type="int",
    default=0,
    help=("debug mode (octal: 0=normal, "
          "1=skips likelihood test, "
          "2=skips cost requirement, "
          "4=computes likelihood for all trees in pool)"))
parser.add_option_group(grp_debug)

options, args = parser.parse_args()

#=============================
# check arguments

# required options
common.check_req_options(parser, options, clade=False)
# --verbose is declared with string choices, so convert it once here
options.verbose = int(options.verbose)

# debug options: --debug is a 3-bit mask, decoded from its binary string
# (bit 1 = skip likelihood test, bit 2 = skip cost requirement,
#  bit 4 = compute likelihood for all trees in pool)
if options.debug < 0 or options.debug > 7:
    parser.error("--debug must be in {0,...,7}: %d" % options.debug)
debug = bin(options.debug)[2:].zfill(3)
DEBUG_SKIP_LIK = debug[-1] == "1"
DEBUG_SKIP_COST = debug[-2] == "1"
DEBUG_COMPUTE_ALL_LIK = debug[-3] == "1"
if DEBUG_SKIP_LIK and DEBUG_COMPUTE_ALL_LIK:
    parser.error("cannot set debug flag 4 and 1: %d" % options.debug)

# other options
if options.pval < 0 or options.pval > 1:
    parser.error("--pval must be in [0,1]: %.5g" % options.pval)

# the messages below must read the attributes that actually exist
# (dest="niter" / dest="nquickiter"); otherwise reporting the error
# itself fails with an AttributeError
if options.niter < 1:
    parser.error("--niter must be >= 1: %d" % options.niter)

if options.nquickiter < 1:
    parser.error("--nquickiter must be >= 1: %d" % options.nquickiter)

# freconroot is a float, so it must not be formatted with %d
if options.freconroot < 0 or options.freconroot > 1:
    parser.error("--freconroot must be in [0,1]: %.5g" % options.freconroot)

# determine gene tree files
treefiles = common.get_input_files(parser, options, args)
if not treefiles:
    parser.error("must specify input file(s)")

# read species tree and species map
stree = treelib.read_tree(options.stree)
gene2species = phylo.read_gene2species(options.smap)

# read duplication and loss cost
dupcost = options.dupcost
losscost = options.losscost

#=============================
# utilities

def output_tree(gtree, out, single_tree):
    """
    Writes gtree to the stream out.

    A file holding a single tree is written with the tree's default
    formatting; otherwise each tree is written on one line followed by
    a newline.
    """
    if not single_tree:
        gtree.write(out, oneline=True)
        out.write('\n')
        return
    gtree.write(out)

def log_tree(gtree, log, oneline=True, writeDists=False):
    """
    Logs a rendering of gtree through log.log().

    With oneline set, the tree is written on a single line via
    gtree.write().  Otherwise it is drawn with treelib.draw_tree();
    unless writeDists is set, the drawing is made with minlen=5 and
    maxlen=5.
    """
    buf = StringIO.StringIO()
    if not oneline:
        if writeDists:
            treelib.draw_tree(gtree, out=buf)
        else:
            treelib.draw_tree(gtree, out=buf, minlen=5, maxlen=5)
        log.log("tree:\n %s\n" % buf.getvalue())
    else:
        gtree.write(buf, oneline=oneline)
        log.log("tree: %s\n" % buf.getvalue())
    buf.close()

def clear_caches():
    """empties the reconciliation, dup/loss cost, and likelihood caches"""
    for cache in (recon_cache, duploss_cache, lik_cache):
        cache.clear()

#=============================
# phylogeny functions

def unroot(gtree, newCopy=True):
    """
    returns unrooted gtree (with internal root always at the same place)

    The tree is unrooted and then rerooted on the parent node of the
    leaf whose name sorts first.  If newCopy is set, a copy is modified
    and returned; otherwise gtree itself is modified.
    """
    tree = gtree.copy() if newCopy else gtree
    treelib.unroot(tree, newCopy=False)

    # anchor the internal root at the parent of the first leaf (by name)
    first_leaf = sorted(tree.leaf_names())[0]
    anchor = tree.nodes[first_leaf].parent.name
    treelib.reroot(tree, anchor, onBranch=False, newCopy=False)
    return tree

recon_cache = {}
def recon_root(gtree, stree, gene2species, newCopy=True, returnCost=False,
               dupcost=dupcost, losscost=losscost):
    """
    cached version of phylo.recon_root

    Results are cached per (unrooted topology hash, dupcost, losscost),
    so a call with different costs never reuses a stale result.

    If newCopy is set, a copy of the cached rerooted tree is returned;
    otherwise the topology of gtree is overwritten in place and gtree is
    returned.  If returnCost is set, returns (tree, cost).
    """

    # can hash with unrooted version since recon_root first unroots anyways
    # (per the original author's note -- phylo.recon_root is not visible here)
    treehash = phylo.hash_tree(unroot(gtree, newCopy=True))
    key = (treehash, dupcost, losscost)
    if key in recon_cache:
        tree, cost = recon_cache[key]
    else:
        tree, cost = phylo.recon_root(gtree, stree, gene2species,
                                      newCopy=True, keepName=True, returnCost=True,
                                      dupcost=dupcost, losscost=losscost)
        recon_cache[key] = tree, cost

    if newCopy:
        gtree = tree.copy()
    else:
        treelib.set_tree_topology(gtree, tree)

    if returnCost:
        return gtree, cost
    else:
        return gtree

duploss_cache = {}
def count_dup_loss_cost(gtree, stree, gene2species,
                        dupcost=dupcost, losscost=losscost):
    """
    cached version of phylo.count_dup_loss

    Returns dupcost * (number of duplications) + losscost * (number of
    losses) for the reconciliation of gtree with stree.  Results are
    cached per (tree hash, dupcost, losscost), so a call with different
    costs never reuses a stale result.
    """

    key = (phylo.hash_tree(gtree), dupcost, losscost)
    if key in duploss_cache:
        return duploss_cache[key]

    recon = phylo.reconcile(gtree, stree, gene2species)
    events = phylo.label_events(gtree, recon)
    cost = 0
    # skip counting an event type whose cost is zero
    if dupcost != 0:
        cost += phylo.count_dup(gtree, events) * dupcost
    if losscost != 0:
        cost += phylo.count_loss(gtree, stree, recon) * losscost
    duploss_cache[key] = cost
    return cost

#=============================
# likelihood functions

lik_cache = {}
def compute_lik_test(gtree):
    """
    cached version of likelihood test

    Returns (pval, Dlnl) as computed by the likelihood module named by
    options.module, using the test statistic options.test.  If the
    module works on unrooted trees, gtree is unrooted (on a copy) before
    hashing and testing.
    """
    if not rooted:
        gtree = unroot(gtree, newCopy=True)

    treehash = phylo.hash_tree(gtree)
    if treehash in lik_cache:
        return lik_cache[treehash]

    # look the already-imported module up by name rather than eval()-ing
    # a string built from the command-line option
    likmodule = sys.modules[options.module]
    pval, Dlnl = likmodule.compute_lik_test(gtree, options.test)
    lik_cache[treehash] = pval, Dlnl
    return pval, Dlnl

#=============================
# main

# log file (only created when some verbosity is requested)
if options.verbose >= 1:
    if options.log != "-":
        outlog = util.open_stream(options.log, "w")
        log = timer.Timer(outlog)
    else:
        log = timer.globalTimer()

    log.log("TreeFix executed with the following arguments:")
    log.log("%s %s\n" % (os.path.basename(sys.argv[0]), ' '.join(sys.argv[1:])))

    # report every active debug flag, then a blank line if any was set
    for enabled, message in (
            (DEBUG_SKIP_LIK, "debug: skip likelihood test"),
            (DEBUG_SKIP_COST, "debug: skip cost requirement"),
            (DEBUG_COMPUTE_ALL_LIK, "debug: compute likelihoods for all trees in pool")):
        if enabled:
            log.log(message)
    if "1" in debug:
        log.log("\n")

# import module (not needed when the likelihood test is skipped)
if not DEBUG_SKIP_LIK:
    exec("import %s" % options.module)
    rooted = eval("%s.rooted" % options.module)
else:
    rooted = False

# process genes trees
#
# For every input file: seed the random generators, optionally read a user
# tree to track, then for every tree in the file run the search:
#   outer loop (niter)       -- builds a pool of proposals, keeps the ones
#                               that do not increase cost, and accepts the
#                               cheapest one that passes the likelihood test
#   inner loop (nquickiter)  -- NNI/SPR proposals scored by dup/loss cost
for treefile in treefiles:
    # seed random generator; test against None so that an explicit
    # "--seed 0" is honored instead of being replaced by a clock-based seed
    if options.seed is not None:
        seed = options.seed
    else:
        seed = int(timer.time.time())
    nprnd.seed(seed)
    random.seed(seed)

    # start log
    if options.verbose >= 1:
        log.start("Working on file '%s'" % treefile)
        log.log("random seed: %s\n" % seed)

    # read user tree
    if options.usertreeext:
        usertreefile = util.replace_ext(treefile, options.oldext, options.usertreeext)
        usertree = treelib.read_tree(usertreefile)
        if options.verbose >= 1:
            log.log("user: tree")
            log_tree(usertree, log, writeDists=True)
            log_tree(usertree, log, oneline=False, writeDists=True)
        if rooted:
            usertreehash = phylo.hash_tree(usertree)
        else:
            usertreehash = phylo.hash_tree(unroot(usertree, newCopy=True))
        usercost = count_dup_loss_cost(usertree, stree, gene2species, dupcost, losscost)
    else:
        usertreehash = None

    # setup files
    alnfile = util.replace_ext(treefile, options.oldext, options.alignext)
    outfile = util.replace_ext(treefile, options.oldext, options.newext)
    out = util.open_stream(outfile, "w")

    # read input trees
    gtrees = treelib.read_trees(treefile)
    single_tree = len(gtrees) == 1

    for treendx, gtree in enumerate(gtrees):
        if options.verbose >= 1:
            log.start("Working on file '%s', tree %d" % (treefile, treendx))

        # special cases: no need to search
        if len(gtree.leaves()) <= 2 or \
           (len(gtree.leaves()) == 3 and not rooted):
            output_tree(gtree, out, single_tree)
            if options.verbose >= 1:
                log.log("tree size <= 2 or == 3 and unrooted -- search skipped")
                log.stop()
            continue

        # remove bootstraps and dists if present
        for node in gtree:
            node.dist = 0
            if "boot" in node.data:
                del node.data["boot"]
        if "boot" in gtree.default_data:
            del gtree.default_data["boot"]

        # log initial tree
        if options.reroot:
            gtree, cost0 = recon_root(gtree, stree, gene2species, newCopy=False, returnCost=True,
                                      dupcost=dupcost, losscost=losscost)
        else:
            cost0 = count_dup_loss_cost(gtree, stree, gene2species, dupcost, losscost)
        if options.verbose >= 1:
            log.log("search: initial")
            log_tree(gtree, log)
            log_tree(gtree, log, oneline=False)
            log.log("search: cost\t= %d" % cost0)

        # store hash of initial tree and whether search has visited user tree
        tree0 = gtree.copy()
        if rooted:
            treehash0 = phylo.hash_tree(tree0)
        else:
            treehash0 = phylo.hash_tree(unroot(tree0, newCopy=True))
        if usertreehash:
            assert set(usertree.leaf_names()) == set(gtree.leaf_names())
        searched_user = treehash0 == usertreehash

        # initialize min values
        mintree, mincost, minpval, minDlnl = tree0, cost0, None, 0

        # initialize module (looked up by name among the imported modules
        # instead of eval()-ing a string built from the command line)
        if not DEBUG_SKIP_LIK:
            likmodule = sys.modules[options.module]
            if options.verbose >= 1:
                log.log('')
                log.start("Optimizing model")
            likmodule.init()
            likmodule.optimize_model(treefile, alnfile, options.extra)
            if options.verbose >= 1:
                log.stop()
                log.log('')

        # runtime statistics
        runtime_prop = 0
        timer_prop = timer.Timer()
        runtime_recon = 0
        timer_recon = timer.Timer()
        runtime_stat = 0
        timer_stat = timer.Timer()

        # user tree statistics
        if options.verbose >= 1:
            if usertreehash:
                log.log("user: cost\t= %d" % usercost)
                if not DEBUG_SKIP_LIK:
                    timer_stat.start()
                    userpval, userDlnl = compute_lik_test(usertree)
                    runtime_stat += timer_stat.stop()
                    log.log("user: pval\t= %.6g" % userpval)
                    log.log("user: Dlnl\t= %.6g" % userDlnl)
                log.log("\n")

        # search functions
        search = phylo.TreeSearchMix(gtree)
        search.add_proposer(phylo.TreeSearchNni(gtree), 0.5)
        search.add_proposer(phylo.TreeSearchSpr(gtree), 0.5)
        uniques = set([treehash0])

        # do search
        nproposals = 0
        nuniques = 0
        npools = 0
        nemptypools = 0
        ndiffrecon = 0
        nrecon = 0
        for i in xrange(options.niter):            # outer search
            # store initial search tree
            if rooted:
                treehash1 = phylo.hash_tree(search.tree)
            else:
                treehash1 = phylo.hash_tree(unroot(search.tree, newCopy=True))

            # random values -- have to reseed in case module has a random number generator
            nprnd.seed(seed + i*1024)
            random.seed(seed + i*1024)
            # randvec[j] decides whether to revert proposal j;
            # randvec[j + nquickiter] decides whether to reconroot it
            if NUMPY:
                randvec = nprnd.random(2*options.nquickiter)
            else:
                randvec = [nprnd.random() for _ in xrange(2*options.nquickiter)]

            # search
            ntrees = 0
            pool = {}
            mincost_pool = mincost

            # note that reconroot is NOT propagated through the subproposals - doing so messes up the unique filter
            for j in xrange(options.nquickiter):   # inner search
                # propose tree
                timer_prop.start()
                tree = search.propose()
                runtime_prop += timer_prop.stop()

                # only allow unique proposals but if I have been rejecting too much allow some non-uniques through
                treehash = phylo.hash_tree(tree)
                if treehash in uniques and ntrees >= 0.1*j:
                    if options.verbose >= 3:
                        log.log("prescreen: iter %d" % j)
                        log.log("prescreen: revert")
                        log.log("")
                    search.revert()
                    continue

                # save tree
                nproposals += 1
                gtree = tree.copy()
                if treehash not in uniques:
                    uniques.add(treehash)
                    nuniques += 1
                ntrees += 1

                if options.verbose >= 3:
                    log.log("prescreen: iter %d" % j)

                # reconroot (some percentage of the time depending on options.freconroot)
                if randvec[j+options.nquickiter] < options.freconroot:
                    timer_recon.start()
                    gtree, cost = recon_root(gtree, stree, gene2species, newCopy=False, returnCost=True,
                                             dupcost=dupcost, losscost=losscost)
                    runtime_recon += timer_recon.stop()

                    # did reconroot change the tree?
                    # (logged at the highest verbosity level, which is 3)
                    nrecon += 1
                    if phylo.hash_tree(gtree) == treehash:
                        if options.verbose >= 3:
                            log.log("prescreen: recon\t= unchanged")
                    else:
                        ndiffrecon += 1
                        if options.verbose >= 3:
                            log.log("prescreen: recon\t= changed")
                else:
                    cost = count_dup_loss_cost(gtree, stree, gene2species, dupcost, losscost)

                # log
                if options.verbose >= 3:
                    log.log("prescreen: cost\t= %d" % cost)
                    log_tree(gtree, log)
                    log.log("")

                # store to pool if unique
                if rooted:
                    treehash = phylo.hash_tree(gtree)
                else:
                    treehash = phylo.hash_tree(unroot(gtree, newCopy=True))
                if treehash != treehash1 and treehash not in pool:
                    pool[treehash] = (gtree, cost, j)
                if (usertreehash) and (not searched_user) and (treehash == usertreehash):
                    searched_user = True

                # update mincost of pool and decide how to continue proposals from here
                if cost < mincost_pool:
                    # make more proposals off this one
                    mincost_pool = cost
                elif cost == mincost_pool:
                    # flip a coin to decide whether to start from original or new proposal
                    if randvec[j] < 0.5:
                        search.revert()
                else:
                    # start from new proposal 10% of the time
                    if randvec[j] < 0.9:
                        search.revert()

            # remove trees with higher costs
            if options.verbose >= 2:
                log.log("pool: size\t= %d" % len(pool))
            pool = list(pool.values())
            if DEBUG_SKIP_COST:
                fpool = pool
            else:
                fpool = [entry for entry in pool
                         if entry[1] <= mincost and entry[1] < cost0]
                if options.verbose >= 2:
                    log.log("pool: filtered size\t= %d" % len(fpool))
            # always defined, so the empty-pool check below also works
            # when the cost requirement is skipped
            nfpool = len(fpool)
            if options.verbose >= 2:
                log.log("")
            fpool.sort(key=lambda x: x[1])

            # propose a tree from the pool with minimum cost that passes threshold
            ntested = 0     # number of leading fpool entries given a likelihood test
            if DEBUG_SKIP_LIK:
                if fpool:
                    reject = False
                    mintree, mincost, minpval, minDlnl = fpool[0][0], fpool[0][1], 1, 0
                else:
                    # nothing to accept; reported as an empty pool below
                    reject = True
            else:
                reject = True
                for j, (gtree, cost, ndx) in enumerate(fpool):
                    ntested = j + 1
                    timer_stat.start()
                    pval, Dlnl = compute_lik_test(gtree)
                    runtime_stat += timer_stat.stop()

                    if (pval < options.pval) or \
                       (cost == mincost and Dlnl > minDlnl):
                        # (1) significantly worse topology or (2) worse Dlnl (smaller Dlnl is better)
                        if options.verbose >= 2:
                            log.log("pool: iter %d (%d)" % (j, ndx))
                            log.log("pool: reject")
                            log.log("pool: cost\t= %d" % cost)
                            log.log("pool: pval\t= %.6g" % pval)
                            log.log("pool: Dlnl\t= %.6g" % Dlnl)
                            log_tree(gtree, log)
                            log.log("")
                    else:
                        reject = False
                        mintree, mincost, minpval, minDlnl = gtree, cost, pval, Dlnl
                        break

            # debug: also test every pool tree that was not tested above
            if DEBUG_COMPUTE_ALL_LIK:
                dpool = fpool[ntested:] + [x for x in pool if x not in fpool]
                for (gtree, cost, ndx) in dpool:
                    timer_stat.start()
                    pval, Dlnl = compute_lik_test(gtree)
                    runtime_stat += timer_stat.stop()

                    if options.verbose >= 2:
                        log.log("pool: iter (%d)" % ndx)
                        log.log("pool: cost\t= %d" % cost)
                        log.log("pool: pval\t= %.6g" % pval)
                        log.log("pool: Dlnl\t= %.6g" % Dlnl)
                        log_tree(gtree, log)
                        log.log("")

            # reset search and log
            search.reset()
            search.set_tree(mintree.copy())
            npools += 1
            if nfpool == 0:
                nemptypools += 1
                if options.verbose >= 1:
                    log.log("search: iter %d" % i)
                    log.log("search: empty pool")
                    log.log("\n")
                continue
            if reject:
                if options.verbose >= 1:
                    log.log("search: iter %d" % i)
                    log.log("search: reject")
                    log.log("\n")
                continue
            if options.verbose >= 1:
                log.log("search: iter %d" % i)
                log.log("search: accept")
                log.log("search: cost\t= %d" % mincost)
                log.log("search: pval\t= %.6g" % minpval)
                log.log("search: Dlnl\t= %.6g" % minDlnl)
                log_tree(mintree, log)
                log_tree(mintree, log, oneline=False)
                log.log("\n")

            # end early if tree with zero cost found
            if mincost == 0:
                if options.verbose >= 1:
                    log.log("search: break\n")
                break

        # cleanup module for tree likelihood
        if not DEBUG_SKIP_LIK:
            likmodule.cleanup()

        # has the tree changed?
        if rooted:
            treehash = phylo.hash_tree(mintree)
        else:
            # do a final reconroot
            timer_recon.start()
            mintree, mincost = recon_root(mintree, stree, gene2species, newCopy=False, returnCost=True,
                                          dupcost=dupcost, losscost=losscost)
            runtime_recon += timer_recon.stop()

            treehash = phylo.hash_tree(unroot(mintree, newCopy=True))
        same = treehash == treehash0

        # output final statistics and tree
        if options.verbose >= 1:
            log.log("search: final")
            if usertreehash:
                log.log("search: visited user tree\t= %s" % ("yes" if searched_user else "no"))
                log.log("search: equal user tree\t= %s" % ("yes" if treehash == usertreehash else "no"))
            if same:
                log.log("search: changed\t= no")
            else:
                log.log("search: changed\t= yes")
                log.log("search: init cost\t= %d" % cost0)
                log.log("search: final cost\t= %d" % mincost)
                # None means no likelihood test result was recorded;
                # a p-value of 0.0 is still a result and is reported
                if minpval is not None:
                    log.log("search: pval\t= %.6g" % minpval)
                    log.log("search: Dlnl\t= %.6g" % minDlnl)
                log_tree(mintree, log)

            log.log("")
            log.log("num proposals:\t%d" % nproposals)
            log.log("num unique proposals:\t%d" % nuniques)
            if options.verbose >= 2:
                log.log("")
                log.log("num pools:\t%d" % npools)
                log.log("num empty pools:\t%d" % nemptypools)
                if npools == 0:
                    log.log("empty pool rate:\tINF")
                else:
                    log.log("empty pool rate:\t%f" % (float(nemptypools)/npools))
            log.log("")
            log.log("num reconroot:\t%d" % nrecon)
            log.log("num diff reconroot:\t%d" % ndiffrecon)
            if nrecon == 0:
                log.log("diff reconroot rate:\tINF")
            else:
                log.log("diff reconroot rate:\t%f" % (float(ndiffrecon)/nrecon))
            log.log("")
            log.log("proposal runtime:\t%f" % runtime_prop)
            log.log("reconroot runtime:\t%f" % runtime_recon)
            log.log("statistic runtime:\t%f" % runtime_stat)
        output_tree(mintree, out, single_tree)

        # cleanup
        clear_caches()
        if options.verbose >= 1:
            log.stop()

    # close output stream
    out.close()

    if options.verbose >= 1:
        log.stop()
        log.log("\n\n")

# close log
if options.verbose >= 1 and options.log != "-":
    outlog.close()
