This commit is contained in:
2007-07-23 13:25:34 +00:00
parent 9db5991108
commit 155dfada5c
4 changed files with 31 additions and 17 deletions

View File

@@ -9,7 +9,7 @@ import select_generators
sys.path.remove("/home/flatberg/fluents/fluents/lib")
def nipals_lpls(X, Y, Z, a_max, alpha=.7, mean_ctr=[2, 0, 1], verbose=True):
def nipals_lpls(X, Y, Z, a_max, alpha=.7, mean_ctr=[2, 0, 1], scale='scores', verbose=True):
""" L-shaped Partial Least Sqaures Regression by the nipals algorithm.
(X!Z)->Y
@@ -113,7 +113,11 @@ def nipals_lpls(X, Y, Z, a_max, alpha=.7, mean_ctr=[2, 0, 1], verbose=True):
evx = 100.0*(1 - var_x/varX)
evy = 100.0*(1 - var_y/varY)
evz = 100.0*(1 - var_z/varZ)
if scale=='loads':
tnorm = apply_along_axis(norm, 0, T)
T = T/tnorm
Q = Q*tnorm
W = W*tnorm
return T, W, P, Q, U, L, K, B, b0, evx, evy, evz
def svd_lpls(X, Y, Z, a_max, alpha=.7, mean_ctr=[2, 0, 1], verbose=True):
@@ -307,7 +311,8 @@ def center(a, axis):
# 0 = col center, 1 = row center, 2 = double center
# -1 = nothing
if axis==-1:
return a
mn = zeros((a.shape[1],))
return a - mn, mn
elif axis==0:
mn = a.mean(0)
return a - mn, mn
@@ -364,14 +369,14 @@ def correlation_loadings(D, T, P, test=True):
def cv_lpls(X, Y, Z, a_max=2, nsets=None,alpha=.5):
"""Performs crossvalidation to get generalisation error in lpls"""
cv_iter = select_generators.pls_gen(X, Y, n_blocks=nsets,center=True,index_out=True)
cv_iter = select_generators.pls_gen(X, Y, n_blocks=nsets,center=False,index_out=True)
k, l = Y.shape
Yhat = empty((a_max,k,l), 'd')
for i, (xcal,xi,ycal,yi,ind) in enumerate(cv_iter):
T, W, P, Q, U, L, K, B, b0, evx, evy, evz = nipals_lpls(xcal,ycal,Z,
a_max=a_max,
alpha=alpha,
mean_ctr=[0,0,1],
mean_ctr=[2,0,1],
verbose=False)
for a in range(a_max):
Yhat[a,ind,:] = b0[a][0][0] + dot(xi, B[a])
@@ -385,7 +390,7 @@ def cv_lpls(X, Y, Z, a_max=2, nsets=None,alpha=.5):
return rmsep, Yhat, class_err
def jk_lpls(X, Y, Z, a_max, nsets=None, alpha=.5):
cv_iter = select_generators.pls_gen(X, Y, n_blocks=nsets,center=True,index_out=False)
cv_iter = select_generators.pls_gen(X, Y, n_blocks=nsets,center=False,index_out=False)
m, n = X.shape
k, l = Y.shape
o, p = Z.shape
@@ -398,7 +403,8 @@ def jk_lpls(X, Y, Z, a_max, nsets=None, alpha=.5):
T, W, P, Q, U, L, K, B, b0, evx, evy, evz = nipals_lpls(xcal,ycal,Z,
a_max=a_max,
alpha=alpha,
mean_ctr=[0,0,1],
mean_ctr=[2,0,1],
scale='loads',
verbose=False)
WWx[i,:,:] = W
WWz[i,:,:] = L