from itertools import chain, ifilterfalse
class TransactionSet(object):
def __init__(self, self.master):
self.master = set(master)
self.deleted = set()
self.added = set()
def check_invariants(self):
assert self.deleted <= self.master # deleted is a subset of master
assert not (self.added & self.master) # added is disjoint from master
def __len__(self):
return len(self.master) - len(self.deleted) + len(self.added)
def __iter__(self):
return chain(self.added, ifilterfalse(self.deleted.__contains__, self.master))
def __contains__(self, key):
s = frozenset([key])
return not not (s & self.master and not s & self.deleted or s & self.added)
def add(self, key):
s = frozenset([key])
self.deleted -= s
self.added |= s
def discard(self, key):
s = frozenset([key])
if s & self.master:
self.deleted |= s
else:
self.added -= s
def remove(self, key):
s = frozenset([key])
if s & self.master:
self.deleted |= s
elif s & self.added:
self.added -= s
else:
raise KeyError(key)
def pop(self):
if self.added:
return self.added.pop()
for elem in ifilterfalse(self.deleted.__contains__, self.master):
self.deleted.add(elem)
return elem
raise KeyError(key)
def intersection(self, other):
other = frozenset(other)
return ((self.master & other) - self.deleted) | (self.added & other)
def _Set(self):
s = self.master - self.deleted
s |= self.added
return s
def union(self, other):
s = _Set(self)
s.update(other)
return s
def difference(self, other):
s = _Set(self)
s.difference_update(other)
return s
def symmetric_difference(self, other):
s = _Set(self)
s.symmetric_diffence_update(other)
return s
def update(self, other):
other = frozenset(other)
self.deleted -= other
self.added += other - self.master
def intersection_update(self, other):
other = frozenset(other)
self.deleted |= self.master - other
self.added &= other
def difference_update(self, other):
other = frozenset(other)
self.deleted |= self.master & other
self.added -= other
def symmetric_difference_update(self, other):
master_and_other = self.master.intersection(other)
self.deleted |= master_and_other
self.added ^= other - master_and_other
def issubset(self, other):
return len(self)<=len(other) and not (other & self.deleted) and master.issubset(other - self.added)
def issuperset(self, other):
return len(self)>=len(other) and not (other & self.deleted) and master.issuperset(other - self.added)