merge fix
This commit is contained in:
parent
7e61e585fc
commit
970454faed
@ -101,6 +101,7 @@ def pca_val(a, a_max, nsets=None, center_axis=[0], method='cv'):
|
|||||||
|
|
||||||
rmsep = sqrt(err).mean(1) # take mean over samples
|
rmsep = sqrt(err).mean(1) # take mean over samples
|
||||||
aopt = rmsep.mean(-1).argmin()
|
aopt = rmsep.mean(-1).argmin()
|
||||||
|
|
||||||
return rmsep, xhat, aopt, err
|
return rmsep, xhat, aopt, err
|
||||||
|
|
||||||
def pls_val(X, Y, a_max=2, nsets=None, center_axis=[0,0], verbose=False):
|
def pls_val(X, Y, a_max=2, nsets=None, center_axis=[0,0], verbose=False):
|
||||||
@ -145,7 +146,7 @@ def pls_val(X, Y, a_max=2, nsets=None, center_axis=[0,0], verbose=False):
|
|||||||
if nsets > X.shape[0]:
|
if nsets > X.shape[0]:
|
||||||
print "nsets (%d) is larger than number of variables (%d).\nnsets: %d -> %d" %(nsets, m, nsets, m)
|
print "nsets (%d) is larger than number of variables (%d).\nnsets: %d -> %d" %(nsets, m, nsets, m)
|
||||||
nsets = m
|
nsets = m
|
||||||
if n > 5*m:
|
if n > 15*m:
|
||||||
# boosting (wide x)
|
# boosting (wide x)
|
||||||
Yhat = _w_pls_predict(X, Y, a_max)
|
Yhat = _w_pls_predict(X, Y, a_max)
|
||||||
|
|
||||||
@ -160,7 +161,7 @@ def pls_val(X, Y, a_max=2, nsets=None, center_axis=[0,0], verbose=False):
|
|||||||
|
|
||||||
# predictions
|
# predictions
|
||||||
for a in range(a_max):
|
for a in range(a_max):
|
||||||
Yhat[a,val,:] = ym + dot(xi, dat['B'][a])
|
Yhat[a,val] = ym + dot(xi, dat['B'][a])
|
||||||
|
|
||||||
sep = (Y - Yhat)**2
|
sep = (Y - Yhat)**2
|
||||||
rmsep = sqrt(sep.mean(1)).T
|
rmsep = sqrt(sep.mean(1)).T
|
||||||
@ -236,9 +237,14 @@ def lpls_val(X, Y, Z, a_max=2, nsets=None,alpha=.5, center_axis=[2,0,2], zorth=F
|
|||||||
center_axis=center_axis, zorth=zorth, verbose=verbose)
|
center_axis=center_axis, zorth=zorth, verbose=verbose)
|
||||||
|
|
||||||
# center test data
|
# center test data
|
||||||
xi = X[val,:] - dat['mnx']
|
if center_axis[0] == 2:
|
||||||
|
xi = X[val,:] - dat['mnx'] - X[val,:].mean(1)[:,newaxis] + dat['mnx'].mean()
|
||||||
ym = dat['mny'][val,:]
|
else:
|
||||||
|
xi = X[val,:] - dat['mnx']
|
||||||
|
if center_axis[1] == 2:
|
||||||
|
ym = dat['mny'] + Y[val,:].mean(1)[:,newaxis] - dat['mny'].mean()
|
||||||
|
else:
|
||||||
|
ym = dat['mny']
|
||||||
# predictions
|
# predictions
|
||||||
for a in range(a_max):
|
for a in range(a_max):
|
||||||
Yhat[a,val,:] = atleast_2d(ym + dot(xi, dat['B'][a]))
|
Yhat[a,val,:] = atleast_2d(ym + dot(xi, dat['B'][a]))
|
||||||
|
Reference in New Issue
Block a user