Removed arpack eigs in pls_jk due to convergence problems

This commit is contained in:
Arnar Flatberg 2007-12-03 16:17:35 +00:00
parent 8e9c4c1c58
commit bc23d933c3
3 changed files with 25 additions and 10 deletions

View File

@ -6,7 +6,7 @@ __all__ = ['lpls_val', 'pls_jk', 'lpls_jk']
__docformat__ = "restructuredtext en"
from numpy import dot,empty,zeros,sqrt,atleast_2d,argmax,asarray,median,\
array_split,arange
array_split,arange, isnan, any
from numpy.random import shuffle
from engines import pls
@ -188,6 +188,8 @@ def pls_jk(X, Y, a_opt, nsets=None, center_axis=True, verbose=False):
if verbose:
print "Segment number: %d" %i
dat = pls(X[cal,:], Y[cal,:], a_opt, scale='loads', mode='fast', center_axis=[0, 0])
if any(isnan(dat['W'])):
1/0
Wcv[i,:,:] = dat['W']
return Wcv

View File

@ -8,7 +8,7 @@ __docformat__ = "restructuredtext en"
from math import sqrt as msqrt
from numpy import dot,empty,zeros,apply_along_axis,newaxis,finfo,sqrt,r_,expand_dims,\
minimum
minimum, any, isnan
from numpy.linalg import inv,svd
from scipy.sandbox import arpack
@ -335,16 +335,30 @@ def pls(X, Y, aopt=2, scale='scores', mode='normal', center_axis=[0, 0]):
w = XY.reshape(n, l)
w = w/vnorm(w)
elif n < l: # more yvars than xvars
s, w = arpack.eigen_symmetric(dot(XY, XY.T),k=1, tol=1e-10, maxiter=100)
#w, s, vh = svd(dot(XY, XY.T))
#w = w[:,:1]
#!!! fixme
# Arpack has convergence problems on large equal eigenvalues
# which is typical for design/category in Y so we switch to regular svd.
# Need to decide wether to remove arpack here or check for system
# with many samples, many x-vars and many non-orth y-vars (where arpack speed
# shines)
#############
#s, w = arpack.eigen_symmetric(dot(XY, XY.T),k=1, tol=1e-10, maxiter=1000)
#if s[0] == 0:
# print "Arpack did not converge... using svd"
w, s, vh = svd(dot(XY, XY.T))
w = w[:,:1]
else: # more xvars than yvars
s, q = arpack.eigen_symmetric(dot(XY.T, XY), k=1, tol=1e-10, maxiter=100)
#q, s, vh = svd(dot(XY.T, XY))
#q = q[:,:1]
#s, q = arpack.eigen_symmetric(dot(XY.T, XY), k=1, tol=1e-10, maxiter=1000)
#if s[0] == 0:
# print "Arpack did not converge... using svd"
q, s, vh = svd(dot(XY.T, XY))
q = q[:,:1]
w = dot(XY, q)
w = w/vnorm(w)
r = w.copy()
if i > 0:
for j in range(0, i, 1):

View File

@ -262,8 +262,7 @@ def lpls_qvals(X, Y, Z, aopt=None, alpha=.3, zx_alpha=.5, n_iter=20,
return _fdr(cal_tsq_z, pert_tsq_z, median), _fdr(cal_tsq_x, pert_tsq_x, median)
def pls_qvals(X, Y, aopt, alpha=.3, zx_alpha=.5, n_iter=20,
sim_method='shuffle',p_center='med', cov_center=median,
def pls_qvals(X, Y, aopt, alpha=.3, n_iter=20,p_center='med', cov_center=median,
crot=True,strict=False, center_axis=[0,0], nsets=None, zorth=False):
"""Returns qvals for pls model by permutation analysis.