
If you need to generate the cross product of a variable number of lists, here is how to do it with an obscure one-liner instead of a nice and clean recursive function. I came to this while looking for a short and simple solution to that problem. Well, it might be short, but it is obscure as well. Maybe somebody can provide a more elegant solution. The function works on lists of lists as well.

Python, 44 lines
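from functools import reduce  # builtin in Python 2; must be imported in Python 3
# ss: list of lists; row: partial result so far; level: recursion depth (informational only)
f=lambda ss,row=[],level=0: len(ss)>1 \
   and reduce(lambda x,y:x+y,[f(ss[1:],row+[i],level+1) for i in ss[0]]) \
   or [row+[i] for i in ss[0]]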

# Example:

# If you have some sets ...

s1=['a1','a2','a3','a4']
s2=['b1','b2']
s3=['c1','c2','c3']

# ... simply put them in a "super-set" ...
ss=[s1,s2,s3]

# ... and call the function
cross_product = f(ss)

# this is the result
assert cross_product==[['a1', 'b1', 'c1'],
                      ['a1', 'b1', 'c2'],
                      ['a1', 'b1', 'c3'],
                      ['a1', 'b2', 'c1'],
                      ['a1', 'b2', 'c2'],
                      ['a1', 'b2', 'c3'],
                      ['a2', 'b1', 'c1'],
                      ['a2', 'b1', 'c2'],
                      ['a2', 'b1', 'c3'],
                      ['a2', 'b2', 'c1'],
                      ['a2', 'b2', 'c2'],
                      ['a2', 'b2', 'c3'],
                      ['a3', 'b1', 'c1'],
                      ['a3', 'b1', 'c2'],
                      ['a3', 'b1', 'c3'],
                      ['a3', 'b2', 'c1'],
                      ['a3', 'b2', 'c2'],
                      ['a3', 'b2', 'c3'],
                      ['a4', 'b1', 'c1'],
                      ['a4', 'b1', 'c2'],
                      ['a4', 'b1', 'c3'],
                      ['a4', 'b2', 'c1'],
                      ['a4', 'b2', 'c2'],
                      ['a4', 'b2', 'c3']],\
                      'cross product failed'

As ss gets shorter with every recursive call, the len(ss)>1 condition checks whether the "lowest level" (the last remaining list) has been reached.

If the lowest level has not been reached, the "and" branch is executed: the first list of ss is stripped off, the recursion level is increased, the current item of the stripped list is appended to the result item (row), and the function calls itself with these new arguments. The reduce step concatenates the lists returned by the recursive calls into one flat list of rows. It is necessary to avoid cascaded nesting of lists, and because it only joins the top level, it also works if your initial lists are lists of lists themselves.
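
To make the role of the reduction concrete, here is a minimal sketch of what the list comprehension produces before and after concatenation. The `per_item` data is made up for illustration and stands in for the values returned by the recursive calls.

from functools import reduce  # builtin in Python 2; must be imported in Python 3

# Hypothetical partial results: one list of rows per element of ss[0]
per_item = [
    [['a1', 'b1'], ['a1', 'b2']],   # rows starting with 'a1'
    [['a2', 'b1'], ['a2', 'b2']],   # rows starting with 'a2'
]

# Without reduction the rows stay grouped one level too deep
print(len(per_item))   # 2 groups rather than 4 rows

# Concatenating the groups with + yields one flat list of rows
flat = reduce(lambda x, y: x + y, per_item)
print(flat)
# [['a1', 'b1'], ['a1', 'b2'], ['a2', 'b1'], ['a2', 'b2']]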

If the lowest level has been reached, the "or" branch is executed: the elements of the innermost list are simply appended to the result item.
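
For comparison, the same logic can be written as an explicit recursive function. This is only a sketch: the name `cross_recursive` is hypothetical, and the cross-check at the end uses `itertools.product` from the standard library (Python 2.6 and later), which the recipe itself does not mention.

from itertools import product

def cross_recursive(ss, row=None):
    """Explicit-recursion counterpart of the one-liner (hypothetical name)."""
    if row is None:
        row = []
    head, rest = ss[0], ss[1:]
    if not rest:
        # lowest level: append each element of the last list to the row
        return [row + [i] for i in head]
    result = []
    for i in head:
        # strip off the first list, extend the row, recurse on the remainder
        result.extend(cross_recursive(rest, row + [i]))
    return result

s1 = ['a1', 'a2', 'a3', 'a4']
s2 = ['b1', 'b2']
s3 = ['c1', 'c2', 'c3']
rows = cross_recursive([s1, s2, s3])
print(len(rows))   # 24
print(rows[0])     # ['a1', 'b1', 'c1']
print(rows[-1])    # ['a4', 'b2', 'c3']
# itertools.product yields tuples in the same order
print(rows == [list(t) for t in product(s1, s2, s3)])   # True

One behavioural difference: if a later list is empty, the reduce result in the one-liner is an empty (falsy) list, so the and/or idiom falls through to the "or" branch and returns rows built from the first list only. The sketch above returns an empty list in that case.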

2 comments

Raymond Hettinger 21 years, 5 months ago

Build the product set by set.

def cross(*args):
    ans = [[]]
    for arg in args:
        ans = [x+[y] for x in ans for y in arg]
    return ans

print(cross(s1, s2, s3))
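
A sketch of how the answer grows one input at a time. `cross_traced` and its `steps` list are hypothetical additions for illustration, not part of the comment above; the loop itself is unchanged.

def cross_traced(*args):
    # Same loop as cross(); `steps` records the partial product
    # after each input list has been folded in.
    ans = [[]]
    steps = [ans]
    for arg in args:
        ans = [x + [y] for x in ans for y in arg]
        steps.append(ans)
    return ans, steps

result, steps = cross_traced(['a1', 'a2'], ['b1', 'b2'])
for step in steps:
    print(step)
# [[]]
# [['a1'], ['a2']]
# [['a1', 'b1'], ['a1', 'b2'], ['a2', 'b1'], ['a2', 'b2']]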

Steven Taschuk 20 years, 11 months ago

Iterator version.

def cross(*sets):
    wheels = [iter(s) for s in sets] # wheels like in an odometer
    try:
        digits = [next(it) for it in wheels]
    except StopIteration:
        return # an input is empty, so the product is empty
    while True:
        yield digits[:]
        for i in range(len(digits)-1, -1, -1):
            try:
                digits[i] = next(wheels[i])
                break
            except StopIteration:
                # this wheel rolled over: reset it and carry to the next one
                wheels[i] = iter(sets[i])
                digits[i] = next(wheels[i])
        else:
            break

Created by Attila Vásárhelyi on Wed, 30 Oct 2002 (PSF)