"""This module implements some common validation schemes from pca and pls.
"""
from scipy import ones,mean,sqrt,dot,newaxis,zeros,sum,empty,\
     apply_along_axis,eye,kron,array,sort
from scipy.stats import median
from scipy.linalg import triu,inv,svd,norm

from select_generators import w_pls_gen,w_pls_gen_jk,pls_gen,pca_gen,diag_pert
from engines import w_simpls,pls,bridge,pca
from cx_utils import m_shape

def w_pls_cv_val(X, Y, amax, n_blocks=None, algo='simpls'):
    """Returns cross-validated RMSEP and an estimated optimal number of
    components for PLS tailored for wide X.

    The model is fitted on the sample kernel matrix X X^T, so the work per
    segment scales with the number of samples rather than the number of
    variables.

    input:
          -- X, data matrix, assumed centered (samples along rows)
          -- Y, response matrix, assumed centered (samples along rows)
          -- amax, maximum number of components
          -- n_blocks, number of cross-validation segments
             (default: Y.shape[0], one segment per sample)
          -- algo, PLS algorithm; only 'simpls' is implemented
    output:
          -- rmsep, (l x amax+1) root mean squared error of prediction per
             response; column 0 is the mean-only (zero component) model
          -- aopt, estimated optimal number of components

    raises:
          -- NotImplementedError, if algo is not 'simpls'
    """
    # Validate before any work so an unsupported algorithm fails before the
    # (potentially large) kernel matrix is built.
    if algo != 'simpls':
        raise NotImplementedError
    _, l = m_shape(Y)
    PRESS = zeros((l, amax+1), dtype='f')
    if n_blocks is None:
        n_blocks = Y.shape[0]
    XXt = dot(X, X.T)
    V = w_pls_gen(XXt, Y, n_blocks=n_blocks, center=True)
    for Din, Doi, Yin, Yout in V:
        # If Y is centered over all samples, sum(Yin) == -sum(Yout), so this
        # is the mean of the in-segment responses.
        ym = -sum(Yout, 0)[newaxis]/(1.0*Yin.shape[0])
        Yin = Yin - ym
        PRESS[:,0] = PRESS[:,0] + ((Yout - ym)**2).sum(0)
        dat = w_simpls(Din, Yin, amax)
        Q, U, H = dat['Q'], dat['U'], dat['H']
        # Out-of-segment scores.
        That = dot(Doi, dot(U, inv(triu(dot(H.T,U))) ))
        for j in range(l):
            # triu of (Q[j] as a column) x ones makes column a hold
            # Q[j,0..a], so TQ[:, a] is the prediction from a+1 components.
            TQ = dot(That, triu(dot(Q[j,:][:,newaxis], ones((1,amax)))) )
            E = Yout[:,j][:,newaxis] - TQ
            # NOTE(review): centering correction on the residuals; intent is
            # not evident from this module alone -- confirm.
            E = E + sum(E, 0)/Din.shape[0]
            PRESS[j,1:] = PRESS[j,1:] + sum(E**2, 0)
    rmsep = sqrt(PRESS/Y.shape[0])
    aopt = find_aopt_from_sep(rmsep)
    return rmsep, aopt

