Goal: Fit independent strain, infection, and time factors, otherwise following the approach of GSE88801_limma2.ipynb

In [1]:
# NOTE(review): hard-coded, user-specific working directory; the relative paths used
# below (sample sheet, per-sample kallisto directories, CDT outputs) resolve against it.
%cd ~/pdf/papers/SciRep_7_42225/
/home/mvoorhie/pdf/papers/SciRep_7_42225

In [2]:
from CdtFile import CdtFile, CdtRow
from MsvUtil import Table, revdict, multirevdict
from SafeMath import safelog, safesub, safeadd, safesum
In [3]:
import gzip
from math import log
import os
import numpy as np
In [4]:
# Uncomment this line for newer Jupyter stacks (e.g., Debian 9 (stretch), current Enthought Canopy)
#%load_ext rpy2.ipython 
# Uncomment this line for older IPython stacks (e.g., Debian 8 (jessie), Ubuntu 16.04 (xenial))
%load_ext rmagic 
In [5]:
%%R
library(limma)
library(edgeR)

Load Annotations

Sample annotations

Hand-compiled spreadsheet based on the paper's supplement, mapped to GEO IDs via read counts

In [6]:
# Read the sample annotation sheet, then report how many samples it holds
# and show the first record as a sanity check.
samples = Table.fromCsv("sample_table_v2.csv")
print(len(samples))
print(samples[0])
36
name:		GSM2348248
infection:		uninfected
strain:		J774
time:		4
replicate:		1

In [7]:
# Rank used to order infection states: Live first, then Dead, then uninfected.
ival = {"Live":0,"Dead":1,"uninfected":2}

def _sample_sort_key(sample):
    """Sort key for a sample record: strain, infection rank, replicate number, time point."""
    return (sample["strain"],
            ival[sample["infection"]],
            int(sample["replicate"]),
            int(sample["time"]))

# Rebuild the table with the same header and the rows in sorted order.
samples = Table(header = samples.header[:],
                rows = [list(record) for record in sorted(samples, key = _sample_sort_key)])
In [8]:
# Index the sample records by their "name" column (the GSM accession).
name2sample = {record["name"]: record for record in samples}

ENSEMBL Mouse Transcript ID -> Gene ID mapping

In [9]:
import re
# Matches one GTF attribute of the form `key "value";`, capturing the attribute
# name as group "key" and the quoted text (without the quotes) as group "val".
gtf_re = re.compile(r'(?P<key>[\S]+)[\s]+"(?P<val>[^"]+)";')
In [10]:
def _gtf_attributes(attribute_field):
    """Parse a GTF attribute column into a {key: value} dict.

    If a key occurs more than once, the last occurrence wins.
    """
    return dict((anno.group("key"), anno.group("val"))
                for anno in gtf_re.finditer(attribute_field))

# transcript2gene: transcript_id -> gene_id, from "transcript" records
# gene2name:       gene_id -> gene_name,     from "gene" records
transcript2gene = {}
gene2name = {}
# The with-block guarantees the gzip handle is closed even if an assertion fails.
with gzip.open("/home/bms270/BMS270_2017/Mus_musculus.GRCm38.79.gtf.gz") as gtf:
    for line in gtf:
        # Skip header/comment lines
        if(line.startswith("#")):
            continue
        fields = line.split("\t")
        if(fields[2] == "transcript"):
            attrs = _gtf_attributes(fields[8])
            gene = attrs.get("gene_id")
            transcript = attrs.get("transcript_id")
            assert(None not in (gene,transcript))
            transcript2gene[transcript] = gene
        elif(fields[2] == "gene"):
            attrs = _gtf_attributes(fields[8])
            gene = attrs.get("gene_id")
            name = attrs.get("gene_name")
            assert(gene is not None)
            # Fall back to the gene ID when the record carries no gene_name
            if(name is None):
                name = gene
            gene2name[gene] = name
len(transcript2gene),len(gene2name)
Out[10]:
(104129, 43629)
In [11]:
# Invert the transcript -> gene map so that each gene ID indexes the transcript IDs
# assigned to it (gene_row iterates gene2transcript[gene] to collect a gene's transcripts).
gene2transcript = multirevdict(transcript2gene)
len(gene2transcript)
Out[11]:
43629

Merge count and TPM values from kallisto

Step 1: TPMs

In [12]:
# Collect one column of per-transcript TPMs per sample, checking that every
# kallisto output lists the transcripts in the same order.
genes = None
cols = []
for sample_name in samples["name"]:
    # Open each abundance file in a with-block so its handle is closed promptly
    # instead of being left to garbage collection.
    with open(os.path.join(sample_name,"abundance.tsv")) as abundance:
        table = Table.fromTdt(abundance)
        if(genes is None):
            genes = table["target_id"]
        else:
            assert(genes == table["target_id"])
        # N.B.: Not log transforming yet so that we can sum over transcript TPMs for each gene
        cols.append([float(tpm) for tpm in table["tpm"]])

