2006-12-18 12:59:12 +01:00
|
|
|
from scipy import zeros,zeros_like,sqrt,dot,trace,sign,round_,argmax,\
|
|
|
|
sort,ravel,newaxis,asarray,diag,sum,outer,argsort,arange,ones_like,\
|
|
|
|
all,apply_along_axis,eye
|
|
|
|
from scipy.linalg import svd,inv,norm,det,sqrtm
|
|
|
|
from scipy.stats import mean,median
|
|
|
|
from cx_utils import mat_center
|
|
|
|
from validation import pls_jkW
|
|
|
|
from select_generators import shuffle_1d
|
|
|
|
from engines import *
|
|
|
|
import time
|
|
|
|
|
2007-01-31 12:57:04 +01:00
|
|
|
|
|
|
|
def hotelling(Pcv, P, p_center='med', cov_center='med',
              alpha=0.3, crot=True, strict=False, metric=None):
    """Returns regularized hotelling T^2, one value per variable.

    For each variable i, the sub-model rows Pcv[:, i, :] give a jackknife
    covariance estimate Cov_i.  Cov_i is shrunk towards a pooled covariance
    Cov (median or mean over all variables):

        Sigma_i = (1 - alpha)*Cov_i + alpha*Cov

    and T^2_i = c_i' inv(Sigma_i) c_i, where c_i is row i of the chosen centre.

    Pcv -- sub-model parameters, unpacked as (n_sets, nvars, amax)
    P -- full-model parameters, one row per variable (2-D)
    p_center -- location of the sub-models: 'med' (median), 'mean',
                anything else uses the (metric-transformed) full model P
    cov_center -- pooled covariance location: 'med' (median), otherwise mean
    alpha -- regularisation towards the pooled cov estimate, in [0, 1]
    crot -- rotate submodels toward full (via procrustes)?
    strict -- forwarded to procrustes (restricted rotations)
    metric -- inverse metric matrix (if Pcv and P from metric pca/pls);
              None means identity

    Returns a float32 array of length nvars.  Pcv is not modified.
    """
    m, _ = P.shape
    # `is None`: `metric == None` is elementwise (and ambiguous) for arrays
    if metric is None:
        metric = eye(m, dtype='<f8')
    P = dot(metric.T, asarray(P))
    n_sets, n, amax = Pcv.shape

    # allocate
    T_sq = zeros((n, ), dtype='f')
    Cov_i = zeros((n, amax, amax), dtype='f')

    # rotate sub_models to full model
    if crot:
        # work on a copy so the caller's Pcv is not overwritten in place
        Pcv = asarray(Pcv).copy()
        for i, Pi in enumerate(Pcv):
            Pcv[i] = procrustes(P, dot(metric.T, Pi), strict=strict)
    # NOTE(review): the metric is applied to the sub-models only when crot
    # is True, while P is always transformed -- confirm this is intended.

    # center of pnull
    if p_center == 'med':
        P_ctr = median(Pcv, 0)
    elif p_center == 'mean':
        # fixme: mean is unstable
        P_ctr = mean(Pcv, 0)
    else:  # use full model
        P_ctr = P

    # jackknife covariance of every variable's sub-model rows:
    # (n_sets-1)/n_sets * sum of squared deviations from the centre
    scale = sqrt(n_sets - 1)
    for i in range(n):
        Pim = (Pcv[:, i, :] - P_ctr[i, :][newaxis]) * scale  # (n_sets x amax)
        Cov_i[i] = (1. / n_sets) * dot(Pim.T, Pim)

    if cov_center == 'med':
        Cov = median(Cov_i, 0)
    else:
        Cov = mean(Cov_i, 0)

    # shrink each variable's covariance towards the pooled one
    reg_cov = (1. - alpha) * Cov_i + alpha * Cov
    for i in range(n):
        c = P_ctr[i, :]
        # 1-D vectors give a true scalar quadratic form (no 1-element array)
        T_sq[i] = dot(dot(c, inv(reg_cov[i])), c)
    return T_sq
|
|
|
|
|
|
|
|
def procrustes(A, B, strict=True, center=False, verbose=False):
    """Rotation of B to A (orthogonal Procrustes, no scaling).

    The rotation R = U V' comes from the SVD  dot(B.T, A) = U S V',
    which minimises ||dot(B, R) - A||; the function returns dot(B, R).

    strict -- pass R through ensure_strict (with its default arguments)
    center -- center A and B (mat_center) before rotation; B's mean is
              added back to the result afterwards
    verbose -- print the rounded rotation matrix and the residual sum of
               squares between A and the rotated B

    No scaling calculated.

    Output B_rot = Rotated B
    """
    if center:
        A, mn_A = mat_center(A, ret_mn=True)
        B, mn_B = mat_center(B, ret_mn=True)
    u, s, vh = svd(dot(B.T, A))
    Cm = dot(u, vh)  # orthogonal rotation matrix
    if strict:  # just inverting and flipping
        Cm = ensure_strict(Cm)
    b_rot = dot(B, Cm)

    if verbose:
        print(Cm.round())
        # goodness of fit: what remains between the target A and rotated B
        fit = sum(ravel(A - b_rot)**2)
        print("Sum of squares: %s" % fit)
    if center:
        # NOTE(review): translates back with B's mean, not A's -- confirm
        return mn_B + b_rot
    return b_rot
