ActiveState Code

Recipe 521906: K fold cross validation partition


Takes a sequence and yields K partitions of it into training and validation sets. Training sets are of size (K-1)*len(X)/K and validation sets are of size len(X)/K (to within one item when K does not divide len(X) exactly).

Python
def k_fold_cross_validation(X, K, randomise = False):
	"""
	Generates K (training, validation) pairs from the items in X.

	Each pair is a partition of X, where validation is an iterable
	of length len(X)/K. So each training iterable is of length (K-1)*len(X)/K.

	If randomise is true, a copy of X is shuffled before partitioning,
	otherwise its order is preserved in training and validation.
	"""
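	if randomise:
		from random import shuffle
		X = list(X)  # shuffle a copy so the caller's sequence is left untouched
		shuffle(X)
	for k in range(K):
		# Item i goes to the validation set of fold i % K and to training otherwise.
		training = [x for i, x in enumerate(X) if i % K != k]
		validation = [x for i, x in enumerate(X) if i % K == k]
		yield training, validation

X = list(range(97))
for training, validation in k_fold_cross_validation(X, K=7):
	# Every item must appear in exactly one of the two sets.
	for x in X:
		assert (x in training) ^ (x in validation), x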
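
As a quick check of the sizes quoted in the description, the sketch below counts the items in each fold for the recipe's own test case (97 items, K=7). It uses a compact copy of the generator without the shuffle option, so it can be run on its own; the names `folds`, `validation_sizes` and `training_sizes` are only illustrative.

```python
def k_fold_cross_validation(X, K):
    """Compact copy of the recipe's generator, without the shuffle option."""
    for k in range(K):
        training = [x for i, x in enumerate(X) if i % K != k]
        validation = [x for i, x in enumerate(X) if i % K == k]
        yield training, validation

X = list(range(97))
folds = list(k_fold_cross_validation(X, K=7))

# 97 = 13 * 7 + 6, so six folds get 14 validation items and the last gets 13.
validation_sizes = [len(validation) for _, validation in folds]
training_sizes = [len(training) for training, _ in folds]
print(validation_sizes)  # [14, 14, 14, 14, 14, 14, 13]
print(training_sizes)    # [83, 83, 83, 83, 83, 83, 84]
```

So len(X)/K and (K-1)*len(X)/K are exact only when K divides len(X); otherwise the fold sizes differ by at most one item.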

Discussion

This is a common task in machine learning.
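
For readers who have not met the idea, here is a minimal sketch of how such folds are typically used: fit something on each training set, score it on the matching validation set, and average the scores. The "model" here is a deliberately trivial baseline that predicts the training mean, and `cross_validated_mse` is a hypothetical helper, not part of the recipe.

```python
def k_fold_cross_validation(X, K):
    """Compact copy of the recipe's generator, without the shuffle option."""
    for k in range(K):
        training = [x for i, x in enumerate(X) if i % K != k]
        validation = [x for i, x in enumerate(X) if i % K == k]
        yield training, validation

def cross_validated_mse(values, K):
    """Mean squared error of a 'predict the training mean' baseline,
    averaged over the K folds (hypothetical helper for illustration)."""
    fold_errors = []
    for training, validation in k_fold_cross_validation(values, K):
        prediction = sum(training) / float(len(training))
        squared = [(v - prediction) ** 2 for v in validation]
        fold_errors.append(sum(squared) / float(len(squared)))
    return sum(fold_errors) / float(len(fold_errors))

values = [2.0, 4.0, 6.0, 8.0]
score = cross_validated_mse(values, K=2)
print(score)  # 8.0
```

With K=2 the folds are ([4, 8], [2, 6]) and ([2, 6], [4, 8]); each gives a validation error of 8.0, so the average is 8.0.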

Any improvements are welcome. There's probably a one-liner out there :)
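
One possible one-liner, offered only as a sketch: a generator expression that builds both lists per fold. It assumes X supports slicing (a list or tuple), and the name `k_fold_one_liner` is made up for this example. It leaves out the shuffle option.

```python
def k_fold_one_liner(X, K):
    # Assumes X is sliceable; validation uses the X[k::K] stride idiom.
    return (([x for i, x in enumerate(X) if i % K != k], list(X[k::K])) for k in range(K))

folds = list(k_fold_one_liner(list(range(10)), 3))
print(folds[0])  # ([1, 2, 4, 5, 7, 8], [0, 3, 6, 9])
```

Both lists keep the original order of X, which matches the recipe's behaviour when randomise is false.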

Comments

  1. At 6:45 a.m. on 14 Jun 2007, Matteo Dell'Amico said:

    You could use the alist[start::step] idiom, and if you don't care about the order, end with:

    from random import shuffle
    
    def k_fold_cross_validation(items, k, randomize=False):
    
        if randomize:
            items = list(items)
            shuffle(items)
    
        slices = [items[i::k] for i in range(k)]
    
        for i in range(k):
            validation = slices[i]
            training = [item
                        for s in slices if s is not validation
                        for item in s]
            yield training, validation
    
    if __name__ == '__main__':
        items = range(97)
        for training, validation in k_fold_cross_validation(items, 7):
            for item in items:
                assert (item in training) ^ (item in validation)
    

    If you do care about the order, you'd need to do some weaving though.
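
One way to read "weaving", sketched under assumptions: interleave the remaining slices round-robin so the training list comes back in the original order. The helper names `weave` and `k_fold_ordered` are hypothetical, and `zip_longest` is the Python 3 name (`izip_longest` in Python 2). This is not the commenter's code.

```python
from itertools import zip_longest

_MISSING = object()  # sentinel marking padding from shorter slices

def weave(slices):
    """Interleave slices round-robin, dropping the padding values."""
    return [item
            for round_items in zip_longest(*slices, fillvalue=_MISSING)
            for item in round_items
            if item is not _MISSING]

def k_fold_ordered(items, k):
    """Stride-slice version whose training lists keep the input order."""
    items = list(items)
    slices = [items[i::k] for i in range(k)]
    for i in range(k):
        others = slices[:i] + slices[i + 1:]
        yield weave(others), slices[i]

folds = list(k_fold_ordered(range(10), 3))
print(folds[1])  # ([0, 2, 3, 5, 6, 8, 9], [1, 4, 7])
```

Round r of the interleaving visits items r*k + j for each remaining slice j in increasing order, which is why the woven list comes out in the original order.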

  2. At 5:30 a.m. on 15 Jun 2007, John Reid (the author) said:

    That was essentially the way I thought of at first, but is that better in some way than the version above? The version above does maintain the order unless randomise is true.

  3. At 2:31 p.m. on 15 Jun 2007, Steven Bethard said:

    No need for index_filter(). The indirection through index_filter() is unnecessary, confusing, and slower than simple list comprehensions:

    training = [x for i, x in enumerate(X) if i % K != k]
    validation = [x for i, x in enumerate(X) if i % K == k]
    
  4. At 2:42 p.m. on 15 Jun 2007, Steven Bethard said:

    That docstring is a little confusing for me. I think I'd instead write it as:

    """Generates K (training, validation) pairs from the items in X.
    
    The validation iterables are a partition of X, and each validation
    iterable is of length len(X)/K. Each training iterable is the
    complement (within X) of the validation iterable, and so each training
    iterable is of length (K-1)*len(X)/K.
    """
    
  5. At 1:41 a.m. on 16 Jun 2007, John Reid (the author) said:

    Thanks.

  6. At 9:24 a.m. on 26 Jun 2007, Matteo Dell'Amico said:

    Not really. It may just be marginally faster, though.
