#!/usr/bin/env python
# Time-stamp: <dp.py 2011-05-03 15:25:14 Mark Voorhies>
"""Dynamic programming examples"""

def makeIdent(alphabet):
    """Build an identity scoring matrix over the symbols of alphabet.

    The result is a dictionary of dictionaries: for every pair of symbols
    (a, b) drawn from alphabet, result[a][b] is 1 when a == b and -1
    otherwise.
    """
    return {
        row: {col: (1 if row == col else -1) for col in alphabet}
        for row in alphabet
    }

def nw(seq1, seq2, s, e=-1, debug=False):
    """Globally align seq1 against seq2 (Needleman-Wunsch).

    s is the scoring matrix, a dictionary of dictionaries indexed as
    s[a][b]; e is the penalty added for every gap position.  Scores are
    assumed to be integers.  When debug is true, the filled score and
    pointer matrices are printed before the traceback.

    Returns the (aligned1, aligned2) pair produced by nw_traceback.
    """
    scores, pointers = nw_fill(seq1, seq2, s, e)
    if debug:
        nw_dump(seq1, seq2, scores, pointers)
    alignment = nw_traceback(seq1, seq2, pointers)
    return alignment

def nw_fill(seq1, seq2, s, e):
    """Run the dynamic programming step of Needleman-Wunsch.

    Returns (m, p):
      m -- score matrix, a list of rows.  Row r / column c holds the best
           score for aligning seq1[:r] with seq2[:c], so the matrix has
           len(seq1)+1 rows and len(seq2)+1 columns.
      p -- pointer matrix of the same shape.  Each cell holds the list of
           (row, column) predecessor cells that achieve the cell's score;
           the top-left cell holds None.

    s is the scoring matrix (s[a][b]) and e the per-position gap penalty.
    """
    # Top row: seq2 prefixes aligned entirely against gaps, each cell
    # pointing at its left-hand neighbour.
    top_scores = [0]
    top_ptrs = [None]
    for col in range(len(seq2)):
        top_scores.append(top_scores[col] + e)
        top_ptrs.append([(0, col)])

    m = [top_scores]
    p = [top_ptrs]

    # One matrix row per seq1 character.  "row"/"col" are the sequence
    # positions, i.e. one less than the index of the cell being filled.
    for row, res1 in enumerate(seq1):
        above = m[row]
        # Leftmost cell: seq1 prefix aligned entirely against gaps
        scores = [above[0] + e]
        ptrs = [[(row, 0)]]
        m.append(scores)
        p.append(ptrs)

        for col, res2 in enumerate(seq2):
            # Diagonal move: pair seq1[row] with seq2[col]
            diag = above[col] + s[res1][res2]
            # Move from the cell to the left in the current row
            left = scores[col] + e
            # Move from the cell directly above
            up = above[col + 1] + e

            top = max(diag, up, left)
            scores.append(top)

            # Record every predecessor that ties for the best score,
            # in the order diagonal, left, up.
            candidates = (
                (diag, (row, col)),
                (left, (row + 1, col)),
                (up, (row, col + 1)),
            )
            ptrs.append([src for (val, src) in candidates if val == top])

    return (m, p)