|
|
|
|
|
2007-03-14 17:31:25 +01:00
|
|
|
def expl_var_x(Xc, T):
    """Percentage of the total sum of squares of Xc captured by each score.

    Xc -- data matrix with zero column means
    T  -- score matrix whose columns carry the component variance in
          their length

    Returns 100 * ||T[:, a]||^2 / ||Xc||^2 for every column a.
    """
    total_ssq = sum(Xc**2)
    score_ssq = diag(dot(T.T, T))
    return score_ssq*100/total_ssq
|
|
|
|
|
|
|
|
def expl_var_y(Y, T, Q):
    """Returns explained variance of Y, in percent, per component.

    Component a contributes the rank-one term outer(T[:, a], Q[:, a]);
    its sum of squares is divided by the total sum of squares of Y
    (Y is expected to be centered).
    """
    n_comp = Q.shape[1]
    total_ssq = sum(Y**2)
    explained = zeros((n_comp, ))
    for comp in range(n_comp):
        fitted = outer(T[:, comp], Q[:, comp])
        explained[comp] = 100*sum(fitted**2)/total_ssq
    return explained
|
|
|
|
|
|
|
|
def _fit_weights(a, b, aopt, algo):
    """Fit one model with the engine selected by `algo`; return its 'W'."""
    if algo == 'bridge':
        dat = bridge(a, b, aopt, 'loads', 'fast')
    else:
        dat = pls(a, b, aopt, 'loads', 'fast')
    return dat['W']


def pls_qvals(a, b, aopt=None, alpha=.3,
              n_iter=20, algo='pls',
              sim_method='shuffle',
              p_center='med', cov_center='med',
              crot=True, strict=False, metric=None):
    """Returns qvals for pls model.

    input:
        a -- centered data matrix
        b -- centered data matrix
        aopt -- scalar, opt. number of components
        alpha -- [0,1] regularisation parameter for T2-test
        n_iter -- number of permutations
        algo -- 'bridge' fits with bridge(), anything else with pls()
        sim_method -- permutation method ['shuffle'] (not used: b is always
                      permuted with shuffle_1d)
        p_center -- location estimator for sub models ['med']
        cov_center -- location estimator for covariance of submodels ['med']
        crot -- bool, use rotations of sub models?
        strict -- bool, use strict (rot/flips only) rotations?
        metric -- None or a matrix; if given, a is replaced by dot(a, metric)
                  and metric is also forwarded to pls_jkW

    output (qval, false_pos, TSQ, tsq_full):
        qval -- false_pos divided by each variable's rank when variables are
                ordered by decreasing tsq_full
        false_pos -- per variable, median over permutations of the number of
                     permuted T^2 values >= its observed T^2
        TSQ -- (nvars x n_iter) T^2 values of the permuted models
        tsq_full -- T^2 values of the unpermuted model
    """
    n = a.shape[1]
    TSQ = zeros((n, n_iter), dtype='<f8')  # (nvars x n_subsets)
    n_false = zeros((n, n_iter), dtype='<f8')

    # full model
    if metric is not None:
        a = dot(a, metric)
    W = _fit_weights(a, b, aopt, algo)
    Wcv = pls_jkW(a, b, aopt, n_blocks=None, algo=algo, metric=metric)
    tsq_full = hotelling(Wcv, W, p_center=p_center,
                         alpha=alpha, crot=crot, strict=strict,
                         cov_center=cov_center)
    t0 = time.time()

    # null distribution: refit every model on the permuted response
    Vs = shuffle_1d(b, n_iter)
    for i, b_shuff in enumerate(Vs):
        t1 = time.time()
        # both engines must see the shuffled b, otherwise the "null" model
        # is just the real model again
        W = _fit_weights(a, b_shuff, aopt, algo)
        Wcv = pls_jkW(a, b_shuff, aopt, n_blocks=None, algo=algo, metric=metric)
        TSQ[:, i] = hotelling(Wcv, W, p_center=p_center,
                              alpha=alpha, crot=crot, strict=strict,
                              cov_center=cov_center)
        print(time.time() - t1)
    sort_index = argsort(tsq_full)[::-1]
    back_sort_index = sort_index.argsort()
    print(time.time() - t0)

    # count false positives: for permutation i and variable j, the number of
    # null values TSQ[:, i] >= tsq_full[j].  Sorting each null column once
    # and searching it gives all counts in O(n log n) instead of O(n^2).
    # NOTE(review): NaN T^2 values would be counted differently than with a
    # direct >= comparison.
    for i in range(n_iter):
        null_sorted = sort(TSQ[:, i])
        n_false[:, i] = n - null_sorted.searchsorted(tsq_full, side='left')
    false_pos = median(n_false, 1)

    # rank of each variable (1 = largest tsq_full)
    ranks = arange(1, len(false_pos) + 1, 1).take(back_sort_index)
    qval = false_pos / ranks
    print(time.time() - t0)
    return qval, false_pos, TSQ, tsq_full
