
Takes a sequence and yields K partitions of it into training and validation sets. Training sets are of size (K-1)*len(X)/K and validation sets are of size len(X)/K.

Python, 19 lines
```
def k_fold_cross_validation(X, K, randomise=False):
    """
    Generates K (training, validation) pairs from the items in X.

    Each pair is a partition of X, where validation is an iterable
    of length len(X)/K. So each training iterable is of length (K-1)*len(X)/K.

    If randomise is true, a copy of X is shuffled before partitioning,
    otherwise its order is preserved in training and validation.
    """
    if randomise: from random import shuffle; X = list(X); shuffle(X)
    for k in range(K):
        training = [x for i, x in enumerate(X) if i % K != k]
        validation = [x for i, x in enumerate(X) if i % K == k]
        yield training, validation

X = [i for i in range(97)]
for training, validation in k_fold_cross_validation(X, K=7):
    for x in X: assert (x in training) ^ (x in validation), x
```

This is a common task in machine learning.
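The sizes quoted in the description are exact only when K divides len(X); otherwise the folds differ in size by at most one item. The sketch below illustrates this with a Python 3 restatement of the recipe's generator (shuffling omitted). The helper name `fold_sizes` is hypothetical and not part of the recipe.

```python
def k_fold_cross_validation(X, K):
    """Python 3 restatement of the recipe's generator, without shuffling."""
    for k in range(K):
        training = [x for i, x in enumerate(X) if i % K != k]
        validation = [x for i, x in enumerate(X) if i % K == k]
        yield training, validation


def fold_sizes(n, K):
    """Hypothetical helper: (training, validation) sizes for n items, K folds."""
    return [(len(t), len(v)) for t, v in k_fold_cross_validation(range(n), K)]


# 97 is not a multiple of 7, so the last validation fold is one item short.
sizes = fold_sizes(97, 7)
print(sizes)  # six folds of (83, 14), then one fold of (84, 13)
```

Because items are dealt out by index modulo K, the larger validation folds always come first.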

Any improvements welcome. There's probably a one-liner out there :)

Matteo Dell'Amico 16 years, 3 months ago

You could use the alist[start::step] idiom, and if you don't care about the order, end with:

```
from random import shuffle

def k_fold_cross_validation(items, k, randomize=False):

    if randomize:
        items = list(items)
        shuffle(items)

    # Slice i holds the items at positions i, i + k, i + 2k, ...
    slices = [items[i::k] for i in range(k)]

    for i in range(k):
        validation = slices[i]
        # Concatenate every slice except the validation one.
        training = [item
                    for s in slices if s is not validation
                    for item in s]
        yield training, validation

if __name__ == '__main__':
    items = range(97)
    for training, validation in k_fold_cross_validation(items, 7):
        for item in items:
            assert (item in training) ^ (item in validation)
```

If you do care about the order, you'd need to do some weaving though.

John Reid (author) 16 years, 3 months ago

That was essentially the way I thought of at first, but is that better in some way than the version above? The version above does maintain the order unless randomise is true.

Steven Bethard 16 years, 3 months ago

No need for index_filter(). The indirection through index_filter() is unnecessary, confusing, and slower than simple list comprehensions:

```
training = [x for i, x in enumerate(X) if i % K != k]
validation = [x for i, x in enumerate(X) if i % K == k]
```

Steven Bethard 16 years, 3 months ago

Docstring is a little confusing. That docstring is a little confusing for me. I think I'd instead write it as:

```
"""Generates K (training, validation) pairs from the items in X.

The validation iterables are a partition of X, and each validation
iterable is of length len(X)/K. Each training iterable is the
complement (within X) of the validation iterable, and so each training
iterable is of length (K-1)*len(X)/K.
"""
```

John Reid (author) 16 years, 3 months ago

Thanks.

Matteo Dell'Amico 16 years, 3 months ago

Not really. It is just maybe marginally faster though.

Created by John Reid on Thu, 14 Jun 2007 (PSF)