223 lines
6.7 KiB
Python
223 lines
6.7 KiB
Python
"""Matrix cross validation selection generators
|
|
"""
|
|
from scipy import take,arange,ceil,repeat,newaxis,mean,asarray,dot,ones,\
|
|
random,array_split,floor,vstack,asarray,minimum
|
|
from cx_utils import randperm
|
|
|
|
def w_pls_gen(aat, b, n_blocks=None, center=True, index_out=False):
    """Random block cross-validation for the wide (X X^T) kernel trick in PLS.

    The samples are split into `n_blocks` segments and each segment is left
    out once. Leave-one-out is the special case ``n_blocks == n_samples``
    (also the default), in which the segments are taken in sample order.

    Parameters:
        aat -- outer product X X^T, (n_samples x n_samples)
        b -- Y, indexed by sample along its rows (``b[rows, :]``)
        n_blocks -- number of segments; None means leave-one-out
        center -- if True, ``aat_in`` is double-centered with
                  ``outerprod_centering`` and the calibration column means
                  are subtracted from ``aat_out``; ``b_in``/``b_out`` are
                  returned uncentered
        index_out -- if True, also yield the left-out sample indices

    Yields:
        (aat_in, aat_out, b_in, b_out), or
        (aat_in, aat_out, b_in, b_out, out) when `index_out` is True.
    """
    m = aat.shape[0]
    if n_blocks is None:
        n_blocks = m
    if n_blocks == m:
        # Leave-one-out needs no shuffling; keep the segments in sample order.
        index = arange(m)
    else:
        index = asarray(randperm(m))
    # array_split spreads the remainder over the segments, so every sample is
    # left out exactly once (floor-sized blocks dropped the last m % n_blocks).
    for out in array_split(index, n_blocks):
        keep = ones(m, dtype=bool)
        keep[out] = False
        inn = index[keep[index]]  # calibration samples, in shuffled order
        aat_in = aat[inn, :][:, inn]
        aat_out = aat[out, :][:, inn]
        b_in = b[inn, :]
        b_out = b[out, :]
        if center:
            aat_in, mn = outerprod_centering(aat_in)
            aat_out = aat_out - mn
        if index_out:
            yield aat_in, aat_out, b_in, b_out, out
        else:
            yield aat_in, aat_out, b_in, b_out
|
|
|
|
def pls_gen(a, b, n_blocks=None, center=False, index_out=False,axis=0, metric=None):
    """Random block cross-validation generator for paired data (a, b).

    The samples along `axis` are split into `n_blocks` segments and each
    segment is left out once. Leave-one-out is the special case
    ``n_blocks == a.shape[axis]`` (also the default), in which the segments
    are taken in sample order.

    Parameters:
        a, b -- data arrays sharing the sample axis `axis`
        n_blocks -- number of segments; None means leave-one-out. Asking for
                    more segments than samples yields empty segments.
        center -- if True, each calibration set is centered with its own
                  mean along `axis`, and the same mean is subtracted from the
                  corresponding left-out set
        index_out -- if True, also yield the left-out sample indices
        axis -- axis of `a` and `b` along which samples are selected
        metric -- optional matrix; if given, ``acal`` is replaced by
                  ``dot(acal, metric)`` (after centering)

    Yields:
        (acal, atrue, bcal, btrue), or
        (acal, atrue, bcal, btrue, out) when `index_out` is True.
    """
    m = a.shape[axis]
    if n_blocks is None:
        n_blocks = m
    if n_blocks == m:
        # Leave-one-out needs no shuffling; keep the segments in sample order.
        index = arange(m)
    else:
        index = asarray(randperm(m))
    # array_split yields integer-indexed, near-equal segments; the former
    # ceil() sizing produced float slice bounds and could leave folds empty.
    for out in array_split(index, n_blocks):
        keep = ones(m, dtype=bool)
        keep[out] = False
        inn = index[keep[index]]  # calibration samples, in shuffled order
        acal = a.take(inn, axis)
        atrue = a.take(out, axis)
        bcal = b.take(inn, axis)
        btrue = b.take(out, axis)
        if center:
            mn_a = acal.mean(axis, keepdims=True)
            acal = acal - mn_a
            atrue = atrue - mn_a
            mn_b = bcal.mean(axis, keepdims=True)
            bcal = bcal - mn_b
            btrue = btrue - mn_b
        # `is not None`: `!=` on an ndarray is elementwise and cannot be
        # used as an `if` condition.
        if metric is not None:
            acal = dot(acal, metric)
        if index_out:
            yield acal, atrue, bcal, btrue, out
        else:
            yield acal, atrue, bcal, btrue
