merge fix

This commit is contained in:
Arnar Flatberg 2008-01-08 11:33:49 +00:00
parent 7e61e585fc
commit 970454faed

View File

@ -101,6 +101,7 @@ def pca_val(a, a_max, nsets=None, center_axis=[0], method='cv'):
rmsep = sqrt(err).mean(1) # take mean over samples rmsep = sqrt(err).mean(1) # take mean over samples
aopt = rmsep.mean(-1).argmin() aopt = rmsep.mean(-1).argmin()
return rmsep, xhat, aopt, err return rmsep, xhat, aopt, err
def pls_val(X, Y, a_max=2, nsets=None, center_axis=[0,0], verbose=False): def pls_val(X, Y, a_max=2, nsets=None, center_axis=[0,0], verbose=False):
@ -145,7 +146,7 @@ def pls_val(X, Y, a_max=2, nsets=None, center_axis=[0,0], verbose=False):
if nsets > X.shape[0]: if nsets > X.shape[0]:
print "nsets (%d) is larger than number of variables (%d).\nnsets: %d -> %d" %(nsets, m, nsets, m) print "nsets (%d) is larger than number of variables (%d).\nnsets: %d -> %d" %(nsets, m, nsets, m)
nsets = m nsets = m
if n > 5*m: if n > 15*m:
# boosting (wide x) # boosting (wide x)
Yhat = _w_pls_predict(X, Y, a_max) Yhat = _w_pls_predict(X, Y, a_max)
@ -160,7 +161,7 @@ def pls_val(X, Y, a_max=2, nsets=None, center_axis=[0,0], verbose=False):
# predictions # predictions
for a in range(a_max): for a in range(a_max):
Yhat[a,val,:] = ym + dot(xi, dat['B'][a]) Yhat[a,val] = ym + dot(xi, dat['B'][a])
sep = (Y - Yhat)**2 sep = (Y - Yhat)**2
rmsep = sqrt(sep.mean(1)).T rmsep = sqrt(sep.mean(1)).T
@ -236,9 +237,14 @@ def lpls_val(X, Y, Z, a_max=2, nsets=None,alpha=.5, center_axis=[2,0,2], zorth=F
center_axis=center_axis, zorth=zorth, verbose=verbose) center_axis=center_axis, zorth=zorth, verbose=verbose)
# center test data # center test data
if center_axis[0] == 2:
xi = X[val,:] - dat['mnx'] - X[val,:].mean(1)[:,newaxis] + dat['mnx'].mean()
else:
xi = X[val,:] - dat['mnx'] xi = X[val,:] - dat['mnx']
if center_axis[1] == 2:
ym = dat['mny'][val,:] ym = dat['mny'] + Y[val,:].mean(1)[:,newaxis] - dat['mny'].mean()
else:
ym = dat['mny']
# predictions # predictions
for a in range(a_max): for a in range(a_max):
Yhat[a,val,:] = atleast_2d(ym + dot(xi, dat['B'][a])) Yhat[a,val,:] = atleast_2d(ym + dot(xi, dat['B'][a]))