2007-01-25 12:58:10 +01:00
|
|
|
"""This module implements some common validation schemes for PCA and PLS models.
"""
|
2006-12-18 12:59:12 +01:00
|
|
|
from scipy import ones,mean,sqrt,dot,newaxis,zeros,sum,empty,\
|
2007-01-25 12:58:10 +01:00
|
|
|
apply_along_axis,eye,kron,array,sort
|
|
|
|
from scipy.stats import median
|
2006-12-18 12:59:12 +01:00
|
|
|
from scipy.linalg import triu,inv,svd,norm
|
|
|
|
|
|
|
|
from select_generators import w_pls_gen,w_pls_gen_jk,pls_gen,pca_gen,diag_pert
|
2007-07-23 20:07:10 +02:00
|
|
|
from engines import w_simpls,pls,bridge,pca,nipals_lpls
|
2007-01-25 12:58:10 +01:00
|
|
|
from cx_utils import m_shape
|
2006-12-18 12:59:12 +01:00
|
|
|
|
|
|
|
def w_pls_cv_val(X, Y, amax, n_blocks=None, algo='simpls'):
    """Returns rmsep, Yhat and aopt for pls tailored for wide X.

    The root mean square error of cross validation is calculated
    based on random block cross-validation. With number of blocks equal to
    number of samples [default] gives leave-one-out cv.
    The pls model is based on the simpls algorithm for wide X and works on
    the (samples x samples) kernel X X^T instead of X itself.

    :Parameters:
        X : ndarray
            column centered data matrix of size (samples x variables)
        Y : ndarray
            column centered response matrix of size (samples x responses)
        amax : scalar
            Maximum number of components
        n_blocks : scalar
            Number of blocks in cross validation (default: number of
            samples, i.e. leave-one-out)
        algo : {'simpls'}
            Model engine; only 'simpls' is implemented.

    :Returns:
        rmsep : ndarray
            Root Mean Square Error of cross-validated Predictions, shape
            (responses x amax+1); column 0 is the zero-component model.
        Yhat : ndarray
            Y minus the amax-component prediction of the LAST validation
            block only (see FIXME in the code).
        aopt : scalar
            Guestimate of the optimal number of components

    :Raises:
        NotImplementedError : if algo is not 'simpls'

    :SeeAlso:
        - pls_val : Same kind of output, not optimised for wide X
        - w_simpls : Simpls algorithm for wide X

    Notes
    -----
    Based (cowardly translated) on m-files from the Chemoact toolbox.
    X, Y inputs need to be centered (fixme: check).
    """
    if algo != 'simpls':
        # Fail fast, before building the kernel matrix.
        raise NotImplementedError
    _, l = m_shape(Y)
    # Float64 accumulator; column 0 holds the zero-component (mean) model.
    PRESS = zeros((l, amax + 1), dtype='d')
    if n_blocks is None:
        n_blocks = Y.shape[0]
    XXt = dot(X, X.T)
    V = w_pls_gen(XXt, Y, n_blocks=n_blocks, center=True)
    for Din, Doi, Yin, Yout in V:
        # Equals the mean of Yin, assuming Yin/Yout partition the rows of a
        # column-centred Y (their column sums then add up to zero).
        ym = -sum(Yout, 0)[newaxis] / (1.0 * Yin.shape[0])
        Yin = Yin - ym
        PRESS[:, 0] = PRESS[:, 0] + ((Yout - ym) ** 2).sum(0)
        dat = w_simpls(Din, Yin, amax)
        Q, U, H = dat['Q'], dat['U'], dat['H']
        That = dot(Doi, dot(U, inv(triu(dot(H.T, U)))))
        for j in range(l):
            # triu(q_j 1^T) has column c equal to q_j[:c+1] (zeros below),
            # so column c of TQ sums the first c+1 component contributions.
            TQ = dot(That, triu(dot(Q[j, :][:, newaxis], ones((1, amax)))))
            E = Yout[:, j][:, newaxis] - TQ
            E = E + sum(E, 0) / Din.shape[0]
            PRESS[j, 1:] = PRESS[j, 1:] + sum(E ** 2, 0)
        # FIXME(review): Y holds every sample while dot(That, Q.T) only covers
        # this block's held-out rows, so this broadcasts only for
        # leave-one-out and is overwritten on every block; w_pls_gen yields
        # no row indices with which per-block predictions could be placed.
        Yhat = Y - dot(That, Q.T)
    rmsep = sqrt(PRESS / Y.shape[0])
    # Pass a copy so aopt estimation can never reorder the returned rmsep.
    aopt = find_aopt_from_sep(rmsep.copy())
    return rmsep, Yhat, aopt
