Projects/laydi
commit 7e9a0882f1
parent d9e5398865
Author: Arnar Flatberg
Date:   2007-09-20 16:10:40 +00:00

4 changed files with 168 additions and 59 deletions

[File 1 of 4]

@@ -15,6 +15,7 @@ def nipals_lpls(X, Y, Z, a_max, alpha=.7, mean_ctr=[2, 0, 1], scale='scores', ve
     X : data matrix (m, n)
     Y : data matrix (m, l)
     Z : data matrix (n, o)
+    alpha : how much z influence (1=max, 0=none)

     :output:
     T : X-scores
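The new docstring line pins down alpha as the X/Z balance in the weight update. For readers unfamiliar with L-PLS, a minimal sketch of how such a blend is typically folded into a NIPALS step; the exact blending form used by nipals_lpls is not visible in this hunk, so treat the formula below as an assumption:

    import numpy as np

    def blended_weight(wx, wz, alpha):
        # alpha=0 ignores Z entirely; alpha=1 follows only the Z-derived direction
        w = (1.0 - alpha) * wx + alpha * wz
        return w / np.linalg.norm(w)

    # toy check: alpha=0 reduces to the normalized X weight
    wx, wz = np.array([1.0, 0.0]), np.array([0.0, 1.0])
    assert np.allclose(blended_weight(wx, wz, 0.0), wx)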
@@ -36,7 +37,7 @@ def nipals_lpls(X, Y, Z, a_max, alpha=.7, mean_ctr=[2, 0, 1], scale='scores', ve
     if mean_ctr:
         xctr, yctr, zctr = mean_ctr
         X, mnX = center(X, xctr)
-        Y, mnY = center(Y, xctr)
+        Y, mnY = center(Y, yctr)
         Z, mnZ = center(Z, zctr)

     varX = pow(X, 2).sum()
@@ -116,7 +117,7 @@ def nipals_lpls(X, Y, Z, a_max, alpha=.7, mean_ctr=[2, 0, 1], scale='scores', ve
         T = T/tnorm
         Q = Q*tnorm
         W = W*tnorm
-    return T, W, P, Q, U, L, K, B, b0, evx, evy, evz
+    return T, W, P, Q, U, L, K, B, b0, evx, evy, evz, mnX, mnY, mnZ

 def svd_lpls(X, Y, Z, a_max, alpha=.7, mean_ctr=[2, 0, 1], verbose=True):
     """
@@ -308,6 +309,12 @@ def bifpls(X, Y, Z, a_max, alpha):
 def center(a, axis):
     # 0 = col center, 1 = row center, 2 = double center
     # -1 = nothing
+    if len(a.shape)==1:
+        mn = a.mean()
+        return a - mn, mn
+    if a.shape[0]==1 or a.shape[1]==1:
+        mn = a.mean()
+        return a - mn, mn
     if axis==-1:
         mn = zeros((a.shape[1],))
         return a - mn, mn
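The two new guards make center() safe for 1-D arrays and single-row or single-column matrices, which the axis-based branches below would otherwise mishandle. A plain numpy illustration of what the guard returns:

    import numpy as np

    a = np.array([[1.0, 2.0, 3.0]])   # a single-row matrix hits the new guard
    mn = a.mean()                     # scalar mean instead of per-axis logic
    print(a - mn, mn)                 # [[-1.  0.  1.]] 2.0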
@@ -318,7 +325,7 @@ def center(a, axis):
         mn = a.mean(1)[:,newaxis]
         return a - mn , mn
     elif axis==2:
-        mn = a.mean(0) + a.mean(1)[:,newaxis] - a.mean()
+        mn = a.mean(1)[:,newaxis] + a.mean(0) - a.mean()
         return a - mn, mn
     else:
         raise IOError("input error: axis must be in [-1,0,1,2]")
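The axis==2 rewrite only reorders mathematically equivalent terms; broadcasting produces the same (m, n) mean surface either way. The defining property of double centering, checked directly:

    import numpy as np

    a = np.random.rand(4, 3)
    mn = a.mean(1)[:, np.newaxis] + a.mean(0) - a.mean()
    ac = a - mn
    print(np.allclose(ac.mean(0), 0), np.allclose(ac.mean(1), 0))  # True True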
@@ -367,27 +374,47 @@ def correlation_loadings(D, T, P, test=True):
 def cv_lpls(X, Y, Z, a_max=2, nsets=None,alpha=.5, mean_ctr=[2,0,1]):
     """Performs crossvalidation to get generalisation error in lpls"""
+    # if double centering of x or y:
+    # row-center prior to cross validation (as this is independent of subsets)
+    if mean_ctr[0]==2:
+        mnx_row = X.mean(1)[:,newaxis]
+        X = X - mnx_row
+        mean_ctr[0] = 0
+    else:
+        mnx_row = 0
+    if mean_ctr[1]==2:
+        if Y.shape[1]!=1:
+            mny_row = Y.mean(1)[:,newaxis]
+            Y = Y - mny_row
+    else:
+        mny_row = 0
     cv_iter = select_generators.pls_gen(X, Y, n_blocks=nsets,center=False,index_out=True)
     k, l = Y.shape
     Yhat = empty((a_max,k,l), 'd')
     for i, (xcal,xi,ycal,yi,ind) in enumerate(cv_iter):
-        T, W, P, Q, U, L, K, B, b0, evx, evy, evz = nipals_lpls(xcal,ycal,Z,
+        T, W, P, Q, U, L, K, B, b0, evx, evy, evz, mnx, mny, mnz = nipals_lpls(xcal,ycal,Z,
                                                                 a_max=a_max,
                                                                 alpha=alpha,
                                                                 mean_ctr=mean_ctr,
                                                                 verbose=False)
         for a in range(a_max):
-            Yhat[a,ind,:] = b0[a][0][0] + dot(xi, B[a])
+            xc = xi - mnx
+            Yhat[a,ind,:] = mny + dot(xc, B[a])

     Yhat_class = zeros_like(Yhat)
     for a in range(a_max):
         for i in range(k):
-            Yhat_class[a,i,argmax(Yhat[a,i,:])]=1.0
+            Yhat_class[a,i,argmax(Yhat[a,i,:])] = 1.0
     class_err = 100*((Yhat_class+Y)==2).sum(1)/Y.sum(0).astype('d')

     sep = (Y - Yhat)**2
     rmsep = sqrt(sep.mean(1))
     return rmsep, Yhat, class_err

-def jk_lpls(X, Y, Z, a_max, nsets=None, alpha=.5, mean_ctr=[2,0,1]):
+def jk_lpls(X, Y, Z, a_max, nsets=None, xz_alpha=.5, mean_ctr=[2,0,1]):
     cv_iter = select_generators.pls_gen(X, Y, n_blocks=nsets,center=False,index_out=False)
     m, n = X.shape
     k, l = Y.shape
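On the class_err line: Y and Yhat_class are 0/1 dummy matrices, so (Yhat_class + Y) == 2 flags samples whose predicted class matches the true one, and dividing by the class sizes Y.sum(0) yields per-class accuracy in percent. A 2-D toy version (the committed code carries an extra leading component axis, hence its .sum(1)):

    import numpy as np

    Y = np.array([[1, 0], [0, 1], [1, 0]], dtype=float)
    Yhat_class = np.array([[1, 0], [1, 0], [1, 0]], dtype=float)
    acc = 100 * ((Yhat_class + Y) == 2).sum(0) / Y.sum(0)
    print(acc)  # [100.   0.] : class one fully recovered, class two always missed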
@@ -398,9 +425,9 @@ def jk_lpls(X, Y, Z, a_max, nsets=None, alpha=.5, mean_ctr=[2,0,1]):
     WWz = empty((nsets, o, a_max), 'd')
     WWy = empty((nsets, l, a_max), 'd')
     for i, (xcal,xi,ycal,yi) in enumerate(cv_iter):
-        T, W, P, Q, U, L, K, B, b0, evx, evy, evz = nipals_lpls(xcal,ycal,Z,
+        T, W, P, Q, U, L, K, B, b0, evx, evy, evz,mnx,mny,mnz = nipals_lpls(xcal,ycal,Z,
                                                                 a_max=a_max,
-                                                                alpha=alpha,
+                                                                alpha=xz_alpha,
                                                                 mean_ctr=mean_ctr,
                                                                 scale='loads',
                                                                 verbose=False)

[File 2 of 4]

@@ -2,8 +2,9 @@ import pylab
 import matplotlib
 import networkx as nx
 import scipy
+import rpy

-def plot_corrloads(R, pc1=0,pc2=1,s=20, c='b', zorder=5,expvar=None,ax=None,drawback=True, labels=None):
+def plot_corrloads(R, pc1=0,pc2=1,s=20, c='b', zorder=5,expvar=None,ax=None,drawback=True, labels=None, **kwds):
     """ Correlation loading plot."""
     # background
@@ -25,7 +26,7 @@ def plot_corrloads(R, pc1=0,pc2=1,s=20, c='b', zorder=5,expvar=None,ax=None,draw
     ax.axvline(lw=1.5,color='k')

     # corrloads
-    ax.scatter(R[:,pc1], R[:,pc2], s=s, c=c,zorder=zorder)
+    ax.scatter(R[:,pc1], R[:,pc2], s=s, c=c,zorder=zorder, **kwds)
     ax.set_xlim([-1,1])
     ax.set_ylim([-1,1])
     if expvar!=None:
@@ -39,23 +40,44 @@ def plot_corrloads(R, pc1=0,pc2=1,s=20, c='b', zorder=5,expvar=None,ax=None,draw
             pylab.text(r[pc1], r[pc2], " " + name)
     #pylab.show()

-def plot_dag(edge_dict, node_color='b', node_size=30,labels=None,nodelist=None,pos=None):
+def dag(terms, ontology):
+    rpy.r.library("GOstats")
+    __parents = {'bp' : rpy.r.GOBPPARENTS,
+                 'mf' : rpy.r.GOMFPARENTS,
+                 'cc' : rpy.r.GOCCPARENTS}
+    gograph = rpy.r.GOGraph(terms, __parents.get(ontology))
+    dag = rpy.r.edges(gograph)
+    #setattr(dag, "_ontology", ontology)
+    return dag
+
+def plot_dag(dag, node_color='b', node_size=30,with_labels=False,nodelist=None,pos=None):
+    rpy.r.library("GOstats")
+    dag_name = "GO-bp"
     # networkx does not play well with colon in node names
     clean_edges = {}
-    for head, neigb in edge_dict.items():
+    for head, neigb in dag.items():
         head = head.replace(":", "_")
         nei = [i.replace(":", "_") for i in neigb]
         clean_edges[head] = nei
     if pos==None:
-        G = nx.from_dict_of_lists(clean_edges, nx.DiGraph(name='GO'))
+        G = nx.from_dict_of_lists(clean_edges, nx.DiGraph(name=dag_name))
         pos = nx.pydot_layout(G, prog='dot')
-    G = nx.from_dict_of_lists(edge_dict, nx.DiGraph(name='GO'))
+        pos_new = {}
+        for k, v in pos.items():
+            x,y = v
+            k = k.replace("_", ":")
+            pos_new[k] = (x, -y)
+        pos = pos_new
+    G = nx.from_dict_of_lists(dag, nx.Graph(name=dag_name))
     if len(node_color)>1:
         assert(len(node_color)==len(nodelist))
+    if labels!=None:
+        with_labels=True
     nx.draw_networkx(G,pos, with_labels=with_labels, node_size=node_size, node_color=node_color, nodelist=nodelist)
+    return pos

 def plot_ZXcorr(gene_ids, term_ids, gene2go, X, D, scale=True):
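Two tricks in the rewritten plot_dag are worth isolating: GO identifiers contain ':', which the pydot round-trip mangles, so the layout runs on a colon-free copy whose positions are then mapped back; and dot's y coordinates are negated so the root lands on top. A standalone sketch with a toy edge dict (nx.nx_pydot.pydot_layout is the modern spelling of the pydot_layout call and needs pydot plus Graphviz installed):

    import networkx as nx

    dag = {"GO:0008150": ["GO:0009987", "GO:0065007"]}
    clean = dict((h.replace(":", "_"), [n.replace(":", "_") for n in nei])
                 for h, nei in dag.items())
    G = nx.from_dict_of_lists(clean, nx.DiGraph())
    pos = nx.nx_pydot.pydot_layout(G, prog="dot")
    # restore the GO names and flip dot's y-axis so roots are drawn on top
    pos = dict((k.replace("_", ":"), (x, -y)) for k, (x, y) in pos.items())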
@@ -81,6 +103,39 @@ def plot_ZXcorr(gene_ids, term_ids, gene2go, X, D, scale=True):
 def clustering_index(T, Yg):
     pass

+def draw_gene(gid, gene_ids, gene2go, Z, tmat, terms, G, pos):
+    """Draw dags with marked go terms and distance to all terms.
+    """
+    sub_terms = gene2go[gid]
+    sub_index = [i for i, tid in enumerate(terms) if tid in sub_terms]
+    node_size = 70.*scipy.ones((len(terms),))
+    node_size[sub_index] = 500
+    gene_index = [i for i, gene_id in enumerate(gene_ids) if gene_id==gid]
+    node_color = Z[:,gene_index].ravel()
+    #1/0
+    #node_size=200*node_color
+    #node_color='g'
+    pylab.figure()
+    nx.draw_networkx(G, pos, node_color=node_color, node_size=node_size, with_labels=False, nodelist=terms)
+    ax = pylab.gca()
+    pylab.colorbar(ax.collections[0])
+    for tid in sub_index:
+        pylab.figure()
+        node_color = tmat[tid,:]
+        #node_size = 70*scipy.ones((len(terms),))
+        node_size = 170*node_color
+        node_size[tid] = 500
+        nx.draw_networkx(G, pos, node_color=node_color, node_size=node_size, with_labels=False, nodelist=terms)
+        pylab.title(terms[tid])
+        ax = pylab.gca()
+        pylab.colorbar(ax.collections[0])
+    pylab.show()
+    #nx.show()
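The highlighting scheme in the new draw_gene reduces to: give every term a small base marker and inflate the markers of the gene's own terms. An isolated numpy version of that step, with hypothetical term names:

    import numpy as np

    terms = ["GO:a", "GO:b", "GO:c", "GO:d"]
    sub_terms = ["GO:b", "GO:d"]                 # terms annotated to this gene
    sub_index = [i for i, t in enumerate(terms) if t in sub_terms]
    node_size = 70.0 * np.ones(len(terms))
    node_size[sub_index] = 500
    print(node_size)                             # [ 70. 500.  70. 500.]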

[File 3 of 4]

@@ -38,6 +38,9 @@ def goterms_from_gene(genelist, ontology='BP', garbage=None):
     print "loading GO definitions environment"

     gene2terms = {}
+    cc = 0
+    dd = 0
+    ii = 0
     for gene in genelist:
         info = rpy.r('GOENTREZID2GO[["' + str(gene) + '"]]')
         #print info
@@ -50,16 +53,22 @@ def goterms_from_gene(genelist, ontology='BP', garbage=None):
                 if ic.get(term)==None:
                     #print "\nHave no IC on this GO term %s for this gene: %s" %(term,gene)
                     skip=True
+                    ii += 1
                 if desc['Ontology']!=ontology:
                     #print "\nThis GO term %s belongs to: %s:" %(term,desc['Ontology'])
                     skip = True
+                    dd += 1
                 if not skip:
                     if gene2terms.has_key(gene):
                         gene2terms[gene].append(term)
                     else:
                         gene2terms[gene] = [term]
         else:
-            print "\nHave no Annotation on this gene: %s" %gene
+            cc += 1
+    print "\nNumber of genes without annotation: %d" %cc
+    print "\nNumber of genes not in %s : %d " %(ontology, dd)
+    print "\nNumber of genes with infs : %d " %ii

     return gene2terms
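The counter change trades one print per unannotated gene for three summary totals. The skip logic it instruments, as a self-contained sketch with plain dicts standing in for the rpy GO environments:

    ic = {"GO:1": 2.3}                       # term -> information content
    onto = {"GO:1": "BP", "GO:2": "MF"}      # term -> ontology
    gene2go = {"g1": ["GO:1"], "g2": ["GO:2"], "g3": None}

    gene2terms, cc, dd, ii = {}, 0, 0, 0
    for gene, annots in gene2go.items():
        if annots is None:
            cc += 1                          # gene has no annotation at all
            continue
        for term in annots:
            skip = False
            if ic.get(term) is None:
                ii += 1; skip = True         # no information content for term
            if onto.get(term) != "BP":
                dd += 1; skip = True         # term outside target ontology
            if not skip:
                gene2terms.setdefault(gene, []).append(term)
    print(cc, dd, ii, gene2terms)            # 1 1 1 {'g1': ['GO:1']}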
@@ -166,7 +175,7 @@ def gene_GO_hypergeo_test(genelist,universe="entrezUniverse",ontology="BP",chip
                           )
     result = rpy.r.summary(rpy.r.hyperGTest(params))
-    return rpy.r.summary(result), params
+    return result, params

 def data_aff2loc_hgu133a(X, aff_ids, verbose=False):
     aff_ids = scipy.asarray(aff_ids)

[File 4 of 4]

@@ -19,42 +19,49 @@ if use_data=='smoker':
     Y = DY.asarray().astype('d')
     gene_ids = DX.get_identifiers('gene_ids', sorted=True)
 elif use_data=='scherf':
-    DX = dataset.read_ftsv(open("../../data/scherf/Scherf.ftsv"))
-    DY = dataset.read_ftsv(open("../../data/scherf/Yd.ftsv"))
+    DX = dataset.read_ftsv(open("../../data/scherf/scherfX.ftsv"))
+    DY = dataset.read_ftsv(open("../../data/scherf/scherfY.ftsv"))
     Y = DY.asarray().astype('d')
     gene_ids = DX.get_identifiers('gene_ids', sorted=True)
 elif use_data=='staunton':
     pass
 elif use_data=='uma':
     DX = dataset.read_ftsv(open("../../data/uma/X133.ftsv"))
-    DY = dataset.read_ftsv(open("../../data/uma/Yg133.ftsv"))
+    DYg = dataset.read_ftsv(open("../../data/uma/Yg133.ftsv"))
+    DY = dataset.read_ftsv(open("../../data/uma/Yd.ftsv"))
     Y = DY.asarray().astype('d')
     gene_ids = DX.get_identifiers('gene_ids', sorted=True)

-# Use only subset defined on GO
-ontology = 'BP'
-print "\n\nFiltering genes by Go terms "
+# use subset with defined GO-terms
 gene2goterms = rpy_go.goterms_from_gene(gene_ids)
 all_terms = set()
 for t in gene2goterms.values():
     all_terms.update(t)
 terms = list(all_terms)
 print "\nNumber of go-terms: %s" %len(terms)

 # update genelist
 gene_ids = gene2goterms.keys()
 print "\nNumber of genes: %s" %len(gene_ids)

+X = DX.asarray()
+index = DX.get_indices('gene_ids', gene_ids)
+X = X[:,index]
+1/0
+# Use only subset defined on GO
+ontology = 'BP'
+print "\n\nFiltering genes by Go terms "

 # use subset based on SAM or IQR
-subset = 'm'
+subset = 'not'
 if subset=='sam':
     # select subset genes by SAM
     rpy.r.library("siggenes")
     rpy.r.library("qvalue")
-    data = DX.asarray().T
-    # data = data[:100,:]
-    rpy.r.assign("data", data)
-    cl = dot(DY.asarray(), diag([1,2,3])).sum(1)
+    rpy.r.assign("data", X.T)
+    cl = dot(DY.asarray(), diag(arange(Y.shape[1])+1)).sum(1)
     rpy.r.assign("cl", cl)
     rpy.r.assign("B", 20)
     # Perform a SAM analysis.
@@ -65,13 +72,21 @@ if subset=='sam':
     qq = rpy.r('qobj<-qvalue(sam.out@p.value)')
     qvals = asarray(qq['qvalues'])
     # cut off
-    cutoff = 0.001
+    cutoff = 0.01
     index = where(qvals<cutoff)[0]

     # Subset data
-    X = DX.asarray()
-    #Xr = X[:,index]
-    gene_ids = DX.get_identifiers('gene_ids', index)
+    X = X[:,index]
+    gene_ids = [gid for i, gid in enumerate(gene_ids) if i in index]
     print "\nWorking on subset with %s genes " %len(gene_ids)
+
+    # update valid go-terms
+    gene2goterms = rpy_go.goterms_from_gene(gene_ids)
+    all_terms = set()
+    for t in gene2goterms.values():
+        all_terms.update(t)
+    terms = list(all_terms)
+    print "\nNumber of go-terms: %s" %len(terms)
 else:
     # noimp (smoker data is prefiltered)
     pass
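The subset step now slices the already-GO-filtered X and keeps gene_ids in sync by position rather than re-querying the dataset object. The indexing pattern in isolation, with fabricated q-values:

    import numpy as np

    qvals = np.array([0.001, 0.5, 0.004, 0.2])
    gene_ids = ["g1", "g2", "g3", "g4"]
    cutoff = 0.01
    index = np.where(qvals < cutoff)[0]
    X = np.random.rand(5, 4)[:, index]            # keep passing columns only
    gene_ids = [gid for i, gid in enumerate(gene_ids) if i in index]
    print(X.shape, gene_ids)                      # (5, 2) ['g1', 'g3']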
@@ -97,9 +112,9 @@ Xr = DX.asarray()[:,newind]

 ######## LPLSR ########
 print "LPLSR ..."
-a_max = 5
+a_max = 10
 aopt = 3
-xz_alpha = .5
+xz_alpha = .6
 w_alpha = .1
 mean_ctr = [2, 0, 2]
@@ -108,7 +123,7 @@ sdtz = False
 if sdtz:
     Z = Z/Z.std(0)

-T, W, P, Q, U, L, K, B, b0, evx, evy, evz = nipals_lpls(Xr,Y,Z, a_max,
+T, W, P, Q, U, L, K, B, b0, evx, evy, evz,mnx,mny,mnz = nipals_lpls(Xr,Y,Z, a_max,
                                                         alpha=xz_alpha,
                                                         mean_ctr=mean_ctr)
@@ -118,11 +133,13 @@ dx,Ry,rssy = correlation_loadings(Y, T, Q)
 cadz,Rz,rssz = correlation_loadings(Z.T, W, L)

 # Prediction error
 rmsep , yhat, class_error = cv_lpls(Xr, Y, Z, a_max, alpha=xz_alpha,mean_ctr=mean_ctr)

-alpha_check=False
+alpha_check=True
 if alpha_check:
     Alpha = arange(0.01, 1, .1)
     Rmsep,Yhat, CE = [],[],[]
     for a in Alpha:
+        print "alpha %f" %a
         rmsep , yhat, ce = cv_lpls(Xr, Y, Z, a_max, alpha=xz_alpha,mean_ctr=mean_ctr)
         Rmsep.append(rmsep)
         Yhat.append(yhat)
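As committed, the loop body still passes alpha=xz_alpha, so every grid point evaluates the same model and only the printed value varies. A sweep that actually exercises the grid (presumably the intent; this reuses the script's own variables) would pass the loop variable instead:

    for a in Alpha:
        rmsep, yhat, ce = cv_lpls(Xr, Y, Z, a_max, alpha=a, mean_ctr=mean_ctr)
        Rmsep.append(rmsep)
        Yhat.append(yhat)
        CE.append(ce)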
@@ -131,11 +148,12 @@ if alpha_check:
     Yhat = asarray(Yhat)
     CE = asarray(CE)

 # Significance Hotellings T
-Wx, Wz, Wy, = jk_lpls(Xr, Y, Z, aopt, mean_ctr=mean_ctr,alpha=w_alpha)
+Wx, Wz, Wy, = jk_lpls(Xr, Y, Z, aopt, mean_ctr=mean_ctr,alpha=xz_alpha)
 Ws = W*apply_along_axis(norm, 0, T)
-tsqx = cx_stats.hotelling(Wx, Ws[:,:aopt])
-tsqz = cx_stats.hotelling(Wz, L[:,:aopt])
+tsqx = cx_stats.hotelling(Wx, Ws[:,:aopt], alpha=w_alpha)
+tsqz = cx_stats.hotelling(Wz, L[:,:aopt], alpha=0)

 ## plots ##
@@ -156,12 +174,12 @@ title('Classification accuracy')

 figure(3) # Hypoid correlations
 tsqz_s = 250*tsqz/tsqz.max()
-plot_corrloads(Rz, pc1=0, pc2=1, s=tsqz_s, c='b', zorder=5, expvar=evz, ax=None)
+plot_corrloads(Rz, pc1=0, pc2=1, s=tsqz_s, c=tsqz, zorder=5, expvar=evz, ax=None,alpha=.5)
 ax = gca()
-ylabels = DY.get_identifiers('_status', sorted=True)
-plot_corrloads(Ry, pc1=0, pc2=1, s=150, c='g', zorder=5, expvar=evy, ax=ax,labels=ylabels)
+ylabels = DY.get_identifiers(DY.get_dim_name()[1], sorted=True)
+plot_corrloads(Ry, pc1=0, pc2=1, s=150, c='g', zorder=5, expvar=evy, ax=ax,labels=ylabels,alpha=.5)

-figure(3)
+figure(4)
 subplot(221)
 ax = gca()
 plot_corrloads(Rx, pc1=0, pc2=1, s=tsqx/2.0, c='b', zorder=5, expvar=evx, ax=ax)