This commit is contained in:
2007-09-20 16:10:40 +00:00
parent d9e5398865
commit 7e9a0882f1
4 changed files with 168 additions and 59 deletions

View File

@@ -15,6 +15,7 @@ def nipals_lpls(X, Y, Z, a_max, alpha=.7, mean_ctr=[2, 0, 1], scale='scores', ve
X : data matrix (m, n)
Y : data matrix (m, l)
Z : data matrix (n, o)
alpha : how much z influence (1=max, 0=none)
:output:
T : X-scores
@@ -36,7 +37,7 @@ def nipals_lpls(X, Y, Z, a_max, alpha=.7, mean_ctr=[2, 0, 1], scale='scores', ve
if mean_ctr:
xctr, yctr, zctr = mean_ctr
X, mnX = center(X, xctr)
Y, mnY = center(Y, xctr)
Y, mnY = center(Y, yctr)
Z, mnZ = center(Z, zctr)
varX = pow(X, 2).sum()
@@ -116,7 +117,7 @@ def nipals_lpls(X, Y, Z, a_max, alpha=.7, mean_ctr=[2, 0, 1], scale='scores', ve
T = T/tnorm
Q = Q*tnorm
W = W*tnorm
return T, W, P, Q, U, L, K, B, b0, evx, evy, evz
return T, W, P, Q, U, L, K, B, b0, evx, evy, evz, mnX, mnY, mnZ
def svd_lpls(X, Y, Z, a_max, alpha=.7, mean_ctr=[2, 0, 1], verbose=True):
"""
@@ -306,8 +307,14 @@ def bifpls(X, Y, Z, a_max, alpha):
evz = 100.0*(1 - var_z/varZ)
def center(a, axis):
# 0 = col center, 1 = row center, 2 = double center
# -1 = nothing
# 0 = col center, 1 = row center, 2 = double center
# -1 = nothing
if len(a.shape)==1:
mn = a.mean()
return a - mn, mn
if a.shape[0]==1 or a.shape[1]==1:
mn = a.mean()
return a - mn, mn
if axis==-1:
mn = zeros((a.shape[1],))
return a - mn, mn
@@ -318,7 +325,7 @@ def center(a, axis):
mn = a.mean(1)[:,newaxis]
return a - mn , mn
elif axis==2:
mn = a.mean(0) + a.mean(1)[:,newaxis] - a.mean()
mn = a.mean(1)[:,newaxis] + a.mean(0) - a.mean()
return a - mn, mn
else:
raise IOError("input error: axis must be in [-1,0,1,2]")
@@ -367,27 +374,47 @@ def correlation_loadings(D, T, P, test=True):
def cv_lpls(X, Y, Z, a_max=2, nsets=None,alpha=.5, mean_ctr=[2,0,1]):
"""Performs crossvalidation to get generalisation error in lpls"""
# if double centering of x or y:
# row-center prior to cross validation (as this is independent of subsets)
if mean_ctr[0]==2:
mnx_row = X.mean(1)[:,newaxis]
X = X - mnx_row
mean_ctr[0] = 0
else:
mnx_row = 0
if mean_ctr[1]==2:
if Y.shape[1]!=1:
mny_row = Y.mean(1)[:,newaxis]
Y = Y - mny_row
else:
mny_row = 0
cv_iter = select_generators.pls_gen(X, Y, n_blocks=nsets,center=False,index_out=True)
k, l = Y.shape
Yhat = empty((a_max,k,l), 'd')
for i, (xcal,xi,ycal,yi,ind) in enumerate(cv_iter):
T, W, P, Q, U, L, K, B, b0, evx, evy, evz = nipals_lpls(xcal,ycal,Z,
a_max=a_max,
alpha=alpha,
mean_ctr=mean_ctr,
verbose=False)
T, W, P, Q, U, L, K, B, b0, evx, evy, evz, mnx, mny, mnz = nipals_lpls(xcal,ycal,Z,
a_max=a_max,
alpha=alpha,
mean_ctr=mean_ctr,
verbose=False)
for a in range(a_max):
Yhat[a,ind,:] = b0[a][0][0] + dot(xi, B[a])
xc = xi - mnx
Yhat[a,ind,:] = mny + dot(xc, B[a])
Yhat_class = zeros_like(Yhat)
for a in range(a_max):
for i in range(k):
Yhat_class[a,i,argmax(Yhat[a,i,:])]=1.0
Yhat_class[a,i,argmax(Yhat[a,i,:])] = 1.0
class_err = 100*((Yhat_class+Y)==2).sum(1)/Y.sum(0).astype('d')
sep = (Y - Yhat)**2
rmsep = sqrt(sep.mean(1))
return rmsep, Yhat, class_err
def jk_lpls(X, Y, Z, a_max, nsets=None, alpha=.5, mean_ctr=[2,0,1]):
def jk_lpls(X, Y, Z, a_max, nsets=None, xz_alpha=.5, mean_ctr=[2,0,1]):
cv_iter = select_generators.pls_gen(X, Y, n_blocks=nsets,center=False,index_out=False)
m, n = X.shape
k, l = Y.shape
@@ -398,12 +425,12 @@ def jk_lpls(X, Y, Z, a_max, nsets=None, alpha=.5, mean_ctr=[2,0,1]):
WWz = empty((nsets, o, a_max), 'd')
WWy = empty((nsets, l, a_max), 'd')
for i, (xcal,xi,ycal,yi) in enumerate(cv_iter):
T, W, P, Q, U, L, K, B, b0, evx, evy, evz = nipals_lpls(xcal,ycal,Z,
a_max=a_max,
alpha=alpha,
mean_ctr=mean_ctr,
scale='loads',
verbose=False)
T, W, P, Q, U, L, K, B, b0, evx, evy, evz,mnx,mny,mnz = nipals_lpls(xcal,ycal,Z,
a_max=a_max,
alpha=xz_alpha,
mean_ctr=mean_ctr,
scale='loads',
verbose=False)
WWx[i,:,:] = W
WWz[i,:,:] = L
WWy[i,:,:] = Q