|
2006-12-18 12:59:12 +01:00
|
|
|
|
2007-03-14 17:33:54 +01:00
|
|
|
def pls_val(X, Y, amax=2, n_blocks=10, algo='pls', metric=None):
    """Returns rmsep, held-out predictions and aopt for pls.

    :Parameters:
        X : ndarray
            data matrix of size (samples x variables)
        Y : ndarray
            response matrix of size (samples x responses); assumed column
            centered, since the calibration mean is derived from Yout
        amax : scalar
            Maximum number of components
        n_blocks : scalar
            Number of cross-validation blocks
        algo : {'pls', 'bridge'}
            Model engine used on each calibration block
        metric : optional
            Passed unchanged to pls_gen

    :Returns:
        rmsep : ndarray
            (responses x amax+1); column 0 is the zero-component model.
            Normalised by (samples - 1).
        Yhat : ndarray
            (amax x samples x responses); Yhat[a] holds the predictions of
            each held-out row with a+1 components (dot(Xout, B[a]), no
            intercept added).
        aopt : scalar
            Guestimate of the optimal number of components

    :Raises:
        NotImplementedError : for an unknown algo
    """
    if algo == 'pls':
        engine = pls
    elif algo == 'bridge':
        # Previously called an undefined name ``simpls`` (NameError);
        # ``bridge`` is the imported engine matching this branch, as in
        # pls_jkW.
        engine = bridge
    else:
        raise NotImplementedError("unknown algo: %r" % (algo,))

    k, l = m_shape(Y)
    PRESS = zeros((l, amax + 1), dtype='<f8')
    Yhat = zeros((amax, k, l), dtype='<f8')
    V = pls_gen(X, Y, n_blocks=n_blocks, center=True, index_out=True,
                metric=metric)
    for Xin, Xout, Yin, Yout, out in V:
        # Mean of Yin, assuming Yin/Yout partition the rows of a
        # column-centred Y; 1.0* avoids integer division on int input.
        ym = -sum(Yout, 0)[newaxis] / (1.0 * Yin.shape[0])
        Yin = Yin - ym
        PRESS[:, 0] = PRESS[:, 0] + ((Yout - ym) ** 2).sum(0)
        B = engine(Xin, Yin, amax, mode='normal')['B']
        for a in range(amax):
            pred = dot(Xout, B[a, :, :])
            Yhat[a, out, :] = pred
            PRESS[:, a + 1] = PRESS[:, a + 1] + sum((Yout - pred) ** 2, 0)
    # NOTE: normalised by (k - 1), unlike w_pls_cv_val which divides by k.
    rmsep = sqrt(PRESS / (k - 1.))
    # Pass a copy so aopt estimation can never reorder the returned rmsep.
    aopt = find_aopt_from_sep(rmsep.copy())
    return rmsep, Yhat, aopt
