""" A collection of some statistical utilities used.
"""
__all__ = ['hotelling', 'lpls_qvals']
__docformat__ = 'restructuredtext en'

from math import sqrt as msqrt

from numpy import dot,empty,zeros,eye,median,sign,arange,argsort
from numpy.linalg import svd,inv,det
from numpy.random import shuffle

from crossvalidation import lpls_jk
from engines import nipals_lpls as lpls


def hotelling(Pcv, P, p_center='median', cov_center=median,
              alpha=0.3, crot=True, strict=False):
    """Returns regularized hotelling T^2.

    Hotelling, is a generalization of Student's t statistic that is
    used in multivariate hypothesis testing. In order to avoid small variance
    samples to become significant this version allows borrowing variance
    from the pooled covariance.

    *Parameters*:

        Pcv : {array}
            Crossvalidation segments of parameter,
            shape (n_sets, n, amax)
        P : {array}
            Calibration model parameter (indexed per variable as
            ``P[i, :]``, i.e. expected to be (n, amax))
        p_center : {'median', 'mean', 'cal_model'}, optional
            Location method for sub-segments. Any value other than
            'median' or 'mean' selects the calibration model ``P``.
        cov_center : {py_func}, optional
            Location function used to pool the per-variable covariances.
            It is called as ``cov_center(Cov_i, axis=0)``; functions that
            do not accept an ``axis`` keyword are called as
            ``cov_center(Cov_i)`` and must reduce over the first axis
            themselves.
        alpha : {float}, optional
            Regularisation towards pooled covariance estimate.
        crot : {boolean}, optional
            Rotate sub-segments toward calibration model.
        strict : {boolean}, optional
            Only rotate 90 degree

    *Returns*:

       tsq : {array}
           Hotellings T^2 estimate, one value per variable, shape (n,)

    *Reference*:

        Gidskehaug et al., A framework for significance analysis of
        gene expression data using dimension reduction methods, BMC
        bioinformatics, 2007

    *Notes*

        The rotational freedom in the solution of bilinear
        models may require a rotation onto the calibration
        model. One way of doing that is procrustes rotation.

        The input ``Pcv`` is not modified; the rotation is applied to
        a copy.

    """
    n_sets, n, amax = Pcv.shape

    # rotate sub_models to full model, writing into a fresh array so the
    # caller's crossvalidation segments are left untouched
    if crot:
        rotated = empty(Pcv.shape, dtype='d')
        for i, Pi in enumerate(Pcv):
            rotated[i] = procrustes(P, Pi, strict=strict)
        Pcv = rotated

    # center of pnull: the location has to be taken over the segments
    # (axis 0); without an explicit axis numpy.median flattens the array
    if p_center == 'median':
        P_ctr = median(Pcv, axis=0)
    elif p_center == 'mean':
        P_ctr = Pcv.mean(0)
    else:  # calibration model
        P_ctr = P

    # per-variable covariance of the segments around the chosen center
    Cov_i = zeros((n, amax, amax), dtype='d')
    scale = msqrt(n_sets - 1)
    for i in range(n):
        Pim = (Pcv[:, i, :] - P_ctr[i, :])*scale  # (n_sets x amax)
        Cov_i[i] = (1./n_sets)*dot(Pim.T, Pim)

    # pooled (amax x amax) covariance across the variables
    try:
        Cov = cov_center(Cov_i, axis=0)
    except TypeError:
        # location function without an ``axis`` keyword
        Cov = cov_center(Cov_i)

    # shrink every covariance towards the pooled estimate
    reg_cov = (1. - alpha)*Cov_i + alpha*Cov

    T_sq = empty((n,), dtype='d')
    for i in range(n):
        Pc = P_ctr[i, :]
        T_sq[i] = dot(dot(Pc, inv(reg_cov[i])), Pc)
    return T_sq

def procrustes(a, b, strict=True, center=False, verbose=False):
    """Orthogonal rotation of b to a.

    Procrustes rotation is an orthogonal rotation of one subspace
    onto another by minimising the squared error.

    *Parameters*:

        a : {array}
            Input array (target)
        b : {array}
            Input array (the one being rotated)
        strict : {boolean}
            Only do flipping and shuffling
        center : {boolean}
            Center before rotation, translate back after
            (the column means of ``b`` are added back to the result)
        verbose : {boolean}
            Print the rounded rotation matrix and the sum of squared
            differences between ``b`` and the rotated ``b``

    *Returns*:

        b_rot : {array}
            B-matrix rotated

    *Reference*:

        Schonemann, A generalized solution of the orthogonal Procrustes problem,
        Psychometrika, 1966
    """

    if center:
        mn_a = a.mean(0)
        a = a - mn_a
        mn_b = b.mean(0)
        b = b - mn_b

    u, s, vt = svd(dot(b.T, a))
    Cm = dot(u, vt)  # Cm: orthogonal rotation matrix
    if strict:
        Cm = _ensure_strict(Cm)
    b_rot = dot(b, Cm)
    if verbose:
        # single-argument print calls behave the same on Python 2 and 3
        print(Cm.round())
        fit = ((b - b_rot)**2).sum()
        print("Error: %.3E" % fit)
    if center:
        return mn_b + b_rot
    return b_rot

def _ensure_strict(C, only_flips=True):
    """Ensure that a rotation matrix does only 90 degree rotations.
    
    In multiplication with pcs this allows flips and reordering.
    if only_flips is True there will onlt be flips allowed

    *Parameters*:

        C : {array}
            Rotation matrix
        only_flips : {boolean}
            Only accept columns to flip (switch signs)

    *Returns*:

        C_rot : {array}
            Restricted rotation matrix
    
    *Notes*:
    
        This function is not ready for use. Use (only_flips=True)
    
    """
    if only_flips:
        C = eye(Cm.shape[0])*sign(Cm)
        return C
    Cm = zeros(C.shape, dtype='d')
    Cm[abs(C)>.6] = 1.
    if det(Cm)>1:
        raise NotImplementedError
    return Cm*S

def lpls_qvals(X, Y, Z, aopt=None, alpha=.3, zx_alpha=.5, n_iter=20,
               sim_method='shuffle',p_center='med', cov_center=median,
               crot=True,strict=False,mean_ctr=[2,0,2], nsets=None):

    """Returns qvals for l-pls model by permutation analysis.

    The response (Y) is randomly permuted, and the number of false positives
    is registered by comparing hotellings T2 statistics of the calibration model.

    *Parameters*:

        X : {array}
            Main data matrix (m, n)
        Y : {array}
            External row data (m, l). A one dimensional Y is turned
            into a column vector.
        Z : {array}
            External column data (k, n)
        aopt : {integer}
            Optimal number of components
        alpha : {float}, optional
            Regularisation towards pooled covariance estimate
            (passed to the hotelling T2 computation).
        zx_alpha : {float}, optional
            Accepted for interface compatibility; not used by this
            function.
        n_iter : {integer}, optional
            Number of permutations
        sim_method : ['shuffle'], optional
            Permutation method. Only row shuffling of Y is implemented;
            the value is not inspected.
        p_center : {'median', 'mean', 'cal_model'}, optional
            Location method for sub-segments. The default 'med' is
            treated as 'median'.
        cov_center : {py_func}, optional
            Location function
        crot : {boolean}, optional
            Rotate sub-segments toward calibration model.
        strict : {boolean}, optional
            Only rotate 90 degree
        mean_ctr : {array-like}, optional
            A three element array-like structure with elements in [-1,0,1,2],
            that decides the type of centering used.
            -1 : nothing
            0 : row center
            1 : column center
            2 : double center
        nsets : {integer}
            Number of crossvalidation segments

    *Returns*:

        fdr_z : {array}
            False discovery rate for the Z variables (length k)
        fdr_x : {array}
            False discovery rate for the X variables (length n)

    *Reference*:

        Gidskehaug et al., A framework for significance analysis of
        gene expression data using dimension reduction methods, BMC
        bioinformatics, 2007
    """
    # local import: atleast_2d is not part of the module level imports
    from numpy import atleast_2d

    m, n = X.shape
    k, nz = Z.shape
    assert(n==nz)
    try:
        my, l = Y.shape
    except (AttributeError, ValueError):
        # make Y a column vector
        Y = atleast_2d(Y).T
        my, l = Y.shape
    assert(m==my)

    # 'med' is this function's default spelling of the 'median' option
    if p_center == 'med':
        p_center = 'median'
    # the same T2 settings are used for calibration and permuted models
    tsq_opts = dict(p_center=p_center, cov_center=cov_center, alpha=alpha,
                    crot=crot, strict=strict)

    pert_tsq_x = zeros((n, n_iter), dtype='d') # (nxvars x n_subsets)
    pert_tsq_z = zeros((k, n_iter), dtype='d') # (nzvars x n_subsets)

    # Full model
    dat = lpls(X, Y, Z, aopt, scale='loads', mean_ctr=mean_ctr)
    Wc, Lc = lpls_jk(X, Y, Z, aopt, nsets=nsets)
    cal_tsq_x = hotelling(Wc, dat['W'], **tsq_opts)
    cal_tsq_z = hotelling(Lc, dat['L'], **tsq_opts)

    # Perturbations: refit on row-permuted responses
    index = arange(m)
    for i in range(n_iter):
        indi = index.copy()
        shuffle(indi)
        Y_perm = Y[indi,:]
        dat = lpls(X, Y_perm, Z, aopt, scale='loads', mean_ctr=mean_ctr)
        Wi, Li = lpls_jk(X, Y_perm, Z, aopt, nsets=nsets)
        pert_tsq_x[:,i] = hotelling(Wi, dat['W'], **tsq_opts)
        pert_tsq_z[:,i] = hotelling(Li, dat['L'], **tsq_opts)

    # note the order: Z variables first, then X variables
    return _fdr(cal_tsq_z, pert_tsq_z, median), _fdr(cal_tsq_x, pert_tsq_x, median)

def _fdr(tsq, tsqp, loc_method=median):
    """Returns false discovery rate.

    Fdr is a method used in multiple hypothesis testing to correct for multiple
    comparisons. It controls the expected proportion of incorrectly rejected null
    hypotheses (type I errors) in a list of rejected hypotheses.
    
    *Parameters*:

        tsq : {array}
            Hotellings T2, calibration model
        tsqp : {array}
            Hotellings T2, submodels

        loc_method : {py_func}
            Location method
    
    *Returns*:

        fdr : {array}
            False discovery rate

    """
    n, = tsq.shape
    k, m = tsqp.shape
    assert(n==k)
    n_false = empty((n, m), 'd')
    sort_index = argsort(tsq)[::-1]
    r_index = argsort(sort_index)
    for i in xrange(m):
        for j in xrange(n):
            n_false[j,i] = (tsqp[:,i]>tsq[j]).sum()
    fp = loc_method(n_false.T)
    n_signif = (arange(n) + 1.0)[r_index]
    fd_rate = fp/n_signif
    fd_rate[fd_rate>1] = 1
    return fd_rate