Welcome, guest | Sign In | My Account | Store | Cart

This recipe applies the "once and only once" principle to function default values. Often two or more callables specify the same default values for one or more arguments; this is especially common when overriding a method. With the defaultsfrom(func) decorator, an overriding method can 'inherit' the default values of the method it overrides in the superclass. More generally, any function can inherit default values from another one.

Python, 126 lines
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import inspect
import itertools as it
from collections import deque

# Public API: only the decorator factory is exported; the demo_* functions
# are examples that the __main__ block runs.
__all__ = ['defaultsfrom']

#======= demos ===============================================================

def demo_func():
    '''Reuse function default values.

    timsort declares cmp, key and reverse without defaults and inherits
    None, None and False for them from genericsort.
    '''

    def genericsort(iterable, cmp=None, key=None, reverse=False):
        raise NotImplementedError()

    @defaultsfrom(genericsort)
    def timsort(iterable, cmp, key, reverse):
        # key/reverse are passed by keyword: Python 3's sorted() accepts only
        # the iterable positionally and has no cmp argument at all.
        if cmp is not None:
            # Emulate Python 2's sorted(cmp=..., key=...), where cmp compares
            # the key values. functools.cmp_to_key exists since Python 2.7.
            from functools import cmp_to_key
            by_cmp = cmp_to_key(cmp)
            if key is None:
                key = by_cmp
            else:
                base_key = key
                key = lambda item: by_cmp(base_key(item))
        return sorted(iterable, key=key, reverse=reverse)

    s = ['hello', 'WORLD']
    timsort(s)
    timsort(s, key=str.lower)


def demo_class():
    '''Reuse method default values.

    FancyLogger.__init__ declares level, handlers and filters without
    defaults and inherits them from GenericLogger.__init__. It can therefore
    be called with just a name plus the arguments it wants to customise.
    '''

    import logging

    class GenericLogger(object):
        def __init__(self, name, level=logging.WARNING, handlers=(),
                     filters=()):
            self._logger = logging.getLogger(name)
            self._logger.setLevel(level)
            for handler in handlers:
                self._logger.addHandler(handler)
            for filt in filters:
                self._logger.addFilter(filt)

    class FancyLogger(GenericLogger):

        @defaultsfrom(GenericLogger)
        def __init__(self, name, level, handlers, filters, bells=None,
                     whistles=None, debug=False):
            super(FancyLogger, self).__init__(name, level, handlers, filters)
            # Plain assignments instead of setattr(self, ..., eval(name)):
            # same attributes, without evaluating strings as code.
            self._bells = bells
            self._whistles = whistles
            self._debug = debug

    FancyLogger('demo', filters=[logging.Filter()], debug=True)


#======= defaultsfrom ========================================================

def _argnames_and_defaults(func):
    '''Return (positional argument names, default values) of func.

    Uses inspect.getfullargspec where it exists (Python 3) and falls back to
    inspect.getargspec (Python 2). The defaults are returned as () rather
    than None when func has no default arguments.
    '''
    getspec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
    spec = getspec(func)
    # In both ArgSpec and FullArgSpec, index 0 holds the positional argument
    # names and index 3 holds their defaults (or None).
    return spec[0], spec[3] or ()


def defaultsfrom(funcOrClass):
    '''Return a decorator d so that d(func) updates func's default arguments.

    If funcOrClass is a function (or method) 'foo', its default arguments are
    'inherited' by any function 'bar' decorated by the returned decorator.
    The walk starts at the last non-default argument of 'bar' and moves
    left. Each argument that has a default in 'foo' receives that default.
    The walk stops at the first argument that 'foo' gives no default to:

    >>> def foo(a, x=0, y=''): pass
    >>> @defaultsfrom(foo)
    ... def bar(a, b, y, x, z=None): return a,b,y,x,z
    >>> bar(1,2,3,4,5)
    (1, 2, 3, 4, 5)
    >>> bar(1,2,3,4)
    (1, 2, 3, 4, None)
    >>> bar(1,2,3)
    (1, 2, 3, 0, None)
    >>> bar(1,2)
    (1, 2, '', 0, None)

    Default arguments redefined by the decorated function are not inherited
    from 'foo':
    >>> @defaultsfrom(foo)
    ... def zong(a, x, y=-1): # y redefined
    ...     return a,x,y
    >>> zong(2)
    (2, 0, -1)

    Default arguments (inherited or not) cannot precede non-default ones.
    A TypeError is therefore raised if an argument with an inheritable
    default is separated from the trailing defaults by an argument without
    one:
    >>> @defaultsfrom(foo)
    ... def zap(a, x, b, y): # b is not a default arg; x cannot be inherited
    ...     return a,x,b,y
    Traceback (most recent call last):
        ...
    TypeError: ...

    If 'foo' has no default arguments, nothing is inherited:
    >>> def nodefaults(a, b): pass
    >>> @defaultsfrom(nodefaults)
    ... def quux(a, b=1): return a,b
    >>> quux(0)
    (0, 1)

    If funcOrClass is a class, its method with the same name as the
    decorated function is handled as above:
    >>> class Base(object):
    ...    def __init__(self, a, b=0, c=None): pass
    >>> class Derived(Base):
    ...    @defaultsfrom(Base)  # equivalent to Base.__init__
    ...    def __init__(self, a, b, c): print((a, b, c))
    >>> d = Derived(1)
    (1, 0, None)

    The returned decorator modifies the decorated function's defaults in
    place and returns the same function object.
    '''

    def decorator(newfunc):
        if inspect.isclass(funcOrClass):
            func = getattr(funcOrClass, newfunc.__name__)
        else:
            func = funcOrClass
        args, defaults = _argnames_and_defaults(func)
        # map each default argument of func to its value
        arg2default = dict(zip(args[len(args) - len(defaults):], defaults))
        newargs, newdefaults = _argnames_and_defaults(newfunc)
        nondefaults = newargs[:len(newargs) - len(newdefaults)]
        # starting from the last non-default argument towards the first, as
        # long as the non-defaults of newfunc are default in func, make them
        # default in newfunc too
        iter_nondefaults = reversed(nondefaults)
        newdefaults = deque(newdefaults)
        for arg in it.takewhile(arg2default.__contains__, iter_nondefaults):
            newdefaults.appendleft(arg2default[arg])
        # takewhile has also consumed the first argument that failed the
        # test. That argument has no inheritable default, so skipping it is
        # harmless. Any earlier argument that *does* have one would leave a
        # non-default gap before it, so inherited defaults must be
        # contiguous.
        if any(arg in arg2default for arg in iter_nondefaults):
            raise TypeError('%s cannot inherit the default arguments of '
                            '%s' % (newfunc, func))
        # __defaults__ is the spelling shared by Python 2.6+ and Python 3
        newfunc.__defaults__ = tuple(newdefaults)
        return newfunc
    return decorator


if __name__ == '__main__':
    # Check the docstring examples first (ELLIPSIS lets the "TypeError: ..."
    # example match any message text), then run both demos in order.
    from doctest import ELLIPSIS, testmod
    testmod(optionflags=ELLIPSIS)
    for demo in (demo_func, demo_class):
        demo()