Welcome, guest | Sign In | My Account | Store | Cart
"""enum.py

This module provides a simple enum type, Enum, which may be subclassed
to create a new enum type:

>>> class Breakfast(Enum):
...     SPAM, HAM, EGGS
...     BACON, SAUSAGE
>>> Breakfast.SPAM
EnumValue(<Breakfast.SPAM>)

Enum values implement the bitwise operators:

>>> Breakfast.SPAM | Breakfast.HAM
EnumValue(<Breakfast.(SPAM|HAM)>)

The idea of using __prepare__() to make this work came from Michael Foord:

http://mail.python.org/pipermail/python-ideas/2013-January/019108.html

"""
# Names exported by a star-import of this module.
__all__ = ['Enum']


class EnumValue:
    """An enum value, associated with an Enum.

    EnumValue instances behave much like frozen sets, particularly with
    regard to bitwise operations (&, ^, |) as well as implementing
    subtraction and inversion (~).

    A value is either a single named member (``name`` is not None and
    ``values`` is ignored) or a combination of members (``name`` is None
    and ``values`` holds the combined members).

    """

    def __init__(self, enum, name, values=None):
        """Store the owning enum, the member name and the combined members.

        enum   -- the enum class this value belongs to
        name   -- the member name, or None for a combined value
        values -- for a combined value, the set of members it contains

        """
        # XXX restrict name to all caps?
        self._enum = enum
        self._name = name
        self._values = values

    def __repr__(self):
        """Return e.g. EnumValue(<Breakfast.SPAM>) or, for a combined
        value, EnumValue(<Breakfast.(SPAM|HAM)>)."""
        if self._name is not None:
            values = self._name
        else:
            # Names appear in the iteration order of the underlying set.
            values = "({})".format("|".join(v._name for v in self.values))
        return "{}(<{}.{}>)".format(self.__class__.__name__,
                                    self._enum.__name__, values)

    def __eq__(self, other):
        # Equality is identity: two values are equal only if they are the
        # very same object.
        return self is other

    def __hash__(self):
        # Named values hash by name; combined values hash by their members.
        return hash(self._name or self.values)

    def _merge(self, op, other):
        """Apply the bound set operation *op* to other's members.

        Returns NotImplemented when *other* belongs to a different enum or
        lacks the expected attributes, so Python can try the reflected
        operation.  Raises TypeError when the result would be empty.

        """
        try:
            if self._enum != other._enum:
                return NotImplemented
            values = op(other.values)
        except AttributeError:
            return NotImplemented
        else:
            if values is NotImplemented:
                # The set operator itself rejected the operand; pass the
                # sentinel on instead of truth-testing it.
                return NotImplemented
            if not values:
                raise TypeError("An EnumValue must not be empty")
            return self._enum._join(values)

    def __and__(self, other):
        """Return the members common to both values."""
        return self._merge(self.values.__and__, other)

    def __xor__(self, other):
        """Return the members found in exactly one of the two values."""
        return self._merge(self.values.__xor__, other)

    def __or__(self, other):
        """Return the members found in either value."""
        return self._merge(self.values.__or__, other)

    def __sub__(self, other):
        """Return the members of this value that are not in *other*."""
        return self._merge(self.values.__sub__, other)

    def __invert__(self):
        """Return every member of the enum that is not in this value.

        Raises TypeError when nothing is left, mirroring the binary
        operators.

        """
        inverted = self._enum.values - self.values
        if not inverted:
            raise TypeError("An EnumValue must not be empty")
        # The enum's joining hook is the private _join(); there is no
        # public join().
        return self._enum._join(inverted)

    @property
    def values(self):
        """The frozenset of members this value represents.

        A named value is a set containing only itself.

        """
        if self._name is not None:
            return frozenset([self])
        else:
            return self._values


class _EnumNamespace(dict):
    """Class-body namespace that records every bare public name it is asked for.

    Looking up a missing name that does not start with an underscore
    registers it in ``_names`` and yields None instead of failing.  Only
    underscore-prefixed keys may be stored, and nothing can be deleted.

    """

    def __init__(self):
        self._names = set()

    def __getitem__(self, key):
        if key in self:
            return super().__getitem__(key)
        if key.startswith("_"):
            # Private and dunder names keep the normal "missing key" behavior.
            raise KeyError(key)
        # An unknown public name: remember it; the lookup itself gives None.
        self._names.add(key)
        return None

    def __setitem__(self, key, value):
        if not key.startswith("_"):
            raise KeyError(key)
        super().__setitem__(key, value)

    def __delitem__(self, key):
        raise KeyError(key)


