# By Christian Muirhead, Menno Smits and Michael Foord 2008
# WTF license
# http://voidspace.org.uk/blog

"""
``total_ordering`` and ``force_total_ordering`` are class decorators for
Python 2.6 & Python 3.

They provide *all* the rich comparison methods on a class by defining *any*
one of '__lt__', '__gt__', '__le__', '__ge__'.

``total_ordering`` fills in all unimplemented rich comparison methods, assuming
at least one is implemented. ``__lt__`` is taken as the base comparison method
on which the others are built, but if that is not available it will be
constructed from the first one found.

``force_total_ordering`` does the same, but having taken a comparison method as
the base it fills in *all* the others - this overwrites additional comparison
methods that may be implemented, guaranteeing consistent comparison semantics.

::

    from total_ordering import total_ordering

    @total_ordering
    class Something(object):
        def __init__(self, value):
            self.value = value
        def __lt__(self, other):
            return self.value < other.value

It also works with Python 2.5, but you need to do the wrapping yourself:

::

    from total_ordering import total_ordering

    class Something(object):
        def __init__(self, value):
            self.value = value
        def __lt__(self, other):
            return self.value < other.value

    total_ordering(Something)

It would be easy to modify for it to work as a class decorator for Python
3.X and a metaclass for Python 2.X.
"""

# No longer needed by the code below; kept so the module namespace is unchanged.
import sys as _sys

__all__ = ['total_ordering', 'force_total_ordering']


def _has_method(cls, name):
    """Tell whether ``cls`` or a non-``object`` base defines ``name``.

    ``object`` is skipped on purpose: on Python 3 its ``__dict__`` already
    holds default rich comparison slots, and those must not be mistaken for
    implementations supplied by the class author.
    """
    return any(name in base.__dict__
               for base in cls.__mro__ if base is not object)


def _ordering(cls, overwrite):
    """Fill in the rich comparison methods of ``cls`` and return ``cls``.

    ``__lt__`` is the base operation.  When the class does not define it, it
    is derived from the first of ``__gt__``, ``__le__``, ``__ge__`` that the
    class does define; the remaining methods are then expressed through
    ``<``.

    :param cls: the class to modify in place.
    :param overwrite: when true, every method other than the base one is
        replaced even if the class already implements it; when false only
        the missing methods are added.
    :raises AssertionError: if ``cls`` defines none of ``__lt__``,
        ``__gt__``, ``__le__``, ``__ge__``.
    """
    def setter(name, value):
        # Install ``value`` as ``cls.<name>`` unless an existing
        # implementation has to be preserved.
        if overwrite or not _has_method(cls, name):
            value.__name__ = name
            setattr(cls, name, value)

    if not _has_method(cls, '__lt__'):
        for name in ('gt', 'le', 'ge'):
            if _has_method(cls, '__' + name + '__'):
                break
        else:
            # Raised explicitly instead of via ``assert`` so the check is not
            # stripped under ``python -O``; the exception type and message
            # are the ones callers have always seen.
            raise AssertionError('must have at least one of ge, gt, le, lt')

        # Fetched before any setter runs, so the derived ``__lt__`` keeps
        # using the author's method even if ``overwrite`` later replaces it.
        comparison = getattr(cls, '__' + name + '__')

        if name.endswith('e'):
            # Inclusive operator (<= / >=): equal when it holds both ways.
            def eq(s, o):
                return comparison(s, o) and comparison(o, s)
        else:
            # Strict operator (>): equal when it holds in neither direction.
            def eq(s, o):
                return not comparison(s, o) and not comparison(o, s)

        def ne(s, o):
            return not eq(s, o)

        if name.startswith('l'):
            # s < o  <=>  s <= o and s != o
            setter('__lt__', lambda s, o: comparison(s, o) and ne(s, o))
        else:
            # s < o  <=>  o > s (or o >= s) and s != o
            setter('__lt__', lambda s, o: comparison(o, s) and ne(s, o))

    # Everything else is defined in terms of ``<``.
    setter('__ne__', lambda s, o: s < o or o < s)
    setter('__eq__', lambda s, o: not s != o)
    setter('__gt__', lambda s, o: o < s)
    setter('__ge__', lambda s, o: not (s < o))
    setter('__le__', lambda s, o: not (s > o))
    return cls


def total_ordering(cls):
    """Class decorator: add only the rich comparison methods ``cls`` lacks."""
    return _ordering(cls, False)


def force_total_ordering(cls):
    """Class decorator: rebuild all comparisons from the base method.

    Comparison methods the class already implements, other than the base
    one, are overwritten.
    """
    return _ordering(cls, True)


def _test():
    """Smoke-test both decorators against each possible base method."""
    class Thing(object):
        def __init__(self, val):
            self.val = val

    class Thing_lt(Thing):
        def __lt__(self, other):
            return self.val < other.val

    class Thing_gt(Thing):
        def __gt__(self, other):
            return self.val > other.val

    class Thing_ge(Thing):
        def __ge__(self, other):
            return self.val >= other.val

    class Thing_le(Thing):
        def __le__(self, other):
            return self.val <= other.val

    for cls in [Thing_lt, Thing_gt, Thing_le, Thing_ge]:
        print(cls.__name__)
        for ordering in (total_ordering, force_total_ordering):
            cls = ordering(cls)

            t1 = cls(1)
            t2 = cls(2)

            assert t1 < t2, 'lt'
            assert t1 == t1, 'eq'
            assert t1 != t2, 'ne'
            assert t2 > t1, 'gt'
            assert t2 >= t2, 'ge'
            assert t1 <= t1, 'le'

    print('no errors')


if __name__ == '__main__':
    _test()