def nw_traceback(seq1, seq2, p):
    """Return the alignment corresponding to the "middle-road" traceback
    of the given pointer matrix.

    p is a pointer matrix as produced by nw_fill: each cell holds a list
    of (row, column) predecessor cells, and the top-left cell holds None.
    The traceback starts at the bottom-right cell and always follows the
    first pointer of the current cell until it reaches a None cell.

    Returns (aligned1, aligned2), the two sequences with "-" inserted at
    gap positions.  A warning is printed if no None cell was reached
    within the maximum possible number of steps.
    """
    curpos = (len(seq1), len(seq2))

    # Characters are collected back-to-front and reversed once at the end,
    # which avoids repeatedly rebuilding the strings.
    rev1 = []
    rev2 = []

    # Every step moves at least one cell up and/or left, so a path takes
    # at most len(seq1)+len(seq2) steps.  One extra iteration is needed to
    # observe the None terminator after a path of that maximum length
    # (an alignment made entirely of gaps); without it such a path would
    # be reported as an unexpected exit.
    exitFlag = False
    for _ in range(len(seq1) + len(seq2) + 1):
        plist = p[curpos[0]][curpos[1]]
        if plist is None:
            exitFlag = True
            break
        nextpos = plist[0]

        # Row unchanged: no seq1 character was consumed, so seq1 gets a gap
        if nextpos[0] == curpos[0]:
            rev1.append("-")
        else:
            rev1.append(seq1[nextpos[0]])

        # Column unchanged: no seq2 character was consumed, so seq2 gets a gap
        if nextpos[1] == curpos[1]:
            rev2.append("-")
        else:
            rev2.append(seq2[nextpos[1]])

        curpos = nextpos

    if not exitFlag:
        # Single-argument print(...) gives the same output whether print
        # is a statement (Python 2) or a function (Python 3).
        print("WARNING: Unexpected exit from traceback")

    rev1.reverse()
    rev2.reverse()
    return ("".join(rev1), "".join(rev2))

def nw_dump(seq1, seq2, m, p):
    """Print a diagnostic dump of the results of nw_fill.

    The score matrix m is printed first, in 5-character columns: a header
    line with a blank followed by the characters of seq2, then one line
    per matrix row with the row's scores followed by the matching
    character of " "+seq1.  After a blank line, each row of the pointer
    matrix p is printed as a plain list.
    """
    # All output goes through single-argument print(...) calls, which
    # produce identical text whether print is a statement (Python 2) or a
    # function (Python 3).
    print("".join("%5s" % label for label in " " + seq2))
    for (row, c) in zip(m, " " + seq1):
        print("".join("%5s" % cell for cell in row + [c]))

    print("")

    for row in p:
        print(row)

def sw(seq1, seq2, s, g=0, e=-1, debug=False):
    """Locally align seq1 against seq2 (Smith-Waterman).

    s is the scoring matrix, a dictionary of dictionaries indexed as
    s[a][b]; g is the gap opening penalty and e the gap extension
    penalty.  Scores are assumed to be integers.  When debug is true, the
    filled score and pointer matrices are printed before the traceback.

    Returns the (aligned1, aligned2) pair produced by sw_traceback.
    """
    scores, pointers, best_cells = sw_fill(seq1, seq2, s, g, e)
    if debug:
        nw_dump(seq1, seq2, scores, pointers)
    alignment = sw_traceback(seq1, seq2, pointers, best_cells)
    return alignment