|
|
|
|
|
|
|
|
def ensure_strict(C, only_flips=True):
    """Ensure that a rotation matrix does only 90 degree rotations.

    In multiplication with pcs this allows flips and reordering.

    only_flips -- if true, return a diagonal matrix holding the signs of
                  C's diagonal (components may only flip sign).  Otherwise
                  every entry with |C_ij| > .6 is snapped to sign(C_ij) and
                  all other entries to 0, giving a signed permutation.

    Raises ValueError if the snapped matrix is not a signed permutation,
    i.e. some row or column has no entry, or more than one entry, above .6
    in absolute value.
    """
    signs = sign(C)
    if only_flips:
        return eye(C.shape[0]) * signs
    Cm = zeros_like(C)
    Cm[abs(C) > .6] = 1.
    # a permutation matrix has exactly one 1 in every row and column
    if not ((Cm.sum(0) == 1).all() and (Cm.sum(1) == 1).all()):
        raise ValueError("Rotation is not close to a signed permutation; "
                         "strict rotation not implemented for this case")
    return Cm * signs
|
|
|
|
|
|
|
|
def leverage(aopt=1, *args):
    """Returns leverages

    input : aopt, number of components to base leverage calculations on
            *args, matrices of normed blm-paramters
    output: list of leverages, one 1-D array per input matrix

    For a matrix U with k rows the leverage of row i is
        1/k + sum_{a < aopt} U[i, a]**2
    i.e. 1/k plus the diagonal of U_a U_a'.

    For PCA typical inputs are normalised T or normalised P
    For PLSR typical inputs are normalised T or normalised W

    Raises ValueError if aopt < 1.
    """
    if aopt < 1:
        raise ValueError("Leverages only make sense for aopt>0")
    # row-wise sum of squares == diag(U U'), without forming the k x k product
    return [1. / u.shape[0] + (u[:, :aopt]**2).sum(1) for u in args]
|
|
|
|
|
2007-03-14 17:31:25 +01:00
|
|
|
def variances(a, t, p):
    """Returns explained variance and ind. var from blm-params.

    input:
        a -- full centered matrix
        t,p -- parameters from a bilinear approx of the above matrix.

    output:
        var -- explained variance of each component in percent,
               100 * ||p_a||^2 * ||t_a||^2 / ||a||^2
        var_exp -- cumulative explained variance in percentage

    Typical inputs are: X(centered),T,P for PCA or
    X(centered),T,P / Y(centered),T,Q for PLSR.
    """
    tot_var = sum(a**2)
    var = 100*(sum(p**2, 0)*sum(t**2, 0))/tot_var
    # ndarray method: no reliance on a module-level `cumsum` name
    var_exp = var.cumsum()
    return var, var_exp
|
|
|
|
|
|
|
|
def residual_diagnostics(Y, Yhat, aopt=1):
    """Placeholder for residual diagnostics (root mean errors, press values,
    R2 vals).

    Not implemented yet: all arguments are ignored and None is returned.
    """
    return None
|
|
|
|
|
|
|
|
|
|
|
|
def ssq(E, axis=0, weights=None):
    """Sum of squares, supports weights.

    E -- 2-D array
    axis -- 0: one value per column (weights scale the rows);
            1: one value per row (weights scale the columns)
    weights -- optional 1-D sequence of length E.shape[axis]; each row
               (axis=0) or column (axis=1) is multiplied by its weight
               before squaring, so the weights enter squared

    Raises NotImplementedError for any axis other than 0 or 1.
    """
    if axis not in (0, 1):
        raise NotImplementedError("Higher order modes not supported")
    if weights is None:
        # unweighted: no need to multiply by an identity matrix
        Ew = E
    elif axis == 0:
        Ew = dot(diag(weights), E)
    else:
        Ew = dot(E, diag(weights))
    return pow(Ew, 2).sum(axis)
|
|
|
|
|