class MultipleAlignment:
    """A multiple sequence alignment stored as one row per sequence.

    Attributes:
        seqmatrix: the aligned rows, one per sequence.  Rows are indexed
            and sliced by column; the parsers in this class supply
            strings.
        seqnames: the row names, in the same order as ``seqmatrix``.
        colAnnotations: dict mapping an annotation tag to a per-column
            annotation; it is sliced together with the columns.
        name2seq: dict mapping each name to its row.
    """

    # Characters accepted in the sequence field by fromClustal.
    _CLUSTAL_ALPHABET = frozenset(
        "ACDEFGHIKLMNPQRSTUVWYXacdefghiklmnpqrstuvwyx-")

    # Layout used by writeClustal: name field width and columns per block.
    _CLUSTAL_NAME_WIDTH = 35
    _CLUSTAL_BLOCK_WIDTH = 50

    def __init__(self, seqmatrix, seqnames=None, colAnnotations=None):
        """Build an alignment from its rows.

        Args:
            seqmatrix: the aligned rows, one per sequence.
            seqnames: optional row names, one per row.  When omitted the
                rows are named "Seq00000", "Seq00001", ...
            colAnnotations: optional dict of per-column annotations;
                defaults to an empty dict.

        Raises:
            AssertionError: if the number of names differs from the number
                of rows, or if the names are not unique.
        """
        self.seqmatrix = seqmatrix
        if seqnames is not None:
            self.seqnames = seqnames
            assert len(self.seqnames) == len(self.seqmatrix), \
                "number of names differs from number of sequences"
        else:
            self.seqnames = ["Seq%05d" % i
                             for i in range(len(self.seqmatrix))]

        if colAnnotations is not None:
            self.colAnnotations = colAnnotations
        else:
            self.colAnnotations = {}

        # Use self.seqnames (not the argument) so generated default names
        # are indexed as well.
        self.name2seq = dict(zip(self.seqnames, self.seqmatrix))
        # Duplicate names would collapse into a single dict entry.
        assert len(self.name2seq) == len(self.seqnames), \
            "sequence names are not unique"

    def _sliceColumns(self, columns):
        """Return a MultipleAlignment restricted to the slice `columns`."""
        return MultipleAlignment(
            seqmatrix=[seq[columns] for seq in self.seqmatrix],
            seqnames=self.seqnames,
            colAnnotations=dict(
                (tag, annotation[columns])
                for (tag, annotation) in self.colAnnotations.items()))

    def __getitem__(self, i):
        """Return the sequence named "i" if i is a string,
        column vector [i] (counting from 0) if i is an integer,
        or a MultipleAlignment of the selected columns if i is a slice.

        Raises KeyError for an unknown name or any other kind of index.
        """
        if isinstance(i, slice):
            return self._sliceColumns(i)
        if isinstance(i, int):
            return [row[i] for row in self.seqmatrix]
        if isinstance(i, str):
            return self.name2seq[i]
        raise KeyError(i)

    def __getslice__(self, i, j):
        """Return a MultipleAlignment corresponding to columns
        [i:j] (indexing from 0, j is one past the last column).

        Kept for interpreters that still dispatch simple slices here;
        slice objects passed to __getitem__ are handled the same way.
        """
        return self._sliceColumns(slice(i, j))

    def writeClustal(self, fout):
        """Write alignment in (inferred) CLUSTALX .aln format.

        Names longer than 35 characters are truncated to their last 35
        characters.  Sequences are written in blocks of 50 columns, each
        block followed by a blank conservation line.  An alignment with
        no sequences produces only the header.

        Args:
            fout: object with a ``write`` method accepting strings.

        Raises:
            AssertionError: if the (truncated) names are not unique.
        """
        nameWidth = self._CLUSTAL_NAME_WIDTH
        step = self._CLUSTAL_BLOCK_WIDTH

        # Keep the tail of over-long names.
        seqnames = [name[-nameWidth:] for name in self.seqnames]
        assert len(seqnames) == len(set(seqnames)), \
            "sequence names are not unique after truncation"

        fout.write('CLUSTAL X (1.83) multiple sequence alignment\n')
        fout.write('\n')

        rowFormat = '%%-%ds %%s\n' % nameWidth
        conservation = ' ' * (nameWidth + 1 + step) + '\n'
        n = len(self.seqmatrix[0]) if len(self.seqmatrix) > 0 else 0
        for offset in range(0, n, step):
            # spacer
            fout.write('\n')
            for (name, seq) in zip(seqnames, self.seqmatrix):
                fout.write(rowFormat % (name, seq[offset:offset + step]))
            # conservation line (left blank)
            fout.write(conservation)

    @staticmethod
    def _isConservationLine(s):
        """Return True if `s` starts with whitespace and otherwise holds
        only whitespace and the conservation symbols '.', ':' and '*'."""
        if len(s) < 1 or not s[0].isspace():
            return False
        for ch in s:
            if not ch.isspace() and ch not in ".:*":
                return False
        return True

    @classmethod
    def fromClustal(cls, fin, headercheck=True):
        """Parse a CLUSTAL .aln alignment.

        Args:
            fin: iterable of lines (e.g. an open file).
            headercheck: if true, require the first line to start with
                "CLUSTAL".

        Returns:
            An alignment whose rows are in the order of the first block.

        Raises:
            AssertionError: on a missing header, a data line that does not
                have exactly two fields, a name repeated within the first
                block, an unexpected sequence character (the offending
                line is echoed to stderr), no sequences at all, or
                sequences of unequal length.
            KeyError: if a later block names a sequence that was absent
                from the first block.
        """
        import sys

        alphabet = cls._CLUSTAL_ALPHABET
        lines = iter(fin)

        # Read past header; None marks end of input.
        line = next(lines, None)
        if headercheck:
            assert line is not None and line[:7] == "CLUSTAL", \
                "missing CLUSTAL header"
        while (line is not None
               and (line[:7] == "CLUSTAL" or len(line.strip()) == 0)):
            line = next(lines, None)

        seqnames = []
        chunks = {}
        inFirstBlock = True
        while line is not None:
            if len(line.strip()) == 0 or cls._isConservationLine(line):
                # The first separator closes the block that defines the
                # set and order of sequence names.
                inFirstBlock = False
                line = next(lines, None)
                continue

            w = line.split()
            assert len(w) == 2, "expected '<name> <sequence>': %r" % line
            (name, residues) = w
            if inFirstBlock:
                assert name not in chunks, \
                    "duplicate sequence name %r" % name
            try:
                assert alphabet.issuperset(residues), \
                    "unexpected character in sequence %r" % name
            except AssertionError:
                sys.stderr.write(line)
                raise

            if inFirstBlock:
                seqnames.append(name)
                chunks[name] = [residues]
            else:
                chunks[name].append(residues)
            line = next(lines, None)

        assert len(seqnames) > 0, "no sequences found"

        seqmatrix = ["".join(chunks[name]) for name in seqnames]
        n = len(seqmatrix[0])
        for (name, seq) in zip(seqnames, seqmatrix):
            assert len(seq) == n, \
                "sequence %r has length %d, expected %d" % (name, len(seq), n)

        return cls(seqmatrix=seqmatrix, seqnames=seqnames)

    @classmethod
    def fromStockholm(cls, fin, headercheck=True):
        """Parse a Stockholm alignment up to the "//" end-of-data marker.

        The characters '*' and '.' in sequences are replaced by '-'.
        "#=GC" lines are collected as column annotations; all other lines
        starting with '#' are skipped.

        Args:
            fin: iterable of lines (e.g. an open file).
            headercheck: if true, consume the first line and require it to
                start with "# STOCKHOLM 1.0".

        Returns:
            An alignment whose rows are in order of first appearance, with
            the collected "#=GC" annotations as colAnnotations.

        Raises:
            AssertionError: on a missing header, a sequence line that does
                not have exactly two fields, no sequences at all,
                sequences of unequal length (the lengths are echoed to
                stderr), or an annotation whose length differs from the
                alignment length.
            ValueError: if a "#=GC" line does not have exactly three
                fields.
        """
        import re
        import sys

        gapPattern = re.compile('[*.]')
        lines = iter(fin)

        # Check file format
        if headercheck:
            line = next(lines, "")
            assert line[:15] == "# STOCKHOLM 1.0", \
                "missing '# STOCKHOLM 1.0' header"

        seqnames = []
        seqChunks = {}
        annotationChunks = {}
        for line in lines:
            # Skip blank lines
            if len(line.strip()) < 1:
                continue

            # Quit on "end-of-data" marker
            if line[:2] == '//':
                break

            # Parse comments and annotations
            if line[0] == '#':
                # Column annotations; all other '#' lines are skipped.
                if line[1:4] == "=GC":
                    (comment, tag, seq) = line.split()
                    annotationChunks.setdefault(tag, []).append(seq)
                continue

            # Parse sequence
            w = line.split()
            assert len(w) == 2, "expected '<name> <sequence>': %r" % line
            (name, residues) = w
            if name not in seqChunks:
                seqnames.append(name)
                seqChunks[name] = []
            seqChunks[name].append(gapPattern.sub('-', residues))

        assert len(seqnames) > 0, "no sequences found"

        seqmatrix = ["".join(seqChunks[name]) for name in seqnames]
        n = len(seqmatrix[0])
        for seq in seqmatrix[1:]:
            try:
                assert len(seq) == n, "sequences differ in length"
            except AssertionError:
                lengths = [(k, len(s)) for (k, s) in zip(seqnames, seqmatrix)]
                sys.stderr.write("%r\n" % (lengths,))
                raise

        colAnnotations = dict((tag, "".join(parts))
                              for (tag, parts) in annotationChunks.items())
        for (tag, seq) in colAnnotations.items():
            assert len(seq) == n, \
                "annotation %r has length %d, expected %d" % (tag, len(seq), n)

        if len(colAnnotations) == 0:
            colAnnotations = None

        return cls(seqmatrix=seqmatrix,
                   seqnames=seqnames, colAnnotations=colAnnotations)
