This repository has been archived on 2024-07-04. You can view files and clone it, but cannot push or open issues or pull requests.
laydi/fluents/lib/cx_stats.py

256 lines
7.9 KiB
Python

from scipy import zeros,zeros_like,sqrt,dot,trace,sign,round_,argmax,\
sort,ravel,newaxis,asarray,diag,sum,outer,argsort,arange,ones_like,\
all,apply_along_axis,eye
from scipy.linalg import svd,inv,norm,det,sqrtm
from scipy.stats import mean,median
from cx_utils import mat_center
from validation import pls_jkW
from select_generators import shuffle_1d
from engines import *
import time
def hotelling(Pcv, P, p_center='med', cov_center='med',
              alpha=0.3, crot=True, strict=False, metric=None):
    """Returns regularized hotelling T^2, one value per variable.

    Pcv -- sub-model parameters, array (n_sets x n_vars x amax)
    P -- full-model parameters, array (n_vars x amax)
    p_center -- location method for submodels: 'med' (median),
                'mean', anything else uses the full model P
    cov_center -- location method for sub covariances: 'med' (median),
                  anything else uses the mean
    alpha -- regularisation towards pooled cov estimates
             (0: per-variable covariance only, 1: pooled only)
    crot -- rotate submodels toward full?
    strict -- handed to procrustes (restrict the rotation)
    metric -- inverse metric matrix (if Pcv and P from metric pca/pls),
              the identity is used when None

    Returns a single precision vector of length n_vars.

    NOTE: when crot is true the rotated sub-models are written back into
    Pcv, i.e. the array passed in by the caller is modified.
    """
    P = asarray(P)
    # 'is None' rather than '== None': comparing an array with '==' is
    # elementwise and cannot be used as a single truth value.
    if metric is None:
        metric = eye(P.shape[0], dtype='<f8')
    P = dot(metric.T, P)
    n_sets, n, amax = Pcv.shape

    # allocate (every entry is overwritten below)
    T_sq = zeros((n,), dtype='f')
    Cov_i = zeros((n, amax, amax), dtype='f')

    # rotate sub_models to full model
    if crot:
        for i, Pi in enumerate(Pcv):
            Pcv[i] = procrustes(P, dot(metric.T, Pi), strict=strict)

    # center of pnull
    if p_center == 'med':
        P_ctr = median(Pcv, 0)
    elif p_center == 'mean':
        # fixme: mean is unstable
        P_ctr = mean(Pcv, 0)
    else:  # use full
        P_ctr = P

    # per-variable scatter of the sub-models around the chosen center,
    # scaled by (n_sets - 1)/n_sets (presumably a jackknife estimate)
    scale = sqrt(n_sets - 1)
    for i in range(n):
        Pim = (Pcv[:, i, :] - P_ctr[i, :][newaxis])*scale  # (n_sets x amax)
        Cov_i[i] = (1./n_sets)*dot(Pim.T, Pim)

    if cov_center == 'med':
        Cov = median(Cov_i, 0)
    else:
        Cov = mean(Cov_i, 0)

    # shrink every per-variable covariance towards the pooled one
    reg_cov = (1. - alpha)*Cov_i + alpha*Cov
    for i in range(n):
        Pc = P_ctr[i, :][:, newaxis]  # (amax x 1)
        # the quadratic form is a (1 x 1) array; pick out the scalar
        T_sq[i] = dot(dot(Pc.T, inv(reg_cov[i])), Pc)[0, 0]
    return T_sq
def procrustes(A, B, strict=True, center=False, verbose=False):
    """Orthogonal rotation of B towards A (no scaling).

    strict -- pass the rotation matrix through ensure_strict
              (only flipping and shuffling)
    center -- center A and B with mat_center before the rotation and
              add the mean of B back to the result
    verbose -- print the rounded rotation matrix and the sum of squared
               differences between B and the rotated B

    Output: the rotated B.
    """
    if center:
        A, mn_A = mat_center(A, ret_mn=True)
        B, mn_B = mat_center(B, ret_mn=True)

    # orthogonal rotation matrix from the svd of the cross product
    left, _, right_h = svd(dot(B.T, A))
    rotation = dot(left, right_h)
    if strict:  # just inverting and flipping
        rotation = ensure_strict(rotation)
    rotated = dot(B, rotation)

    if verbose:
        print(rotation.round())
        residual = sum(ravel(B - rotated)**2)
        print("Sum of squares: %s" % residual)

    if not center:
        return rotated
    return mn_B + rotated
