
The builtin set and frozenset types are based on object equality: they use __hash__ and __eq__ to determine whether an object is a member of the set. But sometimes one needs a set of objects that are compared by other means than their own __eq__ method. There are several ways to achieve that; this recipe presents two classes, FrozenKeyedSet and KeyedSet, that take an additional function, key, which is used to determine membership and uniqueness. If key returns the same value for two objects, only one of them will be in the set.

"""
FrozenKeyedSet and KeyedSet are set implementations that use
a custom element equality function, so that set membership and
uniqueness are not based on the element's __eq__ but on a user-supplied
key function.

Iteration order is not guaranteed (it depends on the Python version),
so the examples below sort element names before displaying them.

>>> class Foo(object):
...   def __init__(self, name): self.name = name
...   def __repr__(self): return "%s(%r)" % (type(self).__name__, self.name)
>>> def getname(o):
...   return o.name
>>> def names(s):
...   return sorted(getname(o) for o in s)
>>> objs = [Foo('Joe'), Foo('Jim'), Foo('Tom'), Foo('Jim')]
>>> s = KeyedSet(objs, key=getname)
>>> names(s)
['Jim', 'Joe', 'Tom']
>>> s.add(Foo('Joe'))
>>> len(s)
3
>>> joe = Foo('Joe')
>>> joe in s
True
>>> 'Joe' in s
False
>>> s2 = set([Foo('Luc'), Foo('Jim'), Foo('Dan')])
>>> names(s | s2)
['Dan', 'Jim', 'Joe', 'Luc', 'Tom']
>>> s & s2
KeyedSet([Foo('Jim')])
>>> s2f = FrozenKeyedSet(s2, key=getname)
>>> names(s2f - s)
['Dan', 'Luc']
>>> names(s ^ s2f)
['Dan', 'Joe', 'Luc', 'Tom']
>>> names(s.copy())
['Jim', 'Joe', 'Tom']
>>> # FrozenKeyedSet can be used as a dictionary key
... d = {}
>>> d[s2f] = ['anything', 'else']

Copyright (C) 2009 Gabriel A. Genellina

"""

__author__ = 'Gabriel A. Genellina' 
__version__ = "$Revision: 1.12 $"[11:-2]

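try:
    # Python 3.3+: the ABCs live in collections.abc (required from 3.10 on)
    from collections.abc import Set, MutableSet
except ImportError:
    # Python 2.6/2.7
    from collections import Set, MutableSet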


# The Set and MutableSet ABCs define only the operators (&, &=, etc.)
# but not the corresponding "worded" methods (intersection, intersection_update, etc.).
# The former require a Set as second argument; the latter accept any iterable.
# This helper function builds the "worded" method variant on top of
# the corresponding operator.
def _build_variant(cls, opname):
    fn = getattr(cls, opname)
    if hasattr(fn, 'im_func'):
        fn = fn.im_func
    def method(self, other, fn=fn):
        if not isinstance(other, Set):
            other = self._from_iterable(other)
        return fn(self, other)
    return method


class FrozenKeyedSet(Set):
    """A frozen set that uses a custom element equality function."""

    # "named" methods like those of frozenset, in addition to operators
    intersection = _build_variant(Set, '__and__')
    union = _build_variant(Set, '__or__')
    difference = _build_variant(Set, '__sub__')
    symmetric_difference = _build_variant(Set, '__xor__')
    issubset = _build_variant(Set, '__le__')
    issuperset = _build_variant(Set, '__ge__')

    def __init__(self, iterable, key=lambda x: x):
        """Create a FrozenKeyedSet from iterable; key function determines uniqueness.

        `key` must be a callable taking a single argument; it is
        applied to every item in the iterable, and its result
        is used to determine set membership. That is, if key(item)
        returns the same value for two items, only the first one
        encountered will be in the set.
        The key *must* return a hashable object.
        """
        self._key = key
        self._items = items = {}
        for item in iterable:
            # setdefault keeps the first item seen for each key, as add() does
            items.setdefault(key(item), item)

    # Implementation of abstract methods from the Set ABC

    def __iter__(self):
        return iter(self._items.values())

    def __contains__(self, value):
        try: key = self._key(value)
        except Exception: return False
        return key in self._items

    def __len__(self):
        return len(self._items)

    # NOT a classmethod because self._key must be transferred too!
    # Fortunately the ABC mixin methods always call it as self._from_iterable(...)
    # (see _abcoll.py, or _collections_abc.py in recent Python 3 versions).
    def _from_iterable(self, iterable):
        return type(self)(iterable, key=self._key)

    def copy(self):
        return type(self)(self._items.values(), key=self._key)

    def __repr__(self):
        return "%s(%r)" % (type(self).__name__, list(self._items.values()))

    def __hash__(self):
        # Set already provides an implementation for a hash method (_hash), 
        # but we have to apply it to the keys only.
        # Python 2.x requires the 'self' argument to be a Set instance, 
        # so we call the underlying function directly. Python 3.x has 
        # relaxed the requirement.
        _hash = Set._hash
        if hasattr(_hash, 'im_func'):
            _hash = _hash.im_func
        return _hash(self._items.keys())


class KeyedSet(FrozenKeyedSet, MutableSet):
    """A mutable set that uses a custom element equality function."""

    # "named" methods like those of `set` class, in addition to operators
    intersection_update = _build_variant(MutableSet, '__iand__')
    update = _build_variant(MutableSet, '__ior__')
    difference_update = _build_variant(MutableSet, '__isub__')
    symmetric_difference_update = _build_variant(MutableSet, '__ixor__')

    __hash__ = None  # mutable sets are unhashable; undo FrozenKeyedSet.__hash__

    # Implementation of abstract methods from the MutableSet ABC

    def add(self, value):
        """Add an element."""
        key = self._key(value)
        if key not in self._items: 
            self._items[key] = value

    def discard(self, value):
        """Remove an element.  Do not raise an exception if absent."""
        key = self._key(value)
        try: del self._items[key]
        except KeyError: pass

    # performance: override implementation in MutableSet
    def clear(self):
        self._items.clear()


# From this point on, only tests.
# Perhaps there are more test cases than really required,
# but the MutableSet ABC isn't tested very much in the
# standard regression tests.

from unittest import TestCase

class Foo(object):
    def __init__(self, name): self.name = name
    def __repr__(self): return "%s(%r)" % (type(self).__name__, self.name)


class KeyedSetTestCase(TestCase):

    def setUp(self):
        self.lst1 = lst1 = [Foo('Joe'), Foo('Jim'), Foo('Tom'), Foo('Jim')]
        self.s1 = KeyedSet(lst1, key=lambda o:o.name)
        self.s1f = FrozenKeyedSet(lst1, key=lambda o:o.name)
        self.lst2 = lst2 = [Foo('Luc'), Foo('Jim'), Foo('Dan')]
        self.s2 = KeyedSet(lst2, key=lambda o:o.name)
        self.s2f = FrozenKeyedSet(lst2, key=lambda o:o.name)

    def test_addremove(self):
        s1 = self.s1
        s1f = self.s1f

        s1.add(Foo('Joe'))
        self.assertEqual(3, len(s1))
        self.assertTrue(Foo('Joe') in s1)
        s1.remove(Foo('Joe'))
        self.assertTrue(Foo('Joe') not in s1)
        s1.discard(Foo('Joe'))
        self.assertRaises(KeyError, s1.remove, Foo('Joe'))
        s1.discard(Foo('Jim'))
        self.assertTrue(Foo('Jim') not in s1)

        self.assertEqual(3, len(s1f))
        self.assertTrue(Foo('Joe') in s1f)
        self.assertFalse(hasattr(s1f,'add') or hasattr(s1f,'remove') or hasattr(s1f,'discard'))

    def test_eq_hash(self):
        jimtomjoe = [Foo('Jim'), Foo('Tom'), Foo('Joe')]
        s3 = KeyedSet(jimtomjoe, key=lambda o:o.name)
        self.assertEqual(self.s1, s3)
        self.assertEqual(s3, self.s1)
        s3f = FrozenKeyedSet(jimtomjoe, key=lambda o:o.name)
        self.assertEqual(self.s1, s3f)
        self.assertEqual(s3f, self.s1)
        self.assertEqual(s3, set(jimtomjoe))
        self.assertEqual(s3f, set(jimtomjoe))
        self.assertEqual(s3f, FrozenKeyedSet(self.s1, key=lambda o:o.name))
        self.assertEqual(hash(s3f), hash(FrozenKeyedSet(self.s1, key=lambda o:o.name)))
        self.assertEqual(s3f, KeyedSet(self.s1, key=lambda o:o.name))
        self.assertRaises(TypeError, hash, self.s1)

    def test_manip(self):
        s3 = self.s1.copy()
        self.assertEqual(self.s1, s3)
        self.assertEqual(self.s1f, s3)
        e = s3.pop()
        self.assertTrue(e in self.s1 and e not in s3)
        s3.clear()
        self.assertEqual(0, len(s3))

    def test_repr(self):
        self.assertEqual(eval(repr(self.s1)), self.s1)
        self.assertTrue(isinstance(eval(repr(self.s1)), type(self.s1)))
        self.assertEqual(eval(repr(self.s1f)), self.s1f)
        self.assertTrue(isinstance(eval(repr(self.s1f)), type(self.s1f)))

    def test_and(self):
        expected = ['Jim']
        s = self.s1f & self.s2
        self.assertEqual(sorted(x.name for x in s), expected)
        s = self.s1.copy()
        s &= self.s2
        self.assertEqual(sorted(x.name for x in s), expected)
        s = self.s1f.intersection(self.lst2)
        self.assertEqual(sorted(x.name for x in s), expected)
        s = self.s1.copy()
        s.intersection_update(self.lst2)
        self.assertEqual(sorted(x.name for x in s), expected)

    def test_or(self):
        expected = ['Dan', 'Jim', 'Joe', 'Luc', 'Tom']
        s = self.s1f | self.s2
        self.assertEqual(sorted(x.name for x in s), expected)
        s = self.s1.copy()
        s |= self.s2
        self.assertEqual(sorted(x.name for x in s), expected)
        s = self.s1f.union(self.lst2)
        self.assertEqual(sorted(x.name for x in s), expected)
        s = self.s1.copy()
        s.update(self.lst2)
        self.assertEqual(sorted(x.name for x in s), expected)

    def test_minus(self):
        expected = ['Joe', 'Tom']
        s = self.s1f - self.s2
        self.assertEqual(sorted(x.name for x in s), expected)
        s = self.s1.copy()
        s -= self.s2
        self.assertEqual(sorted(x.name for x in s), expected)
        s = self.s1f.difference(self.lst2)
        self.assertEqual(sorted(x.name for x in s), expected)
        s = self.s1.copy()
        s.difference_update(self.lst2)
        self.assertEqual(sorted(x.name for x in s), expected)

    def test_xor(self):
        expected = ['Dan', 'Joe', 'Luc', 'Tom']
        s = self.s1f ^ self.s2
        self.assertEqual(sorted(x.name for x in s), expected)
        s = self.s1.copy()
        s ^= self.s2
        self.assertEqual(sorted(x.name for x in s), expected)
        s = self.s1f.symmetric_difference(self.lst2)
        self.assertEqual(sorted(x.name for x in s), expected)
        s = self.s1.copy()
        s.symmetric_difference_update(self.lst2)
        self.assertEqual(sorted(x.name for x in s), expected)

    def test_relations(self):
        self.assertFalse(self.s1f.isdisjoint(self.s2))
        s3 = self.s1 - self.s2
        self.assertTrue(s3.isdisjoint(self.s2))

        self.assertFalse(self.s2f.issubset(self.s1))
        s3 = self.s1 & self.s2
        self.assertTrue(s3.issubset(self.s1f))
        self.assertTrue(s3 <= self.s1)
        self.assertTrue(s3 < self.s1)
        self.assertTrue(s3 <= s3.copy())
        self.assertFalse(s3 < s3.copy())

        self.assertFalse(self.s2f.issuperset(self.s1))
        s3 = self.s1 & self.s2
        self.assertTrue(self.s1f.issuperset(s3))
        self.assertTrue(self.s1 >= s3)
        self.assertTrue(self.s1 > s3)
        self.assertTrue(s3.copy() >= s3)
        self.assertFalse(s3.copy() > s3)


if __name__ == "__main__":
    import unittest
    from doctest import DocTestSuite
    suite = unittest.TestLoader().loadTestsFromTestCase(KeyedSetTestCase)
    suite.addTest(DocTestSuite())
    unittest.TextTestRunner().run(suite)

Suppose we have a large group of people and want a set containing one person from each profession, no matter which person. One alternative would be building a dictionary:

dict((person.profession, person) for person in group)

but the result is a dictionary, not a real set. Another alternative is to create a somewhat artificial PersonThatComparesByProfession class:

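class PersonThatComparesByProfession(Person):
    def __eq__(self, other):
        return self.profession == other.profession
    def __hash__(self):
        # sets hash elements first, so the hash must agree with __eq__
        return hash(self.profession)

set(PersonThatComparesByProfession(person) for person in group)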

but saying that Joe and Mary are the same person just because they share the same profession doesn't look right.
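To make the wrapper approach concrete, here is a runnable sketch. Person is a hypothetical minimal class with name and profession attributes, and the wrapper's constructor (copying a Person) is an assumption, since the text does not say how the wrapper is built. Note that __hash__ has to be overridden along with __eq__: sets locate elements by hash first, and on Python 3 a class that defines only __eq__ becomes unhashable.

```python
class Person(object):
    """Hypothetical minimal record with the attributes the text mentions."""
    def __init__(self, name, profession):
        self.name = name
        self.profession = profession


class PersonThatComparesByProfession(Person):
    """Wrapper whose equality and hash depend only on the profession."""
    def __init__(self, person):
        # Assumed constructor: copy the wrapped person's attributes.
        Person.__init__(self, person.name, person.profession)

    def __eq__(self, other):
        return self.profession == other.profession

    def __ne__(self, other):
        # Python 2 does not derive != from ==; Python 3 does.
        return not self == other

    def __hash__(self):
        # Must agree with __eq__; defining only __eq__ makes the
        # class unhashable on Python 3.
        return hash(self.profession)


group = [Person('Joe', 'plumber'), Person('Mary', 'plumber'),
         Person('Ann', 'doctor')]
one_each = set(PersonThatComparesByProfession(p) for p in group)
print(len(one_each))  # 2

joe = PersonThatComparesByProfession(group[0])
mary = PersonThatComparesByProfession(group[1])
print(joe == mary)  # True
```

The set ends up holding wrapper copies rather than the original people, and the overridden equality applies to every comparison, not just set membership, which is the objection raised in the text.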

This recipe presents another alternative, a KeyedSet class that takes a function used to compute element membership and uniqueness:

KeyedSet(group, key=lambda person: person.profession)

Of those persons that have the same key value (that is, the same profession), only one will be in the set: the first one encountered. A KeyedSet implements the MutableSet abstract base class (collections.MutableSet in Python 2, collections.abc.MutableSet in Python 3), which the builtin set type also conforms to, so it can be used and combined the same way as any other set:

>>> objs = [Foo('Joe'), Foo('Jim'), Foo('Tom'), Foo('Jim')]
>>> s = KeyedSet(objs, key=lambda o: o.name)
>>> s
KeyedSet([Foo('Jim'), Foo('Joe'), Foo('Tom')])
>>> s.add(Foo('Joe'))
>>> len(s)
3
>>> joe = Foo('Joe')
>>> joe in s
True
>>> 'Joe' in s
False
>>> s2 = set([Foo('Luc'), Foo('Jim'), Foo('Dan')])
>>> s | s2
KeyedSet([Foo('Dan'), Foo('Jim'), Foo('Luc'), Foo('Joe'), Foo('Tom')])
>>> s & s2
KeyedSet([Foo('Jim')])
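Combining works because of the MutableSet ABC: a subclass that implements __contains__, __iter__, __len__, add and discard inherits the operators (|, &, -, ^, comparisons) and methods such as remove and pop as mixins. The following Python 3-only sketch illustrates that mechanism with a profession key; SimpleKeyedSet and Person are hypothetical names, and this is a simplified illustration rather than the recipe's code.

```python
from collections.abc import MutableSet


class SimpleKeyedSet(MutableSet):
    """Hypothetical Python 3 keyed set: only the five abstract methods are
    hand-written; MutableSet supplies |, &, -, ^, <=, ==, remove, pop, ..."""

    def __init__(self, iterable=(), key=lambda x: x):
        self._key = key
        self._items = {}          # key(element) -> element
        for item in iterable:
            self.add(item)

    def __contains__(self, value):
        try:
            return self._key(value) in self._items
        except Exception:         # e.g. value lacks the attribute the key reads
            return False

    def __iter__(self):
        return iter(self._items.values())

    def __len__(self):
        return len(self._items)

    def add(self, value):
        self._items.setdefault(self._key(value), value)  # first one wins

    def discard(self, value):
        self._items.pop(self._key(value), None)

    def _from_iterable(self, iterable):
        # Instance method, as in the recipe, so derived sets keep the key.
        return type(self)(iterable, key=self._key)


class Person:
    """Hypothetical record for the example."""
    def __init__(self, name, profession):
        self.name, self.profession = name, profession


group = [Person('Joe', 'plumber'), Person('Mary', 'plumber'),
         Person('Ann', 'doctor')]
others = [Person('Bob', 'doctor'), Person('Eve', 'pilot')]

by_prof = SimpleKeyedSet(group, key=lambda p: p.profession)
print(sorted(p.name for p in by_prof))           # ['Ann', 'Joe']
print(sorted(p.name for p in by_prof | others))  # ['Ann', 'Eve', 'Joe']
print([p.name for p in by_prof & others])        # ['Bob']
print([p.name for p in by_prof - others])        # ['Joe']
```

The _from_iterable override matters: the default implementation is a classmethod that would call SimpleKeyedSet(iterable) and silently fall back to the identity key for every result of &, | or -.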

There is also a FrozenKeyedSet class, which implements the Set abstract base class (similar to the builtin frozenset type):

>>> s2f = FrozenKeyedSet(s2, key=lambda o: o.name)
>>> s2f - s
FrozenKeyedSet([Foo('Dan'), Foo('Luc')])
>>> for elem in s ^ s2f:
...   print(elem.name)
Dan
Luc
Joe
Tom

FrozenKeyedSet can be used as a dictionary key:

>>> d = {}
>>> d[s2f] = ['anything', 'else']

The key function defaults to identity (in that case KeyedSet behaves like the normal set type, only much slower, since it is written in pure Python). It must return a hashable key that is used to determine whether the object is already in the set (keys are hashed and compared with __eq__, just as elements are in normal sets). Note that, unlike the key argument to sort, sorted, max and min, a KeyedSet/FrozenKeyedSet object keeps a reference to the key function itself and calls it again after the set has been created to implement some operations (e.g. in, add, remove).
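Because the recipe keeps its elements in an internal dictionary indexed by key(element) (see __init__ and add in the listing), the hashability requirement behaves exactly like the rule for dictionary keys. The sketch below mimics that storage with a plain dict rather than the recipe's classes; the helper index_by and the sample data are hypothetical, and it assumes Python 3.

```python
def index_by(items, key=lambda x: x):
    """Mimic the recipe's storage: one entry per distinct key(item),
    keeping the first item seen for each key (hypothetical helper)."""
    index = {}
    for item in items:
        # setdefault raises TypeError if key(item) is unhashable
        index.setdefault(key(item), item)
    return index


# Identity key (the default): duplicates collapse as in a plain set.
print(sorted(index_by([3, 1, 3, 2])))  # [1, 2, 3]

pairs = [('joe', ['c', 'py']), ('ann', ['py', 'c'])]

# A list cannot be a dict key, so a key returning a list is rejected.
try:
    index_by(pairs, key=lambda p: p[1])
except TypeError as exc:
    print('rejected:', exc)

# A frozenset is hashable and order-insensitive, so both entries share
# one key and only the first ('joe') is kept.
by_skills = index_by(pairs, key=lambda p: frozenset(p[1]))
print(len(by_skills))  # 1
```

Converting an unhashable value to a hashable equivalent (tuple, frozenset) inside the key function is usually enough to satisfy the requirement.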

This code works without modifications on Python 2.6, 2.7 (in development at the time of writing), and 3.1.

1 comment

Raymond Hettinger, 13 years, 8 months ago:

In Py2.7 and Py3.x, if you build a dictionary person.profession-->person, the keys view and items view provide set operations. Does that help with this recipe?
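The comment's suggestion can be sketched as follows, assuming Python 3 (on 2.7 the equivalent views come from viewkeys() and viewitems()). Person is a hypothetical namedtuple and one_per_profession a hypothetical helper.

```python
from collections import namedtuple

# Hypothetical immutable record; namedtuples are hashable, so the
# items view can take part in set operations too.
Person = namedtuple('Person', 'name profession')


def one_per_profession(people):
    """Map profession -> first person seen with that profession."""
    index = {}
    for person in people:
        index.setdefault(person.profession, person)
    return index


d1 = one_per_profession([Person('Joe', 'plumber'), Person('Mary', 'plumber'),
                         Person('Ann', 'doctor')])
d2 = one_per_profession([Person('Bob', 'doctor'), Person('Eve', 'pilot')])

# Keys views support the set operators and return plain sets of keys.
both = d1.keys() & d2.keys()
either = d1.keys() | d2.keys()
only_first = d1.keys() - d2.keys()
print(sorted(both), sorted(either), sorted(only_first))
# ['doctor'] ['doctor', 'pilot', 'plumber'] ['plumber']

# Getting back to the people takes an explicit lookup.
print([d1[k].name for k in sorted(both)])  # ['Ann']

# Items views compare whole (profession, person) pairs: Ann != Bob.
print(d1.items() & d2.items())  # set()
```

The set operations work on the professions, and the results are plain sets of keys, so recovering the people needs another dictionary lookup. Roughly speaking, that bookkeeping is what KeyedSet packages into a single set-like object.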