diff --git a/pyblm/crossvalidation.py b/pyblm/crossvalidation.py index 41afde5..50a0b5c 100644 --- a/pyblm/crossvalidation.py +++ b/pyblm/crossvalidation.py @@ -63,7 +63,7 @@ def pca_val(a, a_max, nsets=None, center_axis=[0], method='cv'): # impute with mean values b = a.copy() a.put(val, new_values) - dat = pca(a, a_max, mode='normal', center_axis=center_axis) + dat = pca(a, a_max, mode='fast', center_axis=center_axis) Ti, Pi = dat['T'], dat['P'] bc = b - dat['mnx'] bc2 = b - b.mean(0) @@ -85,14 +85,15 @@ def pca_val(a, a_max, nsets=None, center_axis=[0], method='cv'): e = eye(m) rmat = zeros((m, m)) for j, p in enumerate(P.T): - d2 = diag(e) - (p**2).ravel() + #d2 = diag(e) - (p**2).ravel() e = e - dot(p, p.T) d = diag(e) es = e/atleast_2d(d) - xhat[j,:,:] = dot(xval, es) - err[i, a] = (dot(xval, es)**2).sum() + xhat[j,cal,:] = dot(xval, es) + err[j,cal,:] = (xhat - xval)**2 rmsep = sqrt(err).mean(1) # take mean over samples + if method == '' rmsep2 = sqrt(err_mn).mean(1) aopt = rmsep.mean(-1).argmin() return rmsep, xhat, aopt, err, rmsep2