# Time-stamp: <ClustalTools.py 2012-03-29 18:55:26 Mark Voorhies>
"""Tools for reading and manipulating trees and alignments from CLUSTAL.
NewHampshireGraph is a digraph representation of ph(b) trees supporting
re-rooting.
MultipleAlignment is a matrix representation of a multiple alignment
supporting column indexing and slicing.
"""

import re
import sys

#==================================================
#   Digraph implementation for New Hampshire Trees
#==================================================

class LabeledNode:
    """Tree node that carries a display label.

    No __eq__/__hash__ is defined, so two nodes with the same label are
    still distinct graph vertices (default identity semantics).
    """

    def __init__(self, label):
        self.label = label

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        return "%s" % (self.label,)

class BootstrapEdge:
    """Tree edge holding an optional branch length and bootstrap count."""

    def __init__(self, branchlen = None, bootstrap = None):
        self.branchlen = branchlen
        self.bootstrap = bootstrap

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        return "(%s,%s)" % (self.branchlen, self.bootstrap)

    def phbFormat(self):
        """Return the ph(b) edge suffix, e.g. ":0.12345[1000]".

        The branch length is written with "%6.5f" and the bootstrap in
        square brackets.  The leading ":" is emitted only when at least
        one of the two values is set, so an empty edge yields "".
        """
        parts = []
        if(self.branchlen is not None):
            parts.append("%6.5f" % self.branchlen)
        if(self.bootstrap is not None):
            parts.append("[%d]" % self.bootstrap)
        if(not parts):
            return ""
        return ":" + "".join(parts)

class Digraph:
    """Minimal directed graph with at most one edge object per
    (parent, child) pair.

    Nodes are used as dict/set keys, so they must be hashable.  Edge
    objects are stored as given and are otherwise opaque to this class;
    they must also be hashable because they are kept in a set.
    """
    def __init__(self):
        self.__nodes = set()
        self.__edges = set()
        # node -> {child: edge}
        self.__children = {}
        # node -> {parent: edge}
        self.__parents = {}

    def addNode(self, node):
        """Add node to the graph.

        Re-adding an existing node is a no-op: its edges are kept, so the
        adjacency maps of its neighbors stay consistent with its own.
        """
        self.__nodes.add(node)
        self.__parents.setdefault(node, {})
        self.__children.setdefault(node, {})

    def getEdge(self, node1, node2):
        """Return the edge joining node1 and node2 in either direction.

        Raises KeyError if the two nodes are not adjacent.
        """
        try:
            return self.__children[node1][node2]
        except KeyError:
            return self.__children[node2][node1]

    def addEdge(self, parent, child, edge):
        """Add or update the edge from parent to child.

        Multiple edges between two nodes in the same direction are not
        allowed: an existing parent->child edge is replaced, and the old
        edge object is also dropped from getEdges().  Both nodes must
        already have been added (KeyError otherwise, raised before any
        state is modified).
        """
        children = self.__children[parent]
        parents = self.__parents[child]
        if(child in children):
            # Forget the replaced edge so getEdges() does not report it.
            self.__edges.discard(children[child])
        self.__edges.add(edge)
        children[child] = edge
        parents[parent] = edge

    def delEdge(self, parent, child):
        """Remove the edge from parent to child (KeyError if absent)."""
        edge = self.__children[parent][child]
        del self.__children[parent][child]
        del self.__parents[child][parent]
        self.__edges.remove(edge)

    def getParents(self, child):
        """Return a list of the parents of child."""
        return list(self.__parents[child])

    def getChildren(self, parent):
        """Return a list of the children of parent."""
        return list(self.__children[parent])

    def getNodes(self):
        """Return a list of all nodes (arbitrary order)."""
        return list(self.__nodes)

    def getEdges(self):
        """Return a list of all edge objects (arbitrary order)."""
        return list(self.__edges)

    def orient(self, node):
        """Invert the direction of every edge on the path from node to
        the current root.

        This has the effect of rooting the tree at node.  Edges are
        flipped starting from the old root end of the path, using the
        same edge objects.

        Assumes that the digraph is a tree (every node has at most one
        parent); raises ValueError if the parent chain loops.
        """
        # Collect the (parent, child) pairs from node up to the root
        #   before modifying anything.
        path = []
        seen = set([node])
        current = node
        while(True):
            parents = self.getParents(current)
            assert(len(parents) < 2)
            if(len(parents) == 0):
                # Reached the old root
                break
            oldParent = parents[0]
            if(oldParent in seen):
                raise ValueError("cycle detected while orienting digraph")
            seen.add(oldParent)
            path.append((oldParent, current))
            current = oldParent

        # Flip from the root end so each node's new parent is set up in
        #   the same order as a recursive walk would produce.
        for (oldParent, child) in reversed(path):
            edge = self.getEdge(oldParent, child)
            self.delEdge(oldParent, child)
            self.addEdge(child, oldParent, edge)
        