|
|
|
|
|
|
def pca_gen(a, n_sets=None, center=False, index_out=False, axis=0, metric=None):
    """Return a generator of cross-validation sample segments.

    The samples along `axis` are split into `n_sets` segments and each
    segment is left out once. Leave-one-out is the special case
    ``n_sets == a.shape[axis]`` (also the default), in which the segments
    are taken in sample order.

    input:
        -- a, data matrix (m x n)
        -- n_sets, number of segments/subsets to generate; None means
           leave-one-out. More segments than samples yields empty segments.
        -- center, bool, center each calibration subset with its own mean
           along `axis` and subtract that mean from the left-out subset
        -- index_out, bool, also yield the left-out index
        -- axis, int, axis along which subsets are taken
        -- metric, optional matrix; if given, the calibration subset is
           replaced by ``dot(acal, metric)`` (after centering)

    output:
        -- generator with `n_sets` members (acal, atrue[, out])
    """
    m = a.shape[axis]
    if n_sets is None:
        n_sets = m
    if n_sets == m:
        # Leave-one-out needs no shuffling; keep the segments in sample order.
        index = arange(m)
    else:
        index = asarray(randperm(m))
    # array_split yields integer-indexed, near-equal segments; the former
    # ceil() sizing produced float slice bounds and could leave folds empty.
    for out in array_split(index, n_sets):
        keep = ones(m, dtype=bool)
        keep[out] = False
        inn = index[keep[index]]  # calibration samples, in shuffled order
        acal = a.take(inn, axis)
        atrue = a.take(out, axis)
        if center:
            mn_a = acal.mean(axis, keepdims=True)
            acal = acal - mn_a
            atrue = atrue - mn_a
        # `is not None`: `!=` on an ndarray is elementwise and cannot be
        # used as an `if` condition.
        if metric is not None:
            acal = dot(acal, metric)
        if index_out:
            yield acal, atrue, out
        else:
            yield acal, atrue
|
|
|
|
def w_pls_gen_jk(a, b, n_sets=None, center=True,
                 index_out=False, axis=0):
    """Random block cross-validation (jack-knife) for wide X.

    The full cross-product ``X^T Y`` is computed once. For each segment, the
    contribution of the left-out rows is subtracted instead of recomputing
    the product. Leave-one-out is the special case
    ``n_sets == a.shape[axis]`` (also the default), in which the segments
    are taken in sample order.

    Parameters:
        a -- X, samples along the rows
        b -- Y, 2-D, samples along the rows
        n_sets -- number of segments; None means leave-one-out
        center -- if True, X_m is centered with the calibration means and the
                  cross-product is that of the centered calibration data
        index_out -- if True, also yield the left-out row indices
        axis -- axis whose length gives the number of samples; the rows are
                always what gets split, so only axis=0 is consistent

    Returns : generator of (X_m, X_m'Y_m) or (X_m, X_m'Y_m, out), where
              X_m / Y_m are the calibration rows (centered if `center`).
    """
    m = a.shape[axis]
    ab = dot(a.T, b)
    if n_sets is None:
        n_sets = m
    if n_sets == m:
        # Leave-one-out needs no shuffling; keep the segments in sample order.
        index = arange(m)
    else:
        index = asarray(randperm(m))
    for out in array_split(index, n_sets):
        keep = ones(m, dtype=bool)
        keep[out] = False
        inn = index[keep[index]]  # calibration rows, in shuffled order
        a_in = a[inn, :]
        # X_in^T Y_in = X^T Y - X_out^T Y_out
        ab_in = ab - dot(a[out, :].T, b[out, :])
        if center:
            mn_a = a_in.mean(0)[newaxis]
            mn_b = b[inn, :].mean(0)[newaxis]
            # X_c^T Y_c = X_in^T Y_in - n_in * mean(X_in)^T mean(Y_in).
            # Using the actual calibration means keeps this right even when
            # a and b are not globally pre-centered.
            ab_in = ab_in - len(inn) * dot(mn_a.T, mn_b)
            a_in = a_in - mn_a
        if index_out:
            yield a_in, ab_in, out
        else:
            yield a_in, ab_in
