"""Matrix cross validation selection generators
"""
from scipy import take,arange,ceil,repeat,newaxis,mean,asarray,dot,ones,\
     random,array_split,floor,vstack,asarray,minimum
from cx_utils import randperm

def w_pls_gen(aat,b,n_blocks=None,center=True,index_out=False):
    """Random block cross-validation for the wide (X X^T) trick in PLS.

    The m samples are split into `n_blocks` disjoint segments whose sizes
    differ by at most one; every sample is left out exactly once.  When
    n_blocks equals the number of samples (the default) the segments are
    taken in order, i.e. leave-one-out.

    aat -- (m x m) outer product X X^T of the data matrix X
    b -- Y, 2-d response matrix with the same m rows
    n_blocks -- number of segments (default: m, leave-one-out)
    center -- if True, aat_in is double-centred and aat_out is centred with
              the calibration mean, so that they equal
              (X_in - x_mean)(X_in - x_mean)^T and
              (X_out - x_mean)(X_in - x_mean)^T, x_mean being the mean row
              of X_in.  b_in / b_out are returned uncentered.
    index_out -- if True, also yield the indices of the left-out samples

    Returns (generator):
        -- aat_in, aat_out, b_in, b_out[, out]
    """
    m = aat.shape[0]
    if n_blocks is None:
        n_blocks = m
    # Random order (presumably a permutation of range(m) -- see cx_utils);
    # drawn unconditionally as before so the RNG stream is unchanged.
    index = randperm(m)
    if n_blocks == m:
        # Leave-one-out: a deterministic order is as good as any.
        index = arange(m)
    # array_split keeps every sample (no dropped remainder) and balances sizes.
    out_ind = array_split(index, n_blocks)

    for out in out_ind:
        out_set = set(out)
        inn = [i for i in index if i not in out_set]
        aat_in = aat[inn,:][:,inn]
        aat_out = aat[out,:][:,inn]
        b_in = b[inn,:]
        b_out = b[out,:]
        if center:
            # mn[j] = x_mean . x_j ; mn.mean() = x_mean . x_mean
            aat_in, mn = outerprod_centering(aat_in)
            # aat_out.mean(1)[k] = z_k . x_mean; together these give
            # (z_k - x_mean) . (x_j - x_mean).
            aat_out = (aat_out - mn[newaxis]
                       - aat_out.mean(1)[:, newaxis] + mn.mean())
        if index_out:
            yield aat_in,aat_out,b_in,b_out,out
        else:
            yield aat_in,aat_out,b_in,b_out

def pls_gen(a, b, n_blocks=None, center=False, index_out=False,axis=0, metric=None):
    """Random block cross-validation for PLS.

    Samples (along `axis` of a, along the first axis of b) are randomly
    split into `n_blocks` disjoint segments whose sizes differ by at most
    one.  With n_blocks equal to a.shape[axis] (the default) this is
    leave-one-out.

    a -- X, data matrix
    b -- Y, responses; split along its first axis
    n_blocks -- number of segments (default: a.shape[axis])
    center -- if True, centre calibration and test parts of a and b with
              the calibration means
    index_out -- if True, also yield the left-out indices
    axis -- sample axis of a
    metric -- optional matrix; if given, the calibration part of a is
              replaced by dot(acal, metric) after centring (the test part
              is left untouched)

    Yields:
        acal, atrue, bcal, btrue[, out]
    """
    m = a.shape[axis]
    index = randperm(m)
    if n_blocks is None:
        n_blocks = m
    # array_split gives integer-indexed, balanced, never-empty segments
    # (as long as n_blocks <= m), unlike slicing with a float ceil().
    out_ind_sets = array_split(index, n_blocks)
    for out in out_ind_sets:
        out_set = set(out)
        inn = [i for i in index if i not in out_set]
        acal = a.take(inn, axis)
        atrue = a.take(out, axis)
        bcal = b.take(inn, 0)
        btrue = b.take(out, 0)
        if center:
            mn_a = acal.mean(axis, keepdims=True)
            acal = acal - mn_a
            atrue = atrue - mn_a
            mn_b = bcal.mean(0, keepdims=True)
            bcal = bcal - mn_b
            btrue = btrue - mn_b
        # `is not None`: `!= None` is elementwise (ambiguous) for arrays.
        if metric is not None:
            acal = dot(acal, metric)
        if index_out:
            yield acal, atrue, bcal, btrue, out
        else:
            yield acal, atrue, bcal, btrue

         