# One row per transcript: transcript ID as gid/uniqid, parent gene ID as name,
# and the per-sample TPMs as ratios.
trans_tpm = CdtFile(probes = [CdtRow(gid = i[0], uniqid = i[0], name = transcript2gene[i[0]],
                                        ratios = i[1:])
                                 for i in zip(*([genes]+cols))],
                       fieldnames = samples["name"],
                       eweights = [1]*len(samples["name"]))
len(trans_tpm)
Out[12]:
88198
In [13]:
# Distinct gene IDs represented among the quantified transcripts
# (each transcript row carries its gene ID as its name).
genes = {transcript_row.Name() for transcript_row in trans_tpm}
len(genes)
Out[13]:
30735
In [14]:
def gene_row(gene, cdt, gene2transcript):
    """Collapse the transcript rows of one gene into a single gene-level row.

    gene            -- gene ID; used as gid, uniqid and name of the new row
    cdt             -- transcript-level table, looked up by transcript ID via GetUid
    gene2transcript -- mapping from gene ID to its transcript IDs

    For each sample column the transcript values are summed and the sum is
    passed through safelog.
    """
    transcript_rows = []
    for transcript in gene2transcript[gene]:
        transcript_rows.append(cdt.GetUid(transcript))
    log_totals = [safelog(sum(column)) for column in zip(*transcript_rows)]
    return CdtRow(gid = gene, uniqid = gene, name = gene, ratios = log_totals)

# N.B.: At this point, we switch to log space
# (gene_row sums transcript TPMs per gene and applies safelog to each sum)
genes_tpm = CdtFile.fromPrototype(
    trans_tpm,
    probes = [gene_row(gene_id, trans_tpm, gene2transcript) for gene_id in genes])
len(genes_tpm)
Out[14]:
30735

Step 2: estimated counts

Following the same method as for TPMs

In [15]:
# Collect one column of per-transcript estimated counts per sample, checking that
# every kallisto output lists the transcripts in the same order.
genes2 = None
cols = []
for sample_name in samples["name"]:
    # Open each abundance file in a with-block so its handle is closed promptly
    # instead of being left to garbage collection.
    with open(os.path.join(sample_name,"abundance.tsv")) as abundance:
        table = Table.fromTdt(abundance)
        if(genes2 is None):
            genes2 = table["target_id"]
        else:
            assert(genes2 == table["target_id"])
        # N.B.: Keeping raw estimated counts so that we can sum over transcripts for each gene
        cols.append([float(count) for count in table["est_counts"]])

# One row per transcript: transcript ID as gid/uniqid, parent gene ID as name,
# and the per-sample estimated counts as ratios.
trans_count = CdtFile(probes = [CdtRow(gid = i[0], uniqid = i[0], name = transcript2gene[i[0]],
                                        ratios = i[1:])
                                 for i in zip(*([genes2]+cols))],
                       fieldnames = samples["name"],
                       eweights = [1]*len(samples["name"]))
len(trans_count)
Out[15]:
88198
In [16]:
# Gene IDs seen in the count table; the comparison confirms they match the
# gene set derived from the TPM table.
genes2 = {transcript_row.Name() for transcript_row in trans_count}
genes == genes2
Out[16]:
True
In [17]:
def gene_row(gene, cdt, gene2transcript):
    """Collapse the transcript rows of one gene into a single gene-level row of sums.

    NOTE: this rebinds the name gene_row; unlike the earlier definition in this
    notebook, the per-sample sums are returned as-is (no safelog).

    gene            -- gene ID; used as gid, uniqid and name of the new row
    cdt             -- transcript-level table, looked up by transcript ID via GetUid
    gene2transcript -- mapping from gene ID to its transcript IDs
    """
    transcript_rows = []
    for transcript in gene2transcript[gene]:
        transcript_rows.append(cdt.GetUid(transcript))
    totals = [sum(column) for column in zip(*transcript_rows)]
    return CdtRow(gid = gene, uniqid = gene, name = gene, ratios = totals)

# Note that we *do not* log transform counts.
# Use trans_count (rather than trans_tpm) as the prototype so the gene-level count
# table is derived only from the count table; both transcript tables were built
# with the same fieldnames and eweights.
genes_count = CdtFile.fromPrototype(trans_count, probes = [gene_row(i, trans_count, gene2transcript) for i in genes])
len(genes_count)
Out[17]:
30735

TPM Filter

In [18]:
def clip(x):
    """Return 0. when x is zero or negative; otherwise return x unchanged."""
    return 0. if x <= 0. else x