class _EnumMeta(type):
    """Metaclass that builds the values of an enum class.

    Each name collected in the class namespace's ``_names`` becomes an
    EnumValue stored as a class attribute; the full set is available as
    ``cls.values``.  The resulting class is immutable and supports len(),
    ``in`` and iteration over its values.

    """

    @classmethod
    def make(meta, names, name='SimpleEnum'):
        """Return a new Enum subclass based on the passed names.

        names -- either a string of names separated by commas and/or
                 whitespace, or an iterable whose items are converted
                 with str()
        name  -- the name of the new class

        """
        if isinstance(names, str):
            names = names.replace(',', ' ').split()
        else:
            names = map(str, names)

        ns = _EnumNamespace()
        # Keep the same container type _EnumNamespace itself uses; this
        # also collapses repeated names into a single value.
        ns._names = set(names)
        return meta(name, (Enum,), ns)

    @classmethod
    def __prepare__(meta, name, bases):
        """Supply the name-recording namespace used for the class body."""
        return _EnumNamespace()

    def __init__(cls, name, bases, namespace):
        """Validate the bases and create one EnumValue per collected name.

        Raises TypeError for more than one base class or for a base that
        is not an Enum.

        """
        super().__init__(name, bases, namespace)

        # Work on a copy so the caller's namespace is left untouched.
        names = set(namespace._names)

        # validation
        if len(bases) > 1:
            raise TypeError("expected 1 base class, got {}".format(len(bases)))
        elif not bases or bases == (object,):
            pass
        elif not issubclass(bases[0], Enum):
            raise TypeError("enums can only inherit from Enum")
        else:
            # A derived enum also gets (fresh) values for the base's names.
            names.update(value._name for value in bases[0].values)

        # set the values; type.__setattr__ is used directly because the
        # metaclass's own __setattr__ rejects every assignment
        values = frozenset(EnumValue(cls, member) for member in names)
        super().__setattr__('values', values)
        for value in values:
            super().__setattr__(value._name, value)

        # cache used by _join(): frozenset of members -> EnumValue
        super().__setattr__('_joined_values', {})

    def __setattr__(cls, name, value):
        raise TypeError("enums are immutable")

    def __delattr__(cls, name):
        raise TypeError("enums are immutable")

    def __len__(cls):
        """Return the number of values in the enum."""
        return len(cls.values)

    def __contains__(cls, value):
        """Return True if *value* is one of the enum's values."""
        return value in cls.values

    def __iter__(cls):
        """Iterate over the enum's values (in set order)."""
        return iter(cls.values)

    def _join(cls, values):
        """Return the EnumValue representing the given frozenset of members.

        Results are cached, so the same set of members always yields the
        same object.  A one-element set yields the element itself.

        """
        # XXX make sure the values are in the enum?
        if values in cls._joined_values:
            return cls._joined_values[values]
        elif len(values) == 1:
            value = next(iter(values))
            cls._joined_values[values] = value
            return value
        else:
            joined = EnumValue(cls, None, values)
            cls._joined_values[values] = joined
            return joined

    def export(cls, namespace):
        """Copy the enums values into the namespace."""
        namespace.update((v._name, v) for v in cls.values)


class Enum(metaclass=_EnumMeta):
    """Base class for enums; subclass it, never instantiate it."""

    def __new__(cls, *args, **kwargs):
        # Enum classes are only containers for their values, so any attempt
        # to construct an instance is rejected.
        message = "enums do not have instances"
        raise TypeError(message)

Diff to Previous Revision

--- revision 1 2013-02-13 06:35:25
+++ revision 2 2013-02-13 06:53:38
@@ -165,7 +165,9 @@
         if values in cls._joined_values:
             return cls._joined_values[values]
         elif len(values) == 1:
-            return next(iter(values))
+            value = next(iter(values))
+            cls._joined_values[values] = value
+            return value
         else:
             joined = EnumValue(cls, None, values)
             cls._joined_values[values] = joined
@@ -173,7 +175,7 @@
 
     def export(cls, namespace):
         """Copy the enums values into the namespace."""
-        namespace.update(cls._values)
+        namespace.update((v._name, v) for v in cls.values)
 
 
 class Enum(metaclass=_EnumMeta):

History