New crossval index generator
This commit is contained in:
parent
18f33decc7
commit
b114d5aeec
|
@ -0,0 +1,66 @@
|
|||
from numpy import array_split,arange
|
||||
|
||||
|
||||
def cv(n, k, randomise=False, sequential=False):
    """
    Generates k (training, validation) index pairs.

    Each pair is a partition of the integers 0..n-1, where validation
    holds ~n/k of them and training holds the rest.

    If randomise is true, a copy of the index is shuffled before
    partitioning, otherwise its order is preserved in training and
    validation.

    Randomise overrides the sequential argument: if randomise is true,
    sequential is treated as False.

    If sequential is true the index is partitioned in contiguous blocks
    (validation is then a numpy array produced by array_split), otherwise
    interleaved ordering is used (validation is a list holding every k-th
    entry of the index).

    This is a generator; nothing is computed until it is iterated.
    """
    index = list(range(n))
    if randomise:
        from random import shuffle
        shuffle(index)
        sequential = False

    if sequential:
        folds = array_split(index, k)
    else:
        # Assign folds by position rather than by value, so that a shuffled
        # index really produces random folds. Without shuffling, position
        # and value coincide and this equals the "i % k == fold" rule.
        folds = [index[fold::k] for fold in range(k)]

    for validation in folds:
        # A set makes the membership test O(1) per element.
        held_out = set(validation)
        training = [i for i in index if i not in held_out]
        yield training, validation
|
||||
|
||||
def shuffle_diag(shape, K, randomise=False, sequential=False):
    """
    Generates diagonal index arrays over the flat range 0..m*n-1.

    shape is an (m, n) pair. The start indices are the n multiples of m
    below m*n (presumably the column heads of a column-major flattening
    -- TODO confirm against the caller). They are divided into K groups,
    and for every start index, group by group, one numpy array
    arange(start, m*n, max(m, n) + 1) is yielded.

    If randomise is true the start indices are shuffled before grouping;
    this overrides sequential, which is then treated as False.

    If sequential is true the start indices are grouped in contiguous
    blocks, otherwise every K-th start index goes to the same group.

    Raises ValueError if K exceeds either dimension of shape. As this is
    a generator, the error surfaces when iteration starts.
    """
    m, n = shape

    # NOTE(review): the message mentions max(n_rows, n_cols) while the check
    # rejects K larger than *either* dimension -- confirm which is intended.
    if K > m or K > n:
        msg = "You may not use more subsets than max(n_rows, n_cols)"
        raise ValueError(msg)

    step = max(m, n) + 1
    index = [col * m for col in range(n)]
    if randomise:
        from random import shuffle
        shuffle(index)
        sequential = False

    if sequential:
        groups = array_split(index, K)
    else:
        groups = [index[k::K] for k in range(K)]

    for start_inds in groups:
        for start in start_inds:
            # arange needs a scalar start, so iterate inside each group.
            yield arange(start, n * m, step)
|
||||
|
Reference in New Issue