
Takes a sequence and yields K partitions of it into training and validation sets. Training sets have about (K-1)*len(X)/K items and validation sets about len(X)/K.

from random import shuffle

def k_fold_cross_validation(X, K, randomise=False):
	"""
	Generates K (training, validation) pairs from the items in X.

	Each pair is a partition of X, where validation is a list of about
	len(X)/K items (sizes differ by at most one when K does not divide
	len(X)). So each training list has about (K-1)*len(X)/K items.

	If randomise is true, a copy of X is shuffled before partitioning,
	otherwise its order is preserved in training and validation.
	"""
	X = list(X)  # work on a copy: the caller's data is never shuffled, and iterators are consumed only once
	if randomise:
		shuffle(X)
	for k in range(K):
		# item i is in the validation list of fold i % K and in the training list of every other fold
		training = [x for i, x in enumerate(X) if i % K != k]
		validation = [x for i, x in enumerate(X) if i % K == k]
		yield training, validation

X = list(range(97))
for training, validation in k_fold_cross_validation(X, K=7):
	# every item must land in exactly one of the two lists
	for x in X: assert (x in training) ^ (x in validation), x

This is a common task in machine learning.
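The sizes quoted in the description are only exact when K divides len(X). As a rough illustration (a minimal sketch that repeats the recipe's logic without the randomise option, so it runs on its own), here is what the fold sizes look like for the 97-item, 7-fold case used in the test above:

```python
def k_fold_cross_validation(X, K):
    """Yield K (training, validation) pairs; item i goes to validation fold i % K."""
    X = list(X)
    for k in range(K):
        training = [x for i, x in enumerate(X) if i % K != k]
        validation = [x for i, x in enumerate(X) if i % K == k]
        yield training, validation

# 97 = 7 * 13 + 6, so six folds get 14 validation items and the last gets 13.
sizes = [(len(t), len(v)) for t, v in k_fold_cross_validation(range(97), 7)]
print(sizes)
```

Each pair always adds up to len(X), so the training and validation lists together cover every item exactly once.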

Any improvements welcome. There's probably a one liner out there :)

6 comments

Matteo Dell'Amico 16 years, 9 months ago

You could use the alist[start::step] idiom, and if you don't care about the order, end with:

from random import shuffle

def k_fold_cross_validation(items, k, randomize=False):

    if randomize:
        items = list(items)
        shuffle(items)

    slices = [items[i::k] for i in range(k)]

    for i in range(k):
        validation = slices[i]
        training = [item
                    for s in slices if s is not validation
                    for item in s]
        yield training, validation

if __name__ == '__main__':
    items = range(97)
    for training, validation in k_fold_cross_validation(items, 7):
        for item in items:
            assert (item in training) ^ (item in validation)

If you do care about the order, you'd need to do some weaving though.
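One possible way to do that weaving, as a sketch only: interleave the remaining slices with itertools.zip_longest and drop the padding. The name k_fold_ordered and the _MISSING sentinel are made up for this example and are not part of the recipe.

```python
from itertools import zip_longest

_MISSING = object()  # hypothetical padding marker, so None can still be a real item

def k_fold_ordered(items, k):
    """Slice-based k-fold split that keeps the training items in their original order."""
    items = list(items)
    slices = [items[i::k] for i in range(k)]
    for v in range(k):
        others = slices[:v] + slices[v + 1:]
        # Row r of zip_longest holds items r*k + i for every i != v, in increasing i,
        # so flattening row by row restores the original order.
        training = [item
                    for row in zip_longest(*others, fillvalue=_MISSING)
                    for item in row
                    if item is not _MISSING]
        yield training, slices[v]

for training, validation in k_fold_ordered(range(97), 7):
    assert training == sorted(training)
```

The padding only ever appears at the end of the last row, because the shorter slices are the ones with the higher start index, so filtering it out does not disturb the order.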

John Reid (author) 16 years, 9 months ago

That was essentially the way I thought of at first, but is it better in some way than the version above? The version above does maintain the order unless randomise is true.

Steven Bethard 16 years, 9 months ago

no need for index_filter(). The indirection through index_filter() is unnecessary, confusing and slower than simple list comprehensions:

training = [x for i, x in enumerate(X) if i % K != k]
validation = [x for i, x in enumerate(X) if i % K == k]

Steven Bethard 16 years, 9 months ago

docstring is a little confusing. That docstring is a little confusing for me. I think I'd instead write it as:

"""Generates K (training, validation) pairs from the items in X.

The validation iterables are a partition of X, and each validation
iterable is of length len(X)/K. Each training iterable is the
complement (within X) of the validation iterable, and so each training
iterable is of length (K-1)*len(X)/K.
"""

John Reid (author) 16 years, 9 months ago

Thanks.

Matteo Dell'Amico 16 years, 9 months ago

Not really. It is just maybe marginally faster though.
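If anyone wants to check the speed difference on their own machine, something like the following timeit sketch should do. The function names folds_by_filter and folds_by_slice, the data size, and the repeat count are arbitrary choices for this example, and the numbers will vary between machines and Python versions.

```python
import timeit

def folds_by_filter(X, K):
    """The recipe's approach: filter on index modulo K (order preserved)."""
    for k in range(K):
        training = [x for i, x in enumerate(X) if i % K != k]
        validation = [x for i, x in enumerate(X) if i % K == k]
        yield training, validation

def folds_by_slice(items, k):
    """The slice-based approach from the first comment (training order not preserved)."""
    slices = [items[i::k] for i in range(k)]
    for i in range(k):
        validation = slices[i]
        training = [item for s in slices if s is not validation for item in s]
        yield training, validation

data = list(range(10000))
t_filter = timeit.timeit(lambda: list(folds_by_filter(data, 10)), number=20)
t_slice = timeit.timeit(lambda: list(folds_by_slice(data, 10)), number=20)
print("filter: %.4fs  slice: %.4fs" % (t_filter, t_slice))
```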

Created by John Reid on Thu, 14 Jun 2007 (PSF)