Welcome, guest | Sign In | My Account | Store | Cart

Here's a class decorator that adds a rudimentary copy() method to the decorated class. The copy() method also accepts keyword arguments that override individual fields in the new object. Use it like this:

@copiable
class SomethingDifferent:
    def __init__(self, a, b, c):
        self.a = a
        self.b = b
        self.c = c

or like this:

@copiable("a b c")
class SomethingDifferent:
    def __init__(self, a, b, c):
        self.a = a
        self.b = b
        self.c = c

s = SomethingDifferent(1,2,3)
sc = s.copy()
assert vars(s) == vars(sc)

(Python 3.3)

Python, 53 lines
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import inspect


def get_init_params(cls):
    """Return the names of the parameters expected when calling the class.

    Inspects ``cls.__init__``.  When the class inherits ``object.__init__``
    but overrides ``__new__``, ``cls.__new__`` is inspected instead.

    The leading ``self``/``cls`` parameter and any ``*args``/``**kwargs``
    collectors are omitted: they are not individual fields that can be read
    off an instance and passed back by keyword.

    Returns a list of parameter names, or None when no signature can be
    obtained for the initializer.
    """
    initializer = cls.__init__
    if initializer is object.__init__ and cls.__new__ is not object.__new__:
        initializer = cls.__new__
    try:
        signature = inspect.signature(initializer)
    except (TypeError, ValueError):
        # inspect.signature raises ValueError when no signature can be
        # provided (e.g. some builtins) and TypeError for unsupported objects.
        return None
    variadic = (inspect.Parameter.VAR_POSITIONAL,
                inspect.Parameter.VAR_KEYWORD)
    # Skip the first parameter (self for __init__, cls for __new__).
    params = list(signature.parameters.values())[1:]
    return [param.name for param in params if param.kind not in variadic]


def copiable(fields=None):
    """A class decorator factory that adds a copy method to the class.

    May be applied bare (``@copiable``) or called (``@copiable("a b c")``).

    *fields* may be a string of attribute names separated by whitespace
    and/or commas, or any iterable of names.  If fields is not passed, the
    following are tried in order:

    1. cls.__all__
    2. cls._fields
    3. the parameters of cls.__init__
    4. the parameters of cls.__new__

    The generated method is named ``copy``, or ``_copy`` when ``copy`` is
    itself one of the fields.  It reads each field from the instance,
    applies any keyword overrides, and calls the class with every field
    passed as a keyword argument.

    Raises TypeError if the fields cannot be determined, or if the class
    already has an attribute with the chosen method name.
    """
    if isinstance(fields, type):
        # Used bare as ``@copiable``: the "fields" argument is the class.
        return copiable(None)(fields)

    if isinstance(fields, str):
        fields = fields.replace(',', ' ').split()
    elif fields is not None:
        # Snapshot now so a one-shot iterator is not exhausted by the
        # first call to copy().
        fields = tuple(fields)

    def decorator(cls):
        """Return the class after adding the copy() method."""
        names = fields
        if names is None:
            # Resolve lazily: only inspect the initializer's signature when
            # neither __all__ nor _fields is available.
            names = getattr(cls, '__all__', None)
            if names is None:
                names = getattr(cls, '_fields', None)
            if names is None:
                names = get_init_params(cls)
        if names is None:
            raise TypeError("could not determine the fields for this class.")
        names = tuple(names)

        def copy(self, **kwargs):
            """Return a copy of this object, with updates."""
            ns = {name: getattr(self, name) for name in names}
            ns.update(kwargs)
            return type(self)(**ns)

        # Avoid clobbering a field literally called "copy".
        method_name = '_copy' if 'copy' in names else 'copy'
        if hasattr(cls, method_name):
            raise TypeError(
                "{!r} already exists on the class".format(method_name))
        setattr(cls, method_name, copy)
        return cls
    return decorator

1 comment

James Mills 11 years, 2 months ago  # | flag

Nice ;)