This commit is contained in:
Arnar Flatberg 2008-01-08 11:29:23 +00:00
parent c58f0ef453
commit 7e61e585fc
2 changed files with 54 additions and 49 deletions

View File

@ -6,10 +6,10 @@ __all__ = ['pca_val', 'pls_val', 'lpls_val', 'pls_jk', 'lpls_jk']
__docformat__ = "restructuredtext en"
from numpy import dot,empty,zeros,sqrt,atleast_2d,argmax,asarray,median,\
array_split,arange, isnan, any,newaxis
array_split,arange, isnan, any,newaxis,eye,diag,tile,ones
from numpy.random import shuffle
from engines import pls, pca
from engines import pls, pca, center
from engines import nipals_lpls as lpls
def pca_val(a, a_max, nsets=None, center_axis=[0], method='cv'):
@ -33,8 +33,7 @@ def pca_val(a, a_max, nsets=None, center_axis=[0], method='cv'):
Squared error of prediction for each component and xvar (a_max, m)
xhat : {array}
Crossvalidated predicted a (a_max, m, n)
aopt : {integer}
Estimate of optimal number of components
aopt : {integer} Estimate of optimal number of components
*Notes*:
@ -50,59 +49,63 @@ def pca_val(a, a_max, nsets=None, center_axis=[0], method='cv'):
n, m = a.shape
if nsets == None:
nsets = n
err = zeros((a_max, n, m), dtype=a.dtype)
err_mn = zeros((a_max, n, m), dtype=a.dtype)
xhat = zeros((a_max, n, m), dtype=a.dtype)
err = zeros((a_max + 1, n, m), dtype=a.dtype)
xhat = zeros((a_max + 1, n, m), dtype=a.dtype)
if center_axis[0] == 1:
a = a - a.mean(1)[:,newaxis]
center_axis = [-1]
elif center_axis[0] == 2:
a = a - a.mean(1)[:,newaxis]
center_axis = [0]
if method == 'diag':
mn_a = .5*(a.mean(0) + a.mean(1)[:,newaxis])
mn_a = dot(ones((n,1)), a.mean(0)[newaxis])
for i, val in enumerate(diag_cv(a.shape, nsets)):
old_values = a.take(val)
true_values = a.take(val)
new_values = mn_a.take(val)
# impute with mean values
b = a.copy()
a.put(val, new_values)
dat = pca(a, a_max, mode='fast', center_axis=center_axis)
Ti, Pi = dat['T'], dat['P']
bc = b - dat['mnx']
bc2 = b - b.mean(0)
Ti, Pi, mnx = dat['T'], dat['P'], dat['mnx']
# expand mean values
if center_axis[0] == 1:
mnx = tile(mnx, (1, m))
else:
mnx = tile(mnx, (n, 1))
err[0,:,:].put(val, (true_values - mnx.take(val))**2)
for j in xrange(a_max):
# estimate the imputed values
a_pred = dot(Ti[:,:j+1], Pi[:,:j+1].T).take(val)
a_true = bc2.take(val)
err[j,:,:].put(val, (a_true - a_pred)**2)
err_mn[j,:,:].put(val, (bc.take(val) - a_pred)**2)
a_pred = (mnx + dot(Ti[:,:j+1], Pi[:,:j+1].T)).take(val)
err[j+1,:,:].put(val, (true_values - a_pred)**2)
xhat[j,:,:].put(val, a_pred)
# put original values back
a.put(val, old_values)
a.put(val, true_values)
elif method == 'cv':
for i, (cal, val) in enumerate(cv(n, nsets)):
xval = atleast_2d(x[val,:])
xcal = x[cal, :]
P = pca(xcal, aopt, mode='fast', scale='scores')['P']
xcal = a[cal, :]
dat = pca(xcal, a_max, mode='fast', scale='scores', center_axis=center_axis)
P, mnx = dat['P'], dat['mnx']
xval = atleast_2d(a[val,:]) - mnx
err[0,val,:] = xval**2 #pc0
e = eye(m)
rmat = zeros((m, m))
for j, p in enumerate(P.T):
#d2 = diag(e) - (p**2).ravel()
p = p[:,newaxis]
e = e - dot(p, p.T)
d = diag(e)
es = e/atleast_2d(d)
xhat[j,cal,:] = dot(xval, es)
err[j,cal,:] = (xhat - xval)**2
xhat[j+1,val,:] = dot(xval, es)
err[j+1,val,:] = (dot(xval, es)**2)
rmsep = sqrt(err).mean(1) # take mean over samples
if method == ''
rmsep2 = sqrt(err_mn).mean(1)
aopt = rmsep.mean(-1).argmin()
return rmsep, xhat, aopt, err, rmsep2
return rmsep, xhat, aopt, err
def pls_val(X, Y, a_max=2, nsets=None, center_axis=[0,0], verbose=False):
"""Performs crossvalidation for generalisation error in pls.
*Parameters*:
X : {array}
@ -137,7 +140,6 @@ def pls_val(X, Y, a_max=2, nsets=None, center_axis=[0,0], verbose=False):
m, n = X.shape
k, l = Y.shape
assert m == k, "X (%d,%d) - Y (%d,%d) dim mismatch" %(m, n, k, l)
assert n == p, "X (%d,%d) - Z (%d,%d) dim mismatch" %(m, n, o, p)
if nsets == None:
nsets = m
if nsets > X.shape[0]:
@ -150,7 +152,7 @@ def pls_val(X, Y, a_max=2, nsets=None, center_axis=[0,0], verbose=False):
Yhat = empty((a_max, k, l), dtype=dt)
for cal, val in cv(k, nsets):
# do the training model
dat = pls(X[cal], Y[cal], a_max=a_max,center_axis=center_axis)
dat = pls(X[cal], Y[cal], aopt=a_max,center_axis=center_axis)
# center test data
xi = X[val,:] - dat['mnx']
@ -165,7 +167,7 @@ def pls_val(X, Y, a_max=2, nsets=None, center_axis=[0,0], verbose=False):
#aopt = find_aopt_from_sep(rmsep)
# todo: need a better support for classification error
error = prediction_error(Yhat, Y, method='1/2')
error = prediction_error(Yhat, Y, method='squared')
return rmsep, Yhat, error
@ -282,8 +284,7 @@ def pca_jk(a, aopt, nsets=None, center_axis=[0], method='cv'):
*Notes*:
- .
Nope
"""
m, n = a.shape
if nsets == None:
@ -302,7 +303,6 @@ def pca_jk(a, aopt, nsets=None, center_axis=[0], method='cv'):
# put original values back
a.put(val, old_values)
elif method == 'cv':
print "using ....cv "
for i, (cal, val) in enumerate(cv(m, nsets)):
Pcv[i,:,:] = pca(a[cal,:], aopt, mode='fast', scale='loads', center_axis = center_axis)['P']
else:
@ -624,7 +624,7 @@ def prediction_error(y_hat, y, method='squared'):
return error
def _wkernel_pls_val(X, Y, a_max, n_blocks=None):
"""Returns rmsep and aopt for pls tailored for wide X.
"""Returns pls crossvalidated predictions tailored for wide X.
The error of cross validation is calculated
based on random block cross-validation. With number of blocks equal to
@ -672,19 +672,21 @@ def _wkernel_pls_val(X, Y, a_max, n_blocks=None):
[4,5,6], 1
"""
dt = X.dtype
k, l = m_shape(Y)
PRESS = zeros((l, a_max+1), dtype=dt)
k, l = atleast_2d(Y).shape
if k == 1:
Y = Y.T
k, l = Y.shape
PRESS = zeros((l, a_max + 1), dtype=dt)
if n_blocks==None:
n_blocks = Y.shape[0]
XXt = dot(X, X.T)
V = w_pls_gen(XXt, Y, n_blocks=n_blocks, center=True)
for Din, Doi, Yin, Yout in V:
ym = -sum(Yout, 0)[newaxis]/(1.0*Yin.shape[0])
PRESS[:,0] = PRESS[:,0] + ((Yout - ym)**2).sum(0)
for cal, val in cv(k, n_blocks):
ym = -sum(Y[val,:], 0)[newaxis]/(1.0*Y[cal,:].shape[0])
PRESS[:,0] = PRESS[:,0] + ((Y[val,:] - ym)**2).sum(0)
dat = w_simpls(Din, Yin, a_max)
dat = w_simpls(XX[cal,:][:,cal], Y[cal,:], a_max)
Q, U, H = dat['Q'], dat['U'], dat['H']
That = dot(Doi, dot(U, inv(triu(dot(H.T, U))) ))
That = dot(XX[val,:][:,cal], dot(U, inv(triu(dot(H.T, U))) ))
Yhat = zeros((a_max, k, l), dtype=dt)
for j in range(l):

View File

@ -10,8 +10,10 @@ from math import sqrt as msqrt
from numpy import dot,empty,zeros,apply_along_axis,newaxis,finfo,sqrt,r_,expand_dims,\
minimum,any,isnan,ones,tile
from numpy.linalg import inv,svd
from scipy.sandbox import arpack
try:
from scipy.sandbox import arpack
except:
import arpack
def pca(X, aopt, scale='scores', mode='normal', center_axis=[0]):
""" Principal Component Analysis.
@ -718,9 +720,10 @@ def center(a, axis):
mn = a.mean(1)[:,newaxis]
#mn = tile(mn, (1, a.shape[1]))
elif axis == 2:
#fixme: double centering returns column mean as loc-vector, ok?
mn = a.mean(0)[newaxis] + a.mean(1)[:,newaxis] - a.mean()
return a - mn , mn
# double centering returns row vec
# to get correct broadcasting in cv
mn = mn[0][newaxis]
else:
raise IOError("input error: axis must be in [-1,0,1,2]")