Projects/laydi
Projects
/
laydi
Archived
7
0
Fork 0
This commit is contained in:
Arnar Flatberg 2007-07-30 18:04:42 +00:00
parent aa4007e208
commit 10eba079bc
4 changed files with 151 additions and 22 deletions

View File

@ -203,7 +203,7 @@ class LplsXCorrelationPlot(BlmScatterPlot):
facecolor='gray', facecolor='gray',
alpha=.1, alpha=.1,
zorder=1) zorder=1)
c50 = patches.Circle(center, radius=radius/2.0, c50 = patches.Circle(center, radius= sqrt(radius/2.0),
facecolor='gray', facecolor='gray',
alpha=.1, alpha=.1,
zorder=2) zorder=2)
@ -228,7 +228,7 @@ class LplsZCorrelationPlot(BlmScatterPlot):
facecolor='gray', facecolor='gray',
alpha=.1, alpha=.1,
zorder=1) zorder=1)
c50 = patches.Circle(center, radius=radius/2.0, c50 = patches.Circle(center, radius=sqrt(radius/2.0),
facecolor='gray', facecolor='gray',
alpha=.1, alpha=.1,
zorder=2) zorder=2)

View File

@ -14,6 +14,7 @@ try:
except: except:
has_sym = False has_sym = False
def pca(a, aopt,scale='scores',mode='normal',center_axis=0): def pca(a, aopt,scale='scores',mode='normal',center_axis=0):
""" Principal Component Analysis. """ Principal Component Analysis.
@ -187,7 +188,7 @@ def pcr(a, b, aopt, scale='scores',mode='normal',center_axis=0):
dat.update({'Q':Q, 'F':F, 'expvary':expvary}) dat.update({'Q':Q, 'F':F, 'expvary':expvary})
return dat return dat
def pls(a, b, aopt=2, scale='scores', mode='normal', ax_center=0, ab=None): def pls(a, b, aopt=2, scale='scores', mode='normal', center_axis=0, ab=None):
"""Partial Least Squares Regression. """Partial Least Squares Regression.
Performs PLS on given matrix and returns results in a dictionary. Performs PLS on given matrix and returns results in a dictionary.
@ -244,6 +245,10 @@ def pls(a, b, aopt=2, scale='scores', mode='normal', ax_center=0, ab=None):
assert(m==mm) assert(m==mm)
else: else:
k, l = m_shape(b) k, l = m_shape(b)
if center_axis>=0:
a = a - expand_dims(a.mean(center_axis), center_axis)
b = b - expand_dims(b.mean(center_axis), center_axis)
W = empty((n, aopt)) W = empty((n, aopt))
P = empty((n, aopt)) P = empty((n, aopt))
@ -255,25 +260,28 @@ def pls(a, b, aopt=2, scale='scores', mode='normal', ax_center=0, ab=None):
if ab==None: if ab==None:
ab = dot(a.T, b) ab = dot(a.T, b)
for i in range(aopt): for i in range(aopt):
if ab.shape[1]==1: if ab.shape[1]==1: #pls 1
w = ab.reshape(n, l) w = ab.reshape(n, l)
w = w/vnorm(w) w = w/vnorm(w)
elif n<l: elif n<l: # more yvars than xvars
if has_sym: if has_sym:
s, u = symeig(dot(ab.T, ab),range=[l,l],overwrite=True) s, u = symeig(dot(ab, ab.T),range=[l,l],overwrite=True)
else: else:
u, s, vh = svd(dot(ab, ab.T)) u, s, vh = svd(dot(ab, ab.T))
w = u[:,0] w = u[:,0]
else: else: # standard wide xdata
if has_sym: if has_sym:
s, u = symeig(dot(ab.T, ab),range=[l,l],overwrite=True) s, q = symeig(dot(ab.T, ab),range=[l,l],overwrite=True)
else: else:
u, s, vh = svd(dot(ab.T, ab)) q, s, vh = svd(dot(ab.T, ab))
w = dot(ab, u) q = q[:,:1]
w = dot(ab, q)
w = w/vnorm(w)
r = w.copy() r = w.copy()
if i>0: if i>0:
for j in range(0, i, 1): for j in range(0, i, 1):
r = r - dot(P[:,j].T, w)*R[:,j][:,newaxis] r = r - dot(P[:,j].T, w)*R[:,j][:,newaxis]
print vnorm(r)
t = dot(a, r) t = dot(a, r)
tt = vnorm(t)**2 tt = vnorm(t)**2
p = dot(a.T, t)/tt p = dot(a.T, t)/tt
@ -345,9 +353,13 @@ def w_pls(aat, b, aopt):
""" Pls for wide matrices. """ Pls for wide matrices.
Fast pls for crossval, used in calc rmsep for wide X Fast pls for crossval, used in calc rmsep for wide X
There is no P or W. T is normalised There is no P or W. T is normalised
aat = centered kernel matrix
b = centered y
""" """
bb = b.copy() bb = b.copy()
m, m = aat.shape k, l = m_shape(b)
m, m = m_shape(aat)
U = empty((m, aopt)) # W U = empty((m, aopt)) # W
T = empty((m, aopt)) T = empty((m, aopt))
R = empty((m, aopt)) # R R = empty((m, aopt)) # R
@ -355,23 +367,28 @@ def w_pls(aat, b, aopt):
for i in range(aopt): for i in range(aopt):
if has_sym: if has_sym:
pass s, q = symeig(dot(dot(b.T, aat), b), range=(l,l),overwrite=True)
else: else:
q, s, vh = svd(dot(dot(b.T, aat), b), full_matrices=0) q, s, vh = svd(dot(dot(b.T, aat), b), full_matrices=0)
q = q[:,:1] q = q[:,:1]
u = dot(b , q) #y-factor scores u = dot(b , q) #y-factor scores
U[:,i] = u.ravel() U[:,i] = u.ravel()
t = dot(aat, u) t = dot(aat, u)
print "Norm of t: %s" %vnorm(t)
print "s: %s" %s
t = t/vnorm(t) t = t/vnorm(t)
T[:,i] = t.ravel() T[:,i] = t.ravel()
r = dot(aat, t) #score-weights r = dot(aat, t)#score-weights
#r = r/vnorm(r)
print "Norm R: %s" %vnorm(r)
R[:,i] = r.ravel() R[:,i] = r.ravel()
PROJ[:,: i+1] = dot(T[:,:i+1], inv(dot(T[:,:i+1].T, R[:,:i+1])) ) PROJ[:,: i+1] = dot(T[:,:i+1], inv(dot(T[:,:i+1].T, R[:,:i+1])) )
if i<aopt: if i<aopt:
b = b - dot(PROJ[:,:i+1], dot(R[:,:i+1].T,b) ) b = b - dot(PROJ[:,:i+1], dot(R[:,:i+1].T, b) )
C = dot(bb.T, T) C = dot(bb.T, T)
return {'T':T, 'U':U, 'Q':C, 'H':H} return {'T':T, 'U':U, 'Q':C, 'R':R}
def bridge(a, b, aopt, scale='scores', mode='normal', r=0): def bridge(a, b, aopt, scale='scores', mode='normal', r=0):
"""Undeflated Ridged svd(X'Y) """Undeflated Ridged svd(X'Y)
@ -476,6 +493,8 @@ def nipals_lpls(X, Y, Z, a_max, alpha=.7, mean_ctr=[2, 0, 1], mode='normal', sca
P = empty((n, a_max)) P = empty((n, a_max))
K = empty((o, a_max)) K = empty((o, a_max))
L = empty((u, a_max)) L = empty((u, a_max))
B = empty((a_max, n, l))
b0 = empty((a_max, m, l))
var_x = empty((a_max,)) var_x = empty((a_max,))
var_y = empty((a_max,)) var_y = empty((a_max,))
var_z = empty((a_max,)) var_z = empty((a_max,))
@ -485,8 +504,8 @@ def nipals_lpls(X, Y, Z, a_max, alpha=.7, mean_ctr=[2, 0, 1], mode='normal', sca
print "\n Working on comp. %s" %a print "\n Working on comp. %s" %a
u = Y[:,:1] u = Y[:,:1]
diff = 1 diff = 1
MAX_ITER = 100 MAX_ITER = 200
lim = 1e-5 lim = 1e-16
niter = 0 niter = 0
while (diff>lim and niter<MAX_ITER): while (diff>lim and niter<MAX_ITER):
niter += 1 niter += 1
@ -526,8 +545,8 @@ def nipals_lpls(X, Y, Z, a_max, alpha=.7, mean_ctr=[2, 0, 1], mode='normal', sca
var_y[a] = pow(Y, 2).sum() var_y[a] = pow(Y, 2).sum()
var_z[a] = pow(Z, 2).sum() var_z[a] = pow(Z, 2).sum()
B = dot(dot(W, inv(dot(P.T, W))), Q.T) B[a] = dot(dot(W[:,:a+1], inv(dot(P[:,:a+1].T, W[:,:a+1]))), Q[:,:a+1].T)
b0 = mnY - dot(mnX, B) b0[a] = mnY - dot(mnX, B[a])
# variance explained # variance explained
evx = 100.0*(1 - var_x/varX) evx = 100.0*(1 - var_x/varX)
@ -546,6 +565,116 @@ def nipals_lpls(X, Y, Z, a_max, alpha=.7, mean_ctr=[2, 0, 1], mode='normal', sca
def nipals_pls(X, Y, a_max, alpha=.7, ax_center=0, mode='normal', scale='scores', verbose=False):
    """Partial Least Squares Regression by the NIPALS algorithm.

    Components are extracted one at a time.  After each component both X
    and Y are deflated by the part described by the new score vector, and
    the regression coefficients for the model using all components found
    so far are stored.

    :input:
        X : data matrix (m, n)
        Y : data matrix (m, l)
        a_max : number of components to extract
        alpha : not used by this function
        ax_center : axis to mean-center X and Y along; a negative value
                    disables centering.  The intercept b0 is computed as
                    mn_y - dot(mn_x, B), which only conforms in shape for
                    column centering (ax_center=0).
        mode : not used by this function
        scale : if 'loads', the scores T are scaled to unit length and
                the norms are moved onto W and Q
        verbose : if True, print progress for every component

    :output: dictionary with keys
        T : X-scores (m, a_max)
        W : X-weights (n, a_max)
        P : X-loadings (n, a_max)
        Q : Y-loadings (l, a_max)
        U : Y-scores from the last inner iteration (m, a_max)
        B : regression coefficients X->Y, B[a] uses a+1 components,
            shape (a_max, n, l)
        b0 : regression intercepts, shape (a_max, m, l); all zeros when
             no centering was done
        evx : percentage of X-variance explained after each component
        evy : percentage of Y-variance explained after each component
    """
    centered = ax_center >= 0
    if centered:
        mn_x = expand_dims(X.mean(ax_center), ax_center)
        mn_y = expand_dims(Y.mean(ax_center), ax_center)
        X = X - mn_x
        Y = Y - mn_y

    # total sums of squares, the reference for explained variance
    varX = pow(X, 2).sum()
    varY = pow(Y, 2).sum()
    m, n = X.shape
    k, l = Y.shape

    # initialize
    U = empty((k, a_max))
    Q = empty((l, a_max))
    T = empty((m, a_max))
    W = empty((n, a_max))
    P = empty((n, a_max))
    B = empty((a_max, n, l))
    b0 = empty((a_max, m, l))
    var_x = empty((a_max,))
    var_y = empty((a_max,))

    MAX_ITER = 100
    lim = 1e-16
    t1 = X[:, :1]
    for a in range(a_max):
        if verbose:
            print("\n Working on comp. %s" % a)
        # start every component from the first column of the current Y
        u = Y[:, :1]
        diff = 1
        niter = 0
        while (diff > lim and niter < MAX_ITER):
            niter += 1
            w = dot(X.T, u)
            w = w/sqrt(dot(w.T, w))
            t = dot(X, w)
            q = dot(Y.T, t)
            q = q/sqrt(dot(q.T, q))
            u = dot(Y, q)
            # convergence is judged on the change in the score vector
            diff = vnorm(t1 - t)
            t1 = t.copy()
        if verbose:
            print("Converged after %s iterations" % niter)

        # X-loadings; move their norm onto t and w so that p has unit
        # length while the product dot(t, p.T) is unchanged
        p = dot(X.T, t)/dot(t.T, t)
        p_norm = vnorm(p)
        t = t*p_norm
        w = w*p_norm
        p = p/p_norm
        # Y-loadings must be the least-squares regression of Y on the
        # final (rescaled) scores; the unit-length q from the inner loop
        # would give a wrong Y-deflation and wrong coefficients B.
        q = dot(Y.T, t)/dot(t.T, t)

        U[:, a] = u.ravel()
        W[:, a] = w.ravel()
        P[:, a] = p.ravel()
        T[:, a] = t.ravel()
        Q[:, a] = q.ravel()

        # deflate
        X = X - dot(t, p.T)
        Y = Y - dot(t, q.T)

        # residual sums of squares after this component
        var_x[a] = pow(X, 2).sum()
        var_y[a] = pow(Y, 2).sum()

        # regression coefficients for the model with a+1 components
        Wa = W[:, :a+1]
        Pa = P[:, :a+1]
        B[a] = dot(dot(Wa, inv(dot(Pa.T, Wa))), Q[:, :a+1].T)
        if centered:
            b0[a] = mn_y - dot(mn_x, B[a])
        else:
            # no means were removed, hence no intercept
            b0[a] = 0.0

    # variance explained
    evx = 100.0*(1 - var_x/varX)
    evy = 100.0*(1 - var_y/varY)

    if scale == 'loads':
        tnorm = apply_along_axis(vnorm, 0, T)
        T = T/tnorm
        W = W*tnorm
        Q = Q*tnorm

    return {'T':T, 'W':W, 'P':P, 'Q':Q, 'U':U, 'B':B, 'b0':b0, 'evx':evx, 'evy':evy}
########### Helper routines ######### ########### Helper routines #########

View File

@ -1,7 +1,7 @@
"""This module implements some common validation schemes from pca and pls. """This module implements some common validation schemes from pca and pls.
""" """
from scipy import ones,mean,sqrt,dot,newaxis,zeros,sum,empty,\ from scipy import ones,mean,sqrt,dot,newaxis,zeros,sum,empty,\
apply_along_axis,eye,kron,array,sort apply_along_axis,eye,kron,array,sort,zeros_like,argmax
from scipy.stats import median from scipy.stats import median
from scipy.linalg import triu,inv,svd,norm from scipy.linalg import triu,inv,svd,norm
@ -122,7 +122,7 @@ def lpls_val(X, Y, Z, a_max=2, nsets=None,alpha=.5):
B = dat['B'] B = dat['B']
b0 = dat['b0'] b0 = dat['b0']
for a in range(a_max): for a in range(a_max):
Yhat[a,ind,:] = b0[a][0][0] + dot(xi, B[a]) Yhat[a,ind,:] = b0[a][0][0] + dot(xi-xcal.mean(0), B[a])
Yhat_class = zeros_like(Yhat) Yhat_class = zeros_like(Yhat)
for a in range(a_max): for a in range(a_max):
for i in range(k): for i in range(k):

View File

@ -132,7 +132,7 @@ text.dvipnghack : False # some versions of dvipng don't handle
# default fontsizes for ticklabels, and so on. See # default fontsizes for ticklabels, and so on. See
# http://matplotlib.sourceforge.net/matplotlib.axes.html#Axes # http://matplotlib.sourceforge.net/matplotlib.axes.html#Axes
axes.hold : True # whether to clear the axes by default on axes.hold : True # whether to clear the axes by default on
axes.facecolor : white # axes background color axes.facecolor : 0.6 # axes background color
axes.edgecolor : black # axes edge color axes.edgecolor : black # axes edge color
axes.linewidth : 1.0 # edge linewidth axes.linewidth : 1.0 # edge linewidth
axes.grid : True # display grid or not axes.grid : True # display grid or not