|
|
|
|
|
|
|
|
def lpls_val(X, Y, Z, a_max=2, nsets=None, alpha=.5):
    """Performs crossvalidation to get generalisation error in lpls.

    :Parameters:
        X : ndarray
            data matrix (samples x variables)
        Y : ndarray
            response matrix (samples x responses)
        Z : ndarray
            third L-PLS matrix, passed unchanged to nipals_lpls
        a_max : scalar
            Maximum number of components
        nsets : scalar
            Number of cross-validation blocks (default: number of samples,
            i.e. leave-one-out)
        alpha : float
            Passed unchanged to nipals_lpls

    :Returns:
        rmsep : ndarray
            (a_max x responses) root mean squared error of prediction
        Yhat : ndarray
            (a_max x samples x responses) held-out predictions; Yhat[a]
            uses a+1 components
        aopt : scalar
            Guestimate of the optimal number of components
    """
    # Resolve the default before building the iterator, as the other
    # validation functions in this module do.
    if nsets is None:
        nsets = X.shape[0]
    cv_iter = pls_gen(X, Y, n_blocks=nsets, center=False, index_out=True)
    k, l = Y.shape
    Yhat = empty((a_max, k, l), 'd')
    for xcal, xi, ycal, yi, ind in cv_iter:
        dat = nipals_lpls(xcal, ycal, Z,
                          a_max=a_max,
                          alpha=alpha,
                          mean_ctr=[2, 0, 1],
                          verbose=False)
        B = dat['B']
        b0 = dat['b0']
        for a in range(a_max):
            # NOTE(review): only b0[a][0][0] is used as the intercept for
            # every response; confirm this is intended when Y has >1 column.
            Yhat[a, ind, :] = b0[a][0][0] + dot(xi, B[a])
    sep = (Y - Yhat) ** 2
    rmsep = sqrt(sep.mean(1))
    # find_aopt_from_sep expects components along the columns; rmsep has
    # them along the rows. The copy keeps the returned rmsep untouched.
    aopt = find_aopt_from_sep(rmsep.T.copy())
    return rmsep, Yhat, aopt
|
2006-12-18 12:59:12 +01:00
|
|
|
|
2007-01-25 12:58:10 +01:00
|
|
|
def pca_alter_val(a, amax, n_sets=10, method='diag'):
    """Pca validation by altering elements in X.

    For each (perturbed matrix, index) pair yielded by diag_pert, a PCA
    model with up to amax components is fitted. The squared error between
    the reconstruction and `a` at the flat positions `ind` is then
    recorded, relative to the sum of squares of `a` at those positions.

    :Parameters:
        a : ndarray
            data matrix (samples x variables)
        amax : scalar
            Maximum number of components
        n_sets : scalar
            Number of sets, passed to diag_pert
        method : str
            Currently unused; kept for backward compatibility

    :Returns:
        sep : ndarray
            (n_sets x amax) root relative squared error; column j uses j+1
            components
        aopt : scalar
            Guestimate of the optimal number of components

    comments:
        -- may do all jk estimates in this loop
    """
    V = diag_pert(a, n_sets, center=True, index_out=True)
    sep = empty((n_sets, amax), dtype='d')
    a_flat = a.ravel()
    for i, (xi, ind) in enumerate(V):
        dat_i = pca(xi, amax, mode='detailed')
        Ti, Pi = dat_i['T'], dat_i['P']
        # Reference values and their total do not depend on the number of
        # components, so compute them once per set.
        a_sub = a_flat.take(ind)
        tot = (a_sub ** 2).sum()
        for j in range(amax):
            Xhat = dot(Ti[:, :j + 1], Pi[:, :j + 1].T)
            EE = a_sub - Xhat.ravel().take(ind)
            sep[i, j] = (EE ** 2).sum() / tot
    sep = sqrt(sep)
    # Pass a copy so aopt estimation can never reorder the returned sep.
    aopt = find_aopt_from_sep(sep.copy())
    return sep, aopt
