257 lines
7.8 KiB
Python
257 lines
7.8 KiB
Python
|
from scipy import zeros,zeros_like,sqrt,dot,trace,sign,round_,argmax,\
|
||
|
sort,ravel,newaxis,asarray,diag,sum,outer,argsort,arange,ones_like,\
|
||
|
all,apply_along_axis,eye
|
||
|
from scipy.linalg import svd,inv,norm,det,sqrtm
|
||
|
from scipy.stats import mean,median
|
||
|
from cx_utils import mat_center
|
||
|
from validation import pls_jkW
|
||
|
from select_generators import shuffle_1d
|
||
|
from engines import *
|
||
|
import time
|
||
|
|
||
|
def hotelling(P, Pfull, p_center='med', cov_center='med',
              alpha=0.3, crot=True, strict=False, metric=None):
    """Returns regularized hotelling T^2, one value per variable.

    P          -- array (n_sets x n_vars x amax) of stacked sub-model
                  loadings/weights. Not modified: a float64 copy is used.
    Pfull      -- array (n_vars x amax) of the full model.
    p_center   -- location method for submodels: 'med' (median), 'mean',
                  anything else uses the (metric transformed) Pfull
    cov_center -- location method for the sub covariances (pooled
                  estimate): 'med' (median), anything else uses the mean
    alpha      -- regularisation towards the pooled cov estimate;
                  0 uses each variable's own covariance only
    crot       -- rotate submodels toward full (procrustes)?
    strict     -- only rotate 90 degree ? (passed on to procrustes)
    metric     -- inverse metric matrix (if P and Pfull from metric pca/pls);
                  None means identity

    For variable i, with p_i the centre (P_ctr) of its sub-model values:
        T_sq[i] = p_i^T inv((1 - alpha)*Cov_i + alpha*Cov) p_i

    Returns T_sq, a float64 array of length n_vars.
    """
    # work on a float copy: the rotation step writes into P
    P = asarray(P, dtype='<f8').copy()
    Pfull = asarray(Pfull)
    n_sets, n_vars, amax = P.shape
    # `metric == None` is elementwise for an ndarray -> ambiguous truth value
    if metric is None:
        metric = eye(Pfull.shape[0], dtype='<f8')
    Pfull = dot(metric.T, Pfull)

    # rotate sub_models to full model
    # NOTE(review): with crot=False the sub-models are not metric
    # transformed while Pfull is -- confirm this is intended.
    if crot:
        for i in range(n_sets):
            P[i] = procrustes(Pfull, dot(metric.T, P[i]), strict=strict)

    # centre of the sub-models
    if p_center == 'med':
        P_ctr = median(P, 0)
    elif p_center == 'mean':
        # fixme: mean is unstable
        P_ctr = mean(P, 0)
    else:  # use full
        P_ctr = Pfull

    # per-variable (amax x amax) covariance of the sub-models around P_ctr,
    # scaled by (n_sets - 1)/n_sets (presumably the jack-knife variance factor)
    Cov_i = zeros((n_vars, amax, amax), dtype='<f8')
    for i in range(n_vars):
        Pim = (P[:, i, :] - P_ctr[i, :][newaxis])*sqrt(n_sets - 1)
        Cov_i[i] = (1./n_sets)*dot(Pim.T, Pim)

    # pooled covariance over all variables
    if cov_center == 'med':
        Cov = median(Cov_i, 0)
    else:
        Cov = mean(Cov_i, 0)

    # shrink each variable's covariance towards the pooled estimate
    reg_cov = (1. - alpha)*Cov_i + alpha*Cov

    T_sq = zeros((n_vars, ), dtype='<f8')
    for i in range(n_vars):
        pc = P_ctr[i, :]
        # 1-D quadratic form -> plain scalar
        T_sq[i] = dot(pc, dot(inv(reg_cov[i]), pc))
    return T_sq
|
||
|
|
||
|
def procrustes(A, B, strict=True, center=False, verbose=False):
    """Rotation of B to A (orthogonal Procrustes).

    A, B    -- arrays of the same shape; B is rotated towards A
    strict  -- Only do flipping and shuffling (rotation matrix is passed
               through ensure_strict)
    center  -- Center before rotation (mat_center), add the column mean
               of B back after
    verbose -- Print the rotation matrix and the residual ssq between
               A and the rotated B

    No scaling calculated.

    Output B_rot = Rotated B
    """
    if center:
        A, mn_A = mat_center(A, ret_mn=True)
        B, mn_B = mat_center(B, ret_mn=True)
    # with U S V^T = svd(B^T A), R = U V^T minimises ||B R - A|| over
    # orthogonal R
    u, s, vh = svd(dot(B.T, A))
    Cm = dot(u, vh)  # orthogonal rotation matrix
    if strict:  # just inverting and flipping
        Cm = ensure_strict(Cm)
    b_rot = dot(B, Cm)

    if verbose:
        print(Cm.round())
        # residual of the fit: target A versus rotated B
        fit = ((A - b_rot)**2).sum()
        print("Sum of squares: %s" % fit)
    if center:
        # NOTE(review): translates back with B's mean; mn_A is computed but
        # unused -- confirm whether A's location was intended instead.
        return mn_B + b_rot
    return b_rot
