In the card game SET!, players are shown an array of 12 (or more) symbol cards and try to identify a so-called 3-card set among these cards as quickly as possible.
A card has four attributes (number, shape, color and shading), each of which can take 3 possible values. Three cards form a set if, for each attribute, the three cards either all have the same value or all have different values.
This recipe solves the problem of finding all sets within an array of an arbitrary number of cards, showing some clever optimizations and celebrating the clarity of Python in expressing the algorithms.
```python
import random

class Table:
    # a random table of n different cards
    def __init__(self,n=12):
        self.cards = random.sample(Card.allcards(),n)

    ###############################################
    # four different algorithms that list all sets
    # on the table, from slow to fast

    def findsets_gnt(self): # generate and test
        found = []
        for i,ci in enumerate(self.cards):
            for j,cj in enumerate(self.cards[i+1:],i+1):
                for k,ck in enumerate(self.cards[j+1:],j+1):
                    if ci.isset(cj,ck):
                        found.append((ci,cj,ck))
        return found

    def findsets_gnt_mod(self): # generate and test (faster)
        found = []
        for i,ci in enumerate(self.cards):
            for j,cj in enumerate(self.cards[i+1:],i+1):
                for k,ck in enumerate(self.cards[j+1:],j+1):
                    if ci.isset_mod(cj,ck):
                        found.append((ci,cj,ck))
        return found

    def findsets_simple(self): # using thirdcard_simple
        found = []
        have = {}  # card -> position on the table
        for pos,card in enumerate(self.cards):
            have[card]=pos
        for i,ci in enumerate(self.cards):
            for j,cj in enumerate(self.cards[i+1:],i+1):
                k = have.get(ci.thirdcard_simple(cj))
                # k is None if the third card is not on the table;
                # requiring k > j reports each set only once
                if k is not None and k > j:
                    found.append((ci,cj,self.cards[k]))
        return found

    def findsets_fast(self): # using thirdcard_fast
        found = []
        have = [None for _ in range(256)]  # indexed by Card.bits
        for pos,card in enumerate(self.cards):
            have[card.bits]=pos
        for i,ci in enumerate(self.cards):
            for j,cj in enumerate(self.cards[i+1:],i+1):
                k = have[ci.thirdcard_fast(cj)]
                if k is not None and k > j:
                    found.append((ci,cj,self.cards[k]))
        return found

class Card:
    def __init__(self,*attrs):
        # a card is a simple 4-tuple of attributes
        # each attr is supposed to be either 0, 1 or 2
        self.attrs = attrs
        # alternative representation of attrs, in 8 bits:
        # 2 bits per attr, highest bits represent first attr
        self.bits = sum(a<<(2*i) for i,a in enumerate(attrs[::-1]))

    def __eq__(self,other):
        return self.attrs == other.attrs

    def __hash__(self):
        return hash(self.attrs)

    def __repr__(self):
        # attrs are ints, so convert them before joining
        return 'Card({})'.format(','.join(map(str,self.attrs)))

    # most readable way to express what a SET is
    def isset(self,card1,card2):
        def allsame(v0,v1,v2):
            return v0==v1 and v1==v2
        def alldifferent(v0,v1,v2):
            return len(set((v0,v1,v2)))==3
        return all(allsame(v0,v1,v2) or alldifferent(v0,v1,v2)
                   for (v0,v1,v2) in zip(self.attrs,card1.attrs,card2.attrs))

    # a more mathy (and slightly faster) way
    def isset_mod(self,card1,card2):
        return all((v0+v1+v2)%3==0 for (v0,v1,v2) in zip(self.attrs,card1.attrs,card2.attrs))

    # which third card is needed to complete the set
    def thirdcard_simple(self,other):
        def thirdval(pair):
            v0,v1 = pair
            return (-v0-v1)%3
        return Card(*map(thirdval,zip(self.attrs,other.attrs)))

    # same thing, but using the 8-bit representation
    def thirdcard_fast(self,other):
        # NB returns bits, not a Card
        x,y = self.bits,other.bits
        xor = x^y
        swap = ((xor & mask1) >> 1) | ((xor & mask0) << 1)
        return (x&y) | (~(x|y) & swap)

    # all 81 possible cards
    @staticmethod
    def allcards():
        return [ Card(att0,att1,att2,att3)
                 for att0 in (0,1,2)
                 for att1 in (0,1,2)
                 for att2 in (0,1,2)
                 for att3 in (0,1,2)
               ]

# bit masks for low and high bits of the attributes
# (module-level, so thirdcard_fast can see them)
mask0 = sum(1<<(2*i) for i in range(4)) # 01010101
mask1 = sum(1<<(2*i+1) for i in range(4)) # 10101010
```
A card is basically represented as a 4-tuple of integers, each of which can be 0, 1 or 2. This is an example of three cards forming a set:
    first card:  (1, 0, 2, 2)
    second card: (1, 1, 0, 2)
    third card:  (1, 2, 1, 2)
                  |  |  |  L__ all the same
                  |  |  L_____ all different
                  |  L________ all different
                  L___________ all the same
And these cards do not form a set:
    first card:  (0, 0, 1, 1)
    second card: (1, 1, 1, 2)
    third card:  (1, 2, 1, 0)
                  |  |  |  L__ all different
                  |  |  L_____ all the same
                  |  L________ all different
                  L___________ WRONG!
Indeed, every attribute has to pass the test: the first attribute alone is enough to disqualify these three cards.
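The per-attribute rule is easy to check on plain 4-tuples, independently of the Card class. A minimal sketch (the helper names attribute_ok and is_set_tuple are hypothetical, not part of the recipe) that reproduces the two examples above:

```python
def attribute_ok(v0, v1, v2):
    # passes if the values are all equal (set size 1) or all different (set size 3)
    return len({v0, v1, v2}) in (1, 3)

def is_set_tuple(c0, c1, c2):
    # all four attributes must pass
    return all(attribute_ok(*vals) for vals in zip(c0, c1, c2))

good = [(1, 0, 2, 2), (1, 1, 0, 2), (1, 2, 1, 2)]
bad = [(0, 0, 1, 1), (1, 1, 1, 2), (1, 2, 1, 0)]
print(is_set_tuple(*good))  # True
print(is_set_tuple(*bad))   # False
print([attribute_ok(*vals) for vals in zip(*bad)])  # [False, True, True, True]
```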
Around the 4-tuple we construct a small class Card; this allows us to use a more object-oriented style (card0.isset(card1,card2) instead of modulename.isset(card0,card1,card2)) and also to use an alternative representation.
The obvious way to find all sets in a table of n cards is to check all 3-card combinations using a three-level nested loop, as Table.findsets_gnt() does. The loops should not each run over all n cards: that would visit each 3-card combination 6 separate times (once per ordering) and, moreover, examine combinations that include the same card more than once. Instead, the loops are constructed such that for each 3-card combination, the first card (in table order) is the outer loop variable, the second the middle loop variable, and the third the inner loop variable. Each combination is checked; when it is a set, it is appended to the list of found sets.
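The i < j < k loop structure visits exactly the 3-element combinations of the table, in the same order that itertools.combinations generates them. A minimal sketch on plain tuples (findsets_combinations and is_set are hypothetical names, not the recipe's API; math.comb needs Python 3.8+), followed by the combination counts for a 12-card table:

```python
from itertools import combinations
from math import comb

def is_set(c0, c1, c2):
    # each attribute must have 1 or 3 distinct values, never exactly 2
    return all(len({a, b, c}) != 2 for a, b, c in zip(c0, c1, c2))

def findsets_combinations(cards, test=is_set):
    # each unordered triple is generated once, cards kept in table order
    return [t for t in combinations(cards, 3) if test(*t)]

table = [(1, 0, 2, 2), (0, 0, 1, 1), (1, 1, 0, 2), (1, 2, 1, 2)]
print(findsets_combinations(table))
# [((1, 0, 2, 2), (1, 1, 0, 2), (1, 2, 1, 2))]

n = 12
print(comb(n, 3), n * (n - 1) * (n - 2), n ** 3)  # 220 1320 1728
```

Looping each level over all 12 cards would produce 1728 ordered triples, of which 1320 use three distinct cards: 6 orderings for each of the 220 combinations.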
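The first step to optimization is to realize that an individual attribute passes the set test if and only if the sum of the three cards' values for this attribute, modulo 3, equals 0: (0+0+0)%3==0, (1+1+1)%3==0, (2+2+2)%3==0, and (0+1+2)%3==0. Conversely, if exactly two values are equal, say a, a and b with b != a, the sum is 2a+b, which is divisible by 3 only if b and a are equal modulo 3, which never happens for distinct values in {0, 1, 2}. This is exploited in Card.isset_mod() and Table.findsets_gnt_mod().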
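Since a single attribute has only 27 possible value triples, the equivalence of the two tests can be confirmed exhaustively. A short sketch (passes_rule and passes_mod are hypothetical helper names mirroring Card.isset() and Card.isset_mod() for one attribute):

```python
from itertools import product

def passes_rule(v0, v1, v2):
    # all the same or all different
    return len({v0, v1, v2}) in (1, 3)

def passes_mod(v0, v1, v2):
    # the sum-modulo-3 test
    return (v0 + v1 + v2) % 3 == 0

# all 3**3 value triples for one attribute
triples = list(product((0, 1, 2), repeat=3))
mismatches = [t for t in triples if passes_rule(*t) != passes_mod(*t)]
print(len(triples), mismatches)              # 27 []
print(sum(passes_mod(*t) for t in triples))  # 9
```

The 9 passing triples are the 3 all-same triples plus the 6 orderings of (0, 1, 2).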
The second optimization step is to exploit the fact that each pair of distinct cards (card0,card1) forms a set with exactly one other card. This card can be determined by setting card2.attrs[i] = (-card0.attrs[i]-card1.attrs[i])%3 for all i (as implemented in Card.thirdcard_simple()). Checking whether this card is on the table can be done faster than by looping over the remaining cards. In Table.findsets_simple(), we put all the table cards in a dictionary have that maps each card to its position; a dictionary lookup uses the card's hash value instead of looping over the dictionary items. In addition, we need to check that the third card comes after the other two cards in table order; otherwise every set would be reported three times, once for each of its pairs.
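The same pair-plus-lookup idea can be sketched on plain tuples, with a dict from card to table position standing in for have. The names third_card and findsets_lookup are hypothetical; on the full 81-card deck every pair completes to a set, so the result should be 81*80/2/3 = 1080 sets:

```python
from itertools import product

def third_card(c0, c1):
    # the unique card that completes a set with c0 and c1
    return tuple((-a - b) % 3 for a, b in zip(c0, c1))

def findsets_lookup(cards):
    pos = {card: i for i, card in enumerate(cards)}  # card -> table position
    found = []
    for i, ci in enumerate(cards):
        for j in range(i + 1, len(cards)):
            cj = cards[j]
            k = pos.get(third_card(ci, cj))
            # accept only k > j, so each set is reported once, from its first two cards
            if k is not None and k > j:
                found.append((ci, cj, cards[k]))
    return found

print(third_card((1, 0, 2, 2), (1, 1, 0, 2)))  # (1, 2, 1, 2)
deck = list(product((0, 1, 2), repeat=4))       # all 81 cards
print(len(findsets_lookup(deck)))               # 1080
```

Dropping the k > j condition (keeping only the None check) would give 3240 hits on the full deck: one per pair, three per set.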
In the final optimization, we represent the 4 attributes in one integer, with 2 bits per attribute:
    # values of each attribute are encoded like this
    0 -> 00
    1 -> 01
    2 -> 10
    # example encoding for a whole card:
    (1,0,2,1) -> 01 00 10 01   # so Card(1,0,2,1).bits==73
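The recipe computes Card.bits with a sum over the reversed attributes; an equivalent way is to shift in one attribute at a time. A minimal sketch (to_bits and recipe_bits are hypothetical helpers) that reproduces the example encoding and shows the largest possible code, which is why a 256-element list is more than enough:

```python
from itertools import product

def to_bits(attrs):
    # shift in 2 bits per attribute; the first attribute ends up in the highest bits
    bits = 0
    for a in attrs:
        bits = (bits << 2) | a
    return bits

def recipe_bits(attrs):
    # the expression used in Card.__init__
    return sum(a << (2 * i) for i, a in enumerate(attrs[::-1]))

print(to_bits((1, 0, 2, 1)), format(to_bits((1, 0, 2, 1)), '08b'))  # 73 01001001
print(to_bits((2, 2, 2, 2)))  # 170, the largest code
print(all(to_bits(c) == recipe_bits(c) for c in product((0, 1, 2), repeat=4)))  # True
```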
If we can devise a function that calculates (-x-y)%3 in this representation using bitwise operators, we can apply it to all 8 bits (that is, all four attributes) at the same time. And indeed we can (with a little help from Karnaugh maps); the result is Card.thirdcard_fast(). With this 8-bit representation of cards, we also put the table cards in a 256-element list indexed by the bits value; checking whether a card is on the table then takes a single list index, which is probably faster than a dictionary lookup.
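Because each attribute has only 3 values, the bitwise formula can be verified exhaustively. This sketch re-implements the expression from Card.thirdcard_fast() on plain integers (third_bits, to_bits, MASK0 and MASK1 are hypothetical stand-ins for the recipe's method, Card.bits and the module-level masks) and compares it with the per-attribute (-x-y)%3 for all 81*81 ordered pairs:

```python
from itertools import product

MASK0 = 0b01010101  # low bit of each 2-bit field
MASK1 = 0b10101010  # high bit of each 2-bit field

def to_bits(attrs):
    # 2 bits per attribute, first attribute in the highest bits
    bits = 0
    for a in attrs:
        bits = (bits << 2) | a
    return bits

def third_bits(x, y):
    # same expression as Card.thirdcard_fast()
    xor = x ^ y
    swap = ((xor & MASK1) >> 1) | ((xor & MASK0) << 1)  # swap the bits within each field
    return (x & y) | (~(x | y) & swap)

deck = list(product((0, 1, 2), repeat=4))
ok = all(
    third_bits(to_bits(c0), to_bits(c1))
    == to_bits(tuple((-a - b) % 3 for a, b in zip(c0, c1)))
    for c0 in deck for c1 in deck
)
print(ok)  # True
```

In Python, ~(x|y) is a negative integer, but and-ing it with swap, which only has bits in the low byte, keeps the result within 8 bits.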
Timing results
Measured in IPython, with the recipe's module imported as m:
    >>> t = m.Table(12)
    >>> timeit t.findsets_gnt()
    1000 loops, best of 3: 684 us per loop
    >>> timeit t.findsets_gnt_mod()
    1000 loops, best of 3: 460 us per loop
    >>> timeit t.findsets_simple()
    1000 loops, best of 3: 404 us per loop
    >>> timeit t.findsets_fast()
    10000 loops, best of 3: 77.6 us per loop
Naturally, the differences become more pronounced with a table of 81 cards.
    >>> t = m.Table(81)
    >>> timeit t.findsets_gnt()
    1 loops, best of 3: 238 ms per loop
    >>> timeit t.findsets_gnt_mod()
    10 loops, best of 3: 163 ms per loop
    >>> timeit t.findsets_simple()
    10 loops, best of 3: 20.6 ms per loop
    >>> timeit t.findsets_fast()
    100 loops, best of 3: 2.75 ms per loop
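Outside IPython, comparable per-call figures can be obtained with the standard timeit module. A minimal sketch; best_per_call is a hypothetical helper that mimics the "best of 3" figures shown above:

```python
import timeit

def best_per_call(func, number=1000, repeat=3):
    # run `number` calls, `repeat` times over, and keep the fastest total
    totals = timeit.repeat(func, number=number, repeat=repeat)
    return min(totals) / number  # seconds per call

# hypothetical usage with the recipe's classes:
#   t = Table(12)
#   print(best_per_call(t.findsets_fast) * 1e6, 'us per loop')
print(best_per_call(lambda: sorted(range(100)), number=100))
```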
Of course, the motivation for this recipe isn't the milliseconds but the fun of computer science puzzles!