def pca_gen(a, n_sets=None, center=False, index_out=False, axis=0, metric=None):
    """Return a generator of cross-validation sample segments.

    The samples along `axis` are randomly split into `n_sets` disjoint
    segments whose sizes differ by at most one; each segment is left out
    once.

    input:
          -- a, data matrix (m x n)
          -- n_sets, number of segments/subsets to generate
             (default: a.shape[axis], i.e. leave-one-out)
          -- center, bool, centre each subset with the calibration mean
          -- index_out, bool, also return the left-out index
          -- axis, int, which axis to get subsets from
          -- metric, optional matrix; if given, the calibration part is
             replaced by dot(acal, metric) (the test part is not)

    output:
          -- generator with n_sets members, each (acal, atrue) or
             (acal, atrue, out)
    """
    m = a.shape[axis]
    index = randperm(m)
    if n_sets is None:
        n_sets = m
    # Integer, balanced segments; a float ceil() slice breaks on modern
    # numpy and can leave trailing segments empty.
    out_ind_sets = array_split(index, n_sets)
    for out in out_ind_sets:
        out_set = set(out)
        inn = [i for i in index if i not in out_set]
        acal = a.take(inn, axis)
        atrue = a.take(out, axis)
        if center:
            mn_a = acal.mean(axis, keepdims=True)
            acal = acal - mn_a
            atrue = atrue - mn_a
        if metric is not None:
            acal = dot(acal, metric)
        if index_out:
            yield acal, atrue, out
        else:
            yield acal, atrue

def w_pls_gen_jk(a, b, n_sets=None, center=True,
                 index_out=False, axis=0):
    """Random block cross-validation for wide X (m >> n) with a cheap X'Y.

    The full cross product a'b is computed once; for each left-out segment
    the calibration cross product is obtained by subtracting the left-out
    rows' contribution instead of recomputing it.  Segments are disjoint,
    differ in size by at most one, and leave-one-out is the default
    (n_sets = a.shape[axis]).

    a -- X, samples in rows
    b -- Y, 2-d, samples in rows
    center -- if True, a_in is centred with its own mean and ab_in is the
              centred cross product (A_in - mean)'(B_in - mean)
    axis -- only used for the sample count; rows are always indexed, so
            NOTE(review): only axis=0 looks meaningful here

    Yields: a_in, ab_in[, out]   where ab_in = a_in' b_in (centred if
    `center`)
    """
    m = a.shape[axis]
    ab = dot(a.T, b)
    index = randperm(m)
    if n_sets is None:
        n_sets = m
    out_ind_sets = array_split(index, n_sets)
    for out in out_ind_sets:
        out_set = set(out)
        inn = [i for i in index if i not in out_set]
        nin = len(inn)
        a_in = a[inn, :]
        mn_a = 0
        mAB = 0
        if center:
            mn_a = a_in.mean(0)[newaxis]
            mn_b = b[inn, :].mean(0)[newaxis]
            # A_in'B_in - nin * mean_a' mean_b is the centred cross product;
            # valid whether or not a and b were centred globally.
            mAB = nin * dot(mn_a.T, mn_b)
        # Rows partition a'b, so removing the left-out rows gives A_in'B_in.
        ab_in = ab - dot(a[out, :].T, b[out, :]) - mAB
        a_in = a_in - mn_a

        if index_out:
            yield a_in, ab_in, out
        else:
            yield a_in, ab_in

