From 10eba079bccc1bafd7bc848b6e040f10e2ecfbc2 Mon Sep 17 00:00:00 2001
From: flatberg <flatberg@pvv.ntnu.no>
Date: Mon, 30 Jul 2007 18:04:42 +0000
Subject: [PATCH] iii

---
 fluents/lib/blmplots.py   |   4 +-
 fluents/lib/engines.py    | 163 ++++++++++++++++++++++++++++++++++----
 fluents/lib/validation.py |   4 +-
 matplotlibrc              |   2 +-
 4 files changed, 151 insertions(+), 22 deletions(-)

diff --git a/fluents/lib/blmplots.py b/fluents/lib/blmplots.py
index 6a62898..15b0369 100644
--- a/fluents/lib/blmplots.py
+++ b/fluents/lib/blmplots.py
@@ -203,7 +203,7 @@ class LplsXCorrelationPlot(BlmScatterPlot):
                               facecolor='gray',
                               alpha=.1,
                               zorder=1)
-        c50 = patches.Circle(center, radius=radius/2.0,
+        c50 = patches.Circle(center, radius= sqrt(radius/2.0),
                              facecolor='gray',
                              alpha=.1,
                              zorder=2)
@@ -228,7 +228,7 @@ class LplsZCorrelationPlot(BlmScatterPlot):
                               facecolor='gray',
                               alpha=.1,
                               zorder=1)
