# This repository was archived on 2024-07-04. You can view files and clone it,
# but you cannot push to it or open issues or pull requests.
# File: pyblm/tests/test_lplsengine.py (144 lines, 3.8 KiB, Python)

"""Testing routines for the lpls engine.
"""
from math import sqrt as msqrt
from numpy.testing import *
set_package_path()
from pyblm import lpls
from numpy import dot, eye, random,asarray,empty
from numpy.random import rand, randn
from numpy.linalg import svd,norm
restore_path()
def blm_array(shape=(5,10), comp=3, noise=0, seed=1, dtype='d'):
    """Return a reproducible synthetic low-rank matrix, optionally noisy.

    The matrix is ``dot(t, p.T)`` where ``t`` (``shape[0] x comp``) and
    ``p`` (``shape[1] x comp``) are drawn uniformly from [0, 1), so without
    noise its rank is at most ``comp``.

    Parameters
    ----------
    shape : tuple of two ints
        (rows, columns) of the returned matrix.
    comp : int
        Number of latent components (inner dimension of the product).
    noise : float
        Scale of the additive standard-normal noise; values <= 0 add none.
    seed : int
        Seed for numpy's global random generator, for reproducibility.
    dtype : data-type
        Data type of the returned array (default double precision).

    Returns
    -------
    ndarray of shape ``shape`` and type ``dtype``.

    Raises
    ------
    ValueError
        If ``comp`` exceeds the smaller dimension of ``shape``.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if comp > min(shape):
        raise ValueError("comp (%d) must not exceed min(shape) %s" % (comp, shape))
    random.seed(seed)
    t = rand(shape[0], comp)
    p = rand(shape[1], comp)
    x = dot(t, p.T)
    if noise > 0:
        # Build the perturbation separately instead of rebinding `noise`.
        x = x + noise * randn(*shape)
    # Honour the previously ignored `dtype` argument.
    return asarray(x, dtype=dtype)
class LplsTestCase(NumpyTestCase):
    """Base fixture for the lpls engine tests.

    Subclasses override :meth:`do`. The inherited ``check_single`` and
    ``check_double`` methods cast the fixture arrays to single and double
    precision respectively, then call :meth:`do`.

    Fixture attributes set in :meth:`setUp` (all built with ``blm_array``,
    default ``comp=3``, noise scale 0.1):
    a  -- 5 x 10 matrix (first data argument passed to ``lpls`` by subclasses)
    b  -- 5 x 3 matrix (second data argument)
    c  -- 10 x 10 matrix (third data argument)
    nc -- number of components requested from ``lpls``
    """

    def setUp(self):
        self.a = blm_array(shape=(5,10), noise=.1)
        self.b = blm_array(shape=(5,3), noise=.1)
        self.c = blm_array(shape=(10,10), noise=.1)
        self.nc = 2

    def _cast_fixtures(self, dtype):
        """Replace self.a, self.b and self.c with copies cast to ``dtype``."""
        self.a = asarray(self.a, dtype=dtype)
        self.b = asarray(self.b, dtype=dtype)
        self.c = asarray(self.c, dtype=dtype)

    def check_single(self):
        """Run :meth:`do` with single-precision fixtures."""
        self._cast_fixtures('f')
        self.do()

    def check_double(self):
        """Run :meth:`do` with double-precision fixtures."""
        # Previously the cast arrays were bound to unused locals, so the
        # fixtures were never converted; store them on self instead.
        self._cast_fixtures('d')
        self.do()

    def do(self, *args):
        """Test body; subclasses override this. The default is a no-op."""
        pass
class testAlphaZero(LplsTestCase):
    """Placeholder for the ``alpha=0.0`` case of ``lpls``.

    TODO: fit ``lpls(self.a, self.b, self.c, self.nc, alpha=0.0)`` and compare
    its scores against a reference. Until then this case always passes.
    """

    def do(self):
        """No-op: the alpha = 0 comparison is not implemented yet."""
class testAlphaOne(LplsTestCase):
    """Placeholder for the ``alpha=1.0`` case of ``lpls``.

    TODO: implement ``do``. It currently inherits the base no-op, so this
    case always passes.
    """
class testZidentity(LplsTestCase):
    """Scores from an identity Z with alpha=1.0 must match an alpha=0.0 fit."""

    def do(self):
        # Build the identity in the fixture's dtype so the single-precision
        # check really runs in single precision (as testYidentity does).
        ident = eye(self.a.shape[1], dtype=self.a.dtype)
        # Use self.nc for both fits; the first call previously hard-coded 2.
        scores_ident = lpls(self.a, self.b, ident, self.nc, alpha=1.0)['T']
        scores_ref = lpls(self.a, self.b, self.c, self.nc, alpha=0.0)['T']
        assert_almost_equal(scores_ident, scores_ref)
class testYidentity(LplsTestCase):
    """With an identity second argument and alpha=0.0, the lpls scores should
    equal ``u * s`` from the SVD of ``self.a`` (first ``nc`` columns), up to sign.

    ``mean_ctr=[-1, -1, -1]`` is passed to all three inputs; presumably this
    disables centering. TODO: confirm against lpls.
    """

    def do(self):
        n_rows = self.b.shape[0]
        y_ident = eye(n_rows, dtype=self.a.dtype)
        fit = lpls(self.a, y_ident, self.c, self.nc, alpha=0.0,
                   mean_ctr=[-1, -1, -1])
        scores = fit['T']
        left, sing, _ = svd(self.a, 0)
        reference = (left * sing)[:, :self.nc]
        # Compare magnitudes only: singular vectors are defined up to sign.
        assert_almost_equal(abs(scores), abs(reference), 5)
# Placeholder test cases. None of them overrides `do`, so each inherits the
# base-class no-op and currently passes trivially. The shapes and ranks below
# are inferred from the class names only (TODO: implement).

class testWideX(LplsTestCase):
    """Placeholder: presumably X with more columns than rows."""

class testTallX(LplsTestCase):
    """Placeholder: presumably X with more rows than columns."""

class testWideY(LplsTestCase):
    """Placeholder: presumably Y with more columns than rows."""

class testTallY(LplsTestCase):
    """Placeholder: presumably Y with more rows than columns."""

class testWideZ(LplsTestCase):
    """Placeholder: presumably Z with more columns than rows."""

class testTallZ(LplsTestCase):
    """Placeholder: presumably Z with more rows than columns."""

class testRankDeficientX(LplsTestCase):
    """Placeholder: presumably a rank-deficient X."""

class testRankDeficientY(LplsTestCase):
    """Placeholder: presumably a rank-deficient Y."""

class testRankDeficientZ(LplsTestCase):
    """Placeholder: presumably a rank-deficient Z."""
class testCenterX(LplsTestCase):
    """Centering checks on the first data argument of lpls.

    With ``mean_ctr=[0, -1, -1]`` the scores T must have zero column means.
    With ``mean_ctr=[1, -1, -1]`` and ``alpha=0`` the weights W must have
    zero column means.
    """

    def do(self):
        t_ctr = [0, -1, -1]
        w_ctr = [1, -1, -1]
        scores = lpls(self.a, self.b, self.c, self.nc, mean_ctr=t_ctr)['T']
        assert_almost_equal(scores.mean(0), 0)
        weights = lpls(self.a, self.b, self.c, self.nc,
                       alpha=0, mean_ctr=w_ctr)['W']
        assert_almost_equal(weights.mean(0), 0)
class testResiduals(NumpyTestCase):
    """Fit lpls to a noise-free low-rank problem.

    Fixture: ``a`` is a 5 x 5 matrix of rank at most 3 (``blm_array`` with
    ``comp=3``, ``noise=0``), ``b`` is a copy of ``a``, ``c`` is a copy of its
    transpose, and ``nc = 3``.
    """

    def setUp(self):
        self.a = blm_array(shape=(5,5), noise=0, comp=3)
        self.b = self.a.copy()
        self.c = self.a.copy().T
        self.nc = 3

    def _cast_fixtures(self, dtype):
        """Replace self.a, self.b and self.c with copies cast to ``dtype``."""
        self.a = asarray(self.a, dtype=dtype)
        self.b = asarray(self.b, dtype=dtype)
        self.c = asarray(self.c, dtype=dtype)

    def check_single(self):
        """Run :meth:`do` with single-precision fixtures."""
        self._cast_fixtures('f')
        self.do()

    def check_double(self):
        """Run :meth:`do` with double-precision fixtures."""
        # Previously the cast arrays were bound to unused locals, so the
        # fixtures were never converted; store them on self instead.
        self._cast_fixtures('d')
        self.do()

    def do(self):
        # TODO(review): no residual assertion yet. This only checks that lpls
        # runs without raising. With a noise-free rank-3 input and nc = 3, the
        # residuals 'E'/'F' are presumably ~0; confirm before asserting.
        lpls(self.a, self.b, self.c, self.nc, mean_ctr=[-1,-1,-1])
class testOrthogonality(LplsTestCase):
    """Orthogonality properties of a fit with ``mean_ctr=[0, 0, 0]`` and
    ``scale='loads'``.

    The test asserts four properties:
    - T'T is the identity.
    - The column-normalised W is orthonormal (3 decimals).
    - T is orthogonal to the residuals E and F (3 decimals).
    """

    def do(self):
        fit = lpls(self.a, self.b, self.c, self.nc,
                   mean_ctr=[0, 0, 0], scale='loads')
        T, W, L, E, F = [fit[key] for key in ('T', 'W', 'L', 'E', 'F')]
        assert_almost_equal(dot(T.T, T), eye(T.shape[1]))
        # Scale every weight column to unit Euclidean length, in place.
        col_norms = asarray([norm(col) for col in W.T])
        W /= col_norms
        assert_almost_equal(dot(W.T, W), eye(W.shape[1]), 3)
        assert_almost_equal(dot(T.T, E), 0, 3)
        assert_almost_equal(dot(T.T, F), 0, 3)
if __name__ == '__main__':
    # Run this module's tests with the legacy NumpyTest runner, which the
    # star import at the top presumably provides from numpy.testing.
    runner = NumpyTest()
    runner.run()