|
|
|
|
|
|
|
|
def pca_cv_val(a, amax, n_sets):
    """ Returns PRESS from cross-validated pca using random segments.

    input:
        -- a, data matrix (m x n)
        -- amax, maximum number of components used
        -- n_sets, number of segments to calculate

    output:
        -- sep, (amax,) press for 1..amax components, relative to the
           total sum of squares of a
        -- aopt, guestimated optimal number of components
    """
    xtot = (a ** 2).sum()  # this needs centering
    press = zeros((amax,), dtype='d')
    # Previously hard-coded n_sets=7, silently ignoring the argument.
    V = pca_gen(a, n_sets=n_sets, center=True, index_out=True)
    for xi, xout, _ind in V:
        dat_i = pca(xi, amax, mode='fast')
        Pi = dat_i['P']
        # Loop variable must not be named `a`: that would shadow the data.
        for ncomp in range(amax):
            Pia = Pi[:, :ncomp + 1]
            # Residual of the held-out rows after projection onto the
            # calibration loadings (previously referenced an undefined X).
            resid = xout - dot(xout, dot(Pia, Pia.T))
            press[ncomp] = press[ncomp] + (resid ** 2).sum()
    sep = press / xtot
    # find_aopt_from_sep needs a 2-D (rows x components) array.
    aopt = find_aopt_from_sep(sep[newaxis, :])
    return sep, aopt
|
2006-12-18 12:59:12 +01:00
|
|
|
|
2007-03-14 17:33:54 +01:00
|
|
|
def pls_jkW(a, b, amax, n_blocks=None, algo='pls', use_pack=True, center=True, metric=None):
    """ Returns CV-segments of parameter W for wide X.

    :Parameters:
        a : ndarray
            data matrix (samples x variables)
        b : ndarray
            response matrix (samples x responses)
        amax : scalar
            Number of components
        n_blocks : scalar
            Number of segments (default: number of rows of b)
        algo : {'pls', 'bridge'}
            Model engine used on each segment
        use_pack : bool
            When true and metric is None, fit on the SVD-compressed a and
            inflate W back to variable space afterwards
        center : bool
            Passed to pls_gen
        metric : optional
            Passed to pls_gen; disables packing when given

    :Returns:
        Wcv : ndarray
            (n_blocks x variables x amax) weights from each segment

    :Raises:
        NotImplementedError : for an unknown algo

    todo: add support for T,Q and B
    """
    if algo == 'pls':
        engine = pls
    elif algo == 'bridge':
        engine = bridge
    else:
        raise NotImplementedError("unknown algo: %r" % (algo,))
    if n_blocks is None:
        n_blocks = b.shape[0]
    # Allocated with the original number of variables, before packing.
    Wcv = empty((n_blocks, a.shape[1], amax), dtype='d')
    # `is None` (not ==) so an array-valued metric is handled correctly.
    pack = use_pack and metric is None
    if pack:
        # a = u s v^T: fit on u*s, then map W back with the right singular
        # vectors (inflater).
        u, s, inflater = svd(a, full_matrices=0)
        a = u * s
    V = pls_gen(a, b, n_blocks=n_blocks, center=center, metric=metric)
    for nn, (a_in, a_out, b_in, b_out) in enumerate(V):
        W = engine(a_in, b_in, amax, 'loads', 'fast')['W']
        if pack:
            W = dot(inflater.T, W)
        Wcv[nn, :, :] = W
    return Wcv
|
2006-12-18 12:59:12 +01:00
|
|
|
|
2007-03-14 17:33:54 +01:00
|
|
|
def pca_jkP(a, aopt, n_blocks=None, metric=None):
    """Returns loading from PCA on CV-segments.

    input:
        -- a, data matrix (n x m)
        -- aopt, number of components in model.
        -- n_blocks, number of segments (default: number of rows of a)
        -- metric, currently unused; kept for backward compatibility

    output:
        -- PP, loadings collected in a three way matrix
           (n_segments, m, aopt)

    comments:
        * Loadings come from pca(..., scale='loads'); the original note
          states they are scaled with the (1/samples)*eigenvalues.
        * Crossvalidation method is currently set to random blocks of samples.

    todo: add support for T
    fixme: more efficient to add this in validation loop
    """
    if n_blocks is None:
        n_blocks = a.shape[0]
    # Float64, consistent with the weights collected by pls_jkW.
    PP = empty((n_blocks, a.shape[1], aopt), dtype='d')
    V = pca_gen(a, n_sets=n_blocks, center=True)
    for nn, (a_in, a_out) in enumerate(V):
        PP[nn, :, :] = pca(a_in, aopt, mode='fast', scale='loads')['P']
    return PP