def shuffle_1d_block(a, n_sets=None, blocks=None, index_out=False, axis=0):
    """Yield `n_sets` copies of `a` with its entries permuted along `axis`.

    With blocks equal to a.shape[axis] (the default) each copy uses a full
    random permutation.  Otherwise the axis is cut into `blocks` contiguous
    segments (numpy array_split) and positions are shuffled only within each
    segment, so entries never move between segments.

    n_sets -- number of shuffled copies to generate (required)
    blocks -- number of contiguous segments (default: a.shape[axis])
    index_out -- if True, yield (a_out, index) with a_out == a.take(index, axis)

    Returns : Shuffled a by axis
    """
    if n_sets is None:
        raise TypeError("shuffle_1d_block() requires n_sets")
    m = a.shape[axis]
    if blocks is None:
        blocks = m
    for _ in range(n_sets):
        if blocks == m:
            index = randperm(m)
        else:
            index = arange(m)
            # array_split returns views, so shuffling each segment permutes
            # `index` in place.  An explicit loop is required: map() is lazy
            # in Python 3 and would never run the shuffles.
            for segment in array_split(index, blocks):
                random.shuffle(segment)
        a_out = a.take(index, axis)
        if index_out:
            yield a_out, index
        else:
            yield a_out

def shuffle_1d(a, n_sets, axis=0):
    """Random shuffling along 1d axis.

    Yields `n_sets` copies a.take(p, axis), each with a fresh permutation p
    from randperm (presumably a random permutation of range(a.shape[axis])).

    Returns : Shuffled a by axis
    """
    m = a.shape[axis]
    # range (not xrange) works on both Python 2 and Python 3.
    for _ in range(n_sets):
        yield a.take(randperm(m), axis)
         
def diag_pert(a, n_sets=10, center=True, index_out=False):
    """Generate copies of `a` with entries on (wrapped) diagonals replaced.

    The matrix is handled in its wide orientation: it is transposed
    internally if it has more rows than columns, and each output is
    transposed back.  The row indices are randomly partitioned into
    `n_sets` groups.  For set v, every `start` in group v defines the flat
    (C-order) indices start+v, start+v+(n+1), ...; each step moves one row
    down and one column right, wrapping through the flat layout.  Those
    entries are overwritten.

    input:
            a -- 2-d array, data
            n_sets -- number of perturbed copies; must not exceed
                      min(n_rows, n_cols)
            center -- if True, `a` is column-centred (in its wide
                      orientation) and the replaced entries get the grand
                      mean of the centred matrix (zero up to rounding);
                      otherwise they are set to 0
            index_out -- if True, also yield the replaced flat indices
                         (indices into the wide orientation)

    raises:
            ValueError -- if n_sets > min(n_rows, n_cols)
    """
    m, n = a.shape
    tr = False
    if m > n:
        a = a.T
        m, n = a.shape
        tr = True
    # After the transpose m <= n, so this enforces n_sets <= min(m, n).
    if n_sets > m or n_sets > n:
        msg = "You may not use more subsets than min(n_rows, n_cols)"
        raise ValueError(msg)
    nm = n*m
    start_inds = array_split(randperm(m), n_sets)  # random start diagonals
    if center:
        # NOTE(review): centring happens after the optional transpose, so for
        # tall inputs it centres the original rows -- confirm this is intended.
        a = a - mean(a, 0)[newaxis]
    fill = a.mean() if center else 0
    for v in range(n_sets):
        a_out = a.copy()  # C-contiguous, so flat indices below are C-order
        out = []
        for start in start_inds[v]:
            ind = arange(start+v, nm, n+1)
            out.extend(ind)
            # ndarray.put takes (indices, values).
            a_out.put(ind, fill)
        if tr:
            a_out = a_out.T

        if index_out:
            yield a_out, asarray(out)
        else:
            yield a_out


def outerprod_centering(aat, ret_mn=True):
    """Double-centre a symmetric outer-product matrix.

    For aat = X X^T this returns (X - x_mean)(X - x_mean)^T, computed as
    aat[i, j] - r[i] - r[j] + g, where r holds the column (= row) means and
    g is the grand mean.

    ret_mn -- if True (default), return (centred, aat.mean(0)), where the
              second item holds the column means of the *uncentred* input;
              otherwise return only the centred matrix.
    """
    size = aat.shape[0]
    col_sums = aat.sum(0)[:, newaxis]
    # half_offset[i] = r[i] - g/2, so half_offset + half_offset.T = r_i + r_j - g
    half_offset = (col_sums - mean(col_sums) / 2) / size
    centred = aat - (half_offset + half_offset.T)
    if not ret_mn:
        return centred
    return centred, aat.mean(0)