merge fix
This commit is contained in:
		| @@ -101,6 +101,7 @@ def pca_val(a, a_max, nsets=None, center_axis=[0], method='cv'): | |||||||
|                  |                  | ||||||
|     rmsep = sqrt(err).mean(1) # take mean over samples |     rmsep = sqrt(err).mean(1) # take mean over samples | ||||||
|     aopt = rmsep.mean(-1).argmin() |     aopt = rmsep.mean(-1).argmin() | ||||||
|  |      | ||||||
|     return rmsep, xhat, aopt, err |     return rmsep, xhat, aopt, err | ||||||
|  |  | ||||||
| def pls_val(X, Y, a_max=2, nsets=None, center_axis=[0,0], verbose=False): | def pls_val(X, Y, a_max=2, nsets=None, center_axis=[0,0], verbose=False): | ||||||
| @@ -145,7 +146,7 @@ def pls_val(X, Y, a_max=2, nsets=None, center_axis=[0,0], verbose=False): | |||||||
|     if nsets > X.shape[0]: |     if nsets > X.shape[0]: | ||||||
|         print "nsets (%d) is larger than number of variables (%d).\nnsets:  %d -> %d" %(nsets, m, nsets, m) |         print "nsets (%d) is larger than number of variables (%d).\nnsets:  %d -> %d" %(nsets, m, nsets, m) | ||||||
|         nsets = m |         nsets = m | ||||||
|     if n > 5*m: |     if n > 15*m: | ||||||
|         # boosting (wide x) |         # boosting (wide x) | ||||||
|         Yhat = _w_pls_predict(X, Y, a_max) |         Yhat = _w_pls_predict(X, Y, a_max) | ||||||
|          |          | ||||||
| @@ -160,7 +161,7 @@ def pls_val(X, Y, a_max=2, nsets=None, center_axis=[0,0], verbose=False): | |||||||
|          |          | ||||||
|         # predictions |         # predictions | ||||||
|         for a in range(a_max): |         for a in range(a_max): | ||||||
|             Yhat[a,val,:] = ym + dot(xi, dat['B'][a]) |             Yhat[a,val] = ym + dot(xi, dat['B'][a]) | ||||||
|      |      | ||||||
|     sep = (Y - Yhat)**2 |     sep = (Y - Yhat)**2 | ||||||
|     rmsep = sqrt(sep.mean(1)).T |     rmsep = sqrt(sep.mean(1)).T | ||||||
| @@ -236,9 +237,14 @@ def lpls_val(X, Y, Z, a_max=2, nsets=None,alpha=.5, center_axis=[2,0,2], zorth=F | |||||||
|                    center_axis=center_axis, zorth=zorth, verbose=verbose) |                    center_axis=center_axis, zorth=zorth, verbose=verbose) | ||||||
|  |  | ||||||
|         # center test data |         # center test data | ||||||
|         xi = X[val,:] - dat['mnx'] |         if center_axis[0] == 2: | ||||||
|          |             xi = X[val,:] - dat['mnx'] - X[val,:].mean(1)[:,newaxis] + dat['mnx'].mean() | ||||||
|         ym = dat['mny'][val,:] |         else: | ||||||
|  |             xi = X[val,:] - dat['mnx'] | ||||||
|  |         if center_axis[1] == 2: | ||||||
|  |             ym = dat['mny'] + Y[val,:].mean(1)[:,newaxis] - dat['mny'].mean() | ||||||
|  |         else: | ||||||
|  |             ym = dat['mny'] | ||||||
|         # predictions |         # predictions | ||||||
|         for a in range(a_max): |         for a in range(a_max): | ||||||
|             Yhat[a,val,:] = atleast_2d(ym + dot(xi, dat['B'][a])) |             Yhat[a,val,:] = atleast_2d(ym + dot(xi, dat['B'][a])) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user