New crossval index generator
This commit is contained in:
parent
18f33decc7
commit
b114d5aeec
|
@ -0,0 +1,66 @@
|
||||||
|
from numpy import array_split,arange
|
||||||
|
|
||||||
|
|
||||||
|
def cv(n, k, randomise=False, sequential=False):
    """
    Generates k (training, validation) index pairs.

    Each pair is a partition of range(n), where validation is an iterable
    of length ~n/k.

    If randomise is true, a copy of the index is shuffled before
    partitioning, otherwise its order is preserved in training and
    validation.

    Randomise overrides the sequential argument: if randomise is true,
    sequential is treated as False.

    If sequential is true the index is partitioned into contiguous blocks,
    otherwise interleaved ordering is used.

    Parameters
    ----------
    n : int
        Number of samples; the indices 0..n-1 are partitioned.
    k : int
        Number of folds (>= 1).
    randomise : bool
        Shuffle the index before assigning folds.
    sequential : bool
        Use contiguous blocks instead of interleaving.

    Yields
    ------
    (training, validation)
        training is a list of indices. validation is a numpy array in
        sequential mode (as returned by array_split) and a list otherwise.

    Raises
    ------
    ValueError
        If k is smaller than 1.
    """
    if k < 1:
        raise ValueError("k must be a positive integer")

    # range() works on both Python 2 and 3; a list is needed for shuffling
    # and slicing anyway.
    index = list(range(n))
    if randomise:
        from random import shuffle
        shuffle(index)
        sequential = False

    if sequential:
        for validation in array_split(index, k):
            # Set lookup avoids a linear scan of the array per index.
            held_out = set(validation.tolist())
            training = [i for i in index if i not in held_out]
            yield training, validation
    else:
        for fold in range(k):
            # Fold membership is decided by *position* in the (possibly
            # shuffled) index, not by the index value; otherwise shuffling
            # would never change which samples end up in which fold.
            validation = index[fold::k]
            training = [i for pos, i in enumerate(index) if pos % k != fold]
            yield training, validation
|
||||||
|
|
||||||
|
def shuffle_diag(shape, K, randomise=False, sequential=False):
    """
    Generates flat index arrays, one per start index, ordered by subset.

    The start indices are the multiples of m below m*n (n of them). They
    are divided into K subsets, and for every start index, subset by
    subset, the array arange(start, m*n, max(m, n) + 1) is yielded.

    If randomise is true the start indices are shuffled before they are
    divided, and sequential is treated as False.

    If sequential is true the start indices are divided into contiguous
    blocks, otherwise interleaved ordering is used.

    Parameters
    ----------
    shape : (int, int)
        The pair (m, n); presumably (n_rows, n_cols) of a matrix whose
        flattened indices are generated -- verify against callers.
    K : int
        Number of subsets; must satisfy 1 <= K <= min(m, n).
    randomise : bool
        Shuffle the start indices before dividing them.
    sequential : bool
        Use contiguous blocks of start indices instead of interleaving.

    Yields
    ------
    numpy.ndarray
        One array of flat indices per start index.

    Raises
    ------
    ValueError
        If K is smaller than 1 or larger than either dimension.
    """
    m, n = shape

    if K < 1:
        raise ValueError("K must be a positive integer")
    # NOTE(review): the condition rejects K > min(m, n); the message used
    # to say max(...). The message now matches the check -- confirm that
    # the check (and not the old message) reflects the intended limit.
    if K > m or K > n:
        msg = "You may not use more subsets than min(n_rows, n_cols)"
        raise ValueError(msg)

    stride = max(m, n) + 1
    # Multiples of m below m*n; m >= K >= 1 here, so the step is non-zero.
    index = list(range(0, m * n, m))
    if randomise:
        from random import shuffle
        shuffle(index)
        sequential = False

    if sequential:
        groups = array_split(index, K)
    else:
        # Keep every subset instead of overwriting all but the last one.
        groups = [index[g::K] for g in range(K)]

    for start_inds in groups:
        for start in start_inds:
            yield arange(start, n * m, stride)
|
||||||
|
|
Reference in New Issue