def expl_var_x(Xc, T):
    """Percentage of the sum of squares of Xc explained by each column of T.

    T should carry variance in length, Xc has zero col-mean.
    """
    score_ssq = diag(dot(T.T, T))  # squared length of every score vector
    total_ssq = sum(Xc**2)
    return score_ssq*100/total_ssq
def expl_var_y(Y, T, Q):
    """Percentage of the sum of squares of Y explained by each component.

    Component a contributes the rank one matrix outer(T[:,a], Q[:,a]);
    its sum of squares is reported relative to that of Y (centered Y).
    """
    n_comp = Q.shape[1]
    total_ssq = sum(Y**2)
    exp_var_y = zeros((n_comp, ))
    for a in range(n_comp):
        fitted = outer(T[:, a], Q[:, a])
        exp_var_y[a] = 100*sum(fitted**2)/total_ssq
    return exp_var_y
def pls_qvals(a, b, aopt=None, alpha=.3,
              n_iter=20, algo='pls',
              sim_method='shuffle',
              p_center='med', cov_center='med',
              crot=True, strict=False, metric=None):
    """Returns qvals for pls model.

    input:
        a -- centered data matrix
        b -- centered data matrix
        aopt -- scalar, opt. number of components
        alpha -- [0,1] regularisation parameter for T2-test
        n_iter -- number of permutations
        algo -- 'bridge' fits with bridge(), anything else with pls()
        sim_method -- permutation method ['shuffle']; currently not
                      consulted, shuffle_1d is always used
        p_center -- location estimator for sub models ['med']
        cov_center -- location estimator for covariance of submodels ['med']
        crot -- bool, use rotations of sub models?
        strict -- bool, use strict (rot/flips only) rotations?
        metric -- matrix or None; when given, a is replaced by
                  dot(a, metric) and the metric is handed to pls_jkW

    output:
        qval -- per variable: median number of false positives divided
                by the rank of the variable's T^2 (largest T^2 first)
        false_pos -- per variable: median over the permutations of the
                     number of permuted T^2 values >= the observed one
        TSQ -- (nvars x n_iter) T^2 values of the permuted models
        tsq_full -- T^2 values of the model on the unpermuted b

    Elapsed times are printed as a side effect.
    """
    _, n = a.shape
    TSQ = zeros((n, n_iter), dtype='<f8')  # (nvars x n_subsets)
    n_false = zeros((n, n_iter), dtype='<f8')

    if metric is not None:
        a = dot(a, metric)
    # only the selected name is looked up
    fit = bridge if algo == 'bridge' else pls

    # full model
    dat = fit(a, b, aopt, 'loads', 'fast')
    Wcv = pls_jkW(a, b, aopt, n_blocks=None, algo=algo, metric=metric)
    tsq_full = hotelling(Wcv, dat['W'], p_center=p_center,
                         alpha=alpha, crot=crot, strict=strict,
                         cov_center=cov_center)

    # null models: every fit must use the permuted response b_shuff
    t0 = time.time()
    Vs = shuffle_1d(b, n_iter)
    for i, b_shuff in enumerate(Vs):
        t1 = time.time()
        dat = fit(a, b_shuff, aopt, 'loads', 'fast')
        Wcv = pls_jkW(a, b_shuff, aopt, n_blocks=None, algo=algo,
                      metric=metric)
        TSQ[:, i] = hotelling(Wcv, dat['W'], p_center=p_center,
                              alpha=alpha, crot=crot, strict=strict,
                              cov_center=cov_center)
        print(time.time() - t1)

    # rank of every variable when sorted on decreasing T^2
    sort_index = argsort(tsq_full)[::-1]
    back_sort_index = sort_index.argsort()
    print(time.time() - t0)

    # count false positives
    for i in range(n_iter):
        for j in range(n):
            n_false[j, i] = sum(TSQ[:, i] >= tsq_full[j])
    false_pos = median(n_false, 1)
    ll = arange(1, len(false_pos) + 1, 1)
    qval = false_pos/ll.take(back_sort_index)
    print(time.time() - t0)
    return qval, false_pos, TSQ, tsq_full