|
||
|
|
||
|
def expl_var_x(X, T):
    """Explained variance of X in percent, one entry per score column.

    X -- centered data matrix
    T -- score matrix, one column per component

    Entry a is 100 * (squared norm of T[:, a]) / (total sum of squares of X).
    """
    # squared norm of every score vector (diagonal of the Gram matrix)
    score_ssq = diag(dot(T.T, T))
    total_ssq = sum(X**2)
    return score_ssq*100/total_ssq
|
||
|
|
||
|
def expl_var_y(Y, T, Q):
    """Returns explained variance of Y in percent, one entry per component.

    Y -- centered response matrix
    T -- scores, one column per component
    Q -- Y-loadings, one column per component

    Component a explains 100 * ssq(outer(T[:, a], Q[:, a])) / ssq(Y).
    """
    n_comp = Q.shape[1]
    # total sum of squares of Y does not depend on the component: compute once
    tot_ssq = sum(Y**2)
    exp_var_y = zeros((n_comp, ))
    for a in range(n_comp):
        Ya = outer(T[:, a], Q[:, a])  # rank-one part of Y from component a
        exp_var_y[a] = 100*sum(Ya**2)/tot_ssq
    return exp_var_y
|
||
|
|
||
|
def pls_qvals(a, b, aopt=None, alpha=.3,
              n_iter=20, algo='pls',
              sim_method='shuffle',
              p_center='med', cov_center='med',
              crot=True, strict=False, metric=None, verbose=False):
    """Returns qvals for pls model.

    input:
        a -- centered data matrix
        b -- centered data matrix
        aopt -- scalar, opt. number of components
        alpha -- [0,1] regularisation parameter for T2-test
        n_iter -- number of permutations
        algo -- 'bridge' fits with bridge(), anything else with pls()
        sim_method -- permutation method ['shuffle'] (currently unused:
                      shuffle_1d is always used)
        p_center -- location estimator for sub models ['med']
        cov_center -- location estimator for covariance of submodels ['med']
        crot -- bool, use rotations of sub models?
        strict -- bool, use strict (flips/reordering only) rotations?
        metric -- inverse metric matrix passed on to hotelling, or None
        verbose -- bool, print timing information

    output:
        qval -- estimated q-value per variable (column) of a
        false_pos -- median over permutations of the number of permuted
                     T^2 values >= each variable's observed T^2
        TSQ -- (nvars x n_iter) T^2 values of the permuted models
        tsq_full -- T^2 values of the unpermuted model
    """
    n = a.shape[1]
    TSQ = zeros((n, n_iter), dtype='<f8')  # (nvars x n_subsets)
    n_false = zeros((n, n_iter), dtype='<f8')

    def _fit(resp):
        """T^2 of the model of (a, resp): full model vs pls_jkW sub-models."""
        # full model and sub-models must use the SAME response
        if algo == 'bridge':
            dat = bridge(a, resp, aopt, 'loads', 'fast')
        else:
            dat = pls(a, resp, aopt, 'loads', 'fast')
        W = pls_jkW(a, resp, aopt, n_blocks=None, algo=algo)
        return hotelling(W, dat['W'], p_center=p_center,
                         alpha=alpha, crot=crot, strict=strict,
                         cov_center=cov_center, metric=metric)

    # full model
    tsq_full = _fit(b)
    t0 = time.time()
    for i, b_shuff in enumerate(shuffle_1d(b, n_iter)):
        t1 = time.time()
        TSQ[:, i] = _fit(b_shuff)
        if verbose:
            print(time.time() - t1)

    # count false positives: n_false[j, i] = #{k : TSQ[k, i] >= tsq_full[j]}.
    # Sorting each column once and binary-searching replaces the O(n^2)
    # comparison loop; NaNs are dropped first since NaN >= x is never true.
    for i in range(n_iter):
        col = TSQ[:, i]
        ranked = sort(col[col == col])
        n_false[:, i] = len(ranked) - ranked.searchsorted(tsq_full, side='left')
    false_pos = median(n_false, 1)

    # rank (1 = largest T^2) of each variable in the full model = number of
    # variables called significant at its threshold
    sort_index = argsort(tsq_full)[::-1]
    back_sort_index = sort_index.argsort()
    ll = arange(1, n + 1, 1)
    qval = false_pos/ll.take(back_sort_index)
    if verbose:
        print(time.time() - t0)
    return qval, false_pos, TSQ, tsq_full
