Projects/laydi
Projects
/
laydi
Archived
7
0
Fork 0

New crossval index generator

This commit is contained in:
Arnar Flatberg 2007-09-25 06:31:40 +00:00
parent 18f33decc7
commit b114d5aeec
1 changed files with 66 additions and 0 deletions

66
fluents/lib/cv_index.py Normal file
View File

@ -0,0 +1,66 @@
from numpy import array_split,arange
def cv(n, k, randomise=False, sequential=False):
"""
Generates k (training, validation) index pairs.
Each pair is a partition of arange(n), where validation is an iterable
of length ~n/k.
If randomise is true, a copy of index is shuffled before partitioning,
otherwise its order is preserved in training and validation.
Randomise overrides the sequential argument. If randomise is true,
sequential is False
If sequential is true the index is partioned in continous blocks,
otherwise interleaved ordering is used.
"""
index = xrange(N)
if randomise:
from random import shuffle
index = list(index)
shuffle(index)
sequential = False
if sequential:
for validation in array_split(index, K):
training = [i for i in index if i not in validation]
yield training, validation
else:
for k in xrange(K):
training = [i for i in index if i % K != k]
validation = [i for i in index if i % K == k]
yield training, validation
def shuffle_diag(shape, K, randomise=False, sequential=False):
"""
Generates k (training, validation) index pairs.
"""
m, n = shape
if K>m or K>n:
msg = "You may not use more subsets than max(n_rows, n_cols)"
raise ValueError, msg
mon = max(m, n)
#index = xrange(n)
index = [i for i in range(m*n) if i % m == 0]
print index
if randomise:
from random import shuffle
index = list(index)
shuffle(index)
sequential = False
if sequential:
start_inds = array_split(index, K)
else:
for k in xrange(K):
start_inds = [index[i] for i in xrange(n) if i % K == k]
print start_inds
for start in start_inds:
ind = arange(start, n*m, mon+1)
yield ind