Welcome, guest | Sign In | My Account | Store | Cart
#!/usr/bin/env python

from functools import wraps

class InvalidSignature(Exception):
    """Signal that a call did not match a declared signature or return type.

    Raised by the checks installed by ``returns`` and ``signature``; the
    ``overloaded`` dispatcher catches it in order to try the next overload.
    """

def returns(rettype):
    """Build a decorator that enforces the decorated function's return type.

    The wrapper calls the function and hands its result back unchanged when
    that result is an instance of *rettype* (as judged by ``isinstance``);
    otherwise it raises ``InvalidSignature``.  The wrapper also exposes
    *rettype* as its ``returns`` attribute.
    """
    def decorate(target):
        @wraps(target)
        def guarded(*args, **kwargs):
            result = target(*args, **kwargs)
            if isinstance(result, rettype):
                return result
            raise InvalidSignature()

        # Publish the declared return type for introspection.
        guarded.returns = rettype
        return guarded

    return decorate

def signature(*argtypes, **kwtypes):
    """Build a decorator that type-checks the arguments of each call.

    A call is accepted only when:

    * it passes exactly ``len(argtypes)`` positional arguments, each an
      instance of the type at the same position in *argtypes*; and
    * it passes exactly the keyword names found in *kwtypes*, each value an
      instance of the type registered under that name.

    Any other call raises ``InvalidSignature`` before the decorated function
    runs.  The wrapper exposes ``(argtypes, kwtypes)`` as its ``signature``
    attribute.
    """
    def accepts(args, kwargs):
        # Positional part: the count must match, then each value's type.
        if len(args) != len(argtypes):
            return False
        for value, expected in zip(args, argtypes):
            if not isinstance(value, expected):
                return False

        # Keyword part: identical name sets (which also implies identical
        # counts), then each value's type.
        if set(kwargs) != set(kwtypes):
            return False
        for name, expected in kwtypes.items():
            if not isinstance(kwargs[name], expected):
                return False
        return True

    def decorate(target):
        @wraps(target)
        def guarded(*args, **kwargs):
            if not accepts(args, kwargs):
                raise InvalidSignature()
            return target(*args, **kwargs)

        # Publish the declared types for introspection.
        guarded.signature = (argtypes, kwtypes)
        return guarded

    return decorate

def overloaded(func):
    """Turn *func* into a dispatcher over a growing list of overloads.

    The returned callable tries each registered overload in registration
    order and returns the result of the first one that does not raise
    ``InvalidSignature`` or ``TypeError``.  Note that a ``TypeError`` raised
    from inside an overload's body is treated the same as a mismatch.  If
    every overload is rejected, ``TypeError("No compatible signatures")`` is
    raised.

    The dispatcher carries two attributes:

    * ``overloads`` -- the list of candidates, initially ``[func]``;
    * ``overload_with`` -- a decorator that appends another candidate and
      returns the dispatcher itself, so the decorated name stays bound to it.
    """
    @wraps(func)
    def dispatcher(*args, **kwargs):
        for candidate in dispatcher.overloads:
            try:
                result = candidate(*args, **kwargs)
            except (InvalidSignature, TypeError):
                # This overload rejected the call; move on to the next one.
                continue
            return result
        raise TypeError("No compatible signatures")

    dispatcher.overloads = [func]

    def overload_with(func):
        dispatcher.overloads.append(func)
        return dispatcher

    dispatcher.overload_with = overload_with
    return dispatcher

#############

# Demo 1: dispatch on arity alone.  Neither overload declares types, so the
# dispatcher relies on the TypeError Python raises for a wrong argument count.
# print(...) with a single parenthesised argument behaves the same on
# Python 2 and Python 3.
@overloaded
def a():
    print('no args a')


@a.overload_with
def a(n):
    print('arged a')


a()    # handled by the zero-argument overload
a(4)   # zero-argument overload raises TypeError, so the one-argument one runs


# Demo 2: dispatch on argument types via @signature, with @returns checking
# the result type.  print(...) with a single parenthesised argument behaves
# the same on Python 2 and Python 3.
@overloaded
@returns(int)
@signature(int, int, float)
def foo(a, b, c):
    # Overload for three positionals (int, int, float): truncated product.
    return int(a * b * c)


@foo.overload_with
@returns(int)
@signature(int, float, c=int)
def foo(a, b, c):
    # Overload for two positionals (int, float) plus keyword c=int:
    # truncated sum.
    return int(a + b + c)


print(foo(2, 3, 4.))       # first overload: int(2 * 3 * 4.0) -> 24
print(foo(10, 3., c=30))   # second overload: int(10 + 3.0 + 30) -> 43
# Four positionals match neither declared signature, so this call raises
# TypeError("No compatible signatures") -- presumably kept on purpose to
# show the failure mode.
print(foo(1, 9., 3, 3))

History