67 lines
1.9 KiB
Python
67 lines
1.9 KiB
Python
from numpy import array_split,arange
|
|
|
|
|
|
def cv(n, k, randomise=False, sequential=False):
    """
    Generates k (training, validation) index pairs.

    Each pair is a partition of range(n), where validation is an iterable
    of length ~n/k.

    If randomise is true, a copy of the index is shuffled before
    partitioning, otherwise its order is preserved in training and
    validation.

    Randomise overrides the sequential argument: if randomise is true,
    sequential is treated as False.

    If sequential is true the index is partitioned in contiguous blocks
    (validation is then a numpy array produced by array_split), otherwise
    interleaved ordering is used (validation is then a list).

    Raises ValueError if k is smaller than 1. Because this is a generator,
    the error surfaces when the first pair is requested.
    """
    if k < 1:
        raise ValueError("k must be a positive integer")

    # A concrete list works on both Python 2 and 3 and can be shuffled.
    index = list(range(n))

    if randomise:
        from random import shuffle
        shuffle(index)
        sequential = False

    if sequential:
        for validation in array_split(index, k):
            # Set lookup avoids scanning the validation array per element.
            held_out = set(validation.tolist())
            training = [i for i in index if i not in held_out]
            yield training, validation
    else:
        for fold in range(k):
            # Partition by *position* in the (possibly shuffled) index.
            # Partitioning by value would give the same folds no matter
            # how the index was shuffled.
            training = [i for pos, i in enumerate(index) if pos % k != fold]
            validation = [i for pos, i in enumerate(index) if pos % k == fold]
            yield training, validation
|
|
|
|
def shuffle_diag(shape, K, randomise=False, sequential=False):
    """
    Generates K index arrays over a flattened matrix of the given shape.

    shape is an (m, n) pair. The start indices are the multiples of m below
    m*n (n values in total). They are divided into K groups, and for every
    group one numpy array is yielded: the concatenation, in group order, of
    arange(start, m*n, max(m, n) + 1) for each start in the group.

    If randomise is true, the start indices are shuffled before grouping;
    randomise overrides sequential.

    If sequential is true the start indices are grouped in contiguous
    blocks (via array_split), otherwise they are interleaved (position
    modulo K).

    Raises ValueError if K is smaller than 1 or larger than min(m, n).
    Because this is a generator, the error surfaces when the first array is
    requested.

    NOTE(review): the stride max(m, n) + 1 is kept from the original code.
    For non-square shapes, runs from different start indices can overlap
    (e.g. shape (2, 3) yields index 4 twice in one group) -- confirm the
    intended stride.
    """
    from numpy import concatenate

    m, n = shape

    if K < 1:
        raise ValueError("K must be a positive integer")
    # Same condition as the original `K > m or K > n`; the message now
    # states the limit that is actually enforced.
    if K > min(m, n):
        msg = "You may not use more subsets than min(n_rows, n_cols)"
        raise ValueError(msg)

    size = m * n
    step = max(m, n) + 1
    # Multiples of m below m*n: exactly n start indices.
    index = list(range(0, size, m))

    if randomise:
        from random import shuffle
        shuffle(index)
        sequential = False

    if sequential:
        groups = array_split(index, K)
    else:
        # Build all K interleaved groups rather than keeping only the last.
        groups = [index[k::K] for k in range(K)]

    for starts in groups:
        # K <= n guarantees every group holds at least one start index,
        # so concatenate never receives an empty list.
        yield concatenate([arange(start, size, step) for start in starts])
|
|
|