|
|
|
|
def shuffle_1d_block(a, n_sets=None, blocks=None, index_out=False, axis=0):
    """Random block shuffling along one axis.

    Parameters:
        a -- array to shuffle
        n_sets -- number of shuffled copies to yield; must be an int (the
                  None default cannot be iterated)
        blocks -- number of contiguous blocks; positions are shuffled only
                  within each block. None (or a.shape[axis]) means a full
                  random permutation.
        index_out -- if True, also yield the index used to reorder `a`
        axis -- axis along which to shuffle

    Returns : generator of shuffled copies ``a.take(index, axis)``, or of
              ``(a_out, index)`` pairs when `index_out` is True.
    """
    m = a.shape[axis]
    if blocks is None:
        blocks = m
    for _ in range(n_sets):
        if blocks == m:
            index = randperm(m)
        else:
            index = arange(m)
            # array_split returns views into `index`, so shuffling each piece
            # in place permutes `index` within its blocks. An explicit loop is
            # required: a lazy Python 3 `map` would never run the shuffles.
            for block in array_split(index, blocks):
                random.shuffle(block)
        a_out = a.take(index, axis)
        if index_out:
            yield a_out, index
        else:
            yield a_out
|
|
|
|
def shuffle_1d(a, n_sets, axis=0):
    """Random shuffling along one axis.

    Each of the `n_sets` copies is reordered by a fresh ``randperm`` of the
    length of `axis`.

    Returns : generator of shuffled copies ``a.take(index, axis)``
    """
    m = a.shape[axis]
    # `range` (not the Python 2-only `xrange`) so this runs on Python 3.
    for _ in range(n_sets):
        index = randperm(m)
        yield a.take(index, axis)
|
|
|
|
def diag_pert(a, n_sets=10, center=True, index_out=False):
    """Generate copies of `a` with elements on wrapped diagonals overwritten.

    The matrix is processed in the orientation that has no more rows (m)
    than columns (n). A random permutation of the m row indices is split
    into `n_sets` groups. Copy number `v` overwrites, for each `start` in
    group `v`, the flat positions ``start + v, start + v + (n + 1), ...`` of
    that orientation, i.e. a diagonal that wraps around the matrix.

    input:
        a -- 2-D data matrix
        n_sets -- number of perturbed copies; at most min(n_rows, n_cols)
        center -- if True, the columns of `a` are mean-centered first and the
                  overwritten elements get the overall mean of the centered
                  matrix; if False they are set to 0
        index_out -- if True, also yield the flat indices (in the layout of
                     the input `a`) of the overwritten elements

    Raises:
        ValueError -- if n_sets exceeds min(n_rows, n_cols)
    """
    m, n = a.shape
    if n_sets > min(m, n):
        msg = "You may not use more subsets than min(n_rows, n_cols)"
        raise ValueError(msg)
    if center:
        # Center the caller's columns before any reorientation, so the
        # centering does not depend on whether `a` is tall or wide.
        a = a - mean(a, 0)[newaxis]
    tr = m > n
    if tr:
        a = a.T
        m, n = n, m
    nm = n * m
    start_inds = array_split(randperm(m), n_sets)  # random start diagonals
    fill = a.mean() if center else 0
    for v in range(n_sets):
        a_out = a.copy()
        out = []
        for start in start_inds[v]:
            ind = arange(start + v, nm, n + 1)
            out.extend(ind)
            # ndarray.put takes (indices, values); the old call passed them
            # the other way round.
            a_out.put(ind, fill)
        out = asarray(out, dtype=int)
        if tr:
            a_out = a_out.T
            # Map flat indices of the (m x n) working layout back to the
            # caller's (n x m) layout: working (r, c) is caller (c, r).
            out = (out % n) * m + out // n
        if index_out:
            yield a_out, out
        else:
            yield a_out
|
|
|
|
|
|
def outerprod_centering(aat, ret_mn=True):
    """Double-center a symmetric outer-product (Gram) matrix.

    The column means and the row means are subtracted and the grand mean is
    added back. For ``aat = X X^T`` this gives the Gram matrix of the
    column-centered X. Column sums serve as both the row and the column
    correction, so `aat` is assumed to be symmetric.

    Returns:
        (centered, col_means) when `ret_mn` is True, where col_means is
        ``aat.mean(0)`` of the uncentered input; otherwise only centered.
    """
    size = aat.shape[0]
    col_sums = aat.sum(0)[:, newaxis]
    # Half the grand-mean correction goes into each of the two rank-one terms.
    half_shift = (col_sums - mean(col_sums) / 2) / size
    centered = aat - (half_shift + half_shift.T)
    if not ret_mn:
        return centered
    return centered, aat.mean(0)
|
|
|
|
|
|
|