This repository has been archived on 2024-07-04. You can view files and clone it, but cannot push or open issues or pull requests.
laydi/fluents/lib/validation.py

225 lines
6.7 KiB
Python
Raw Normal View History

2007-01-25 12:58:10 +01:00
"""This module implements common validation schemes (cross-validation and
jackknife) for PCA and PLS models.
"""
2006-12-18 12:59:12 +01:00
from scipy import ones,mean,sqrt,dot,newaxis,zeros,sum,empty,\
2007-01-25 12:58:10 +01:00
apply_along_axis,eye,kron,array,sort
from scipy.stats import median
2006-12-18 12:59:12 +01:00
from scipy.linalg import triu,inv,svd,norm
from select_generators import w_pls_gen,w_pls_gen_jk,pls_gen,pca_gen,diag_pert
2007-01-25 12:58:10 +01:00
from engines import w_simpls,pls,bridge,pca
from cx_utils import m_shape
2006-12-18 12:59:12 +01:00
def w_pls_cv_val(X, Y, amax, n_blocks=None, algo='simpls'):
    """Returns rmsep and aopt for pls tailored for wide X.

    The kernel matrix X X^T is formed once and segmented by w_pls_gen;
    each segment is fitted with the kernel SIMPLS engine (w_simpls).

    input:
      -- X, data matrix, rows are samples
      -- Y, response matrix, rows are samples (k, l taken from m_shape)
      -- amax, maximum number of components
      -- n_blocks, number of segments (default: one per row of Y)
      -- algo, only 'simpls' is supported
    output:
      -- rmsep, (l x amax+1) root mean squared error of prediction;
         column 0 is the mean-only (zero-component) model
      -- aopt, estimated optimal number of components
    raises:
      -- NotImplementedError, if algo is not 'simpls'
    comments:
    -- X, Y inputs need to be centered (fixme: check)
    """
    # Fail fast instead of only after the first segment has been built.
    if algo != 'simpls':
        raise NotImplementedError("algo=%r is not supported, use 'simpls'" % (algo,))
    k, l = m_shape(Y)
    # float64 accumulator (as in pls_val); float32 loses precision when
    # summing many squared residuals.
    PRESS = zeros((l, amax+1), dtype='<f8')
    if n_blocks is None:
        n_blocks = Y.shape[0]
    XXt = dot(X, X.T)
    V = w_pls_gen(XXt, Y, n_blocks=n_blocks, center=True)
    for Din, Doi, Yin, Yout in V:
        # If Y is column centered and the segments partition it,
        # sum(Yin) == -sum(Yout), so ym is the mean of the in-segment rows.
        ym = -Yout.sum(0)[newaxis]/(1.0*Yin.shape[0])
        Yin = Yin - ym
        PRESS[:,0] += ((Yout - ym)**2).sum(0)

        dat = w_simpls(Din, Yin, amax)
        Q, U, H = dat['Q'], dat['U'], dat['H']
        That = dot(Doi, dot(U, inv(triu(dot(H.T, U)))))

        for j in range(l):
            # Upper-triangular (amax x amax) matrix whose column a holds
            # Q[j, :a+1], so TQ[:, a] uses the first a+1 components.
            TQ = dot(That, triu(dot(Q[j,:][:,newaxis], ones((1, amax)))))
            E = Yout[:,j][:,newaxis] - TQ
            # NOTE(review): correction term kept from the original; its
            # derivation is not shown here - confirm.
            E = E + E.sum(0)/Din.shape[0]
            PRESS[j,1:] += (E**2).sum(0)
    rmsep = sqrt(PRESS/Y.shape[0])
    aopt = find_aopt_from_sep(rmsep)
    return rmsep, aopt