def sw_fill(seq1, seq2, s, g, e):
    """Return the score and pointer matrices from the dynamic
    programming step of Smith-Waterman.

    s is the scoring matrix (s[a][b]), g the gap opening penalty and e
    the gap extension penalty; a gap of n positions contributes g + n*e.

    Returns (m, p, bestCell):
      m        -- score matrix with len(seq1)+1 rows and len(seq2)+1
                  columns; scores are floored at 0.
      p        -- pointer matrix of the same shape.  A cell holds None
                  when its score is 0 (alignment not started), otherwise
                  the list of (row, column) cells its score derives from.
                  Gap pointers may jump several cells along a row/column.
      bestCell -- list of the (row, column) cells sharing the highest
                  score, in fill order, or None if no cell scored
                  above 0.
    """
    # Initialize dp matrix, a two dimensional matrix of scores
    #    first dimension  = rows = seq1 positions, starting one position early
    #    second dimension = cols = seq2 positions, starting one position early
    m = [[0]]
    # Initialize pointer matrix, a two dimensional matrix of lists of
    #   (row, column) pointers
    p = [[None]]

    # In the following loops, i = row and j = column.
    #   i and j always index the _current_ sequence position
    #     and one _before_ the current cell.

    # Fill first row as "alignment not started"
    for j in range(len(seq2)):
        m[-1].append(0)
        p[-1].append(None)

    # Keep track of the current best alignment
    bestCell = None
    bestScore = None

    # Fill remaining rows
    for i in range(len(seq1)):
        # First column is "alignment not started"
        m.append([0])
        p.append([None])
        for j in range(len(seq2)):
            # Score for aligning seq1[i] with seq2[j]
            #  (diagonal arrow)
            match = m[i][j] + s[seq1[i]][seq2[j]]

            # Best score for ending in a gap that runs along the current
            #  row (horizontal arrow).  Start with a gap of length one and
            #  try progressively longer gaps; ties keep the shorter gap.
            curgap = e + g
            hgap = m[i+1][j] + curgap
            hpos = j
            for k in reversed(range(0, j)):
                curgap += e
                tgap = curgap + m[i+1][k]
                if tgap > hgap:
                    hgap = tgap
                    hpos = k

            # Best score for ending in a gap that runs along the current
            #  column (vertical arrow), found the same way.
            curgap = e + g
            vgap = m[i][j+1] + curgap
            vpos = i
            for k in reversed(range(0, i)):
                curgap += e
                tgap = curgap + m[k][j+1]
                if tgap > vgap:
                    vgap = tgap
                    vpos = k

            best = max(match, vgap, hgap, 0)
            m[-1].append(best)

            if best == 0:
                # Nothing beats starting a fresh alignment here
                p[-1].append(None)
                continue

            # Record every predecessor that ties for the best score
            plist = []
            if match == best:
                plist.append((i, j))
            if vgap == best:
                plist.append((vpos, j+1))
            if hgap == best:
                plist.append((i+1, hpos))
            p[-1].append(plist)

            # bestScore starts as None; test for that explicitly rather
            # than comparing an integer with None, which only works by
            # accident on Python 2 and raises TypeError on Python 3.
            if bestScore is None or best > bestScore:
                bestScore = best
                bestCell = [(i+1, j+1)]
            elif best == bestScore:
                bestCell.append((i+1, j+1))

    return (m, p, bestCell)

def sw_traceback(seq1, seq2, p, bestCell):
    """Return the alignment corresponding to the "middle-road" traceback
    of the given pointer matrix.

    p is a pointer matrix in which each cell holds either None (stop
    here) or a list of (row, column) predecessor cells; a predecessor in
    the same row or column may be several cells away, denoting a
    multi-position gap.  bestCell is a list of candidate starting cells,
    or None.

    The traceback starts at the last cell of bestCell and always follows
    the first pointer of the current cell until it reaches a None cell.

    Returns (aligned1, aligned2) with "-" at gap positions, or ("", "")
    when bestCell is None.  A warning is printed if no None cell was
    reached within the maximum possible number of steps.
    """
    if bestCell is None:
        return ("", "")

    curpos = bestCell[-1]

    # Alignment pieces are collected back-to-front and joined once at the
    # end, which avoids repeatedly rebuilding the strings.
    pieces1 = []
    pieces2 = []

    # Every step moves at least one cell up and/or left, so a path takes
    # at most len(seq1)+len(seq2) steps.  One extra iteration is needed to
    # observe the None terminator after a path of that maximum length
    # (e.g. a global alignment made entirely of single-position gaps);
    # without it such a path would be reported as an unexpected exit.
    exitFlag = False
    for _ in range(len(seq1) + len(seq2) + 1):
        plist = p[curpos[0]][curpos[1]]
        if plist is None:
            exitFlag = True
            break
        nextpos = plist[0]

        if nextpos[0] == curpos[0]:
            # Same row: a run of seq2 characters is aligned against gaps
            pieces1.append("-" * (curpos[1] - nextpos[1]))
            pieces2.append(seq2[nextpos[1]:curpos[1]])
        elif nextpos[1] == curpos[1]:
            # Same column: a run of seq1 characters is aligned against gaps
            pieces2.append("-" * (curpos[0] - nextpos[0]))
            pieces1.append(seq1[nextpos[0]:curpos[0]])
        else:
            # Diagonal step: one character from each sequence
            pieces1.append(seq1[nextpos[0]])
            pieces2.append(seq2[nextpos[1]])

        curpos = nextpos

    if not exitFlag:
        # Single-argument print(...) gives the same output whether print
        # is a statement (Python 2) or a function (Python 3).
        print("WARNING: Unexpected exit from traceback")

    pieces1.reverse()
    pieces2.reverse()
    return ("".join(pieces1), "".join(pieces2))

