Projects/laydi
Projects
/
laydi
Archived
7
0
Fork 0
This repository has been archived on 2024-07-04. You can view files and clone it, but cannot push or open issues or pull requests.
laydi/fluents/lib/cv_index.py

67 lines
1.9 KiB
Python

from numpy import array_split,arange
def cv(n, k, randomise=False, sequential=False):
    """
    Generate k (training, validation) index pairs for k-fold cross-validation.

    Each pair is a partition of arange(n): every index is in exactly one of
    training and validation. Across the k pairs the validation sets are
    disjoint and together cover all n indices, each of length ~n/k.

    Parameters
    ----------
    n : int
        Number of samples; the indices are 0 .. n-1.
    k : int
        Number of folds; must be at least 1.
    randomise : bool
        If true, a shuffled copy of the index is partitioned, so fold
        membership is random. Randomise overrides ``sequential``: when
        randomise is true, sequential is treated as False.
    sequential : bool
        If true (and not randomised), the index is partitioned into
        contiguous blocks. Otherwise an interleaved assignment is used:
        the element at position p goes to validation fold ``p % k``.

    Yields
    ------
    (training, validation)
        ``training`` is a list. ``validation`` is a numpy array in
        sequential mode and a list otherwise. Within each, elements keep
        the order of the (possibly shuffled) index.

    Raises
    ------
    ValueError
        If k < 1. Because this is a generator, the error is raised on the
        first iteration, not at call time.
    """
    if k < 1:
        raise ValueError("k must be at least 1")
    index = list(range(n))
    if randomise:
        from random import shuffle
        shuffle(index)
        sequential = False
    if sequential:
        for validation in array_split(index, k):
            # A set gives O(1) membership tests instead of scanning the array.
            held_out = set(validation.tolist())
            training = [i for i in index if i not in held_out]
            yield training, validation
    else:
        for fold in range(k):
            # Assign by *position*, not by value, so that a shuffled index
            # really produces random folds (for an unshuffled index the two
            # are identical).
            training = [i for pos, i in enumerate(index) if pos % k != fold]
            validation = [i for pos, i in enumerate(index) if pos % k == fold]
            yield training, validation
def shuffle_diag(shape, K, randomise=False, sequential=False):
    """
    Generate flat index arrays with stride max(m, n) + 1 over an m x n matrix.

    For ``shape == (m, n)`` there are n start positions
    ``0, m, 2*m, ..., (n-1)*m``. From each start the function yields
    ``arange(start, m*n, max(m, n) + 1)``. For a square matrix this stride
    steps one row and one column at a time, i.e. it traces a diagonal.

    The starts are distributed over K folds, and the folds are yielded in
    order:

    * ``sequential=True``: contiguous blocks of starts (``array_split``);
    * otherwise: interleaved, so the start at position p goes to fold
      ``p % K``;
    * ``randomise=True``: the starts are shuffled first and then
      interleaved (``sequential`` is ignored).

    Every start is yielded exactly once, so the folding only affects the
    order of the output. Fold boundaries are not marked in what is yielded.

    NOTE(review): this yields single index arrays, not (training,
    validation) pairs; the function appears unfinished. Whether the flat
    indices assume row- or column-major storage, and whether the stride is
    meant to be correct when m < n, is not established here. Confirm both
    before relying on it.

    Raises
    ------
    ValueError
        If K < 1 or K exceeds min(m, n). Because this is a generator, the
        error is raised on the first iteration.
    """
    m, n = shape
    if K < 1:
        raise ValueError("K must be at least 1")
    if K > m or K > n:
        msg = "You may not use more subsets than min(n_rows, n_cols)"
        raise ValueError(msg)
    step = max(m, n) + 1
    # Same values as the multiples of m in range(m*n), without scanning m*n items.
    starts = [i * m for i in range(n)]
    if randomise:
        from random import shuffle
        shuffle(starts)
        sequential = False
    if sequential:
        folds = [block.tolist() for block in array_split(starts, K)]
    else:
        folds = [[s for pos, s in enumerate(starts) if pos % K == fold]
                 for fold in range(K)]
    for fold_starts in folds:
        for start in fold_starts:
            yield arange(start, n * m, step)