2006-12-18 12:59:12 +01:00
2007-03-14 17:33:54 +01:00
def pls_val(X, Y, amax=2, n_blocks=10, algo='pls', metric=None):
    """ Validation results of pls model.

    input:
      -- X, data matrix, rows are samples
      -- Y, response matrix, rows are samples (k, l taken from m_shape)
      -- amax, maximum number of components
      -- n_blocks, number of cross-validation segments
      -- algo, 'pls' or 'bridge'
      -- metric, passed on to pls_gen
    output:
      -- rmsep, (l x amax+1) root mean squared error of prediction;
         column 0 is the mean-only (zero-component) model
      -- aopt, estimated optimal number of components
    raises:
      -- NotImplementedError, for an unsupported algo
    comments:
    -- X, Y inputs need to be centered (fixme: check)
    """
    if algo == 'pls':
        fit = pls
    elif algo == 'bridge':
        # The old code called `simpls`, which is not imported here
        # (NameError). Assumes bridge shares pls' calling convention, as
        # pls_jkW's calls suggest - confirm.
        fit = bridge
    else:
        raise NotImplementedError("algo=%r is not supported" % (algo,))

    k, l = m_shape(Y)
    PRESS = zeros((l, amax+1), dtype='<f8')
    V = pls_gen(X, Y, n_blocks=n_blocks, center=True, index_out=True, metric=metric)
    for Xin, Xout, Yin, Yout, _out in V:
        # If Y is column centered, this is the mean of the in-segment rows.
        ym = -Yout.sum(0)[newaxis]/Yin.shape[0]
        Yin = Yin - ym
        PRESS[:,0] += ((Yout - ym)**2).sum(0)

        B = fit(Xin, Yin, amax, mode='normal')['B']
        for a in range(amax):
            # NOTE(review): the model is fitted on Yin - ym, but ym is not
            # added back to the prediction here - confirm intent.
            E = Yout - dot(Xout, B[a])
            PRESS[:,a+1] += (E**2).sum(0)
    # NOTE(review): w_pls_cv_val divides by the number of samples, not k-1.
    rmsep = sqrt(PRESS/(k-1.))
    aopt = find_aopt_from_sep(rmsep)
    return rmsep, aopt
2006-12-18 12:59:12 +01:00
2007-01-25 12:58:10 +01:00
def pca_alter_val(a, amax, n_sets=10, method='diag'):
    """Pca validation by altering elements in X.

    input:
      -- a, data matrix
      -- amax, maximum number of components
      -- n_sets, number of perturbed copies requested from diag_pert
      -- method, currently unused (kept for interface compatibility)
    output:
      -- sep, (n_sets x amax) square root of the squared error on the
         altered elements relative to their sum of squares in a;
         column j is the model with j+1 components
      -- aopt, estimated optimal number of components
    comments:
    -- may do all jk estimates in this loop
    """
    V = diag_pert(a, n_sets, center=True, index_out=True)
    sep = empty((n_sets, amax), dtype='f')
    flat_a = a.ravel()
    for i, (xi, ind) in enumerate(V):
        dat_i = pca(xi, amax, mode='detailed')
        Ti, Pi = dat_i['T'], dat_i['P']
        # The altered elements and their total sum of squares do not depend
        # on the number of components, so compute them once per set.
        a_sub = flat_a.take(ind)
        tot = (a_sub**2).sum()
        for j in range(amax):
            Xhat = dot(Ti[:,:j+1], Pi[:,:j+1].T)
            EE = a_sub - Xhat.ravel().take(ind)
            sep[i,j] = (EE**2).sum()/tot
    sep = sqrt(sep)
    aopt = find_aopt_from_sep(sep)
    return sep, aopt
def pca_cv_val(a, amax, n_sets):
    """ Returns PRESS from cross-validated pca using segments from pca_gen.

    input:
      -- a, data matrix (m x n)
      -- amax, maximum number of components used
      -- n_sets, number of segments to calculate
    output:
      -- sep, (amax,) squared error of prediction (press) relative to the
         total sum of squares of a; entry i is the model with i+1 components
      -- aopt, guesstimated optimal number of components
    """
    m, n = a.shape
    # zeros (not empty): rows never left out by a segment add nothing.
    E = zeros((amax, m, n), dtype='f')
    xtot = (a**2).sum() # fixme: this needs centering
    V = pca_gen(a, n_sets=n_sets, center=True, index_out=True)
    for xi, xout, ind in V:
        Pi = pca(xi, amax, mode='fast')['P']
        # `comp`, not `a`: reusing `a` clobbered the data matrix.
        for comp in range(amax):
            Pia = Pi[:,:comp+1]
            # Residual of the left-out rows after projection onto the first
            # comp+1 loadings (the old code read an undefined `X`).
            E[comp][ind,:] = (xout - dot(xout, dot(Pia, Pia.T)))**2
    sep = array([E[comp].sum()/xtot for comp in range(amax)])
    # find_aopt_from_sep indexes sep.shape[1], so pass a single 2-D row.
    aopt = find_aopt_from_sep(sep[newaxis])
    return sep, aopt