class NewHampshireGraph:
    """Digraph representation of (bootstrapped) phylogram.

    Nodes are LabeledNode objects and edges are BootstrapEdge objects
    stored in a Digraph (self.graph).  Build a tree with fromPhb(); the
    first internal node parsed becomes self.root until setRoot() is
    called.  Writers emit ph(b), Cytoscape SIF and JavaTreeView GTR/CDT
    output starting from self.root.
    """
    def __init__(self):
        self.graph = Digraph()

        #  Internal variables used by LR parser

        # Root := First internal node encountered during parsing
        self.root = None
        # Number of internal nodes created so far (also used to number the
        #   DUMMY nodes inserted by makeBtree)
        self.icount = 0
        # Node that should be parent to the next new node
        self.activeNode = None
        # Ancestors of the current active node
        self.stack = []
        # Edge that should receive the next branchlen and bootstrap values
        self.activeEdge = None

    def setRoot(self, node = None):
        """Root the tree at node, or use a heuristic based on njplot
        for finding the center of an unrooted tree.

        The heuristic chooses a node with the maximum average distance to
        the leaf nodes, as computed by calcDepths().  Edges on the path
        from the chosen node to the old root are reversed in place.
        """
        if(node is None):
            dists = self.calcDepths()[1]

            # Project-local helper; presumably inverts a dict into
            #   value -> [keys] -- TODO confirm against MsvUtil.
            from MsvUtil import multirevdict
            dist2nodes = multirevdict(dists)
            maxdist = max(dist2nodes.keys())
            node = dist2nodes[maxdist][0]

        assert(node is not None)

        self.graph.orient(node)
        self.root = node

    def calcDepths(self):
        """Return a tuple of dicts (depths, dists) where
        depths is the depth of each node (leaf nodes = 0)
        and dists is the average distance of each node from the
        tree leaves (again, leaf nodes = 0.0).

        Nodes are peeled from the leaves inward.  On each pass, every
        unvisited node with at most one unvisited neighbor receives the
        current depth.  Its distance is the mean, over its already-visited
        neighbors, of (neighbor distance + connecting branch length).
        Edge direction is ignored, so the result does not depend on the
        current root.  Missing (None) branch lengths count as 0.0, as in
        gtrFormat/cdtFormat.  If a pass assigns no node (the graph is not
        a forest), the partial result is returned.
        """
        depths = {}
        dists = {}

        nodes = self.graph.getNodes()
        depth = 0

        while(len(depths) < len(nodes)):
            # Cache this pass's values so every node sees the pre-update
            #   state of the tree.
            depth_cache = {}
            dist_cache = {}

            for i in nodes:
                if(i in depths):
                    continue

                neighbors = (list(self.graph.getParents(i)) +
                             list(self.graph.getChildren(i)))
                children = [j for j in neighbors if(j in depths)]
                if(len(neighbors) - len(children) > 1):
                    # More than one unvisited neighbor: not peelable yet
                    continue

                depth_cache[i] = depth

                if(len(children) == 0):
                    dist_cache[i] = 0.0
                else:
                    total = 0.0
                    for j in children:
                        branchlen = self.graph.getEdge(i, j).branchlen
                        if(branchlen is None):
                            branchlen = 0.0
                        total += dists[j] + branchlen
                    dist_cache[i] = total/float(len(children))

            if(len(depth_cache) == 0):
                # No progress possible (e.g. a cycle); avoid looping forever
                break

            depths.update(depth_cache)
            dists.update(dist_cache)
            depth += 1

        return (depths, dists)

    def makeBtree(self, node = None):
        """Make the subtree below node (default: self.root) binary.

        Children are processed first.  A node with more than two children
        has adjacent children paired under new "DUMMY%04d" nodes joined to
        it by 0.0-length edges, repeatedly, until two children remain.
        The original child edges are reused, so root-to-leaf branch
        lengths are unchanged.  Leaves are left alone.  An internal node
        with exactly one child fails the final assertion.
        """
        if(node is None):
            self.makeBtree(self.root)
            return

        children = list(self.graph.getChildren(node))
        if(len(children) < 1):
            return
        for i in children:
            self.makeBtree(i)

        while(len(children) > 2):
            newkids = []
            for i in range(len(children)//2):
                self.icount += 1
                dummy = LabeledNode(label = "DUMMY%04d" % self.icount)
                self.graph.addNode(dummy)
                left = children[i*2]
                right = children[i*2+1]
                edgeL = self.graph.getEdge(node, left)
                edgeR = self.graph.getEdge(node, right)
                self.graph.delEdge(node, left)
                self.graph.delEdge(node, right)
                self.graph.addEdge(node, dummy,
                                   BootstrapEdge(branchlen = 0.0))
                self.graph.addEdge(dummy, left, edgeL)
                self.graph.addEdge(dummy, right, edgeR)
                newkids.append(dummy)
            if(len(children) % 2 == 1):
                # Odd child out is carried up to the next round unchanged
                newkids.append(children[-1])

            children = newkids

        assert(len(children) == 2)

    def phbFormat(self, node):
        """Return the ph(b) text for the subtree rooted at node, followed
        by the branch length/bootstrap of the edge leading into node (if
        node has a parent)."""
        children = list(self.graph.getChildren(node))
        parents = list(self.graph.getParents(node))
        assert(len(parents) < 2)
        if(len(parents) == 1):
            edge = self.graph.getEdge(parents[0], node)
        else:
            # Root: no incoming edge, so no ":len[boot]" suffix
            edge = BootstrapEdge()
        if(len(children) == 0):
            # Leaf node
            return "%s%s" % (str(node), edge.phbFormat())

        return "(\n%s\n)%s" % (
            ",\n".join(self.phbFormat(i) for i in children),
            edge.phbFormat())

    def writePhb(self, fout):
        """Output the tree in New Hampshire phb format."""
        fout.write(self.phbFormat(self.root)+";")

    def sifFormat(self, node):
        """Return one "parent pp child" line per edge below node,
        depth-first."""
        retval = ""
        for i in self.graph.getChildren(node):
            retval += "%s pp %s\n" % (node, i)
            retval += self.sifFormat(i)
        return retval

    def writeSif(self, fout):
        """Output the tree in sif format for Cytoscape."""
        fout.write(self.sifFormat(self.root))

    def gtrFormat(self, node, branchlen = 0.0):
        """Return GTR rows for the internal nodes below and including node.

        Rows for descendants come before the row for node.  branchlen is
        the accumulated distance from the root to node; it is written in
        the TIME column.  Missing branch lengths contribute 0.
        """
        retval = ""
        children = list(self.graph.getChildren(node))
        # No rows generated for leaf nodes
        if(len(children) == 0):
            return ""
        # Only binary trees can be written to GTR (see makeBtree)
        assert(len(children) == 2)
        branchlens = [branchlen, branchlen]
        for i in range(2):
            edge = self.graph.getEdge(node, children[i])
            if(edge.branchlen is not None):
                branchlens[i] += edge.branchlen
            retval += self.gtrFormat(children[i], branchlens[i])
        retval += "\t".join((str(node),
                             str(children[0]),
                             str(children[1]),
                             "%6.5f" % branchlen))+"\n"

        return retval

    def writeGtr(self, fout):
        """Output tree as GTR file for JavaTreeView.
        Note that only internal nodes are written."""

        fout.write("\t".join(("NODEID","LEFT","RIGHT","TIME"))+"\n")
        fout.write(self.gtrFormat(self.root))

    def cdtFormat(self, node, alignment, branchlen = 0.0):
        """Return CDT rows for the leaves below node, depth-first.

        alignment is indexed by leaf label (e.g. a MultipleAlignment,
        whose string lookup returns that row).  branchlen is the
        accumulated distance from the root to node, written in the LEAF
        column.  Missing branch lengths contribute 0.
        """
        children = list(self.graph.getChildren(node))
        if(len(children) == 0):
            aln = alignment[node.label]
            return "\t".join((node.label,            # GID
                              node.label,            # UID
                              "%6.5f" % branchlen,   # LEAF
                              aln,                   # ALN
                              node.label,            # NAME
                              "1",                   # GWEIGHT
                              "")                    # DUMMY
                             )+"\n"

        retval = ""
        for i in children:
            edge = self.graph.getEdge(node, i)
            if(edge.branchlen is not None):
                b = branchlen + edge.branchlen
            else:
                b = branchlen
            retval += self.cdtFormat(i, alignment, b)

        return retval

    def writeCdt(self, fout, alignment):
        """Write alignment as CDT for JavaTreeView, using the same
        sequence ordering and GIDs as writeGtr."""
        fout.write("\t".join(("GID","UID","LEAF","ALN","NAME",
                              "GWEIGHT","DUMMY"))+"\n")
        fout.write("\t".join(["EWEIGHT"]+([""]*5)+["1"])+"\n")
        fout.write(self.cdtFormat(self.root, alignment))

    # ----- Parser callbacks (driven by lrPhbParse) -----

    def finalize(self):
        """Hook called when ";" is reached; currently does nothing."""
        pass

    def initInternalNode(self):
        """Handle "(": create a NODE%05d node under the active node (or as
        the root if there is none) and make it active."""
        self.icount += 1
        node = LabeledNode(label = "NODE%05d" % self.icount)
        self.graph.addNode(node)

        if(self.activeNode is None):
            assert(self.root is None)
            self.root = node
        else:
            self.graph.addEdge(parent = self.activeNode,
                               child = node,
                               edge = BootstrapEdge())
            self.stack.append(self.activeNode)
        self.activeNode = node

    def closeInternalNode(self):
        """Handle ")": make the edge into the closed node the target of
        any following branch length/bootstrap, and reactivate its parent."""
        parents = list(self.graph.getParents(self.activeNode))
        if(len(parents) == 0):
            assert(self.root == self.activeNode)
            assert(len(self.stack) == 0)
            self.activeNode = None
            # The root has no incoming edge.  Give a trailing root
            #   ":len[boot]" a throwaway target instead of letting it
            #   overwrite the last child's edge.
            self.activeEdge = BootstrapEdge()
        else:
            assert(len(parents) == 1)
            self.activeEdge = self.graph.getEdge(parents[0], self.activeNode)
            assert(len(self.stack) > 0)
            self.activeNode = self.stack.pop()

    def setBranchLen(self, n):
        """Store branch length n on the active edge."""
        self.activeEdge.branchlen = float(n)

    def setBootstrap(self, n):
        """Store bootstrap count n on the active edge."""
        self.activeEdge.bootstrap = int(n)

    def closeSisterElement(self):
        """Handle ",": no edge is active until the next sister appears."""
        self.activeEdge = None

    def initLabel(self, label):
        """Handle a leaf label: add a leaf under the active node and make
        its edge active."""
        node = LabeledNode(label = label)
        self.graph.addNode(node)
        edge = BootstrapEdge()
        self.graph.addEdge(parent = self.activeNode,
                           child = node,
                           edge = edge)
        self.activeEdge = edge

    def lrPhbParse(self, s, mapper):
        """left->right parser for New Hampshire format dendrograms.

        Scans s and drives the parser callbacks:
          "("        open an internal node
          ")"        close the current internal node
          ":<num>"   branch length for the active edge (plain or exponent
                     notation, e.g. 0.01234 or 1.2e-05)
          "[<int>]"  bootstrap count for the active edge
          ","        separate sister elements
          ";"        end of tree (anything after it is ignored)
        Any other run of non-space characters up to one of ",();:" is a
        leaf label and is passed through mapper.  Input that ends without
        ";" is accepted.
        """
        n = len(s)
        p = 0
        while(p < n):
            c = s[p]
            if(c.isspace()):
                p += 1
            elif(c == ";"):
                self.finalize()
                break
            elif(c == "("):
                self.initInternalNode()
                p += 1
            elif(c == ")"):
                self.closeInternalNode()
                p += 1
            elif(c == ":"):
                p += 1
                l = p
                while((p < n) and (s[p] in "0123456789-.eE+")):
                    p += 1
                self.setBranchLen(float(s[l:p]))
            elif(c == "["):
                p += 1
                l = p
                while((p < n) and (s[p] in "0123456789")):
                    p += 1
                assert((p < n) and (s[p] == "]"))
                self.setBootstrap(int(s[l:p]))
                p += 1
            elif(c == ","):
                self.closeSisterElement()
                p += 1
            else:
                l = p
                while((p < n) and (s[p] not in ",();:")
                      and (not s[p].isspace())):
                    p += 1
                self.initLabel(mapper(s[l:p]))

    @classmethod
    def fromPhb(cls, fin, mapper = lambda x: x):
        """Construct a NewHampshireGraph from a phb file.

        fin is either the tree text itself (a str) or a file-like object
        with read().  If mapper is given, it is a function mapping names
        in the phb file to desired names in the constructed graph."""
        if(isinstance(fin, str)):
            s = fin
        else:
            s = fin.read()
        tree = cls()
        tree.lrPhbParse(s, mapper)
        return tree

#==================================================
#    MultipleAlignment
#==================================================

class MultipleAlignment:
    """Matrix representation of a multiple alignment supporting column
    indexing and slicing.

    Attributes:
      seqmatrix      -- list of aligned rows, one per sequence (strings as
                        built by the from* constructors)
      seqnames       -- list of row names, parallel to seqmatrix
      name2seq       -- dict mapping each name to its row
      colAnnotations -- dict mapping an annotation tag (e.g. a Stockholm
                        #=GC tag) to a per-column annotation string
    """

    # Characters accepted in the sequence field of CLUSTAL rows
    _CLUSTAL_CHARS = "ACDEFGHIKLMNPQRSTUVWYXacdefghiklmnpqrstuvwyx-"

    def __init__(self, seqmatrix, seqnames = None, colAnnotations = None):
        """seqnames defaults to Seq00000, Seq00001, ...; names must be
        unique (asserted)."""
        self.seqmatrix = seqmatrix
        if(seqnames is not None):
            self.seqnames = seqnames
            assert(len(self.seqnames) == len(self.seqmatrix))
        else:
            self.seqnames = ["Seq%05d" % i for i in range(len(self.seqmatrix))]

        if(colAnnotations is not None):
            self.colAnnotations = colAnnotations
        else:
            self.colAnnotations = {}

        # Index the resolved names (not the argument, which may be None)
        self.name2seq = dict(zip(self.seqnames, self.seqmatrix))
        # Duplicate names would silently collapse in the index
        assert(len(self.name2seq) == len(self.seqnames))

    def remap_names(self, mapper):
        """Update sequence names and name2seq index with the given transform.

        This is a utility for, e.g., wrapping calls to CLUSTALX with a
        reversible name shortening."""

        self.seqnames = [mapper(i) for i in self.seqnames]
        self.name2seq = dict(zip(self.seqnames, self.seqmatrix))
        assert(len(self.name2seq) == len(self.seqnames))

    def __getitem__(self, i):
        """Return the sequence named "i" if i is a string,
        column vector [i] (counting from 0) if i is an integer,
        or a column-sliced MultipleAlignment if i is a slice.
        Raises KeyError for any other key type.
        """
        if(isinstance(i, slice)):
            return self._columnSlice(i)
        if(isinstance(i, int)):
            return [row[i] for row in self.seqmatrix]
        if(isinstance(i, str)):
            return self.name2seq[i]
        raise KeyError(i)

    def _columnSlice(self, cols):
        """Return a new MultipleAlignment keeping the columns selected by
        the slice object cols (rows, names and annotations in step)."""
        return MultipleAlignment(
            seqmatrix = [seq[cols] for seq in self.seqmatrix],
            seqnames = self.seqnames,
            colAnnotations = dict((key, seq[cols])
                                  for (key, seq) in self.colAnnotations.items()))

    def __getslice__(self, i, j):
        """Return a MultipleAlignment corresponding to columns
        [i:j] (indexing from 0, j is one past the last column).

        Python 2 simple-slice hook; Python 3 routes aln[i:j] through
        __getitem__ with a slice object instead.
        """
        return self._columnSlice(slice(i, j))

    def writeClustal(self, fout):
        """Write alignment in (inferred) CLUSTALX .aln format.

        Names longer than 35 characters are truncated from the left (the
        last 35 characters are kept); truncated names must stay unique.
        Sequences are written in 50-column blocks, each followed by a
        blank conservation line.
        """
        seqnames = [i[max(0,len(i)-35):] for i in self.seqnames]
        for i in seqnames:
            assert(len(i) < 36)
        assert(len(seqnames) == len(set(seqnames)))

        fout.write('CLUSTAL X (1.83) multiple sequence alignment\n')
        fout.write('\n')
        step = 50
        offset = 0
        if(len(self.seqmatrix) > 0):
            n = len(self.seqmatrix[0])
        else:
            n = 0
        while(offset < n):
            # spacer
            fout.write('\n')
            for (name, seq) in zip(seqnames, self.seqmatrix):
                fout.write('%-35s %s\n' % (name, seq[offset:offset+step]))
            offset += step
            # conservation line (35-char name + 1 space + step columns)
            fout.write(' '*(36+step)+'\n')

    @classmethod
    def fromClustal(cls, fin, headercheck = True):
        """Initialize from CLUSTAL format aln file.

        fin is any iterable of lines (e.g. an open file).  Row order is
        taken from the first block; later blocks are appended by name.
        Rows with characters outside _CLUSTAL_CHARS fail an assertion
        (the offending first-block line is echoed to stderr).
        """
        seqnames = []
        seqs = {}
        lines = iter(fin)

        # Read past header
        line = next(lines)
        if(headercheck):
            assert(line[:7] == "CLUSTAL")
        while((line[:7] == "CLUSTAL") or (len(line.strip()) == 0)):
            line = next(lines)

        def isConserved(s):
            """True for a conservation line: leading whitespace, then only
            whitespace and ".:*"."""
            if(len(s) < 1):
                return False
            if(not s[0].isspace()):
                return False
            for i in s:
                if(i.isspace()):
                    continue
                if(i not in ".:*"):
                    return False
            return True

        # Read first block (defines the sequence order).  EOF ("") also
        #   ends the block, so single-block files need no trailing line.
        while((len(line.strip()) > 0) and (not isConserved(line))):
            w = line.split()
            assert(len(w) == 2)
            assert(w[0] not in seqs)
            try:
                for i in w[1]:
                    assert(i in cls._CLUSTAL_CHARS)
            except AssertionError:
                sys.stderr.write(line)
                raise
            seqnames.append(w[0])
            seqs[w[0]] = w[1]
            line = next(lines, "")

        # Read subsequent blocks:
        for line in lines:
            if((len(line.strip()) == 0) or isConserved(line)):
                continue

            w = line.split()
            assert(len(w) == 2)
            for i in w[1]:
                assert(i in cls._CLUSTAL_CHARS)
            seqs[w[0]] += w[1]

        n = len(seqs[seqnames[0]])
        for i in seqnames[1:]:
            assert(len(seqs[i]) == n)

        return cls(seqmatrix = [seqs[i] for i in seqnames],
                   seqnames = seqnames)

    @classmethod
    def fromStockholm(cls, fin, headercheck = True):
        """Initialize from Stockholm format file.

        "." and "*" in sequence rows are converted to "-".  "#=GC" lines
        become colAnnotations; all other "#" lines are skipped.  Parsing
        stops at "//".
        """
        seqnames = []
        seqs = {}
        colAnnotations = {}
        gap_re = re.compile('[*.]')
        lines = iter(fin)

        # Check file format
        if(headercheck):
            line = next(lines)
            assert(line[:15] == "# STOCKHOLM 1.0")

        for line in lines:
            # Skip blank lines
            if(len(line.strip()) < 1):
                continue

            # Quit on "end-of-data" marker
            if(line[:2] == '//'):
                break

            # Parse comments and annotations
            if(line[0] == '#'):
                # Column annotations
                if(line[1:4] == "=GC"):
                    (comment, tag, seq) = line.split()
                    if(tag in colAnnotations):
                        colAnnotations[tag] += seq
                    else:
                        colAnnotations[tag] = seq

                # Skip all other
                continue

            # Parse sequence
            w = line.split()
            assert(len(w) == 2)

            seq = gap_re.sub('-', w[1])
            if(w[0] in seqs):
                seqs[w[0]] += seq
            else:
                seqnames.append(w[0])
                seqs[w[0]] = seq

        n = len(seqs[seqnames[0]])
        for i in seqnames[1:]:
            try:
                assert(len(seqs[i]) == n)
            except AssertionError:
                sys.stderr.write("%s\n" % ([(j, len(seqs[j]))
                                             for j in seqnames],))
                raise

        for (tag, seq) in colAnnotations.items():
            assert(len(seq) == n)

        if(len(colAnnotations) == 0):
            colAnnotations = None

        return cls(seqmatrix = [seqs[i] for i in seqnames],
                   seqnames = seqnames, colAnnotations = colAnnotations)

    def writeStockholm(self, fout):
        """Write alignment in Stockholm format, as defined on page 60
        of the HMMer 2 user guide (version 2.3.2).

        Column annotations are not written."""

        # Header
        fout.write("# STOCKHOLM 1.0\n")

        # Write alignment blocks.  Pfam does not appear to break alignments
        # into multiple blocks, so we won't either.

        name_width = max(len(i) for i in self.seqnames)
        seqformat = "%-"+("%d" % name_width)+"s %s\n"

        for (seqname, sequence) in zip(self.seqnames, self.seqmatrix):
            fout.write(seqformat % (seqname, sequence))

        # Footer
        fout.write("//\n")

    @classmethod
    def fromCdt(cls, fin):
        """Construct a MultipleAlignment from a CDT file with an ALN column.

        C.f. NewHampshireGraph for the corresponding write method, which
        requires tree data for correct row ordering.
        """

        from CdtFile import CdtFile
        cdt = CdtFile.fromCdt(fin)
        aln_col = cdt.extranames.index("ALN")
        seqnames = []
        seqmatrix = []
        for row in cdt:
            seqnames.append(row.uniqid)
            seqmatrix.append(row.extra[aln_col])

        return cls(seqmatrix = seqmatrix, seqnames = seqnames)

    @classmethod
    def fromFasta(cls, fin):
        """Construct a MultipleAlignment from a FASTA file, with gaps
        indicated by dashes (-).

        Each name is the first whitespace-delimited word of its header.
        Every character in a record body other than letters and "-" is
        discarded.
        """
        # Not using FastaFile module because we don't want to lose
        # ordering and we need gap characters.
        # TODO: fold this back into a more generic version of FastaFile.

        fasta_re = re.compile(
            # Header lines are marked by the '>' sign
            # We allow headers to begin in the middle of a line
            # We remove whitespace at the ends of the header
            r">[\s]*(?P<header>.*?)[\s]*$"
            # Sequence is anything that is not a header line
            # We currently count any text in comments (';.*$') as
            # sequence (could parse this out at the same time as
            # whitespace)
            r"(?P<seq>[^>]*)",
            # Use multiline mode to parse an entire FASTA file in one go
            re.M)

        name_re = re.compile(
            r"^(?P<name>[\S]+)")

        seq_re = re.compile(r"[^-A-Za-z]+")

        seqnames = []
        seqmatrix = []

        # finditer (not findall) so each record is a match object with
        #   named groups.
        for match in fasta_re.finditer(fin.read()):
            seqnames.append(name_re.search(match.group("header")).group("name"))
            seqmatrix.append(seq_re.sub("", match.group("seq")))

        return cls(seqmatrix = seqmatrix, seqnames = seqnames)

    def writeFasta(self, fout, w = 80):
        """Write alignment in FASTA format using w character lines.

        Every record ends with exactly one newline, including when the
        sequence length is a multiple of w.
        """

        # Based on Sequence.DnaSequence.FormatFasta
        # Only difference is that we allow gap characters.
        for (name, seq) in zip(self.seqnames, self.seqmatrix):
            fout.write(">%s\n" % name)
            fout.write("\n".join(seq[k:k+w]
                                 for k in range(0, len(seq), w))+"\n")