def pls_val(X, Y, amax=2, n_blocks=10, algo='pls'):
    """Returns cross-validated RMSEP and an estimated optimal number of
    components for a PLS model.

    input:
          -- X, data matrix (samples along rows)
          -- Y, response matrix (samples along rows)
          -- amax, maximum number of components
          -- n_blocks, number of cross-validation segments
          -- algo, 'pls' or 'bridge'
    output:
          -- rmsep, (l x amax+1) root mean squared error of prediction per
             response; column 0 is the mean-only (zero component) model
          -- aopt, estimated optimal number of components

    raises:
          -- NotImplementedError, if algo is not 'pls' or 'bridge'
    """
    # Validate up front: an unknown algo used to leave `dat` unbound.
    if algo not in ('pls', 'bridge'):
        raise NotImplementedError
    fit = pls if algo == 'pls' else bridge

    k, l = m_shape(Y)
    PRESS = zeros((l, amax+1), dtype='<f8')
    # X,Y are centered
    V = pls_gen(X, Y, n_blocks=n_blocks, center=True, index_out=True)
    for Xin, Xout, Yin, Yout, _out in V:
        # Mean of the in-segment responses, assuming Y is centered overall
        # (then sum(Yin) == -sum(Yout)); 1.0* forces float division.
        ym = -sum(Yout, 0)[newaxis]/(1.0*Yin.shape[0])
        Yin = Yin - ym
        PRESS[:,0] = PRESS[:,0] + ((Yout - ym)**2).sum(0)

        dat = fit(Xin, Yin, amax, mode='normal')
        for a in range(amax):
            # dat['B'][a] holds the regression coefficients for a+1 components.
            E = Yout - dot(Xout, dat['B'][a,:,:])
            PRESS[:,a+1] = PRESS[:,a+1] + sum(E**2, 0)

    # NOTE(review): normalized by k-1, whereas w_pls_cv_val divides by the
    # number of samples -- confirm which is intended.
    rmsep = sqrt(PRESS/(k-1.))
    aopt = find_aopt_from_sep(rmsep)
    return rmsep, aopt

def pca_alter_val(a, amax, n_sets=10, method='diag'):
    """Pca validation by altering elements in X.

    For each perturbed copy of `a` produced by diag_pert, a PCA model is
    fitted. The altered elements (flat indices `ind`) are then reconstructed
    from the first 1..amax components and compared with the values in `a`.

    input:
          -- a, data matrix
          -- amax, maximum number of components
          -- n_sets, number of perturbation sets
          -- method, currently unused
    output:
          -- sep, (n_sets x amax) sqrt of the residual sum of squares at the
             altered elements divided by their total sum of squares;
             column j corresponds to j+1 components
          -- aopt, estimated optimal number of components
    """
    # todo: it is just as easy to do jk-estimates here as well
    V = diag_pert(a, n_sets, center=True, index_out=True)
    sep = empty((n_sets, amax), dtype='f')
    a_flat = a.ravel()
    for i, (xi, ind) in enumerate(V):
        dat_i = pca(xi, amax, mode='detailed')
        Ti, Pi = dat_i['T'], dat_i['P']
        # The altered elements and their total sum of squares do not depend
        # on the number of components, so compute them once per set.
        a_sub = a_flat.take(ind)
        tot = (a_sub**2).sum()
        for j in range(amax):
            Xhat = dot(Ti[:,:j+1], Pi[:,:j+1].T)
            EE = a_sub - Xhat.ravel().take(ind)
            sep[i,j] = (EE**2).sum()/tot
    sep = sqrt(sep)
    aopt = find_aopt_from_sep(sep)
    return sep, aopt

def pca_cv_val(a, amax, n_sets):
    """ Returns PRESS from cross-validated pca using random segments.

    input:
          -- a, data matrix (m x n)
          -- amax, maximum number of components used
          -- n_sets, number of segments to calculate
    output:
          -- sep, (amax,) squared prediction error of the held-out rows for
             1..amax components, divided by the total sum of squares of the
             column-centered data
          -- aopt, guesstimated optimal number of components
    """
    m, n = a.shape
    # zeros (not empty) so rows never visited by a segment do not add garbage.
    E = zeros((amax, m, n), dtype='f')
    # Total sum of squares of the column-centered data, consistent with the
    # center=True segments below.
    xtot = ((a - a.mean(0)[newaxis])**2).sum()
    V = pca_gen(a, n_sets=n_sets, center=True, index_out=True)
    for xi, xout, ind in V:
        dat_i = pca(xi, amax, mode='fast')
        Pi = dat_i['P']
        # `j` (not `a`) so the data matrix is not shadowed.
        for j in range(amax):
            Pij = Pi[:,:j+1]
            # Residual of the held-out rows after projection on the first
            # j+1 loadings. xout is used on both sides so both terms live in
            # the same space (presumably centered by pca_gen -- confirm).
            E[j][ind,:] = (xout - dot(xout, dot(Pij, Pij.T)))**2

    sep = array([E[j].sum()/xtot for j in range(amax)])
    # find_aopt_from_sep expects rows = replicates, columns = components.
    aopt = find_aopt_from_sep(sep[newaxis, :])
    return sep, aopt

