In the card game SET!, players are shown an array of 12 (or more) symbol cards and try to identify a so-called 3-card set among these cards as quickly as possible.

A card has four attributes (number, shape, color and shading), each of which can take 3 possible values. In a set, for each attribute, all three cards should have either the same value, or the three different values.

This recipe solves the problem of finding all sets within an array of an arbitrary number of cards, showing some clever optimizations and celebrating the clarity of Python in expressing the algorithms.

Python, 113 lines
import random

class Table:

    # a random table of n different cards
    def __init__(self,n=12):
        self.cards = random.sample(Card.allcards(),n)
    
    ###############################################
    # four different algorithms that list all sets
    # on the table, from slow to fast
    
    def findsets_gnt(self):     # generate and test
        found = []
        for i,ci in enumerate(self.cards):
            for j,cj in enumerate(self.cards[i+1:],i+1):
                for k,ck in enumerate(self.cards[j+1:],j+1):
                    if ci.isset(cj,ck):
                        found.append((ci,cj,ck))
        return found

    def findsets_gnt_mod(self):   # generate and test (faster)
        found = []
        for i,ci in enumerate(self.cards):
            for j,cj in enumerate(self.cards[i+1:],i+1):
                for k,ck in enumerate(self.cards[j+1:],j+1):
                    if ci.isset_mod(cj,ck):
                        found.append((ci,cj,ck))
        return found
    
    def findsets_simple(self):  # using thirdcard_simple
        found = []
        have = {}
        for pos,card in enumerate(self.cards):
            have[card]=pos
        for i,ci in enumerate(self.cards):
            for j,cj in enumerate(self.cards[i+1:],i+1):
                k = have.get(ci.thirdcard_simple(cj))
                if k is not None and k > j:  # k is None when the third card is not on the table
                    found.append((ci,cj,self.cards[k]))
        return found
        
    def findsets_fast(self):  # using thirdcard_fast
        found = []
        have = [None for _ in range(256)]
        for pos,card in enumerate(self.cards):
            have[card.bits]=pos
        for i,ci in enumerate(self.cards):
            for j,cj in enumerate(self.cards[i+1:],i+1):
                k = have[ci.thirdcard_fast(cj)]
                if k is not None and k > j:  # k is None when the third card is not on the table
                    found.append((ci,cj,self.cards[k]))
        return found
            
class Card:
    
    def __init__(self,*attrs):
        # a card is a simple 4-tuple of attributes
        # each attr is supposed to be either 0, 1 or 2
        self.attrs = attrs
        # alternative representation of attrs, in 8 bits:
        # 2 bits per attr, highest bits represent first attr
        self.bits = sum(a<<(2*i) for i,a in enumerate(attrs[::-1]))
    
    def __eq__(self,other):
        return self.attrs == other.attrs
    
    def __hash__(self):
        return hash(self.attrs)
    
    def __repr__(self):
        return 'Card({})'.format(','.join(map(str,self.attrs)))
    
    # most readable way to express what a SET is
    def isset(self,card1,card2):
        def allsame(v0,v1,v2):
            return v0==v1 and v1==v2
        def alldifferent(v0,v1,v2):
            return len(set((v0,v1,v2)))==3
        return all(allsame(v0,v1,v2) or alldifferent(v0,v1,v2)
                   for (v0,v1,v2) in zip(self.attrs,card1.attrs,card2.attrs))
    
    # a more mathy (and slightly faster) way
    def isset_mod(self,card1,card2):
        return all((v0+v1+v2)%3==0 for (v0,v1,v2) in zip(self.attrs,card1.attrs,card2.attrs))
    
    # which third card is needed to complete the set
    def thirdcard_simple(self,other):
        def thirdval(v0,v1):
            return (-v0-v1)%3
        return Card(*map(thirdval,self.attrs,other.attrs))
    
    # same thing, but using the 8-bit representation
    def thirdcard_fast(self,other):
        # NB returns bits
        x,y = self.bits,other.bits
        xor = x^y
        swap = ((xor & mask1) >> 1) | ((xor & mask0) << 1)
        return (x&y) | (~(x|y) & swap)

    # all 81 possible cards
    @staticmethod
    def allcards():
        return [ Card(att0,att1,att2,att3)
                   for att0 in (0,1,2)
                   for att1 in (0,1,2) 
                   for att2 in (0,1,2)
                   for att3 in (0,1,2)
               ]
               
# bit masks for low and high bits of the attributes        
mask0 = sum(1<<(2*i) for i in range(4))    # 01010101
mask1 = sum(1<<(2*i+1) for i in range(4))  # 10101010

A card is basically represented as a 4-tuple of integers, each of which can be 0, 1 or 2. This is an example of three cards forming a set:

first card:   (1, 0, 2, 2)
second card:  (1, 1, 0, 2)
third card:   (1, 2, 1, 2)
               |  |  |  L__ all the same
               |  |  L_____ all different
               |  L________ all different
               L___________ all the same

And these cards do not form a set:

first card:   (0, 0, 1, 1)
second card:  (1, 1, 1, 2)
third card:   (1, 2, 1, 0)
               |  |  |  L__ all different
               |  |  L_____ all the same
               |  L________ all different
               L___________ WRONG!

Indeed, all attributes have to pass the test.
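
The rule illustrated by the two examples above can be sketched on plain tuples, independently of the recipe's Card class. The helper name is_set is an assumption made for this illustration only; it relies on the observation that an attribute fails exactly when it shows two distinct values.

```python
def is_set(c0, c1, c2):
    """Return True if three attribute tuples form a set (hypothetical helper)."""
    # Per attribute, the three values must be all equal (1 distinct value)
    # or all different (3 distinct values); exactly 2 distinct values fails.
    return all(len({v0, v1, v2}) != 2 for v0, v1, v2 in zip(c0, c1, c2))

# The two examples from the text above.
print(is_set((1, 0, 2, 2), (1, 1, 0, 2), (1, 2, 1, 2)))  # True
print(is_set((0, 0, 1, 1), (1, 1, 1, 2), (1, 2, 1, 0)))  # False
```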

Around the 4-tuple we construct a small class Card; this allows us to use a more object-oriented style (card0.isset(card1,card2) instead of modulename.isset(card0,card1,card2)) and also to use an alternative representation.

The obvious way to find all sets in a table of n cards is to check all 3-card combinations using a three-level nested loop, like Table.findsets_gnt() does. Not every level needs to loop over n cards: this would visit each 3-card combination 6 separate times (and moreover, investigate combinations that include the same card more than once). Instead, the loops are constructed such that for each 3-card combination, the first card (in table order) is represented by the outer loop variable, the second by the middle loop variable, and the third by the inner loop variable. Each combination is checked; when it is a set, it is appended to the list of found sets.
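
As an aside, the same "each combination exactly once, in table order" traversal is what itertools.combinations provides. The following is a minimal sketch on plain tuples, not the recipe's own implementation; is_set and find_sets are names assumed for the illustration.

```python
from itertools import combinations

def is_set(c0, c1, c2):
    # An attribute fails only when it shows exactly two distinct values.
    return all(len({v0, v1, v2}) != 2 for v0, v1, v2 in zip(c0, c1, c2))

def find_sets(cards):
    # combinations() yields every 3-card subset once, preserving table order,
    # so a 12-card table needs 220 checks instead of 12**3 == 1728.
    return [trio for trio in combinations(cards, 3) if is_set(*trio)]

table = [(1, 0, 2, 2), (0, 0, 1, 1), (1, 1, 0, 2), (1, 2, 1, 2), (1, 1, 1, 2)]
print(find_sets(table))
# [((1, 0, 2, 2), (1, 1, 0, 2), (1, 2, 1, 2))]
```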

The first step to optimization is to realize that an individual attribute passes the set-test if and only if the sum of the values on the three cards for this attribute, modulo 3, equals 0: (0+0+0)%3==0, (1+1+1)%3==0, (2+2+2)%3==0, and (0+1+2)%3==0. This is exploited in Card.isset_mod() and Table.findsets_gnt_mod().
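
The "if and only if" can be checked exhaustively, since a single attribute has only 27 possible value triples. This is a small verification sketch; the two function names are assumptions for the illustration.

```python
from itertools import product

def passes_by_definition(v0, v1, v2):
    # all the same (1 distinct value) or all different (3 distinct values)
    return len({v0, v1, v2}) in (1, 3)

def passes_by_sum(v0, v1, v2):
    # the modulo-3 shortcut described above
    return (v0 + v1 + v2) % 3 == 0

triples = list(product(range(3), repeat=3))   # all 27 value triples
agree = all(passes_by_definition(*t) == passes_by_sum(*t) for t in triples)
passing = [t for t in triples if passes_by_sum(*t)]
print(agree, len(passing))  # True 9
```

The 9 passing triples are the 3 all-same triples plus the 6 orderings of (0, 1, 2).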

The second optimization step is to exploit the fact that each pair of cards (card0,card1) forms a set with exactly one other card. This card can be determined by setting card2.attrs[i] = (-card0.attrs[i]-card1.attrs[i])%3 for all i (as implemented in Card.thirdcard_simple()). There are faster ways to check whether this card is on the table than looping over the remaining cards. In Table.findsets_simple(), we put all the table cards in a dictionary have that maps each card to its table position; a dictionary lookup uses the card's hash value instead of looping over the dictionary items. In addition, we need to check that the third card comes after the other two cards in table order; otherwise each set would be found three times, once for each pair of its cards.
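
The idea can be sketched on plain tuples, without the Card class. The names third_card, find_sets and position are assumptions for this illustration; the structure mirrors what the paragraph above describes rather than reproducing the recipe's code.

```python
def third_card(c0, c1):
    # The unique card that completes a set with c0 and c1.
    return tuple((-v0 - v1) % 3 for v0, v1 in zip(c0, c1))

def find_sets(cards):
    # Map each card to its table position for constant-time lookups.
    position = {card: pos for pos, card in enumerate(cards)}
    found = []
    for i, ci in enumerate(cards):
        for j in range(i + 1, len(cards)):
            cj = cards[j]
            k = position.get(third_card(ci, cj))  # None if not on the table
            # Accept only k > j so that each set is reported once.
            if k is not None and k > j:
                found.append((ci, cj, cards[k]))
    return found

print(third_card((1, 0, 2, 2), (1, 1, 0, 2)))  # (1, 2, 1, 2)
table = [(1, 0, 2, 2), (0, 0, 1, 1), (1, 1, 0, 2), (1, 2, 1, 2), (1, 1, 1, 2)]
print(find_sets(table))
# [((1, 0, 2, 2), (1, 1, 0, 2), (1, 2, 1, 2))]
```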

In the final optimization, we represent the 4 attributes in one integer, with 2 bits per attribute:

# values of each attribute are encoded like this
0 -> 00
1 -> 01
2 -> 10

# example encoding for a whole card:
(1,0,2,1) -> 01 00 10 01   # so Card(1,0,2,1).bits==73
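
The encoding can be tried out with two small helpers. The names encode and decode are assumptions for this sketch; the recipe itself computes the same value in Card.__init__.

```python
def encode(attrs):
    # Pack the attributes into one integer, 2 bits each, first attribute highest.
    bits = 0
    for a in attrs:
        bits = (bits << 2) | a
    return bits

def decode(bits, n=4):
    # Inverse of encode(): read n 2-bit fields, highest field first.
    return tuple((bits >> (2 * i)) & 3 for i in reversed(range(n)))

print(encode((1, 0, 2, 1)))   # 73
print(format(73, '08b'))      # 01001001
print(decode(73))             # (1, 0, 2, 1)
```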

If we can devise a function that calculates (-x-y)%3 in this representation using bitwise operators, we can apply it to all 8 bits at the same time. And indeed we can (with a little help from Karnaugh maps); the result is Card.thirdcard_fast(). With this 8-bit representation of cards, we can also store the table positions in a 256-element list indexed by card.bits; checking whether a card is on the table then only requires indexing the list, which is probably faster than a dictionary lookup.
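
Because there are only 81 cards, the bitwise formula can be checked against the arithmetic definition for every pair of cards. The sketch below copies the expression from Card.thirdcard_fast(); the helper names and the upper-case mask constants are assumptions for this illustration.

```python
from itertools import product

MASK0 = 0b01010101  # low bit of each 2-bit attribute
MASK1 = 0b10101010  # high bit of each 2-bit attribute

def encode(attrs):
    bits = 0
    for a in attrs:
        bits = (bits << 2) | a
    return bits

def third_bits(x, y):
    xor = x ^ y
    # Swap the two bits inside each attribute of x ^ y.
    swap = ((xor & MASK1) >> 1) | ((xor & MASK0) << 1)
    # Equal attributes: x & y keeps the shared value (swap is 0 there).
    # Differing attributes: ~(x | y) & swap picks the value neither card has.
    return (x & y) | (~(x | y) & swap)

def third_attrs(c0, c1):
    return tuple((-v0 - v1) % 3 for v0, v1 in zip(c0, c1))

cards = list(product(range(3), repeat=4))
ok = all(third_bits(encode(a), encode(b)) == encode(third_attrs(a, b))
         for a in cards for b in cards)
print(len(cards), ok)  # 81 True
```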

Timing results

Measured in IPython with timeit; the recipe's module is imported as m.

>>> t = m.Table(12)
>>> timeit t.findsets_gnt()
1000 loops, best of 3: 684 us per loop
>>> timeit t.findsets_gnt_mod()
1000 loops, best of 3: 460 us per loop
>>> timeit t.findsets_simple()
1000 loops, best of 3: 404 us per loop
>>> timeit t.findsets_fast()
10000 loops, best of 3: 77.6 us per loop

Naturally, the differences become more pronounced with a table of 81 cards.

>>> t = m.Table(81)
>>> timeit t.findsets_gnt()
1 loops, best of 3: 238 ms per loop
>>> timeit t.findsets_gnt_mod()
10 loops, best of 3: 163 ms per loop
>>> timeit t.findsets_simple()
10 loops, best of 3: 20.6 ms per loop
>>> timeit t.findsets_fast()
100 loops, best of 3: 2.75 ms per loop

Of course, the motivation for this recipe isn't the milliseconds but the fun of computer science puzzles!