"""Matrix cross validation selection generators """ from scipy import take,arange,ceil,repeat,newaxis,mean,asarray,dot,ones,\ random,array_split,floor,vstack,asarray,minimum from cx_utils import randperm def w_pls_gen(aat,b,n_blocks=None,center=True,index_out=False): """Random block crossvalidation for wide (XX.T) trick in PLS. Leave-one-out is a subset, with n_blocks equals nSamples aat -- outerproduct of X b -- Y n_blocks = center -- use centering of calibration ,sets (aat_in,b_in) are centered Returns: -- aat_in,aat_out,b_in,b_out,[out] """ m,n = aat.shape index = randperm(m) nValuesInBlock = m/n_blocks if n_blocks==m: index = arange(m) out_ind = [index[i*nValuesInBlock:(i+1)*nValuesInBlock] for i in range(n_blocks)] for out in out_ind: inn = [i for i in index if i not in out] aat_in = aat[inn,:][:,inn] aat_out = aat[out,:][:,inn] b_in = b[inn,:] b_out = b[out,:] if center: # centering projector: I - (1/n)11' # nin = len(inn) # Pc = eye(nin) - outer(ones((nin,)),ones((nin,)))/nin # xxt - x( outer(ones((nin,)),ones((nin,)))/nin ) x.T # de jong: h = sum(aat_in,0)[ :,newaxis] h = (h - mean(h)/2)/len(inn) mn_a = h + h.T aat_in = aat_in - mn_a if index_out: yield aat_in,aat_out,b_in,b_out,out else: yield aat_in,aat_out,b_in,b_out def pls_gen(a,b, n_blocks=None, center=False, index_out=False,axis=0): """Random block crossvalidation Leave-one-out is a subset, with n_blocks equals a.shape[-1] """ index = randperm(a.shape[axis]) if n_blocks==None: n_blocks = a.shape[axis] n_in_set = ceil(float(a.shape[axis])/n_blocks) out_ind_sets = [index[i*n_in_set:(i+1)*n_in_set] for i in range(n_blocks)] for out in out_ind_sets: inn = [i for i in index if i not in out] if center: a = a - mean(a,0)[newaxis] b = b - mean(b,0)[newaxis] if index_out: yield a.take(inn,0),a.take(out,0), b.take(inn,0),b.take(out,0),out else: yield a.take(inn,0),a.take(out,0), b.take(inn,0),b.take(out,0) def pca_gen(a,n_sets=None, center=False, index_out=False,axis=0): """PCA random block crossval generator. """ m = a.shape[axis] index = randperm(m) if n_sets==None: n_sets = m n_in_set = ceil(float(m)/n_sets) out_ind_sets = [index[i*n_in_set:(i+1)*n_in_set] for i in range(n_sets)] for out in out_ind_sets: inn = [i for i in index if i not in out] if center: a = a - mean(a,0)[newaxis] if index_out: yield a.take(inn,0),a.take(out,0),out else: yield a.take(inn,0),a.take(out,0) def w_pls_gen_jk(a,b,n_sets=None,center=True,index_out=False,axis=0): """Random block crossvalidation for wide X (m>>n) Leave-one-out is a subset, with n_sets equals a.shape[-1] Returns : X_m and X_m'Y_m """ m = a.shape[axis] ab = dot(a.T,b) index = randperm(m) if n_sets==None: n_sets = m n_in_set = ceil(float(m)/n_sets) out_ind_sets = [index[i*n_in_set:(i+1)*n_in_set] for i in range(n_sets)] for out in out_ind_sets: inn = [i for i in index if i not in out] nin = len(inn) nout = len(out) a_in = a[inn,:] mn_a = 0 mAB = 0 if center: mn_a = mean(a,0)[newaxis] mAin = dot(-ones((1,nout)),a[out,:])/nin mBin = dot(-ones((1,nout)),b[out,:])/nin mAB = dot(mAin.T,(mBin*nin)) ab_in = ab - dot(a[out,].T,b[out,:]) - mAB a_in = a_in - mn_a if index_out: yield ain,ab, out else: yield a_in, ab def shuffle_1d_block(a, n_sets=None, blocks=None, index_out=False, axis=0): """Random block shuffling along 1d axis Returns : Shuffled a by axis """ m = a.shape[axis] if blocks==None: blocks = m for ii in xrange(n_sets): index = randperm(m) if blocks==m: a_out = a.take(index, axis) else: index = arange(m) dummy = map(random.shuffle, array_split(index, blocks)) a_out = a.take(index, axis) if index_out: yield a_out, index else: yield a_out def shuffle_1d(a, n_sets, axis=0): """Random shuffling along 1d axis. Returns : Shuffled a by axis """ m = a.shape[axis] for ii in xrange(n_sets): index = randperm(m) yield a.take(index, axis) def diag_pert(a, n_sets=10, center=True, index_out=False): """Alter generator returning sets perturbed with means at diagonals. input: X -- matrix, data alpha -- scalar, approx. portion of data perturbed """ m, n = a.shape tr=False if m>n: a = a.T m, n = a.shape tr = True if n_sets>m or n_sets>n: msg = "You may not use more subsets than max(n_rows, n_cols)" raise ValueError, msg nm=n*m start_inds = array_split(randperm(m),n_sets) # we use random start diags if center: a = a - mean(a, 0)[newaxis] for v in range(n_sets): a_out = a.copy() out = [] for start in start_inds[v]: ind = arange(start+v, nm, n+1) [out.append(i) for i in ind] if center: a_out.put(a.mean(),ind) else: a_out.put(0, ind) if tr: a_out = a_out.T if index_out: yield a_out, asarray(out) else: yield a_out