Welcome, guest | Sign In | My Account | Store | Cart
import inspect


def get_init_params(cls):
    """Return the parameter names expected when calling the class.

    The signature of ``cls.__init__`` is inspected, unless the class only
    inherits ``object.__init__`` while providing its own ``__new__``, in
    which case ``cls.__new__`` is inspected instead.  The first parameter
    of that signature (normally ``self`` or ``cls``) is dropped.

    Returns:
        A list of parameter names, or ``None`` if no signature could be
        determined for the initializer.
    """
    initializer = cls.__init__
    if initializer is object.__init__ and cls.__new__ is not object.__new__:
        initializer = cls.__new__
    try:
        signature = inspect.signature(initializer)
    except (TypeError, ValueError):
        # inspect.signature() raises TypeError for unsupported objects and
        # ValueError when no signature can be provided (e.g. some builtins).
        # Both mean "cannot tell", which callers expect to see as None.
        return None
    return list(signature.parameters)[1:]


def copiable(fields=None):
    """A class decorator factory that adds a copy method to the class.

    Can be applied bare (``@copiable``) or called with the field names
    (``@copiable('x, y')`` or ``@copiable(['x', 'y'])``).  A string is
    split on commas and whitespace.

    If fields is not passed, the following are tried in order:

    1. cls.__all__
    2. cls._fields
    3. the parameters of cls.__init__
    4. the parameters of cls.__new__

    The added method is named ``copy``, or ``_copy`` when ``copy`` is itself
    one of the field names.  It builds a new instance by calling
    ``type(self)`` with every field as a keyword argument, after applying
    any keyword overrides passed to it.

    Raises:
        TypeError: if the fields cannot be determined, or if the class
            already has an attribute with the name of the method to add.

    """
    if isinstance(fields, type):
        # Used as a bare decorator: ``fields`` is really the class.
        return copiable(None)(fields)

    if isinstance(fields, str):
        fields = fields.replace(',', ' ').split()

    def decorator(cls):
        """Return the class after adding the copy() method."""
        names = fields
        # Consult each source lazily, in the documented order, so that
        # signature introspection only runs when nothing else is available.
        if names is None:
            names = getattr(cls, '__all__', None)
        if names is None:
            names = getattr(cls, '_fields', None)
        if names is None:
            names = get_init_params(cls)
        if names is None:
            raise TypeError("could not determine the fields for this class.")
        # Snapshot the names: a one-shot iterable would otherwise be
        # exhausted by the membership test below, and a caller's list could
        # be mutated after decoration.
        names = tuple(names)

        def copy(self, **kwargs):
            """Return a copy of this object, with updates."""
            ns = {name: getattr(self, name) for name in names}
            ns.update(kwargs)
            return type(self)(**ns)

        # Avoid shadowing a field that is itself called "copy".
        method_name = '_copy' if 'copy' in names else 'copy'
        if hasattr(cls, method_name):
            raise TypeError(
                "{!r} already exists on the class".format(method_name))
        setattr(cls, method_name, copy)
        return cls
    return decorator

History