def nwg(seq1, seq2, s, g=0, e=-1, debug=False):
    """Globally align seq1 against seq2 with separate gap opening and
    extension penalties.

    s is the scoring matrix, a dictionary of dictionaries indexed as
    s[a][b]; g is the gap opening penalty and e the gap extension
    penalty.  Scores are assumed to be integers.  When debug is true, the
    filled score and pointer matrices are printed before the traceback.

    Returns the (aligned1, aligned2) pair produced by the traceback.
    """
    scores, pointers = nwg_fill(seq1, seq2, s, g, e)
    if debug:
        nw_dump(seq1, seq2, scores, pointers)

    # sw_traceback also works for the global case, provided it is told to
    #   begin at the bottom-right cell of the matrix.
    corner = (len(seq1), len(seq2))
    return sw_traceback(seq1, seq2, pointers, [corner])

def nwg_fill(seq1, seq2, s, g, e):
    """Run the dynamic programming step of global (Needleman-Wunsch
    style) alignment with gap opening and extension penalties.

    s is the scoring matrix (s[a][b]), g the gap opening penalty and e
    the gap extension penalty; a gap of n positions contributes g + n*e.

    Returns (m, p):
      m -- score matrix with len(seq1)+1 rows and len(seq2)+1 columns.
      p -- pointer matrix of the same shape.  Each cell holds the list of
           (row, column) cells its score derives from; gap pointers may
           jump several cells along a row/column.  The top-left cell
           holds None.
    """
    # Top row: leading gaps in seq1.  The opening penalty is charged once,
    # on the first gap position.
    top_scores = [0]
    top_ptrs = [None]
    for col in range(len(seq2)):
        cell = top_scores[col] + e
        if col == 0:
            cell += g
        top_scores.append(cell)
        top_ptrs.append([(0, col)])

    m = [top_scores]
    p = [top_ptrs]

    # One matrix row per seq1 character.  "row"/"col" are the sequence
    # positions, i.e. one less than the index of the cell being filled.
    for row, res1 in enumerate(seq1):
        above = m[row]

        # Leftmost cell: leading gaps in seq2, opening penalty charged once
        first = above[0] + e
        if row == 0:
            first += g
        scores = [first]
        ptrs = [[(row, 0)]]
        m.append(scores)
        p.append(ptrs)

        for col, res2 in enumerate(seq2):
            # Diagonal move: pair seq1[row] with seq2[col]
            diag = above[col] + s[res1][res2]

            # Best gap running along the current row: begin with a gap of
            # length one, then try ever longer gaps.  Only a strictly
            # better score replaces the current choice.
            penalty = e + g
            row_gap = scores[col] + penalty
            row_src = col
            for k in range(col - 1, -1, -1):
                penalty += e
                trial = penalty + scores[k]
                if trial > row_gap:
                    row_gap, row_src = trial, k

            # Best gap running along the current column, found likewise
            penalty = e + g
            col_gap = above[col + 1] + penalty
            col_src = row
            for k in range(row - 1, -1, -1):
                penalty += e
                trial = penalty + m[k][col + 1]
                if trial > col_gap:
                    col_gap, col_src = trial, k

            top = max(diag, col_gap, row_gap)
            scores.append(top)

            # Record every predecessor that ties for the best score, in
            # the order diagonal, column gap, row gap.
            candidates = (
                (diag, (row, col)),
                (col_gap, (col_src, col + 1)),
                (row_gap, (row + 1, row_src)),
            )
            ptrs.append([src for (val, src) in candidates if val == top])

    return (m, p)