|
2007-01-25 12:58:10 +01:00
|
|
|
|
2007-03-14 17:33:54 +01:00
|
|
|
|
2007-07-23 19:33:21 +02:00
|
|
|
def lpls_jk(X, Y, Z, a_max, nsets=None, alpha=.5):
    """Returns L-PLS weights estimated on each cross-validation segment.

    :Parameters:
        X : ndarray
            data matrix (samples x variables)
        Y : ndarray
            response matrix, passed to pls_gen
        Z : ndarray
            third L-PLS matrix, passed unchanged to nipals_lpls
        a_max : scalar
            Number of components
        nsets : scalar
            Number of segments (default: number of samples)
        alpha : float
            Passed unchanged to nipals_lpls

    :Returns:
        WWx : ndarray
            (nsets x variables x a_max), dat['W'] from each segment
        WWz : ndarray
            (nsets x rows of Z x a_max), dat['L'] from each segment

    todo: also collect dat['Q'] (the y side).
    """
    m, n = X.shape
    o = Z.shape[0]
    if nsets is None:
        nsets = m
    # Build the iterator only after nsets has its default, so the number
    # of segments matches the first dimension of the output arrays.
    cv_iter = pls_gen(X, Y, n_blocks=nsets, center=False, index_out=False)
    WWx = empty((nsets, n, a_max), 'd')
    WWz = empty((nsets, o, a_max), 'd')
    for i, (xcal, xi, ycal, yi) in enumerate(cv_iter):
        dat = nipals_lpls(xcal, ycal, Z, a_max=a_max, alpha=alpha,
                          mean_ctr=[2, 0, 1], scale='loads', verbose=False)
        WWx[i, :, :] = dat['W']
        WWz[i, :, :] = dat['L']
    return WWx, WWz
|
|
|
|
|
2007-01-25 12:58:10 +01:00
|
|
|
def find_aopt_from_sep(sep, method='75perc'):
    """Returns an estimate of optimal number of components from rmsecv.

    :Parameters:
        sep : array_like
            (rows x components) error matrix, one column per model in
            increasing complexity. A 1-D input is treated as one row.
            The input is never modified.
        method : {'75perc', 'vanilla'}
            'vanilla' : index of the column with the smallest mean error,
                        plus one.
            '75perc'  : the first column index i >= 1 for which the median
                        of column i-1 is below the 75th percentile of
                        column i; the number of columns if no such i exists.

    :Returns:
        aopt : int

    :Raises:
        ValueError : for an unknown method
    """
    # 2-D copy: accepts 1-D/list input and keeps the caller's array intact
    # (the previous per-column in-place sort reordered the caller's data).
    work = array(sep, ndmin=2)
    if method == 'vanilla':
        # min rmsep (sqrt is monotone, so it does not change the argmin)
        rmsecv = sqrt(work.mean(0))
        return rmsecv.argmin() + 1
    elif method == '75perc':
        prct = .75  # percentile
        ind = int(1. * work.shape[0] * prct)
        med = median(work, axis=0)
        # sort() returns a sorted copy; row `ind` is each column's
        # (lower) 75th percentile.
        prc_75 = sort(work, axis=0)[ind]
        for i in range(1, work.shape[1]):
            if med[i - 1] < prc_75[i]:
                return i
        return len(med)
    raise ValueError("unknown method: %r" % (method,))
|