-        c50 = patches.Circle(center, radius=radius/2.0,
+        c50 = patches.Circle(center, radius=sqrt(radius/2.0),
                              facecolor='gray',
                              alpha=.1,
                              zorder=2)
diff --git a/fluents/lib/engines.py b/fluents/lib/engines.py
index 219f1ce..0d52917 100644
--- a/fluents/lib/engines.py
+++ b/fluents/lib/engines.py
@@ -14,6 +14,7 @@ try:
 except:
     has_sym = False
 
+
 def pca(a, aopt,scale='scores',mode='normal',center_axis=0):
     """ Principal Component Analysis.
 
@@ -187,7 +188,7 @@ def pcr(a, b, aopt, scale='scores',mode='normal',center_axis=0):
     dat.update({'Q':Q, 'F':F, 'expvary':expvary})
     return dat
 
-def pls(a, b, aopt=2, scale='scores', mode='normal', ax_center=0, ab=None):
+def pls(a, b, aopt=2, scale='scores', mode='normal', center_axis=0, ab=None):
     """Partial Least Squares Regression.
 
     Performs PLS on given matrix and returns results in a dictionary.
@@ -244,6 +245,10 @@ def pls(a, b, aopt=2, scale='scores', mode='normal', ax_center=0, ab=None):
         assert(m==mm)
     else:
          k, l = m_shape(b)
+
+    if center_axis>=0:
+        a = a - expand_dims(a.mean(center_axis), center_axis)
+        b = b - expand_dims(b.mean(center_axis), center_axis)
     
     W = empty((n, aopt))
     P = empty((n, aopt))
@@ -255,25 +260,28 @@ def pls(a, b, aopt=2, scale='scores', mode='normal', ax_center=0, ab=None):
     if ab==None:
         ab = dot(a.T, b)
     for i in range(aopt):
-        if ab.shape[1]==1:
+        if ab.shape[1]==1: #pls 1
             w = ab.reshape(n, l)
             w = w/vnorm(w)
-        elif n<l:
+        elif n<l: # more yvars than xvars
             if has_sym:
-                s, u = symeig(dot(ab.T, ab),range=[l,l],overwrite=True)
+                s, u = symeig(dot(ab, ab.T),range=[l,l],overwrite=True)
             else:
                 u, s, vh = svd(dot(ab, ab.T))
             w = u[:,0]
-        else:
+        else: # standard wide xdata
             if has_sym:
-                s, u = symeig(dot(ab.T, ab),range=[l,l],overwrite=True)
+                s, q = symeig(dot(ab.T, ab),range=[l,l],overwrite=True)
             else:
-                u, s, vh = svd(dot(ab.T, ab))
-            w = dot(ab, u)
+                q, s, vh = svd(dot(ab.T, ab))
+                q = q[:,:1]
+            w = dot(ab, q)
+            w = w/vnorm(w)
         r = w.copy()
         if i>0:
             for j in range(0, i, 1):
                 r = r - dot(P[:,j].T, w)*R[:,j][:,newaxis]
+        print vnorm(r)
         t = dot(a, r)
         tt = vnorm(t)**2
         p  = dot(a.T, t)/tt
@@ -345,9 +353,13 @@ def w_pls(aat, b, aopt):
     """ Pls for wide matrices.
     Fast pls for crossval, used in calc rmsep for wide X
     There is no P or W.  T is normalised
+
+    aat = centered kernel matrix
+    b = centered y
     """
     bb = b.copy()
-    m, m = aat.shape
+    k, l = m_shape(b)
+    m, m = m_shape(aat)
     U = empty((m, aopt)) # W
     T = empty((m, aopt))
     R = empty((m, aopt)) # R
@@ -355,23 +367,28 @@ def w_pls(aat, b, aopt):
 
     for i in range(aopt):
         if has_sym:
-            pass
+            s, q = symeig(dot(dot(b.T, aat), b), range=(l,l),overwrite=True)
         else:
             q, s, vh = svd(dot(dot(b.T, aat), b), full_matrices=0)
             q = q[:,:1]
         u = dot(b , q) #y-factor scores
         U[:,i] = u.ravel()
         t = dot(aat, u)
+        print "Norm of t: %s" %vnorm(t)
+        print "s: %s" %s
+        
         t = t/vnorm(t)
         T[:,i] = t.ravel()
-        r = dot(aat, t) #score-weights
+        r = dot(aat, t)#score-weights
+        #r = r/vnorm(r)
+        print "Norm R: %s" %vnorm(r)
         R[:,i] = r.ravel()
         PROJ[:,: i+1] = dot(T[:,:i+1], inv(dot(T[:,:i+1].T, R[:,:i+1])) )
         if i<aopt:
-            b = b - dot(PROJ[:,:i+1], dot(R[:,:i+1].T,b) )
+            b = b - dot(PROJ[:,:i+1], dot(R[:,:i+1].T,  b) )
     C = dot(bb.T, T)
 
-    return {'T':T, 'U':U, 'Q':C, 'H':H}
+    return {'T':T, 'U':U, 'Q':C, 'R':R}
 
 def bridge(a, b, aopt, scale='scores', mode='normal', r=0):
     """Undeflated Ridged svd(X'Y)
@@ -476,6 +493,8 @@ def nipals_lpls(X, Y, Z, a_max, alpha=.7, mean_ctr=[2, 0, 1], mode='normal', sca
     P = empty((n, a_max))
     K = empty((o, a_max))
     L = empty((u, a_max))
+    B = empty((a_max, n, l))
+    b0 = empty((a_max, m, l))
     var_x = empty((a_max,))
     var_y = empty((a_max,))
     var_z = empty((a_max,))
@@ -485,8 +504,8 @@ def nipals_lpls(X, Y, Z, a_max, alpha=.7, mean_ctr=[2, 0, 1], mode='normal', sca
             print "\n Working on comp. %s" %a
         u = Y[:,:1]
         diff = 1
-        MAX_ITER = 100
-        lim = 1e-5
+        MAX_ITER = 200
+        lim = 1e-16
         niter = 0
         while (diff>lim and niter<MAX_ITER):
             niter += 1
@@ -526,8 +545,8 @@ def nipals_lpls(X, Y, Z, a_max, alpha=.7, mean_ctr=[2, 0, 1], mode='normal', sca
         var_y[a] = pow(Y, 2).sum()
         var_z[a] = pow(Z, 2).sum()
     
-    B = dot(dot(W, inv(dot(P.T, W))), Q.T)
-    b0 = mnY - dot(mnX, B)
+        B[a] = dot(dot(W[:,:a+1], inv(dot(P[:,:a+1].T, W[:,:a+1]))), Q[:,:a+1].T)
+        b0[a] = mnY - dot(mnX, B[a])
     
     # variance explained
     evx = 100.0*(1 - var_x/varX)
@@ -546,6 +565,116 @@ def nipals_lpls(X, Y, Z, a_max, alpha=.7, mean_ctr=[2, 0, 1], mode='normal', sca
 
 
 
+def nipals_pls(X, Y, a_max, alpha=.7, ax_center=0, mode='normal', scale='scores', verbose=False):
+    """Partial Least Sqaures Regression by the nipals algorithm.
+
+    (X!Z)->Y
+    :input:
+        X : data matrix (m, n)
+        Y : data matrix (m, l)
+
+    :output:
+      T : X-scores
+      W : X-weights/Z-weights
+      P : X-loadings
+      Q : Y-loadings
+      U : X-Y relation
+      B : Regression coefficients X->Y
+      b0: Regression coefficient intercept
+      evx : X-explained variance
+      evy : Y-explained variance
+      evz : Z-explained variance
+
+    :Notes:
+    
+    """
+    if ax_center>=0:
+        mn_x = expand_dims(X.mean(ax_center), ax_center)
+        mn_y = expand_dims(Y.mean(ax_center), ax_center)
+        X = X - mn_x
+        Y = Y - mn_y
+
+    varX = pow(X, 2).sum()
+    varY = pow(Y, 2).sum()
+    
+    m, n = X.shape
+    k, l = Y.shape
+
+    # initialize 
+    U = empty((k, a_max))
+    Q = empty((l, a_max))
+    T = empty((m, a_max))
+    W = empty((n, a_max))
+    P = empty((n, a_max))
+    B = empty((a_max, n, l))
+    b0 = empty((a_max, m, l))
+    var_x = empty((a_max,))
+    var_y = empty((a_max,))
+
+    t1 = X[:,:1]
+    for a in range(a_max):
+        if verbose:
+            print "\n Working on comp. %s" %a
+        u = Y[:,:1]
+        diff = 1
+        MAX_ITER = 100
+        lim = 1e-16
+        niter = 0
+        while (diff>lim and niter<MAX_ITER):
+            niter += 1
+            #u1 = u.copy()
+            w = dot(X.T, u)
+            w = w/sqrt(dot(w.T, w))
+            #l = dot(Z, w)
+            #k = dot(Z.T, l)
+            #k = k/sqrt(dot(k.T, k))
+            #w = alpha*k + (1-alpha)*w
+            #w = w/sqrt(dot(w.T, w))
+            t = dot(X, w)
+            q = dot(Y.T, t)
+            q = q/sqrt(dot(q.T, q))
+            u = dot(Y, q)
+            diff = vnorm(t1 - t)
+            t1 = t.copy()
+        if verbose:
+            print "Converged after %s iterations" %niter
+        #tt = dot(t.T, t)
+        #p = dot(X.T, t)/tt
+        #q = dot(Y.T, t)/tt
+        #l = dot(Z, w)
+        p = dot(X.T, t)/dot(t.T, t)
+        p_norm = vnorm(p)
+        t = t*p_norm
+        w = w*p_norm
+        p = p/p_norm
+        
+        U[:,a] = u.ravel()
+        W[:,a] = w.ravel()
+        P[:,a] = p.ravel()
+        T[:,a] = t.ravel()
+        Q[:,a] = q.ravel()
+
+        X = X - dot(t, p.T)
+        Y = Y - dot(t, q.T)
+
+        var_x[a] = pow(X, 2).sum()
+        var_y[a] = pow(Y, 2).sum()
+    
+        B[a] = dot(dot(W[:,:a+1], inv(dot(P[:,:a+1].T, W[:,:a+1]))), Q[:,:a+1].T)
+        b0[a] =  mn_y - dot(mn_x, B[a])
+    
+    # variance explained
+    evx = 100.0*(1 - var_x/varX)
+    evy = 100.0*(1 - var_y/varY)
+    
+    if scale=='loads':
+        tnorm = apply_along_axis(vnorm, 0, T)
+        T = T/tnorm
+        W = W*tnorm
+        Q = Q*tnorm
+    
+    return {'T':T, 'W':W, 'P':P, 'Q':Q, 'U':U, 'B':B, 'b0':b0, 'evx':evx, 'evy':evy}
+
 
 ########### Helper routines #########
 
diff --git a/fluents/lib/validation.py b/fluents/lib/validation.py
index eb49197..6ce394c 100644
--- a/fluents/lib/validation.py
+++ b/fluents/lib/validation.py
@@ -1,7 +1,7 @@
 """This module implements some common validation schemes from pca and pls.
 """
 from scipy import ones,mean,sqrt,dot,newaxis,zeros,sum,empty,\
-     apply_along_axis,eye,kron,array,sort
+     apply_along_axis,eye,kron,array,sort,zeros_like,argmax
 from scipy.stats import median
 from scipy.linalg import triu,inv,svd,norm
 
@@ -122,7 +122,7 @@ def lpls_val(X, Y, Z, a_max=2, nsets=None,alpha=.5):
         B = dat['B']
         b0 = dat['b0']
         for a in range(a_max):
-            Yhat[a,ind,:] = b0[a][0][0] + dot(xi, B[a])
+            Yhat[a,ind,:] = b0[a][0][0] + dot(xi-xcal.mean(0), B[a])
     Yhat_class = zeros_like(Yhat)
     for a in range(a_max):
         for i in range(k):
diff --git a/matplotlibrc b/matplotlibrc
index 31ccdc1..ad3b671 100644
--- a/matplotlibrc
+++ b/matplotlibrc
@@ -132,7 +132,7 @@ text.dvipnghack     : False  # some versions of dvipng don't handle
 # default fontsizes for ticklabels, and so on.  See
 # http://matplotlib.sourceforge.net/matplotlib.axes.html#Axes
 axes.hold           : True    # whether to clear the axes by default on
-axes.facecolor      : white   # axes background color
+axes.facecolor      : 0.6   # axes background color
 axes.edgecolor      : black   # axes edge color
 axes.linewidth      : 1.0     # edge linewidth
 axes.grid           : True   # display grid or not