# Welcome, guest | Sign In | My Account | Store | Cart
# By Christian Muirhead, Menno Smits and Michael Foord 2008
# WTF license
# http://voidspace.org.uk/blog

"""
``total_ordering`` and ``force_total_ordering`` are class decorators for 
Python 2.6 & Python 3.

They provide *all* the rich comparison methods on a class that defines *any*
one of '__lt__', '__gt__', '__le__', '__ge__'.

``total_ordering`` fills in all unimplemented rich comparison methods, assuming
at least one is implemented. ``__lt__`` is taken as the base comparison method
on which the others are built, but if that is not available it will be
constructed from the first one found.

``force_total_ordering`` does the same, but having taken a comparison method as
the base it fills in *all* the others - this overwrites additional comparison
methods that may be implemented, guaranteeing consistent comparison semantics.

::
    
    from total_ordering import total_ordering
    
    @total_ordering
    class Something(object):
        def __init__(self, value):
            self.value = value
        def __lt__(self, other):
            return self.value < other.value

It also works with Python 2.5, but you need to do the wrapping yourself:

::
    
    from total_ordering import total_ordering
    
    class Something(object):
        def __init__(self, value):
            self.value = value
        def __lt__(self, other):
            return self.value < other.value

    total_ordering(Something)

It would be easy to modify for it to work as a class decorator for Python
3.X and a metaclass for Python 2.X.
"""


import sys as _sys

if _sys.version_info[0] == 3:
    def _has_method(cls, name):
        for B in cls.__mro__:
            if B is object:
                continue
            if name in B.__dict__:
                return True
        return False
else:
    def _has_method(cls, name):
        for B in cls.mro():
            if B is object:
                continue
            if name in B.__dict__:
                return True
        return False



def _ordering(cls, overwrite):
    """Fill in the rich comparison methods of ``cls`` and return ``cls``.

    ``__lt__`` is the base comparison. If neither ``cls`` nor any of its
    bases (other than ``object``) defines it, ``__lt__`` is first derived
    from the first of ``__gt__``, ``__le__``, ``__ge__`` that is found.
    ``__ne__``, ``__eq__``, ``__gt__``, ``__ge__`` and ``__le__`` are then
    expressed in terms of ``<``.

    ``overwrite`` selects what happens to those five methods when they
    already exist: if false they are kept, if true they are replaced.

    Raises ``AssertionError`` if none of the four ordering methods is
    defined.
    """
    def setter(name, value):
        # Install ``value`` as ``cls.<name>`` unless an existing
        # implementation has to be preserved.
        if overwrite or not _has_method(cls, name):
            value.__name__ = name
            setattr(cls, name, value)

    if not _has_method(cls, '__lt__'):
        # Look for another ordering method to build ``__lt__`` from.
        comparison = None
        for name in ('gt', 'le', 'ge'):
            if _has_method(cls, '__' + name + '__'):
                comparison = getattr(cls, '__' + name + '__')
                break
        if comparison is None:
            # Raised explicitly (same exception type and message as the
            # former ``assert``) so the check still runs under ``python -O``.
            raise AssertionError('must have at least one of ge, gt, le, lt')

        # ``comparison`` is called directly rather than through an operator,
        # so the derived ``__lt__`` never dispatches back to itself.
        if name.endswith('e'):
            # Non-strict base (<= or >=): equal when it holds both ways.
            eq = lambda s, o: comparison(s, o) and comparison(o, s)
        else:
            # Strict base (>): equal when it holds neither way.
            eq = lambda s, o: not comparison(s, o) and not comparison(o, s)
        ne = lambda s, o: not eq(s, o)
        if name.startswith('l'):
            setter('__lt__', lambda s, o: comparison(s, o) and ne(s, o))
        else:
            # gt / ge point the other way, so swap the operands.
            setter('__lt__', lambda s, o: comparison(o, s) and ne(s, o))

    setter('__ne__', lambda s, o: s < o or o < s)
    setter('__eq__', lambda s, o: not s != o)
    setter('__gt__', lambda s, o: o < s)
    setter('__ge__', lambda s, o: not (s < o))
    setter('__le__', lambda s, o: not (s > o))
    return cls


def total_ordering(cls):
    """Class decorator: add the rich comparison methods ``cls`` lacks.

    Comparison methods that are already implemented are left untouched.
    Returns ``cls``.
    """
    return _ordering(cls, overwrite=False)

def force_total_ordering(cls):
    """Class decorator: rebuild the comparison methods from one base method.

    Unlike ``total_ordering``, comparison methods that are already
    implemented are overwritten with the derived ones. Returns ``cls``.
    """
    return _ordering(cls, overwrite=True)


def _test():
    """Self-test: check the comparisons produced for each single-method base.

    Four base classes each define exactly one of ``__lt__``, ``__gt__``,
    ``__ge__``, ``__le__``. Both decorators are applied to each of them and
    the comparison operators are checked on two instances. Prints the name of
    every base class and finally ``no errors``; a failed check raises
    ``AssertionError``.
    """
    class Thing(object):
        def __init__(self, val):
            self.val = val

    class Thing_lt(Thing):
        def __lt__(self, other):
            return self.val < other.val

    class Thing_gt(Thing):
        def __gt__(self, other):
            return self.val > other.val

    class Thing_ge(Thing):
        def __ge__(self, other):
            return self.val >= other.val

    class Thing_le(Thing):
        def __le__(self, other):
            return self.val <= other.val

    for base in [Thing_lt, Thing_gt, Thing_le, Thing_ge]:

        print (base.__name__)
        for ordering in (total_ordering, force_total_ordering):
            # Decorate a fresh, empty subclass each time. Decorating ``base``
            # itself would hand the second decorator a class the first one
            # has already completed, so it would never derive ``__lt__``
            # from gt/le/ge on its own.
            cls = ordering(type(base.__name__, (base,), {}))
            t1 = cls(1)
            t2 = cls(2)

            assert t1 < t2, 'lt'
            assert t1 == t1, 'eq'
            assert t1 != t2, 'ne'
            assert t2 > t1, 'gt'
            assert t2 >= t2, 'ge'
            assert t1 <= t1, 'le'

            # Negative cases: a comparison that is always true must fail.
            assert not (t2 < t1), 'not lt'
            assert not (t1 == t2), 'not eq'
            assert not (t1 > t2), 'not gt'
    print ('no errors')

# Run the self-test when the module is executed as a script.
if __name__ == '__main__':
    _test()

# History
#
#   • revision 5 (15 years ago)
#   • previous revisions are not available