291 lines
9.3 KiB
291 lines
9.3 KiB
"""This module implements some common validation schemes from pca and pls.
from scipy import ones,mean,sqrt,dot,newaxis,zeros,sum,empty,\
     array,zeros_like
from scipy.stats import median
from scipy.linalg import triu,inv,svd,norm
from select_generators import w_pls_gen,w_pls_gen_jk,pls_gen,pca_gen,diag_pert
from engines import w_simpls,pls,bridge,pca,nipals_lpls
from cx_utils import m_shape
def w_pls_cv_val(X, Y, amax, n_blocks=None, algo='simpls'):
    """Return the cross-validated rmsep for pls tailored for wide X.

    The root mean square error of cross validation is calculated
    based on random block cross-validation. With the number of blocks equal
    to the number of samples [default] this gives leave-one-out cv.
    The pls model is based on the simpls algorithm for wide X.

    :Parameters:
        X : ndarray
            column centered data matrix of size (samples x variables)
        Y : ndarray
            column centered response matrix of size (samples x responses)
        amax : scalar
            Maximum number of components
        n_blocks : scalar
            Number of blocks in cross validation; defaults to the number
            of rows of Y
        algo : str
            Only 'simpls' is implemented

    :Returns:
        rmsep : ndarray
            Root Mean Square Error of cross-validated Predictions, one row
            per response and amax + 1 columns; column 0 holds the error
            of predicting with the segment mean only (zero components).

    :Raises:
        NotImplementedError
            If algo is anything but 'simpls'.

    :SeeAlso:
        - pls_val : not optimised for wide X
        - w_simpls : Simpls algorithm for wide X

    Notes
    -----
    Based (cowardly translated) on m-files from the Chemoact toolbox
    X, Y inputs need to be centered (fixme: check)

    Only rmsep is returned; no estimate of the optimal number of
    components is produced here.
    """
    # Fail before doing any work; only the wide-X simpls engine is available.
    if algo != 'simpls':
        raise NotImplementedError

    k, l = m_shape(Y)
    PRESS = zeros((l, amax + 1), dtype='f')
    if n_blocks is None:
        n_blocks = Y.shape[0]

    # The wide-X engine works on the (samples x samples) kernel X X'.
    XXt = dot(X, X.T)
    V = w_pls_gen(XXt, Y, n_blocks=n_blocks, center=True)
    for Din, Doi, Yin, Yout in V:
        # Y is column centered, so the mean of the kept rows equals minus
        # the sum of the left-out rows divided by the number of kept rows.
        ym = -Yout.sum(0)[newaxis]/(1.0*Yin.shape[0])
        PRESS[:, 0] = PRESS[:, 0] + ((Yout - ym)**2).sum(0)

        dat = w_simpls(Din, Yin, amax)
        Q, U, H = dat['Q'], dat['U'], dat['H']
        That = dot(Doi, dot(U, inv(triu(dot(H.T, U)))))

        for j in range(l):
            # The upper-triangular mask makes column a of TQ the prediction
            # that uses components 0..a only.
            TQ = dot(That, triu(dot(Q[j, :][:, newaxis], ones((1, amax)))))
            E = Yout[:, j][:, newaxis] - TQ
            E = E + E.sum(0)/Din.shape[0]
            PRESS[j, 1:] = PRESS[j, 1:] + (E**2).sum(0)

    msep = PRESS/(Y.shape[0])
    return sqrt(msep)
def pls_val(X, Y, amax=2, n_blocks=10, algo='pls', metric=None):
    """Cross-validate a pls model and collect the prediction error.

    :Parameters:
        X : ndarray
            data matrix (samples x variables)
        Y : ndarray
            response matrix (samples x responses)
        amax : scalar
            Maximum number of components
        n_blocks : scalar
            Number of cross validation segments
        algo : str
            'pls' or 'bridge'; selects the engine fitted on each segment
        metric : object
            Passed through to pls_gen unchanged

    :Returns:
        msep : ndarray
            Accumulated squared prediction error (PRESS, not normalised),
            one row per response and amax + 1 columns; column 0 is the
            error of the zero-component model.
        Yhat : ndarray
            Cross-validated predictions of shape (amax, samples, responses)
        aopt : scalar
            Estimate from find_aopt_from_sep applied to msep

    :Raises:
        NotImplementedError
            If algo is neither 'pls' nor 'bridge'.
    """
    # Resolve the engine once. The 'bridge' branch used to call the name
    # `simpls`, which this module never imports.
    if algo == 'pls':
        engine = pls
    elif algo == 'bridge':
        engine = bridge
    else:
        raise NotImplementedError("unknown algo: %r" % (algo,))

    k, l = m_shape(Y)
    PRESS = zeros((l, amax + 1), dtype='<f8')
    Yhat = zeros((amax, k, l), dtype='<f8')
    V = pls_gen(X, Y, n_blocks=n_blocks, center=True, index_out=True,
                metric=metric)
    for Xin, Xout, Yin, Yout, out in V:
        # Mean correction derived from the left-out rows.
        ym = -Yout.sum(0)[newaxis]/Yin.shape[0]
        Yin = Yin - ym
        PRESS[:, 0] = PRESS[:, 0] + ((Yout - ym)**2).sum(0)

        dat = engine(Xin, Yin, amax, mode='normal')
        for a in range(amax):
            # dat['B'][a] holds the regression coefficients for a+1 components.
            pred = dot(Xout, dat['B'][a, :, :])
            Yhat[a, out[:], :] = pred
            PRESS[:, a + 1] = PRESS[:, a + 1] + ((Yout - pred)**2).sum(0)

    # NOTE(review): the value is returned un-normalised; a division by the
    # number of samples was disabled in the original.
    msep = PRESS
    aopt = find_aopt_from_sep(msep)
    return msep, Yhat, aopt
def lpls_val(X, Y, Z, a_max=2, nsets=None,alpha=.5):
    """Performs crossvalidation to get generalisation error in lpls

    :Parameters:
        X : ndarray
            data matrix (samples x variables)
        Y : ndarray
            response matrix (samples x responses), must be two-dimensional
        Z : ndarray
            Third block handed to nipals_lpls unchanged
        a_max : scalar
            Number of components
        nsets : scalar
            Number of cross validation segments, passed to pls_gen as
            n_blocks
        alpha : scalar
            Passed to nipals_lpls

    :Returns:
        rmsep : ndarray
            Square root of the squared error averaged over samples,
            shape (a_max, responses)
        Yhat : ndarray
            Cross-validated predictions, shape (a_max, samples, responses)
        aopt : scalar
            Estimate from find_aopt_from_sep applied to rmsep
    """
    cv_iter = pls_gen(X, Y, n_blocks=nsets, center=False, index_out=True)
    k, l = Y.shape
    Yhat = empty((a_max, k, l), 'd')
    for xcal, xi, ycal, yi, ind in cv_iter:
        # NOTE(review): the original call was cut off after `Z,`. a_max and
        # alpha are forwarded because the loop below indexes a_max
        # components; restore any further keyword arguments the original
        # passed (verify against version control).
        dat = nipals_lpls(xcal, ycal, Z, a_max=a_max, alpha=alpha)
        B = dat['B']
        b0 = dat['b0']
        for a in range(a_max):
            Yhat[a, ind, :] = b0[a][0][0] + dot(xi, B[a])

    # NOTE(review): the original also derived a classification error from a
    # 0/1 recoding of Yhat, but the recoding statement was lost and the
    # value was never returned, so that dead computation is dropped.
    sep = (Y - Yhat)**2
    rmsep = sqrt(sep.mean(1))
    aopt = find_aopt_from_sep(rmsep)
    return rmsep, Yhat, aopt
def pca_alter_val(a, amax, n_sets=10, method='diag'):
    """Pca validation by altering elements in X.

    For every perturbed copy of the data produced by diag_pert, a pca
    model is fitted and the relative squared reconstruction error at the
    altered positions is recorded for 1..amax components.

    :Parameters:
        a : ndarray
            data matrix
        amax : scalar
            Maximum number of components
        n_sets : scalar
            Number of perturbed data sets
        method : str
            Currently unused

    :Returns:
        sep : ndarray
            Square root of the relative squared error, shape (n_sets, amax)
        aopt : scalar
            Estimate from find_aopt_from_sep applied to sep

    comments:
      -- may do all jk estimates in this loop
    """
    V = diag_pert(a, n_sets, center=True, index_out=True)
    sep = empty((n_sets, amax), dtype='f')
    for i, (xi, ind) in enumerate(V):
        dat_i = pca(xi, amax, mode='detailed')
        Ti, Pi = dat_i['T'], dat_i['P']
        # The reference values and their energy depend only on the segment,
        # not on the number of components, so compute them once per segment.
        a_sub = a.ravel().take(ind)
        tot = (a_sub**2).sum()
        for j in range(amax):
            Xhat = dot(Ti[:, :j + 1], Pi[:, :j + 1].T)
            EE = a_sub - Xhat.ravel().take(ind)
            sep[i, j] = (EE**2).sum()/tot
    sep = sqrt(sep)
    aopt = find_aopt_from_sep(sep)
    return sep, aopt
def pca_cv_val(a, amax, n_sets):
    """ Returns PRESS from cross-validated pca using random segments.

    input:
        -- a, data matrix (m x n)
        -- amax, maximum number of components used
        -- n_sets, number of segments to calculate
    output:
        -- sep, (amax,), squared error of prediction per number of
           components, relative to the total sum of squares of a
        -- aopt, guestimated optimal number of components
    """
    m, n = a.shape
    E = empty((amax, m, n), dtype='f')
    xtot = (a**2).sum() # this needs centering
    # n_sets used to be hard-coded to 7 here, silently ignoring the argument.
    V = pca_gen(a, n_sets=n_sets, center=True, index_out=True)
    for xi, xout, ind in V:
        dat_i = pca(xi, amax, mode='fast')
        Pi = dat_i['P']
        # A separate index name keeps the data matrix `a` intact; the
        # original reused `a` as the loop counter and then read an
        # undefined `X`.
        for j in range(amax):
            Pj = Pi[:, :j + 1]
            # NOTE(review): the left-out rows are read from the raw input,
            # while xout comes from a centering generator; confirm that the
            # input is expected to be centered already.
            E[j][ind, :] = (a[ind, :] - dot(xout, dot(Pj, Pj.T)))**2

    # NOTE(review): the body of the accumulation loop was lost; the total
    # squared error per component relative to xtot is the reconstruction
    # implied by the otherwise unused xtot.
    sep = array([E[j].sum()/xtot for j in range(amax)])
    # The helper indexes sep.shape[1], so present the vector as one row.
    aopt = find_aopt_from_sep(sep[newaxis, :])
    return sep, aopt
def pls_jkW(a, b, amax, n_blocks=None, algo='pls', use_pack=True, center=True, metric=None):
    """ Returns CV-segments of parameter W for wide X.

    :Parameters:
        a : ndarray
            data matrix (samples x variables)
        b : ndarray
            response matrix
        amax : scalar
            Number of components
        n_blocks : scalar
            Number of segments; defaults to the number of rows of b
        algo : str
            'pls' or 'bridge'
        use_pack : bool
            When true and no metric is given, the model is fitted on the
            svd-compressed data u*s and W is mapped back to the original
            variable space afterwards
        center : bool
            Passed to pls_gen
        metric : object
            Passed to pls_gen; a metric disables the svd compression

    :Returns:
        Wcv : ndarray
            W of every segment, shape (n_blocks, variables, amax)

    :Raises:
        NotImplementedError
            If algo is neither 'pls' nor 'bridge'.

    todo: add support for T,Q and B
    """
    if algo == 'pls':
        engine = pls
    elif algo == 'bridge':
        engine = bridge
    else:
        # Previously an unknown algo surfaced as an unbound `dat`.
        raise NotImplementedError("unknown algo: %r" % (algo,))

    if n_blocks is None:
        n_blocks = b.shape[0]

    # Sized with the original variable count, before any compression.
    Wcv = empty((n_blocks, a.shape[1], amax), dtype='d')

    # Identity test: `metric == None` is elementwise for array metrics.
    packed = use_pack and metric is None
    if packed:
        u, s, inflater = svd(a, full_matrices=0)
        a = u*s

    V = pls_gen(a, b, n_blocks=n_blocks, center=center, metric=metric)
    for nn, (a_in, a_out, b_in, b_out) in enumerate(V):
        W = engine(a_in, b_in, amax, 'loads', 'fast')['W']
        if packed:
            # Expand from the compressed basis back to the variables.
            W = dot(inflater.T, W)
        Wcv[nn, :, :] = W

    return Wcv
def pca_jkP(a, aopt, n_blocks=None, metric=None):
    """Returns loading from PCA on CV-segments.

    input:
        -- a, data matrix (n x m)
        -- aopt, number of components in model.
        -- n_blocks, number of segments; defaults to the number of rows
           of a
        -- metric, currently unused
    output:
        -- PP, loadings collected in a three way matrix
           (n_segments, m, aopt)

    comments:
        * The loadings are scaled with the (1/samples)*eigenvalues.
        * Crossvalidation method is currently set to random blocks of samples.

    todo: add support for T
    fixme: more efficient to add this in validation loop
    """
    if n_blocks is None:
        n_blocks = a.shape[0]

    PP = empty((n_blocks, a.shape[1], aopt), dtype='f')
    V = pca_gen(a, n_sets=n_blocks, center=True)
    for nn, (a_in, a_out) in enumerate(V):
        # Only the kept rows are modelled; the left-out rows are not used.
        PP[nn, :, :] = pca(a_in, aopt, mode='fast', scale='loads')['P']

    return PP
def lpls_jk(X, Y, Z, a_max, nsets=None, alpha=.5):
    """Collect the lpls weights of every cross validation segment.

    :Parameters:
        X : ndarray
            data matrix (samples x variables)
        Y : ndarray
            response matrix
        Z : ndarray
            Third block handed to nipals_lpls unchanged
        a_max : scalar
            Number of components
        nsets : scalar
            Number of segments; defaults to the number of rows of X
        alpha : scalar
            Passed to nipals_lpls

    :Returns:
        WWx : ndarray
            dat['W'] of every segment, shape (nsets, X.shape[1], a_max)
        WWz : ndarray
            dat['L'] of every segment, shape (nsets, Z.shape[0], a_max)
    """
    m, n = X.shape
    o = Z.shape[0]
    # Resolve the default before creating the generator so the segment
    # count and the output arrays always agree.
    if nsets is None:
        nsets = m
    cv_iter = pls_gen(X, Y, n_blocks=nsets, center=False, index_out=False)

    WWx = empty((nsets, n, a_max), 'd')
    WWz = empty((nsets, o, a_max), 'd')
    for i, (xcal, xi, ycal, yi) in enumerate(cv_iter):
        # NOTE(review): the original call was cut off after `alpha=alpha,`;
        # restore any further keyword arguments it passed (verify against
        # version control).
        dat = nipals_lpls(xcal, ycal, Z, a_max=a_max, alpha=alpha)
        WWx[i, :, :] = dat['W']
        WWz[i, :, :] = dat['L']

    return WWx, WWz
def find_aopt_from_sep(sep, method='75perc'):
    """Returns an estimate of optimal number of components from rmsecv.

    :Parameters:
        sep : ndarray
            Prediction errors with one column per candidate model and one
            row per repetition (segment or response). A one-dimensional
            input is treated as a single row. The input is not modified.
        method : str
            'vanilla' : 1 + index of the column with the smallest mean
            '75perc'  : the first column index i >= 1 for which the median
                        of column i-1 is below the 75th-percentile entry
                        of column i; the number of columns when no such
                        index exists

    :Returns:
        aopt : int

    :Raises:
        ValueError
            If method is not one of the two names above.
    """
    # Local import keeps the block self-contained; the median helper that
    # used to live in scipy.stats is no longer provided by scipy.
    import numpy as np

    sep = np.atleast_2d(np.asarray(sep))

    if method == 'vanilla':
        # min rmsep
        rmsecv = np.sqrt(sep.mean(0))
        return int(rmsecv.argmin()) + 1

    if method == '75perc':
        prct = .75  # percentile
        row = int(1.*sep.shape[0]*prct)
        med = np.median(sep, axis=0)
        # np.sort returns a sorted copy, so the caller's array is untouched;
        # row `row` of the column-sorted data is the percentile entry.
        prc_75 = np.sort(sep, axis=0)[row]
        for i in range(1, sep.shape[1]):
            if med[i - 1] < prc_75[i]:
                return i
        return len(med)

    raise ValueError("unknown method: %r" % (method,))