def clip_cdt(cdt):
    """Return a copy of cdt in which every row's values have been passed through clip."""
    clipped_rows = []
    for row in cdt:
        floored = [clip(value) for value in row]
        clipped_rows.append(CdtRow.fromPrototype(row, ratios = floored))
    return CdtFile.fromPrototype(cdt, probes = clipped_rows)

def threshold_cdt(cdt, thresh):
    """Return a copy of cdt keeping only the rows whose maximum value is >= thresh."""
    kept_rows = []
    for row in cdt:
        if(max(row) >= thresh):
            kept_rows.append(row)
    return CdtFile.fromPrototype(cdt, probes = kept_rows)

Filter for genes with $TPM \geq 10$ in at least one sample

In [19]:
# genes_tpm holds log-transformed TPMs, so the cutoff is expressed as log2(10)
# NOTE(review): this presumes safelog is base 2 — confirm against SafeMath.
min_log_tpm = log(10)/log(2)
ct_genes = threshold_cdt(clip_cdt(genes_tpm), min_log_tpm)
len(ct_genes), len(genes_tpm)
Out[19]:
(9939, 30735)
In [20]:
# Pull the count rows for exactly the genes that passed the TPM filter,
# in the same order as ct_genes.
passing_uids = [row.Uniqid() for row in ct_genes]
ct_count = CdtFile.fromPrototype(
    genes_count,
    probes = [genes_count.GetUid(uid) for uid in passing_uids])
len(ct_count)
Out[20]:
9939

Stage for import into R

In [21]:
# Genes x samples matrix of filtered counts (each value passed through clip)
# for transfer to R.
count_rows = [[clip(count) for count in gene_counts] for gene_counts in ct_count]
C = np.array(count_rows)
C.shape
Out[21]:
(9939, 36)
In [22]:
# Map each gene name to its 1-based row number in ct_count, plus the reverse lookup.
name2row = {row.Name(): row_number for (row_number, row) in enumerate(ct_count, 1)}
row2name = revdict(name2row)
[len(mapping) for mapping in (name2row, row2name)]
Out[22]:
[9939, 9939]
In [23]:
# Per-sample annotation vectors, all in the row order of `samples`.
cols = [genes_count.fieldnames.index(sample["name"]) for sample in samples]
time = [int(sample["time"]) for sample in samples]
infection = [sample["infection"] for sample in samples]
strain = [sample["strain"] for sample in samples]
replicate = [sample["replicate"] for sample in samples]
# Combined label of the form <strain>.<infection>.<time>
state = ["%s.%s.%d" % labels for labels in zip(strain, infection, time)]

Limma fit

In [24]:
%%R -i C
# Wrap the imported count matrix C in an edgeR DGEList, then compute
# normalization factors with calcNormFactors' default settings.
dge <- DGEList(counts=C)
dge <- calcNormFactors(dge)
In [25]:
%%R -i state,strain,infection,time -o d,cpm,fc,cn
# Imports the per-sample annotation vectors from Python and exports the design
# matrix (d), voom expression values (cpm), fitted coefficients (fc) and their
# column names (cn).
# state is converted to a factor here but is not used in the design below.
state <- as.factor(state)
strain <- as.factor(strain)
# Note: setting levels so that uninfected is the reference
infection <- factor(infection, levels = c("uninfected","Dead","Live"))
time <- as.factor(time)
# Additive model: independent strain, infection and time effects (no interactions)
d <- model.matrix(~strain+infection+time)
print(colnames(d))

# voom transforms the normalized counts for linear modelling; lmFit fits the design per gene
v <- voom(dge, d, plot = TRUE)
fit <- lmFit(v, d)

# Moderated statistics, then a per-coefficient summary of down (-1) / not significant (0) / up (1) calls
fit2 <- eBayes(fit)
print(summary(decideTests(fit2)))

# Values handed back to Python via -o
fc <- fit$coefficients
cn <- colnames(fit$coefficients)
cpm <- v$E
[1] "(Intercept)"   "strainJ774"    "infectionDead" "infectionLive"
[5] "time24"       
   (Intercept) strainJ774 infectionDead infectionLive time24
-1         278       3851          1145          2225   2239
0          203       1976          7743          5789   5468
1         9458       4112          1051          1925   2232

In [26]:
%%R
# Same summary as before, additionally requiring lfc=1 (at least a 2-fold change
# if the coefficients are on a log2 scale).
print(summary(decideTests(fit2, lfc=1)))
   (Intercept) strainJ774 infectionDead infectionLive time24
-1         253       1433            81           284    159
0          288       7274          9505          9086   9535
1         9398       1232           353           569    245

Extract significantly differential genes