def ensure_strict(C, only_flips=True):
    """Ensure that a rotation matrix does only 90 degree rotations.

    In multiplication with pcs this allows flips and reordering.

    C -- square rotation matrix (array)
    only_flips -- if true only flips are allowed: the result is the
                  identity with the signs of the diagonal of C.
                  Otherwise entries of C with absolute value above .6
                  become +/-1 (keeping their sign) and the rest 0.

    Raises ValueError if the determinant of the thresholded matrix
    exceeds 1 (case not implemented).
    """
    S = sign(C)  # signs
    if only_flips:
        return eye(C.shape[0])*S

    Cm = zeros_like(C)
    # boolean mask assignment; ndarrays have no putmask method
    Cm[abs(C) > .6] = 1.
    if det(Cm) > 1:
        raise ValueError("Implement this!")
    return Cm*S
def leverage(aopt=1, *args):
    """Returns leverages

    input : aopt, number of components to base leverage calculations on
            *args, matrices of normed blm-paramters
    output: list with one leverage vector per input matrix; entry i is
            1/n_rows + the sum of squares of row i over the first aopt
            columns

    For PCA typical inputs are normalised T or normalised P
    For PLSR typical inputs are normalised T or normalised W

    Raises ValueError if aopt < 1.
    """
    if aopt < 1:
        raise ValueError("Leverages only make sense for aopt>0")
    lev = []
    for u in args:
        u_a = asarray(u)[:, :aopt]
        # row sums of squares equal the diagonal of dot(u_a, u_a.T)
        # without forming the full (n_rows x n_rows) product
        lev.append(1./u_a.shape[0] + (u_a*u_a).sum(1))
    return lev
def variances(a, t, p):
    """Returns explained variance and ind. var from blm-params.

    input:
        a -- full centered matrix
        t,p -- parameters from a bilinear approx of the above matrix.
    output:
        var -- sum of squares of each component as a percentage of the
               total sum of squares of a
        var_exp -- cumulative sum of var

    Typical inputs are: X(centered),T,P for PCA or
    X(centered),T,P / Y(centered),T,Q for PLSR.
    """
    tot_var = sum(a**2)
    var = 100*(sum(p**2, 0)*sum(t**2, 0))/tot_var
    # array method: does not depend on a module level cumsum name
    var_exp = var.cumsum()
    return var, var_exp
def residual_diagnostics(Y, Yhat, aopt=1):
    """Placeholder for residual diagnostics (not implemented).

    Meant for root mean errors, press values and R2 vals; for now the
    arguments are ignored and None is returned.
    """
    return None
def ssq(E, axis=0, weights=None):
    """Sum of squares, supports weights.

    E -- 2-d array
    axis -- 0 or 1, the mode that is weighted and summed over
    weights -- vector with E.shape[axis] weights placed on the diagonal
               of a weight matrix; no weighting when None

    Returns pow(Ew, 2).sum(axis), where Ew is E with row i (axis=0) or
    column i (axis=1) multiplied by weights[i].

    Raises NotImplementedError for any other axis.
    """
    n = E.shape[axis]
    if weights is None:
        weights = eye(n)
    else:
        weights = diag(weights)
    if axis == 0:
        Ew = dot(weights, E)
    elif axis == 1:
        Ew = dot(E, weights)
    else:
        raise NotImplementedError("Higher order modes not supported")
    return pow(Ew, 2).sum(axis)