Welcome, guest | Sign In | My Account | Store | Cart
import collections

try:
    from collections.abc import MutableSet
except ImportError:  # Python 2: the ABCs live directly in ``collections``
    from collections import MutableSet

class OrderedSet(MutableSet):
    """A mutable set that remembers the order in which keys were first added.

    Keys are stored in a dict (``self.map``) that maps each key to a node of
    a circular doubly linked list threaded through a sentinel node
    (``self.end``).  Each node is a 3-element list ``[key, prev, next]``:
    index 0 is the key, index 1 the previous node, index 2 the next node.
    Membership, ``add`` and ``discard`` are a dict operation plus O(1)
    pointer updates, and iteration walks the list from oldest to newest key.

    The rest of the set API (``|``, ``&``, ``-``, ``^``, ``<=``, ``remove``,
    ``clear``, ...) comes from the ``MutableSet`` mixin, whose operators
    build their result by calling this class's constructor, so results are
    ``OrderedSet`` instances too.
    """

    def __init__(self, iterable=None):
        """Create the set, adding each item of *iterable* (if given) in order."""
        # The sentinel's prev and next both point back at itself, so an
        # empty set is a one-node circular list.
        self.end = end = []
        end += [None, end, end]
        self.map = {}                   # key --> [key, prev, next]
        if iterable is not None:
            self |= iterable            # MutableSet.__ior__ adds item by item

    def __len__(self):
        """Return the number of keys in the set."""
        return len(self.map)

    def __contains__(self, key):
        """Return True if *key* is in the set (a dict lookup)."""
        return key in self.map

    def add(self, key):
        """Append *key* as the newest element; do nothing if already present."""
        if key not in self.map:
            end = self.end
            last = end[1]
            # Splice the new node between the current last node and the
            # sentinel, and index it in the map, in one chained assignment.
            last[2] = end[1] = self.map[key] = [key, last, end]

    def discard(self, key):
        """Remove *key* if present; do nothing otherwise."""
        if key in self.map:
            _, prev_node, next_node = self.map.pop(key)
            # Unlink the node by joining its neighbours to each other.
            prev_node[2] = next_node
            next_node[1] = prev_node

    def __iter__(self):
        """Yield keys from oldest to newest."""
        end = self.end
        curr = end[2]
        while curr is not end:
            yield curr[0]
            curr = curr[2]

    def __reversed__(self):
        """Yield keys from newest to oldest."""
        end = self.end
        curr = end[1]
        while curr is not end:
            yield curr[0]
            curr = curr[1]

    def pop(self, last=True):
        """Remove and return the newest key, or the oldest if *last* is false.

        Raises:
            KeyError: if the set is empty.
        """
        if not self:
            raise KeyError('set is empty')
        key = self.end[1][0] if last else self.end[2][0]
        self.discard(key)
        return key

    def __repr__(self):
        """Return ``OrderedSet()`` or ``OrderedSet([...])`` in iteration order."""
        if not self:
            return '%s()' % (self.__class__.__name__,)
        return '%s(%r)' % (self.__class__.__name__, list(self))

    def __eq__(self, other):
        """Compare with *other*.

        Against another OrderedSet the comparison is order-sensitive.
        Against any other iterable, only the contents are compared (as sets).
        If *other* cannot be turned into a set (it is not iterable, or holds
        unhashable items), return NotImplemented so Python falls back to its
        default comparison instead of raising TypeError.
        """
        if isinstance(other, OrderedSet):
            return len(self) == len(other) and list(self) == list(other)
        try:
            other_items = set(other)
        except TypeError:
            return NotImplemented
        return set(self) == other_items

    # Mutable, therefore unhashable.  Defining __eq__ already implies this on
    # Python 3; it is stated explicitly so the intent is visible.
    __hash__ = None

            
if __name__ == '__main__':
    # Demo: the union, intersection and difference operators inherited from
    # the set mixin all produce OrderedSet results.
    first = OrderedSet('abracadaba')
    second = OrderedSet('simsalabim')
    for combined in (first | second, first & second, first - second):
        print(combined)

Diff to Previous Revision

--- revision 8 2012-12-18 20:38:26
+++ revision 9 2012-12-19 07:12:32
@@ -1,6 +1,4 @@
 import collections
-
-KEY, PREV, NEXT = range(3)
 
 class OrderedSet(collections.MutableSet):
 
@@ -20,33 +18,33 @@
     def add(self, key):
         if key not in self.map:
             end = self.end
-            curr = end[PREV]
-            curr[NEXT] = end[PREV] = self.map[key] = [key, curr, end]
+            curr = end[1]
+            curr[2] = end[1] = self.map[key] = [key, curr, end]
 
     def discard(self, key):
         if key in self.map:        
             key, prev, next = self.map.pop(key)
-            prev[NEXT] = next
-            next[PREV] = prev
+            prev[2] = next
+            next[1] = prev
 
     def __iter__(self):
         end = self.end
-        curr = end[NEXT]
+        curr = end[2]
         while curr is not end:
-            yield curr[KEY]
-            curr = curr[NEXT]
+            yield curr[0]
+            curr = curr[2]
 
     def __reversed__(self):
         end = self.end
-        curr = end[PREV]
+        curr = end[1]
         while curr is not end:
-            yield curr[KEY]
-            curr = curr[PREV]
+            yield curr[0]
+            curr = curr[1]
 
     def pop(self, last=True):
         if not self:
             raise KeyError('set is empty')
-        key = next(reversed(self)) if last else next(iter(self))
+        key = self.end[1][0] if last else self.end[2][0]
         self.discard(key)
         return key
 
@@ -62,5 +60,8 @@
 
             
 if __name__ == '__main__':
-    print(OrderedSet('abracadaba'))
-    print(OrderedSet('simsalabim'))
+    s = OrderedSet('abracadaba')
+    t = OrderedSet('simsalabim')
+    print(s | t)
+    print(s & t)
+    print(s - t)

History