diff --git a/pyblm/crossvalidation.py b/pyblm/crossvalidation.py
index 9d41250..16e4989 100644
--- a/pyblm/crossvalidation.py
+++ b/pyblm/crossvalidation.py
@@ -101,6 +101,7 @@ def pca_val(a, a_max, nsets=None, center_axis=[0], method='cv'):
 
     rmsep = sqrt(err).mean(1) # take mean over samples
     aopt = rmsep.mean(-1).argmin()
+    return rmsep, xhat, aopt, err
 
 
 def pls_val(X, Y, a_max=2, nsets=None, center_axis=[0,0], verbose=False):
@@ -145,7 +146,7 @@ def pls_val(X, Y, a_max=2, nsets=None, center_axis=[0,0], verbose=False):
     if nsets > X.shape[0]:
         print "nsets (%d) is larger than number of variables (%d).\nnsets: %d -> %d" %(nsets, m, nsets, m)
         nsets = m
-    if n > 5*m:
+    if n > 15*m:
         # boosting (wide x)
         Yhat = _w_pls_predict(X, Y, a_max)
 
@@ -160,7 +161,7 @@ def pls_val(X, Y, a_max=2, nsets=None, center_axis=[0,0], verbose=False):
 
         # predictions
         for a in range(a_max):
-            Yhat[a,val,:] = ym + dot(xi, dat['B'][a])
+            Yhat[a,val] = ym + dot(xi, dat['B'][a])
 
     sep = (Y - Yhat)**2
     rmsep = sqrt(sep.mean(1)).T
@@ -236,9 +237,14 @@ def lpls_val(X, Y, Z, a_max=2, nsets=None,alpha=.5, center_axis=[2,0,2], zorth=F
                        center_axis=center_axis, zorth=zorth, verbose=verbose)
         # center test data
-        xi = X[val,:] - dat['mnx']
-
-        ym = dat['mny'][val,:]
+        if center_axis[0] == 2:
+            xi = X[val,:] - dat['mnx'] - X[val,:].mean(1)[:,newaxis] + dat['mnx'].mean()
+        else:
+            xi = X[val,:] - dat['mnx']
+        if center_axis[1] == 2:
+            ym = dat['mny'] + Y[val,:].mean(1)[:,newaxis] - dat['mny'].mean()
+        else:
+            ym = dat['mny']
 
         # predictions
         for a in range(a_max):
             Yhat[a,val,:] = atleast_2d(ym + dot(xi, dat['B'][a]))
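
For reference, a minimal NumPy sketch (not part of the patch) of what the new center_axis == 2 branch in lpls_val computes for held-out rows. The names X, val, cal and mnx below are illustrative stand-ins; mnx is assumed to play the role of dat['mnx'], i.e. the column means of the calibration rows. The same double-centering pattern, applied to Y and dat['mny'], gives the ym branch.

import numpy as np

# stand-ins (hypothetical data, not from the patch)
X = np.random.rand(10, 5)                      # full data matrix
val = np.array([0, 3])                         # held-out (validation) row indices
cal = np.setdiff1d(np.arange(X.shape[0]), val) # calibration row indices
mnx = X[cal].mean(0)                           # assumed equivalent of dat['mnx']

# center_axis == 2 (double centering): remove calibration column means and
# each test row's own mean, then add back the grand mean of mnx
xi_double = X[val] - mnx - X[val].mean(1)[:, np.newaxis] + mnx.mean()

# any other center_axis value: plain column centering, as before the patch
xi_plain = X[val] - mnx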