Welcome, guest | Sign In | My Account | Store | Cart
# (c) Michael Foord, 2010
# http://voidspace.org.uk/blog
'''
Create objects that act as both context managers *and* as decorators, and behave the same in both cases.

Works with Python 2.5 - 2.7 and Python 3 (the ``with_statement`` future import below requires Python 2.5 or later). The tests require unittest2, or Python 3.2 or later, to run.

Example:

from contextdecorator import ContextDecorator

class mycontext(ContextDecorator):

    def __init__(self, *args):
        """Normal initialiser"""

    def before(self):
        """
        Called on entering the with block or starting the decorated function.
        
        If used in a with statement, whatever this method returns is bound to
        the target of the ``as`` clause. When used as a decorator the return
        value is ignored.
        """
    
    def after(self, *exc):
        """
        Called on exit. Arguments and return value of this method have
        the same meaning as the __exit__ method of a normal context
        manager.
        """

@mycontext('some', 'args')
def function():
    pass

with mycontext('some', 'args') as something:
    pass

See the tests for more usage examples.
'''

# Only needed for tests
from __future__ import with_statement

import sys

try:
    from functools import wraps
except ImportError:
    # Python 2.4 has no functools module. This minimal stand-in copies the
    # same metadata functools.wraps copies by default (name, docstring,
    # module, and the instance dictionary), so decorated functions still
    # introspect like the function they wrap.
    def wraps(original):
        def inner(f):
            f.__name__ = original.__name__
            f.__doc__ = original.__doc__
            f.__module__ = original.__module__
            # Some callables (e.g. builtins) have no __dict__ to merge.
            f.__dict__.update(getattr(original, '__dict__', {}))
            return f
        return inner

# horrible reraise code for compatibility
# with Python 2 & 3
if sys.version_info >= (3,0):
    exec ("""
def _reraise(cls, val, tb):
    raise val
""")
else:
    exec ("""
def _reraise(cls, val, tb):
    raise cls, val, tb
""")


# The exc_info triple meaning "no exception": it is what ``after`` receives
# on a clean exit, mirroring the arguments __exit__ gets in that case.
EXC = (None, None, None)

class ContextDecorator(object):
    """
    Base class for objects usable both as a context manager and as a
    function decorator, with the same semantics in both roles.

    Subclasses define ``before`` and/or ``after``; either may be left as
    None to skip that hook. ``before`` takes no arguments. ``after`` gets
    the (type, value, traceback) triple of any exception (or three Nones),
    and a true return value suppresses the exception, exactly like
    ``__exit__``.
    """
    before = None
    after = None

    def __call__(self, f):
        """Wrap ``f`` so every call runs between ``before`` and ``after``."""
        @wraps(f)
        def inner(*args, **kw):
            if self.before is not None:
                # The return value only matters for ``with ... as``.
                self.before()

            exc = EXC
            # If ``after`` suppresses an exception there is no value from
            # ``f``; return None, just as a with-block that swallows an
            # exception produces nothing.
            result = None
            try:
                result = f(*args, **kw)
            except:
                # Deliberately bare rather than ``except Exception``: a
                # with-statement's __exit__ sees every exception, including
                # KeyboardInterrupt and SystemExit, and the decorator form
                # must behave the same. A bare except also works on Pythons
                # that predate BaseException. The exception is re-raised
                # below unless ``after`` suppresses it.
                exc = sys.exc_info()

            catch = False
            if self.after is not None:
                catch = self.after(*exc)

            if not catch and exc is not EXC:
                _reraise(*exc)
            return result
        return inner

    def __enter__(self):
        """Run ``before`` and hand its result to the ``as`` target."""
        if self.before is not None:
            return self.before()

    def __exit__(self, *exc):
        """Run ``after``; a true result suppresses the active exception."""
        catch = False
        if self.after is not None:
            catch = self.after(*exc)
        return catch


if __name__ == '__main__':
    import sys

    if sys.version_info >= (3, 2):
        import unittest as unittest2
    else:
        import unittest2

    class mycontext(ContextDecorator):
        """Test double whose hooks record how they were called."""
        started = False
        exc = None
        catch = False

        def before(self):
            self.started = True
            return self

        def after(self, *exc):
            # Record the exc_info triple; suppress only when asked to.
            self.exc = exc
            return self.catch


    class TestContext(unittest2.TestCase):

        def _assert_raises_regex(self, exc_type, pattern):
            """
            Return an assertRaisesRegex context manager.

            Python 3.12 removed the deprecated ``assertRaisesRegexp`` alias,
            while unittest2 on Python 2 may only provide that spelling, so
            use whichever one this TestCase actually has.
            """
            check = getattr(self, 'assertRaisesRegex', None)
            if check is None:
                check = self.assertRaisesRegexp
            return check(exc_type, pattern)

        def test_context(self):
            context = mycontext()
            with context as result:
                self.assertIs(result, context)
                self.assertTrue(context.started)

            self.assertEqual(context.exc, (None, None, None))

        def test_context_with_exception(self):
            context = mycontext()

            with self._assert_raises_regex(NameError, 'foo'):
                with context:
                    raise NameError('foo')

            context.exc = (None, None, None)
            context.catch = True
            with context:
                raise NameError('foo')
            self.assertNotEqual(context.exc, (None, None, None))

        def test_decorator(self):
            context = mycontext()

            @context
            def test():
                self.assertIsNone(context.exc)
                self.assertTrue(context.started)
            test()
            self.assertEqual(context.exc, (None, None, None))

        def test_decorator_with_exception(self):
            context = mycontext()

            @context
            def test():
                self.assertIsNone(context.exc)
                self.assertTrue(context.started)
                raise NameError('foo')

            with self._assert_raises_regex(NameError, 'foo'):
                test()
            self.assertNotEqual(context.exc, (None, None, None))

        def test_decorating_method(self):
            context = mycontext()

            class Test(object):

                @context
                def method(self, a, b, c=None):
                    self.a = a
                    self.b = b
                    self.c = c

            test = Test()
            test.method(1, 2)
            self.assertEqual(test.a, 1)
            self.assertEqual(test.b, 2)
            self.assertEqual(test.c, None)

            test.method('a', 'b', 'c')
            self.assertEqual(test.a, 'a')
            self.assertEqual(test.b, 'b')
            self.assertEqual(test.c, 'c')

    unittest2.main()

Diff to Previous Revision

--- revision 11 2010-06-27 12:49:03
+++ revision 12 2010-06-27 15:15:01
@@ -7,7 +7,7 @@
 
 Example:
 
-from context import ContextDecorator
+from contextdecorator import ContextDecorator
 
 class mycontext(ContextDecorator):
 

History