|
||
|
|
||
|
def ensure_strict(C, only_flips=True):
    """Ensure that a rotation matrix does only 90 degree rotations.

    In multiplication with pcs this allows flips and reordering.

    C          -- square rotation matrix
    only_flips -- if True only flips are allowed: the result is the
                  identity carrying the signs of C's diagonal.
                  Otherwise entries with |C| > .6 become signed ones; they
                  must form a (signed) permutation matrix.

    Raises ValueError if the thresholded matrix is not a permutation.
    """
    S = sign(C)  # signs
    if only_flips:
        return eye(C.shape[0])*S
    Cm = zeros_like(C)
    # mark the dominant entries (ndarray has no putmask method)
    Cm[abs(C) > .6] = 1.
    # each row and each column needs exactly one marked entry; otherwise the
    # result would be singular or mix components
    if not ((Cm.sum(0) == 1).all() and (Cm.sum(1) == 1).all()):
        raise ValueError("Implement this!")
    return Cm*S
|
||
|
|
||
|
|
||
|
|
||
|
def leverage(aopt=1,*args):
    """Returns leverages

    input : aopt, number of components to base leverage calculations on
            *args, matrices of normed blm-paramters (one row per object)
    output: list of leverages, one array per matrix in *args

    For PCA typical inputs are normalised T or normalised P
    For PLSR typical inputs are normalised T or normalised W

    For a matrix u with n rows the leverage of row i is
        1/n + sum(u[i, :aopt]**2)
    i.e. the diagonal of 1/n + U U^T with U = u[:, :aopt].

    Raises ValueError if aopt < 1.
    """
    if aopt < 1:
        raise ValueError("Leverages only make sense for aopt>0")
    lev = []
    for u in args:
        u_a = u[:, :aopt]
        # row sums of squares == diag(dot(u_a, u_a.T)) without building the
        # (n x n) matrix
        lev.append(1./u.shape[0] + (u_a**2).sum(1))
    return lev
|
||
|
|
||
|
|
||
|
def variances(a,t,p):
    """Returns explained variance and ind. var from blm-params.

    input:
        a -- full centered matrix
        t,p -- parameters from a bilinear approx of the above matrix.
    output:
        var -- variance of each component (percent of total ssq of a)
        var_exp -- cumulative explained variance in percentage

    Typical inputs are: X(centered),T,P for PCA or
                        X(centered),T,P / Y(centered),T,Q for PLSR.
    """
    tot_var = sum(a**2)
    # float constant: avoids integer division for integer inputs
    var = 100.*(sum(p**2, 0)*sum(t**2, 0))/tot_var
    # ndarray method: does not depend on a module-level cumsum import
    var_exp = var.cumsum()
    return var, var_exp
|
||
|
|
||
|
def residual_diagnostics(Y, Yhat, aopt=1):
    """Placeholder for residual diagnostics of a fitted model.

    Meant to report root mean errors, PRESS values and R2 values; not
    implemented yet, so every call simply returns None.
    """
    return None
|
||
|
|
||
|
|
||
|
def ssq(E, axis=0, weights=None):
    """Sum of squares along `axis`, supports weights.

    E       -- 2-D array
    axis    -- 0: one value per column (sum over rows)
               1: one value per row (sum over columns)
    weights -- optional 1-D sequence of length E.shape[axis]; entry k scales
               row k (axis=0) or column k (axis=1) before squaring, i.e. the
               same as dot(diag(weights), E) / dot(E, diag(weights)).
               None means unit (float) weights.

    Raises NotImplementedError for any other axis and ValueError if the
    weights do not match E.shape[axis].
    """
    if axis not in (0, 1):
        raise NotImplementedError("Higher order modes not supported")
    E = asarray(E)
    n = E.shape[axis]
    if weights is None:
        weights = [1.]*n  # float unit weights: integer input gives float ssq
    w = asarray(weights)
    if w.shape != (n, ):
        raise ValueError("weights must have length %d" % n)
    # broadcasting instead of multiplying by an (n x n) diagonal matrix
    if axis == 0:
        Ew = E*w[:, newaxis]
    else:
        Ew = E*w[newaxis, :]
    return (Ew**2).sum(axis)
|
||
|
|