bugfixed scaling issues
This commit is contained in:
parent
f064b7134d
commit
8e607c7b1a
|
@ -9,7 +9,8 @@ from select_generators import shuffle_1d
|
||||||
from engines import *
|
from engines import *
|
||||||
import time
|
import time
|
||||||
|
|
||||||
def hotelling(P, Pfull, p_center='med', cov_center='med',
|
|
||||||
|
def hotelling(Pcv, P, p_center='med', cov_center='med',
|
||||||
alpha=0.3, crot=True, strict=False, metric=None):
|
alpha=0.3, crot=True, strict=False, metric=None):
|
||||||
"""Returns regularized hotelling T^2.
|
"""Returns regularized hotelling T^2.
|
||||||
|
|
||||||
|
@ -20,35 +21,35 @@ def hotelling(P, Pfull, p_center='med', cov_center='med',
|
||||||
alpha -- regularisation
|
alpha -- regularisation
|
||||||
crot -- rotate submodels toward full?
|
crot -- rotate submodels toward full?
|
||||||
strict -- only rotate 90 degree ?
|
strict -- only rotate 90 degree ?
|
||||||
metric -- inverse metric matrix (if P and Pfull from metric pca/pls)
|
metric -- inverse metric matrix (if Pcv and P from metric pca/pls)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
m, n = Pfull.shape
|
m, n = P.shape
|
||||||
if metric==None:
|
if metric==None:
|
||||||
metric = eye(m, dtype='<f8')
|
metric = eye(m, dtype='<f8')
|
||||||
Pfull = dot(metric.T, asarray(Pfull))
|
P = dot(metric.T, asarray(P))
|
||||||
n_sets,n,amax = P.shape
|
n_sets, n, amax = Pcv.shape
|
||||||
# allocate
|
# allocate
|
||||||
T_sq = empty((n, ),dtype='f')
|
T_sq = empty((n, ),dtype='f')
|
||||||
Cov_i = zeros((n, amax, amax),dtype='f')
|
Cov_i = zeros((n, amax, amax),dtype='f')
|
||||||
|
|
||||||
# rotate sub_models to full model
|
# rotate sub_models to full model
|
||||||
if crot:
|
if crot:
|
||||||
for i,Pi in enumerate(P):
|
for i, Pi in enumerate(Pcv):
|
||||||
Pi = dot(metric.T, Pi)
|
Pi = dot(metric.T, Pi)
|
||||||
P[i] = procrustes(Pfull, Pi, strict=strict)
|
Pcv[i] = procrustes(P, Pi, strict=strict)
|
||||||
|
|
||||||
# center of pnull
|
# center of pnull
|
||||||
if p_center=='med':
|
if p_center=='med':
|
||||||
P_ctr = median(P, 0)
|
P_ctr = median(Pcv, 0)
|
||||||
elif p_center=='mean':
|
elif p_center=='mean':
|
||||||
# fixme: mean is unstable
|
# fixme: mean is unstable
|
||||||
P_ctr = mean(P, 0)
|
P_ctr = mean(Pcv, 0)
|
||||||
else: #use full
|
else: #use full
|
||||||
P_ctr = Pfull
|
P_ctr = P
|
||||||
|
|
||||||
for i in xrange(n):
|
for i in xrange(n):
|
||||||
Pi = P[:,i,:] # (n_sets x amax)
|
Pi = Pcv[:,i,:] # (n_sets x amax)
|
||||||
Pi_ctr = P_ctr[i,:] # (1 x amax)
|
Pi_ctr = P_ctr[i,:] # (1 x amax)
|
||||||
Pim = (Pi - Pi_ctr[newaxis])*sqrt(n_sets-1)
|
Pim = (Pi - Pi_ctr[newaxis])*sqrt(n_sets-1)
|
||||||
Cov_i[i] = (1./n_sets)*dot(Pim.T, Pim)
|
Cov_i[i] = (1./n_sets)*dot(Pim.T, Pim)
|
||||||
|
@ -131,8 +132,6 @@ def pls_qvals(a, b, aopt=None, alpha=.3,
|
||||||
crot -- bool, use rotations of sub models?
|
crot -- bool, use rotations of sub models?
|
||||||
strict -- bool, use stict (rot/flips only) rotations?
|
strict -- bool, use stict (rot/flips only) rotations?
|
||||||
metric -- bool, use row metric?
|
metric -- bool, use row metric?
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
m, n = a.shape
|
m, n = a.shape
|
||||||
|
@ -144,8 +143,8 @@ def pls_qvals(a, b, aopt=None, alpha=.3,
|
||||||
dat = bridge(a, b, aopt, 'loads', 'fast')
|
dat = bridge(a, b, aopt, 'loads', 'fast')
|
||||||
else:
|
else:
|
||||||
dat = pls(a, b, aopt, 'loads', 'fast')
|
dat = pls(a, b, aopt, 'loads', 'fast')
|
||||||
W = pls_jkW(a, b, aopt, n_blocks=None, algo=algo)
|
Wcv = pls_jkW(a, b, aopt, n_blocks=None, algo=algo)
|
||||||
tsq_full = hotelling(W, dat['W'], p_center=p_center,
|
tsq_full = hotelling(Wcv, dat['W'], p_center=p_center,
|
||||||
alpha=alpha, crot=crot, strict=strict,
|
alpha=alpha, crot=crot, strict=strict,
|
||||||
cov_center=cov_center, metric=metric)
|
cov_center=cov_center, metric=metric)
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
|
@ -156,8 +155,8 @@ def pls_qvals(a, b, aopt=None, alpha=.3,
|
||||||
dat = bridge(a, b_shuff, aopt, 'loads','fast')
|
dat = bridge(a, b_shuff, aopt, 'loads','fast')
|
||||||
else:
|
else:
|
||||||
dat = pls(a, b, aopt, 'loads', 'fast')
|
dat = pls(a, b, aopt, 'loads', 'fast')
|
||||||
W = pls_jkW(a, b_shuff, aopt, n_blocks=None, algo=algo)
|
Wcv = pls_jkW(a, b_shuff, aopt, n_blocks=None, algo=algo)
|
||||||
TSQ[:,i] = hotelling(W, dat['W'],p_center=p_center,
|
TSQ[:,i] = hotelling(Wcv, dat['W'], p_center=p_center,
|
||||||
alpha=alpha, crot=crot, strict=strict,
|
alpha=alpha, crot=crot, strict=strict,
|
||||||
cov_center=cov_center, metric=metric)
|
cov_center=cov_center, metric=metric)
|
||||||
print time.time() - t1
|
print time.time() - t1
|
||||||
|
@ -194,8 +193,6 @@ def ensure_strict(C, only_flips=True):
|
||||||
raise ValueError,"Implement this!"
|
raise ValueError,"Implement this!"
|
||||||
return Cm*S
|
return Cm*S
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def leverage(aopt=1,*args):
|
def leverage(aopt=1,*args):
|
||||||
"""Returns leverages
|
"""Returns leverages
|
||||||
input : aopt, number of components to base leverage calculations on
|
input : aopt, number of components to base leverage calculations on
|
||||||
|
@ -213,7 +210,6 @@ def leverage(aopt=1,*args):
|
||||||
lev.append(lev_u)
|
lev.append(lev_u)
|
||||||
return lev
|
return lev
|
||||||
|
|
||||||
|
|
||||||
def variances(a,t,p):
|
def variances(a,t,p):
|
||||||
"""Returns explained variance and ind. var from blm-params.
|
"""Returns explained variance and ind. var from blm-params.
|
||||||
input:
|
input:
|
||||||
|
|
Reference in New Issue