In [27]:
# Coefficient (design-matrix column) names exported from the R fit
cn
Out[27]:
array(['(Intercept)', 'strainJ774', 'infectionDead', 'infectionLive',
       'time24'], 
      dtype='|S13')
In [28]:
# Fitted coefficient matrix: one row per gene, one column per coefficient in cn
fc.shape
Out[28]:
(9939, 5)
In [29]:
%%R -o t0
# Genes passing lfc=1 and p.value=.05 for the infectionLive (Live vs. uninfected)
# coefficient; n = 40000 is larger than the number of genes fitted, so nothing is truncated.
t0 <- topTable(fit2, coef="infectionLive", n = 40000, lfc=1, p.value = .05)
# Prepend the row names as the first column; the Python side reads that first
# field as the row number of each gene.
t0 <- cbind(Row.Names = rownames(t0), t0)
In [30]:
# Number of genes returned by the topTable call
len(t0)
Out[30]:
853

853 genes are at least 2x differential (adjusted p-value ≤ 0.05) in live vs. uninfected after factoring out strain and time effects

In [31]:
# Key each topTable record by gene ID: the first field of a record is its
# 1-based row number, which row2name translates back to the gene.
uid2t0 = {row2name[int(hit[0])]: list(hit) for hit in t0}
In [32]:
# Number of genes in the filtered count table (should equal the number of rows in fc)
len(ct_count)
Out[32]:
9939
In [33]:
# Pair each gene (in ct_count row order) with its row of fitted coefficients.
uid2fc = {row.Uniqid(): coefficients for (row, coefficients) in zip(ct_count, fc)}
In [34]:
# Coefficients for the first gene, in the column order given by cn
fc[0]
Out[34]:
array([  7.09644996e+00,   1.24938101e+00,   5.13558886e-04,
         7.15088718e-02,   1.24814609e-01])

Export heatmaps in CDT format

In [35]:
def fit_ratios(uid):
    """Return the fitted coefficients for gene uid, leaving out the first one (the intercept)."""
    coefficients = uid2fc[uid]
    return coefficients[1:]

Limma fit contrasts for all genes

In [36]:
# Sorted list of the distinct gene IDs that passed the TPM filter.
common = sorted({row.Uniqid() for row in ct_genes})
In [37]:
# Sanity check: every filtered gene has fitted coefficients, and vice versa
set(common) == set(uid2fc)
Out[37]:
True
In [38]:
# Coefficient names without the intercept, i.e. the columns returned by fit_ratios
list(cn)[1:]
Out[38]:
['strainJ774', 'infectionDead', 'infectionLive', 'time24']
In [39]:
# Human-readable labels for the non-intercept coefficients
# (strainJ774, infectionDead, infectionLive, time24).
fieldnames = "J774/BMDM","Dead/uninfected","Live/uninfected","24h/4h"
probes = []
for i in common:
    probes.append(CdtRow(gid = i, uniqid = i, name = gene2name[i], 
                         ratios = fit_ratios(i),
                         # Adding intercept as an annotation column for two reasons
                         # 1) prevents it showing up as a saturated value in the heatmap
                         # 2) works around JavaTreeView bug where no annotation cols -> disabled scatterplot
                         extra = list(uid2fc[i])[:1]))

cdt_contrast2 = CdtFile(probes = probes,
                       fieldnames = fieldnames,
                       eweights = [1.]*len(fieldnames),
                       extranames = ["Abundance"])

# Write through a with-block so the output file is flushed and closed as soon
# as the export finishes, rather than whenever the handle is garbage collected.
with open("cdt_contrast2.cdt","w") as cdt_out:
    cdt_contrast2.write(cdt_out)
In [40]:
# Number of gene rows in the exported contrast table
len(cdt_contrast2)
Out[40]:
9939

Limma fit contrasts for significantly differential genes

In [41]:
# Keep only the rows whose gene ID appears among the topTable hits (uid2t0).
# Membership is tested with `in`; dict.has_key is deprecated and absent in Python 3.
cdt_contrast2_sig = CdtFile.fromPrototype(
   cdt_contrast2, 
   probes = [i for i in cdt_contrast2 if(i.Uniqid() in uid2t0)])
len(cdt_contrast2_sig)
Out[41]:
853
In [42]:
%%time
# Cluster the significant-gene table and write it out together with the resulting tree.
# NOTE(review): dist = "u" / method = "m" presumably follow Cluster 3.0 codes
# (uncentered correlation, maximum linkage) — confirm against CdtFile.cluster.
tree = cdt_contrast2_sig.cluster(dist = "u", method = "m")
cdt_contrast2_sig.writeCdtGtr("cdt_contrast2_sig.um",tree)
CPU times: user 412 ms, sys: 12 ms, total: 424 ms
Wall time: 410 ms

Building array...
Building distance matrix...
Clustering...

In [42]: