"""Algorithms for computationally heavy (burdensome) calculations.

No type checking of any kind is done here; the focus is on speed.
"""

from scipy.linalg import svd,norm,inv,pinv,qr
from scipy import dot,empty,eye,newaxis,zeros,sqrt,diag,\
     apply_along_axis,mean,ones,randn,empty_like,outer,c_,\
     rand,sum,cumsum,matrix

def pca(a, aopt, scale='scores', mode='normal'):
    """Principal Component Analysis of ``a`` truncated to ``aopt`` components.

    With ``(m, n) = a.shape`` the factorisation is taken from ``esvd`` when
    ``m*10 > n`` and from a reduced ``svd`` otherwise; in both cases the
    result is unpacked as ``(u, s, vt)``.  No centering or scaling of ``a``
    is done here.

    Parameters
    ----------
    a : 2-d array, shape (m, n)
        Data matrix.
    aopt : int
        Number of components to keep.
    scale : str
        'loads' divides every score column by its norm and multiplies the
        matching loading column by it, so T gets unit-norm columns.  Any
        other value leaves T = u*s and P = rows of vt untouched.
    mode : str
        -- fast     : return only T and P
        -- detailed : E holds the residuals after 1, 2, ..., aopt
                      components, shape (aopt, m, n)
        -- anything else ('normal') : E is the (m, n) residual after
                      aopt components

    Returns
    -------
    dict
        {'T': scores, 'P': loadings} in fast mode, otherwise
        {'T': scores, 'P': loadings, 'E': residuals}.
    """
    m, n = a.shape

    if m*10. > n:
        u, s, vt = esvd(a)
    else:
        u, s, vt = svd(a, full_matrices=0)

    T = (u*s)[:, :aopt]
    P = vt[:aopt, :].T

    if scale == 'loads':
        # move the column norms from the scores into the loadings;
        # the product T P' is unchanged
        tnorm = apply_along_axis(norm, 0, T)
        T = T/tnorm
        P = P*tnorm

    if mode == 'fast':
        return {'T': T, 'P': P}

    if mode == 'detailed':
        # E is a three-mode array: E[i] is the residual after i+1 components
        E = empty((aopt, m, n))
        for ai in range(aopt):
            E[ai] = a - dot(T[:, :ai+1], P[:, :ai+1].T)
    else:
        E = a - dot(T, P.T)

    return {'T': T, 'P': P, 'E': E}

def pcr(a, b, aopt=2, scale='scores', mode='normal'):
    """Principal component regression of ``b`` on ``a``.

    Parameters
    ----------
    a : 2-d array, shape (m, n)
        Predictor matrix (tall or wide).
    b : array, shape (m, l) or (m,)
        Responses.  A 1-d vector is treated as a single response column,
        so Q, B and F are always 2-d / 3-d in the result.
    aopt : int
        Number of principal components to use.
    scale, mode
        Accepted for signature compatibility with the other models in this
        module; not used by this function.

    Returns
    -------
    dict with
        'T' : (m, aopt) scores, ``U*s`` from the SVD of ``a``
        'P' : (n, aopt) loadings
        'Q' : (l, aopt) regression of ``b`` on all ``aopt`` scores
        'B' : (aopt, n, l); B[i] are the coefficients using i+1 components
        'E' : (m, n) residual ``a - T P'``
        'F' : (m, l) residual ``b - T Q'``
    """
    m, n = a.shape
    if len(b.shape) == 1:
        # single response given as a vector: work on an (m, 1) view
        b = b[:, newaxis]
    l = b.shape[1]

    B = empty((aopt, n, l))
    # Reduced SVD: U is (m, min(m, n)), so U*s broadcasts for tall as well
    # as wide ``a`` (the full (m, m) U does not when m > n).
    U, s, Vt = svd(a, full_matrices=False)
    T = (U*s)[:, :aopt]
    P = Vt[:aopt, :].T
    Q = dot(dot(inv(dot(T.T, T)), T.T), b).T
    for i in range(aopt):
        ti = T[:, :i+1]
        r = dot(dot(inv(dot(ti.T, ti)), ti.T), b)
        B[i] = dot(P[:, :i+1], r)
    E = a - dot(T, P.T)
    F = b - dot(T, Q.T)

    return {'T': T, 'P': P, 'Q': Q, 'B': B, 'E': E, 'F': F}
    
def pls(a, b, aopt=2, scale='scores', mode='normal', ab=None):
    """Kernel pls for tall/wide matrices.

    Fast pls for calibration. Only inefficient for many Y-vars.

    Parameters
    ----------
    a : 2-d array, shape (m, n)
        Predictor matrix.
    b : 2-d array, shape (m, l)
        Response matrix.  Not read in 'fast' mode when ``ab`` is given.
    aopt : int
        Number of components.
    scale : str
        'loads' normalises the score columns of T and multiplies the
        returned W (and, outside fast mode, Q and P) by the score norms.
    mode : str
        -- fast     : return only T and W
        -- detailed : E and F hold the residuals after 1..aopt components,
                      shapes (aopt, m, n) and (aopt, m, l)
        -- anything else ('normal') : E and F are the residuals after
                      aopt components
    ab : 2-d array, shape (n, l), optional
        Precomputed ``dot(a.T, b)``.  It is not modified.

    Returns
    -------
    dict
        {'T', 'W'} in fast mode, otherwise
        {'B', 'Q', 'P', 'T', 'W', 'R', 'E', 'F'} where B[i] holds the
        regression coefficients using i+1 components.
    """
    m, n = a.shape
    # ``is None``: comparing an ndarray with ``!=``/``==`` is elementwise
    if ab is None:
        ab = dot(a.T, b)
    l = ab.shape[1]

    W = empty((n, aopt))
    P = empty((n, aopt))
    R = empty((n, aopt))
    Q = empty((l, aopt))
    T = empty((m, aopt))
    B = empty((aopt, n, l))

    for i in range(aopt):
        if l == 1:
            w = ab.reshape(n, l)
        else:
            u, s, vh = svd(dot(ab.T, ab))
            w = dot(ab, u[:, :1])

        w = w/norm(w)
        # r = w - sum_j (p_j' w) r_j over the components found so far
        r = w - dot(R[:, :i], dot(P[:, :i].T, w))
        t = dot(a, r)
        tt = norm(t)**2
        p = dot(a.T, t)/tt
        q = dot(r.T, ab).T/tt
        # deflate the cross-product; rebinding keeps the caller's ab intact
        ab = ab - dot(p, q.T)*tt
        T[:, i] = t.ravel()
        W[:, i] = w.ravel()
        P[:, i] = p.ravel()
        R[:, i] = r.ravel()

        if mode == 'fast' and i == aopt - 1:
            # fast mode stops before Q and B are filled for the last component
            if scale == 'loads':
                tnorm = apply_along_axis(norm, 0, T)
                T = T/tnorm
                W = W*tnorm
            return {'T': T, 'W': W}

        Q[:, i] = q.ravel()
        B[i] = dot(R[:, :i+1], Q[:, :i+1].T)

    if mode == 'detailed':
        E = empty((aopt, m, n))
        # b has as many rows as a (m), whether or not ab was precomputed
        F = empty((aopt, m, l))
        for i in range(1, aopt + 1):
            E[i-1] = a - dot(T[:, :i], P[:, :i].T)
            F[i-1] = b - dot(T[:, :i], Q[:, :i].T)
    else:
        E = a - dot(T[:, :aopt], P[:, :aopt].T)
        F = b - dot(T[:, :aopt], Q[:, :aopt].T)

    if scale == 'loads':
        tnorm = apply_along_axis(norm, 0, T)
        T = T/tnorm
        W = W*tnorm
        Q = Q*tnorm
        P = P*tnorm

    return {'B': B, 'Q': Q, 'P': P, 'T': T, 'W': W, 'R': R, 'E': E, 'F': F}

def w_simpls(aat, b, aopt):
    """ Simpls for wide matrices.
    Fast pls for crossval, used in calc rmsep for wide X
    There is no P,W.  T is normalised

    Parameters
    ----------
    aat : 2-d array, shape (m, m)
        The product ``dot(a, a.T)`` of the predictor matrix with itself.
    b : 2-d array, shape (m, l)
        Responses.  Not modified.
    aopt : int
        Number of components.

    Returns
    -------
    dict with
        'T' : (m, aopt) scores, each column scaled to unit norm
        'U' : (m, aopt) y-factor scores
        'Q' : (l, aopt) ``dot(b.T, T)`` using the undeflated ``b``
        'H' : (m, aopt) score weights ``dot(aat, T)``
    """
    # b is only ever rebound below, never changed in place, so the original
    # responses can be kept by reference
    b_orig = b
    _, m = aat.shape
    U = empty((m, aopt))
    T = empty((m, aopt))
    H = empty((m, aopt))  # just like W in simpls

    for i in range(aopt):
        u, s, vh = svd(dot(dot(b.T, aat), b), full_matrices=0)
        u = dot(b, u[:, :1])  # y-factor scores
        U[:, i] = u.ravel()
        t = dot(aat, u)
        t = t/norm(t)
        T[:, i] = t.ravel()
        H[:, i] = dot(aat, t).ravel()  # score-weights
        if i < aopt - 1:
            # Deflate b for the next component.  Nothing reads b after the
            # last component, so that final deflation is skipped.
            Ti = T[:, :i+1]
            Hi = H[:, :i+1]
            proj = dot(Ti, inv(dot(Ti.T, Hi)))  # just like R in simpls
            b = b - dot(proj, dot(Hi.T, b))
    C = dot(b_orig.T, T)

    return {'T': T, 'U': U, 'Q': C, 'H': H}

def bridge(a, b, aopt, scale='scores', mode='normal', r=0):
    """Undeflated Ridged svd(X'Y)

    Parameters
    ----------
    a : 2-d array, shape (m, n)
        Predictor matrix.
    b : 2-d array, shape (m, l)
        Response matrix.
    aopt : int
        Number of components.
    scale : str
        'loads' normalises the score columns of T and multiplies W (and,
        outside fast mode, Q) by the score norms.
    mode : str
        -- fast     : return only T and W
        -- detailed : E and F hold residuals for 1..aopt components,
                      shapes (aopt, m, n) and (aopt, m, l)
        -- anything else ('normal') : E and F for aopt components
    r : float
        Ridge weight: the response kernel used is
        ``(1 - r)*g0 + r*eye(m)`` with ``g0 = dot(u*s, u.T)`` built from
        the reduced SVD of ``b``.

    Returns
    -------
    dict
        {'T', 'W'} in fast mode, otherwise
        {'B', 'W', 'T', 'Q', 'E', 'F', 'U', 'P'}; 'P' is the same array
        as 'W'.  B[i] combines the first i+1 columns of W with the first
        i+1 columns of Q, where Q comes from one regression of ``b`` on
        all ``aopt`` scores.
    """
    m, n = a.shape
    u, s, vt = svd(b, full_matrices=0)
    g0 = dot(u*s, u.T)
    g = (1 - r)*g0 + r*eye(m)
    ag = dot(a.T, g)
    u, s, vt = svd(ag, full_matrices=0)
    W = u[:, :aopt]
    K = vt[:aopt, :].T
    T = dot(a, W)
    tnorm = apply_along_axis(norm, 0, T)  # norm of T-columns

    if mode == 'fast':
        if scale == 'loads':
            T = T/tnorm
            W = W*tnorm
        return {'T': T, 'W': W}

    k, l = b.shape
    U = dot(g0, K)  # fixme check this
    Q = dot(b.T, dot(T, inv(dot(T.T, T))))
    # double precision like every other quantity here; a single-precision B
    # would also truncate the F residuals computed from it
    B = zeros((aopt, n, l))
    for i in range(aopt):
        B[i] = dot(W[:, :i+1], Q[:, :i+1].T)
    # TODO: leverages and explained variances are not computed here.
    #       T (scores) are not orthogonal, so row-space leverage probably
    #       needs an orthonormal basis first (e.g. from a qr decomposition
    #       of T).

    if mode == 'detailed':
        E = empty((aopt, m, n))
        F = empty((aopt, k, l))
        for i in range(aopt):
            E[i] = a - dot(T[:, :i+1], W[:, :i+1].T)
            F[i] = b - dot(a, B[i])
    else:  # normal
        F = b - dot(a, B[-1])
        E = a - dot(T, W.T)

    if scale == 'loads':
        T = T/tnorm
        W = W*tnorm
        Q = Q*tnorm

    return {'B': B, 'W': W, 'T': T, 'Q': Q, 'E': E, 'F': F, 'U': U, 'P': W}
    

def m_shape(array):
    """Return the shape ``array`` has after being wrapped in ``matrix``.

    A 1-d sequence therefore reports as a single row, ``(1, len)``.
    """
    as_matrix = matrix(array)
    n_rows, n_cols = as_matrix.shape
    return (n_rows, n_cols)

def esvd(data, economy=1):
    """SVD with the option of economy sized calculation
    Calculate subspaces of X'X or XX' depending on the shape
    of the matrix.

    Good for extreme fat or thin matrices

    Numpy supports this by setting full_matrices=0

    Parameters
    ----------
    data : 2-d array, shape (m, n)
    economy
        Not used; kept so existing calls keep working.

    Returns
    -------
    u : array, shape (m, k)
    s : array, shape (k,)
    vt : array, shape (k, n)
        With ``k = min(m, n)``, laid out like the result of a reduced
        ``svd`` so that ``dot(u*s, vt)`` rebuilds ``data``.  Both branches
        return the right singular vectors as rows.  Directions whose
        singular value is exactly zero are left as zero vectors instead of
        being divided by zero.
    """
    m, n = data.shape
    if m >= n:
        # decompose the small (n, n) matrix X'X
        _, _, vt = svd(dot(data.T, data))
        u = dot(data, vt.T)  # column i is s_i * u_i
        s = sqrt((u*u).sum(axis=0))
        safe = s.copy()
        safe[safe == 0] = 1.
        u /= safe
    else:
        # decompose the small (m, m) matrix XX'
        u, _, _ = svd(dot(data, data.T))
        vt = dot(u.T, data)  # row i is s_i * v_i'
        s = sqrt((vt*vt).sum(axis=1))
        safe = s.copy()
        safe[safe == 0] = 1.
        vt /= safe[:, newaxis]

    return u, s, vt