This commit is contained in:
Arnar Flatberg 2007-09-20 16:10:40 +00:00
parent d9e5398865
commit 7e9a0882f1
4 changed files with 168 additions and 59 deletions

View File

@ -15,6 +15,7 @@ def nipals_lpls(X, Y, Z, a_max, alpha=.7, mean_ctr=[2, 0, 1], scale='scores', ve
X : data matrix (m, n)
Y : data matrix (m, l)
Z : data matrix (n, o)
alpha : how much Z influences the model (1=max, 0=none)
:output:
T : X-scores
@ -36,7 +37,7 @@ def nipals_lpls(X, Y, Z, a_max, alpha=.7, mean_ctr=[2, 0, 1], scale='scores', ve
if mean_ctr:
xctr, yctr, zctr = mean_ctr
X, mnX = center(X, xctr)
Y, mnY = center(Y, xctr)
Y, mnY = center(Y, yctr)
Z, mnZ = center(Z, zctr)
varX = pow(X, 2).sum()
@ -116,7 +117,7 @@ def nipals_lpls(X, Y, Z, a_max, alpha=.7, mean_ctr=[2, 0, 1], scale='scores', ve
T = T/tnorm
Q = Q*tnorm
W = W*tnorm
return T, W, P, Q, U, L, K, B, b0, evx, evy, evz
return T, W, P, Q, U, L, K, B, b0, evx, evy, evz, mnX, mnY, mnZ
def svd_lpls(X, Y, Z, a_max, alpha=.7, mean_ctr=[2, 0, 1], verbose=True):
"""
@ -306,8 +307,14 @@ def bifpls(X, Y, Z, a_max, alpha):
evz = 100.0*(1 - var_z/varZ)
def center(a, axis):
# 0 = col center, 1 = row center, 2 = double center
# -1 = nothing
# 0 = col center, 1 = row center, 2 = double center
# -1 = nothing
if len(a.shape)==1:
mn = a.mean()
return a - mn, mn
if a.shape[0]==1 or a.shape[1]==1:
mn = a.mean()
return a - mn, mn
if axis==-1:
mn = zeros((a.shape[1],))
return a - mn, mn
@ -318,7 +325,7 @@ def center(a, axis):
mn = a.mean(1)[:,newaxis]
return a - mn , mn
elif axis==2:
mn = a.mean(0) + a.mean(1)[:,newaxis] - a.mean()
mn = a.mean(1)[:,newaxis] + a.mean(0) - a.mean()
return a - mn, mn
else:
raise IOError("input error: axis must be in [-1,0,1,2]")
@ -367,27 +374,47 @@ def correlation_loadings(D, T, P, test=True):
def cv_lpls(X, Y, Z, a_max=2, nsets=None,alpha=.5, mean_ctr=[2,0,1]):
    """Performs crossvalidation to get generalisation error in lpls.

    X, Y, Z : data matrices; X and Y are split into calibration and
        held-out parts by select_generators.pls_gen, Z is used as is
    a_max : number of components to fit and predict with
    nsets : number of cross-validation blocks (n_blocks of pls_gen)
    alpha : passed on to nipals_lpls
    mean_ctr : centering codes [x, y, z] as understood by nipals_lpls;
        the list is read only, never modified

    Returns (rmsep, Yhat, class_err):
    rmsep : sqrt of the squared prediction error averaged over rows,
        one row per number of components
    Yhat : predictions, shape (a_max,) + Y.shape
    class_err : 100 * (rows where the largest prediction and Y are both 1)
        / column sums of Y; only meaningful for 0/1 dummy-coded Y
    """
    # Unpack into locals instead of writing into mean_ctr: the caller's list
    # (and the shared default list) must survive this call unchanged.
    xctr, yctr, zctr = mean_ctr

    # if double centering of x or y:
    # row-center prior to cross validation (as this is independent of subsets)
    # and leave only the column centering to each calibration fit
    if xctr==2:
        X = X - X.mean(1)[:,newaxis]
        xctr = 0
    if yctr==2 and Y.shape[1]!=1:
        Y = Y - Y.mean(1)[:,newaxis]
        # as for X: otherwise nipals_lpls returns a full matrix of means
        # that cannot be added to the held-out predictions below
        yctr = 0
    fold_ctr = [xctr, yctr, zctr]

    cv_iter = select_generators.pls_gen(X, Y, n_blocks=nsets,center=False,index_out=True)
    k, l = Y.shape
    Yhat = empty((a_max,k,l), 'd')
    for xcal, xi, ycal, yi, ind in cv_iter:
        (T, W, P, Q, U, L, K, B, b0,
         evx, evy, evz, mnx, mny, mnz) = nipals_lpls(xcal,ycal,Z,
                                                     a_max=a_max,
                                                     alpha=alpha,
                                                     mean_ctr=fold_ctr,
                                                     verbose=False)
        # center the held-out rows with the calibration means (same for all a)
        xc = xi - mnx
        for a in range(a_max):
            Yhat[a,ind,:] = mny + dot(xc, B[a])

    # one-hot coding of the largest prediction in every row
    Yhat_class = zeros_like(Yhat)
    for a in range(a_max):
        for i in range(k):
            Yhat_class[a,i,argmax(Yhat[a,i,:])] = 1.0
    class_err = 100*((Yhat_class+Y)==2).sum(1)/Y.sum(0).astype('d')
    sep = (Y - Yhat)**2
    rmsep = sqrt(sep.mean(1))
    return rmsep, Yhat, class_err
def jk_lpls(X, Y, Z, a_max, nsets=None, alpha=.5, mean_ctr=[2,0,1]):
def jk_lpls(X, Y, Z, a_max, nsets=None, xz_alpha=.5, mean_ctr=[2,0,1]):
cv_iter = select_generators.pls_gen(X, Y, n_blocks=nsets,center=False,index_out=False)
m, n = X.shape
k, l = Y.shape
@ -398,12 +425,12 @@ def jk_lpls(X, Y, Z, a_max, nsets=None, alpha=.5, mean_ctr=[2,0,1]):
WWz = empty((nsets, o, a_max), 'd')
WWy = empty((nsets, l, a_max), 'd')
for i, (xcal,xi,ycal,yi) in enumerate(cv_iter):
T, W, P, Q, U, L, K, B, b0, evx, evy, evz = nipals_lpls(xcal,ycal,Z,
a_max=a_max,
alpha=alpha,
mean_ctr=mean_ctr,
scale='loads',
verbose=False)
T, W, P, Q, U, L, K, B, b0, evx, evy, evz,mnx,mny,mnz = nipals_lpls(xcal,ycal,Z,
a_max=a_max,
alpha=xz_alpha,
mean_ctr=mean_ctr,
scale='loads',
verbose=False)
WWx[i,:,:] = W
WWz[i,:,:] = L
WWy[i,:,:] = Q

View File

@ -2,8 +2,9 @@ import pylab
import matplotlib
import networkx as nx
import scipy
import rpy
def plot_corrloads(R, pc1=0,pc2=1,s=20, c='b', zorder=5,expvar=None,ax=None,drawback=True, labels=None):
def plot_corrloads(R, pc1=0,pc2=1,s=20, c='b', zorder=5,expvar=None,ax=None,drawback=True, labels=None, **kwds):
""" Correlation loading plot."""
# background
@ -25,7 +26,7 @@ def plot_corrloads(R, pc1=0,pc2=1,s=20, c='b', zorder=5,expvar=None,ax=None,draw
ax.axvline(lw=1.5,color='k')
# corrloads
ax.scatter(R[:,pc1], R[:,pc2], s=s, c=c,zorder=zorder)
ax.scatter(R[:,pc1], R[:,pc2], s=s, c=c,zorder=zorder, **kwds)
ax.set_xlim([-1,1])
ax.set_ylim([-1,1])
if expvar!=None:
@ -39,24 +40,45 @@ def plot_corrloads(R, pc1=0,pc2=1,s=20, c='b', zorder=5,expvar=None,ax=None,draw
pylab.text(r[pc1], r[pc2], " " + name)
#pylab.show()
def plot_dag(edge_dict, node_color='b', node_size=30,labels=None,nodelist=None,pos=None):
def dag(terms, ontology):
    """Build the GO graph spanned by a set of terms, using R's GOstats.

    terms : GO term identifiers, handed to R's GOGraph
    ontology : 'bp', 'mf' or 'cc' (case does not matter)

    Returns the result of R's edges() on that graph (presumably a dict of
    adjacency lists, which is what plot_dag iterates over -- confirm
    against the rpy conversion).
    Raises ValueError for an unknown ontology.
    """
    parent_envs = {'bp' : 'GOBPPARENTS',
                   'mf' : 'GOMFPARENTS',
                   'cc' : 'GOCCPARENTS'}
    # validate before touching R, so a typo fails here with a clear message
    # instead of passing None on to GOGraph
    key = str(ontology).lower()
    if key not in parent_envs:
        raise ValueError("ontology must be one of 'bp', 'mf' or 'cc', got %r"
                         % (ontology,))
    rpy.r.library("GOstats")
    parents = getattr(rpy.r, parent_envs[key])
    gograph = rpy.r.GOGraph(terms, parents)
    return rpy.r.edges(gograph)
def plot_dag(dag, node_color='b', node_size=30,with_labels=False,nodelist=None,pos=None):
    """Draw a GO graph given as a dict of adjacency lists.

    dag : dict mapping a node id to the list of its neighbour ids
    node_color, node_size, with_labels, nodelist : passed on to
        nx.draw_networkx
    pos : node positions keyed by node id; when None they are computed
        with the graphviz 'dot' layout

    Returns the positions used, so later calls can reuse the same layout.
    """
    rpy.r.library("GOstats")
    dag_name = "GO-bp"

    if pos is None:
        # networkx does not play well with colon in node names, so the
        # layout is computed on a copy whose ids use underscores instead
        clean_edges = {}
        for head, neigb in dag.items():
            clean_edges[head.replace(":", "_")] = [i.replace(":", "_") for i in neigb]
        G = nx.from_dict_of_lists(clean_edges, nx.DiGraph(name=dag_name))
        layout = nx.pydot_layout(G, prog='dot')
        # back to the original ids, with the y-axis flipped
        pos = {}
        for k, (x, y) in layout.items():
            pos[k.replace("_", ":")] = (x, -y)

    G = nx.from_dict_of_lists(dag, nx.Graph(name=dag_name))
    # Only a real sequence of per-node colours has to match nodelist; a
    # colour name such as 'red' is a string, not one colour per node.
    per_node_colors = not isinstance(node_color, str) and len(node_color)>1
    if per_node_colors and nodelist is not None:
        assert(len(node_color)==len(nodelist))
    nx.draw_networkx(G,pos, with_labels=with_labels, node_size=node_size, node_color=node_color, nodelist=nodelist)
    return pos
def plot_ZXcorr(gene_ids, term_ids, gene2go, X, D, scale=True):
""" Plot correlation/covariance between genes as a function of
@ -80,6 +102,39 @@ def plot_ZXcorr(gene_ids, term_ids, gene2go, X, D, scale=True):
def clustering_index(T, Yg):
pass
def draw_gene(gid, gene_ids, gene2go, Z, tmat, terms, G, pos):
    """Draw the GO graph for one gene, then once more per annotated term.

    gid : id of the gene to draw
    gene_ids : ids used to locate the column(s) of Z belonging to gid
    gene2go : mapping from gene id to its GO terms
    Z : matrix whose gid column(s) colour the nodes of the first figure
        (presumably terms x genes -- confirm against caller)
    tmat : matrix whose row for an annotated term colours and sizes the
        nodes of that term's figure (presumably terms x terms -- confirm)
    terms : node list, in the order used by Z and tmat
    G, pos : graph and node positions handed to nx.draw_networkx

    Terms annotated to the gene are drawn with node size 500.
    """
    annotated = gene2go[gid]
    marked = [j for j, term in enumerate(terms) if term in annotated]

    # first figure: every term coloured by the gene's column of Z
    sizes = 70.*scipy.ones((len(terms),))
    sizes[marked] = 500
    columns = [j for j, gene_id in enumerate(gene_ids) if gene_id==gid]
    colors = Z[:,columns].ravel()
    pylab.figure()
    nx.draw_networkx(G, pos, node_color=colors, node_size=sizes, with_labels=False, nodelist=terms)
    pylab.colorbar(pylab.gca().collections[0])

    # one figure per annotated term: colour and size follow its row of tmat
    for term_index in marked:
        pylab.figure()
        colors = tmat[term_index,:]
        sizes = 170*colors
        sizes[term_index] = 500
        nx.draw_networkx(G, pos, node_color=colors, node_size=sizes, with_labels=False, nodelist=terms)
        pylab.title(terms[term_index])
        pylab.colorbar(pylab.gca().collections[0])
    pylab.show()

View File

@ -38,6 +38,9 @@ def goterms_from_gene(genelist, ontology='BP', garbage=None):
print "loading GO definitions environment"
gene2terms = {}
cc = 0
dd = 0
ii = 0
for gene in genelist:
info = rpy.r('GOENTREZID2GO[["' + str(gene) + '"]]')
#print info
@ -50,17 +53,23 @@ def goterms_from_gene(genelist, ontology='BP', garbage=None):
if ic.get(term)==None:
#print "\nHave no IC on this GO term %s for this gene: %s" %(term,gene)
skip=True
ii += 1
if desc['Ontology']!=ontology:
#print "\nThis GO term %s belongs to: %s:" %(term,desc['Ontology'])
skip = True
dd += 1
if not skip:
if gene2terms.has_key(gene):
gene2terms[gene].append(term)
else:
gene2terms[gene] = [term]
else:
print "\nHave no Annotation on this gene: %s" %gene
cc += 1
print "\nNumber of genes without annotation: %d" %cc
print "\nNumber of genes not in %s : %d " %(ontology, dd)
print "\nNumber of genes with infs : %d " %ii
return gene2terms
def genego_matrix(goterms, tmat, gene_ids, term_ids, func=max):
@ -166,7 +175,7 @@ def gene_GO_hypergeo_test(genelist,universe="entrezUniverse",ontology="BP",chip
)
result = rpy.r.summary(rpy.r.hyperGTest(params))
return rpy.r.summary(result), params
return result, params
def data_aff2loc_hgu133a(X, aff_ids, verbose=False):
aff_ids = scipy.asarray(aff_ids)

View File

@ -19,42 +19,49 @@ if use_data=='smoker':
Y = DY.asarray().astype('d')
gene_ids = DX.get_identifiers('gene_ids', sorted=True)
elif use_data=='scherf':
DX = dataset.read_ftsv(open("../../data/scherf/Scherf.ftsv"))
DY = dataset.read_ftsv(open("../../data/scherf/Yd.ftsv"))
DX = dataset.read_ftsv(open("../../data/scherf/scherfX.ftsv"))
DY = dataset.read_ftsv(open("../../data/scherf/scherfY.ftsv"))
Y = DY.asarray().astype('d')
gene_ids = DX.get_identifiers('gene_ids', sorted=True)
elif use_data=='staunton':
pass
elif use_data=='uma':
DX = dataset.read_ftsv(open("../../data/uma/X133.ftsv"))
DY = dataset.read_ftsv(open("../../data/uma/Yg133.ftsv"))
DYg = dataset.read_ftsv(open("../../data/uma/Yg133.ftsv"))
DY = dataset.read_ftsv(open("../../data/uma/Yd.ftsv"))
Y = DY.asarray().astype('d')
gene_ids = DX.get_identifiers('gene_ids', sorted=True)
# Use only subset defined on GO
ontology = 'BP'
print "\n\nFiltering genes by Go terms "
# use subset with defined GO-terms
gene2goterms = rpy_go.goterms_from_gene(gene_ids)
all_terms = set()
for t in gene2goterms.values():
all_terms.update(t)
terms = list(all_terms)
print "\nNumber of go-terms: %s" %len(terms)
# update genelist
gene_ids = gene2goterms.keys()
print "\nNumber of genes: %s" %len(gene_ids)
X = DX.asarray()
index = DX.get_indices('gene_ids', gene_ids)
X = X[:,index]
1/0
# Use only subset defined on GO
ontology = 'BP'
print "\n\nFiltering genes by Go terms "
# use subset based on SAM or IQR
subset = 'm'
subset = 'not'
if subset=='sam':
# select subset genes by SAM
rpy.r.library("siggenes")
rpy.r.library("qvalue")
data = DX.asarray().T
# data = data[:100,:]
rpy.r.assign("data", data)
cl = dot(DY.asarray(), diag([1,2,3])).sum(1)
rpy.r.assign("data", X.T)
cl = dot(DY.asarray(), diag(arange(Y.shape[1])+1)).sum(1)
rpy.r.assign("cl", cl)
rpy.r.assign("B", 20)
# Perform a SAM analysis.
@ -65,13 +72,21 @@ if subset=='sam':
qq = rpy.r('qobj<-qvalue(sam.out@p.value)')
qvals = asarray(qq['qvalues'])
# cut off
cutoff = 0.001
cutoff = 0.01
index = where(qvals<cutoff)[0]
# Subset data
X = DX.asarray()
#Xr = X[:,index]
gene_ids = DX.get_identifiers('gene_ids', index)
X = X[:,index]
gene_ids = [gid for i, gid in enumerate(gene_ids) if i in index]
print "\nWorking on subset with %s genes " %len(gene_ids)
# update valid go-terms
gene2goterms = rpy_go.goterms_from_gene(gene_ids)
all_terms = set()
for t in gene2goterms.values():
all_terms.update(t)
terms = list(all_terms)
print "\nNumber of go-terms: %s" %len(terms)
else:
# noimp (smoker data is prefiltered)
pass
@ -97,9 +112,9 @@ Xr = DX.asarray()[:,newind]
######## LPLSR ########
print "LPLSR ..."
a_max = 5
a_max = 10
aopt = 3
xz_alpha = .5
xz_alpha = .6
w_alpha = .1
mean_ctr = [2, 0, 2]
@ -108,9 +123,9 @@ sdtz = False
if sdtz:
Z = Z/Z.std(0)
T, W, P, Q, U, L, K, B, b0, evx, evy, evz = nipals_lpls(Xr,Y,Z, a_max,
alpha=xz_alpha,
mean_ctr=mean_ctr)
T, W, P, Q, U, L, K, B, b0, evx, evy, evz,mnx,mny,mnz = nipals_lpls(Xr,Y,Z, a_max,
alpha=xz_alpha,
mean_ctr=mean_ctr)
# Correlation loadings
dx,Rx,rssx = correlation_loadings(Xr, T, P)
@ -118,11 +133,13 @@ dx,Ry,rssy = correlation_loadings(Y, T, Q)
cadz,Rz,rssz = correlation_loadings(Z.T, W, L)
# Prediction error
rmsep , yhat, class_error = cv_lpls(Xr, Y, Z, a_max, alpha=xz_alpha,mean_ctr=mean_ctr)
alpha_check=False
alpha_check=True
if alpha_check:
Alpha = arange(0.01, 1, .1)
Rmsep,Yhat, CE = [],[],[]
for a in Alpha:
print "alpha %f" %a
# cross-validate with the alpha being scanned, not the fixed xz_alpha
rmsep , yhat, ce = cv_lpls(Xr, Y, Z, a_max, alpha=a,mean_ctr=mean_ctr)
Rmsep.append(rmsep)
Yhat.append(yhat)
@ -131,11 +148,12 @@ if alpha_check:
Yhat = asarray(Yhat)
CE = asarray(CE)
# Significance Hotellings T
Wx, Wz, Wy, = jk_lpls(Xr, Y, Z, aopt, mean_ctr=mean_ctr,alpha=w_alpha)
# jk_lpls takes this setting under the keyword xz_alpha
Wx, Wz, Wy, = jk_lpls(Xr, Y, Z, aopt, mean_ctr=mean_ctr,xz_alpha=xz_alpha)
Ws = W*apply_along_axis(norm, 0, T)
tsqx = cx_stats.hotelling(Wx, Ws[:,:aopt])
tsqz = cx_stats.hotelling(Wz, L[:,:aopt])
tsqx = cx_stats.hotelling(Wx, Ws[:,:aopt], alpha=w_alpha)
tsqz = cx_stats.hotelling(Wz, L[:,:aopt], alpha=0)
## plots ##
@ -156,12 +174,12 @@ title('Classification accuracy')
figure(3) # Hypoid correlations
tsqz_s = 250*tsqz/tsqz.max()
plot_corrloads(Rz, pc1=0, pc2=1, s=tsqz_s, c='b', zorder=5, expvar=evz, ax=None)
plot_corrloads(Rz, pc1=0, pc2=1, s=tsqz_s, c=tsqz, zorder=5, expvar=evz, ax=None,alpha=.5)
ax = gca()
ylabels = DY.get_identifiers('_status', sorted=True)
plot_corrloads(Ry, pc1=0, pc2=1, s=150, c='g', zorder=5, expvar=evy, ax=ax,labels=ylabels)
ylabels = DY.get_identifiers(DY.get_dim_name()[1], sorted=True)
plot_corrloads(Ry, pc1=0, pc2=1, s=150, c='g', zorder=5, expvar=evy, ax=ax,labels=ylabels,alpha=.5)
figure(3)
figure(4)
subplot(221)
ax = gca()
plot_corrloads(Rx, pc1=0, pc2=1, s=tsqx/2.0, c='b', zorder=5, expvar=evx, ax=ax)