2006-12-18 12:59:12 +01:00
2007-03-14 17:33:54 +01:00
def pls_jkW(a, b, amax, n_blocks=None, algo='pls', use_pack=True, center=True, metric=None):
    """ Returns CV-segments of parameter W for wide X.

    input:
      -- a, data matrix, rows are samples
      -- b, response matrix, rows are samples
      -- amax, number of components
      -- n_blocks, number of segments (default: one per row of b)
      -- algo, 'pls' or 'bridge'
      -- use_pack, if True (and metric is None) fit on the thin-SVD scores
         u*s of a and map W back to the original columns of a
      -- center, passed on to pls_gen
      -- metric, passed on to pls_gen; disables packing when given
    output:
      -- Wcv, (n_blocks x a.shape[1] x amax) weights, one slice per segment
    raises:
      -- NotImplementedError, for an unsupported algo
    todo: add support for T,Q and B
    """
    if algo == 'pls':
        fit = pls
    elif algo == 'bridge':
        fit = bridge
    else:
        raise NotImplementedError("algo=%r is not supported" % (algo,))

    if n_blocks is None:
        n_blocks = b.shape[0]
    # Sized from the unpacked data, so W is stored in the original variables.
    Wcv = empty((n_blocks, a.shape[1], amax), dtype='f')

    # `is None`: `metric == None` is element-wise when metric is an array,
    # and its truth value is then ambiguous.
    pack = use_pack and metric is None
    if pack:
        u, s, inflater = svd(a, full_matrices=0)
        a = u*s

    V = pls_gen(a, b, n_blocks=n_blocks, center=center, metric=metric)
    for nn, (a_in, a_out, b_in, b_out) in enumerate(V):
        W = fit(a_in, b_in, amax, 'loads', 'fast')['W']
        if pack:
            W = dot(inflater.T, W)
        Wcv[nn,:,:] = W
    return Wcv
2006-12-18 12:59:12 +01:00
2007-03-14 17:33:54 +01:00
def pca_jkP(a, aopt, n_blocks=None, metric=None):
    """Returns loading from PCA on CV-segments.
    input:
      -- a, data matrix (n x m)
      -- aopt, number of components in model.
      -- n_blocks, number of segments (default: one per row of a)
      -- metric, currently unused (kept for interface compatibility)
    output:
      -- PP, loadings collected in a three way matrix
         (n_segments, m, aopt)
    comments:
      * pca is called with scale='loads'; per the original note the loadings
        are scaled with (1/samples)*eigenvalues - confirm in engines.pca.
      * Segments come from pca_gen (originally noted as random blocks of
        samples - confirm there).
    todo: add support for T
    fixme: more efficient to add this in validation loop
    """
    if n_blocks is None:
        n_blocks = a.shape[0]
    PP = empty((n_blocks, a.shape[1], aopt), dtype='f')
    segments = pca_gen(a, n_sets=n_blocks, center=True)
    for nn, (a_in, _a_out) in enumerate(segments):
        PP[nn,:,:] = pca(a_in, aopt, mode='fast', scale='loads')['P']
    return PP
2007-01-25 12:58:10 +01:00
2007-03-14 17:33:54 +01:00
2007-01-25 12:58:10 +01:00
def find_aopt_from_sep(sep, method='75perc'):
    """Returns an estimate of optimal number of components from rmsecv.

    input:
      -- sep, (rows x components) error matrix; a 1-D array is treated
         as a single row
      -- method, 'vanilla' or '75perc'
    output:
      -- aopt, estimated number of components
    raises:
      -- ValueError, for an unknown method
    comments:
      * 'vanilla': index of the column with the smallest sqrt(mean), plus one.
      * '75perc': the first i >= 1 where the median of column i-1 is below
        the 75th-percentile value of column i; the number of columns if none.
      * sep is not modified.
    """
    import numpy

    sep = numpy.atleast_2d(numpy.asarray(sep))
    if method == 'vanilla':
        # min rmsep
        rmsecv = numpy.sqrt(sep.mean(0))
        return rmsecv.argmin() + 1
    if method == '75perc':
        prct = .75 # percentile
        row = int(1.*sep.shape[0]*prct)
        # numpy.median replaces scipy.stats.median (column-wise median).
        med = numpy.median(sep, axis=0)
        # Sort a copy: the old code sorted the columns of sep.T in place,
        # which reordered the caller's error matrix.
        prc_75 = numpy.sort(sep, axis=0)[row]
        for i in range(1, sep.shape[1]):
            if med[i-1] < prc_75[i]:
                return i
        return len(med)
    raise ValueError("unknown method %r" % (method,))