def pls_jkW(a, b, amax, n_blocks=None, algo='pls', use_pack=False, center=True):
    """ Returns CV-segments of parameter W for wide X.

    input:
          -- a, data matrix (samples along rows)
          -- b, response matrix (samples along rows)
          -- amax, number of components
          -- n_blocks, number of segments (default: b.shape[0])
          -- algo, 'pls' or 'bridge'
          -- use_pack, if True the models are fitted on u*s from an economy
             svd of `a`, and each W is mapped back with the right singular
             vectors to the original variables
          -- center, passed on to pls_gen
    output:
          -- Wcv, (n_blocks x a.shape[1] x amax) weights from each segment

    raises:
          -- NotImplementedError, if algo is not 'pls' or 'bridge'

    todo: add support for T,Q and B
    """
    # Validate up front: an unknown algo used to leave `dat` unbound.
    if algo not in ('pls', 'bridge'):
        raise NotImplementedError
    if n_blocks is None:
        n_blocks = b.shape[0]

    # Sized from the original `a`, before any packing below.
    Wcv = empty((n_blocks, a.shape[1], amax), dtype='f')

    if use_pack:
        u, s, inflater = svd(a, full_matrices=0)
        a = u*s

    fit = pls if algo == 'pls' else bridge
    V = pls_gen(a, b, n_blocks=n_blocks, center=center)
    for nn, (a_in, a_out, b_in, b_out) in enumerate(V):
        W = fit(a_in, b_in, amax, 'loads', 'fast')['W']
        if use_pack:
            # Back from the packed space to the original variables.
            W = dot(inflater.T, W)
        Wcv[nn,:,:] = W

    return Wcv

def pca_jkP(a, aopt, n_blocks=None):
    """Collects the PCA loadings fitted on each cross-validation segment.

    input:
           -- a, data matrix (n x m)
           -- aopt, number of components in the model
           -- n_blocks, number of segments (default: a.shape[0])
    output:
           -- PP, loadings stacked in a three-way array of shape
              (n_segments, m, aopt)

    comments:
    * The loadings are requested from pca with scale='loads' (described
      previously as scaled with (1/samples)*eigenvalues).
    * Segments come from pca_gen (random blocks of samples).

    todo: add support for T
    fixme: more efficient to add this in validation loop
    """
    if n_blocks is None:
        n_blocks = a.shape[0]

    n_vars = a.shape[1]
    PP = empty((n_blocks, n_vars, aopt), dtype='f')
    segments = pca_gen(a, n_sets=n_blocks, center=True)
    for seg_no, (a_in, _a_out) in enumerate(segments):
        PP[seg_no, :, :] = pca(a_in, aopt, mode='fast', scale='loads')['P']

    return PP

def find_aopt_from_sep(sep, method='75perc'):
    """Returns an estimate of optimal number of components from rmsecv.

    input:
          -- sep, 2-D array with replicates/segments along rows and an
             increasing number of components along columns; a 1-D array is
             treated as a single row. The array is not modified.
          -- method, 'vanilla' or '75perc'
    output:
          -- aopt, estimated number of components

    methods:
    * 'vanilla': argmin over columns of sqrt(column mean), plus one.
    * '75perc': the first column index i >= 1 where the median of column i-1
      is below (approximately) the 75th percentile of column i, i.e. column i
      is not clearly lower than column i-1; the number of columns if no such
      i exists.

    raises:
          -- ValueError, for an unknown method
    """
    if sep.ndim == 1:
        sep = sep[newaxis, :]

    if method == 'vanilla':
        # min rmsep
        rmsecv = sqrt(sep.mean(0))
        return rmsecv.argmin() + 1

    elif method == '75perc':
        prct = .75  # percentile
        ind = int(1.*sep.shape[0]*prct)
        med = median(sep, 0)
        # sort() returns a sorted copy. Sorting the columns of sep in place
        # (as col.sort() on sep.T did) silently reordered the caller's array.
        prc_75 = sort(sep, 0)[ind, :]
        for i in range(1, sep.shape[1]):
            if med[i-1] < prc_75[i]:
                return i
        return len(med)

    raise ValueError("unknown method: %r" % (method,))