Welcome, guest | Sign In | My Account | Store | Cart

Often one wants to rebind a name or modify a mutable object, perform a bunch of actions and finally restore the name/object to its original state. An example is redirecting stdout/stderr temporarily (http://www.diveintopython.org/scripts_and_streams/stdin_stdout_stderr.html). The restoring context manager shown below simplifies this pattern::

# Usage example: assumes ``restoring`` (defined in the listing below) is in scope.
import sys

# prints in console
print("hello world!")

# sys.stdout is rebound to a file only inside the outer block; the inner
# ``with`` closes the file first, then ``restoring`` puts the console back.
with restoring('sys.stdout'):
    with open('hello.txt', 'w') as sys.stdout:
        # prints in file
        print("hello world!")

# prints in console again
print("hello world!")
Recipe source (Python, 140 lines):
from __future__ import with_statement
from contextlib import contextmanager
import sys

# Declares the docstring markup used in this module (docutils/epydoc
# convention): the docstrings below use reStructuredText fields such as
# ``:param:`` and ``:type:``.
__docformat__ = "restructuredtext"


@contextmanager
def restoring(expr, clone=None):
    '''A context manager that evaluates an expression when entering the runtime
    context and restores its value when exiting.

    This context manager makes

    .. python::
        with restoring(expr, clone) as value:
            BODY

    a shortcut for

    .. python::
        value = EXPR
        __cloned = clone(value) if clone is not None else value
        try:
            BODY
        finally:
            EXPR = __cloned
            del __cloned

    where ``__cloned`` is a temporary hidden name and ``EXPR`` is ``expr``
    substituted textually in the code snippet. Therefore ``expr`` can only be an
    assignable expression, i.e. an expression that is allowed on the left hand
    side of '=' (e.g. identifier, subscription, attribute reference, etc.).

    ``expr`` is evaluated and restored in the namespace of the code that
    contains the ``with`` statement. How the value is restored depends on the
    form of ``expr``:

    * a bare local or enclosing-scope name goes through ``_update_locals``
      (a trace-function hack);
    * a bare global name is written straight into the caller's globals;
    * any other expression is assigned by ``exec`` against a copy of the
      caller's locals.

    Any exception raised while evaluating ``expr`` or calling ``clone``
    propagates on entry, before the body runs. ``expr`` is passed to
    ``eval``/``exec``, so it must come from trusted (hard-coded) source.

    :param expr: The expression whose value is to be evaluated and restored.
    :type expr: str
    :param clone: A callable that takes the object ``expr`` evaluates to and
        returns an appropriate copy to be used for restoring. If None, the
        original object is used.
    :type clone: callable or None
    '''
    # Frame 0 is this generator and frame 1 is the contextmanager wrapper's
    # __enter__, so frame 2 is the code containing the ``with`` statement.
    f = sys._getframe(2)
    # evaluate the expression and make a clone of the value to be restored
    value = eval(expr, f.f_globals, f.f_locals)
    restored_value = clone(value) if clone is not None else value
    try:
        yield value
    finally:
        if expr in f.f_locals:      # local or nonlocal name
            _update_locals(f, {expr: restored_value})
        elif expr in f.f_globals:   # global name
            f.f_globals[expr] = restored_value
        else:
            # Arbitrary assignable expression: bind restored_value to a
            # temporary name in a copy of the caller's locals and execute
            # ``EXPR = <tmp>`` there. The temporary name must not collide
            # with any caller local or global, otherwise ``expr`` could read
            # the wrong object. (The previous ``'__' + min(locals)`` scheme
            # raised ValueError when the caller had no locals at all.)
            tmp_locals = dict(f.f_locals)
            tmp_name = '__restoring_tmp'
            while tmp_name in tmp_locals or tmp_name in f.f_globals:
                tmp_name += '_'
            tmp_locals[tmp_name] = restored_value
            # Function-call form of exec: valid on both Python 2 and 3.
            exec('%s = %s' % (expr, tmp_name), f.f_globals, tmp_locals)


def _update_locals(frame, new_locals, clear=False):
    '''Rebind local variables of *frame* to the values in *new_locals*.

    Writing to a frame's ``f_locals`` normally has no lasting effect on the
    frame's variables. This helper therefore defers the update to a one-shot
    local trace function installed on *frame*. When the next trace event is
    delivered to that frame (presumably its next line event in CPython), the
    tracer writes *new_locals* into ``f_locals``. It then reinstalls the
    frame's previous local trace function and the previously active system
    trace function.

    NOTE(review): this relies on trace-function behaviour that is most likely
    specific to CPython (see the XXX comment below).

    :param frame: The frame whose local (or enclosing-scope) variables are
        to be rebound.
    :param new_locals: Maps variable names to their new values.
    :type new_locals: dict
    :param clear: If true, empty ``frame.f_locals`` before applying
        *new_locals*.
    :type clear: bool
    '''
    # XXX: obscure, most likely implementation-dependent fact:
    # f_locals can be modified (only?) from within a trace function
    f_trace = frame.f_trace
    try:
        sys_trace = sys.gettrace()
    except AttributeError: # sys.gettrace() not available before 2.6
        # The active global tracer cannot be queried, so it is assumed to be
        # None; an existing tracer (e.g. a debugger) would be switched off
        # when update_tracer "restores" it.
        sys_trace = None
    def update_tracer(frm, event, arg):
        # Update the frame's locals and restore both the local and the system's
        #trace function
        assert frm is frame
        if clear:
            frm.f_locals.clear()
        frm.f_locals.update(new_locals)
        # One-shot: hand tracing back to whatever was active before.
        frm.f_trace = f_trace
        sys.settrace(sys_trace)
    # Force tracing on with setting the global tracing function and set
    # the frame's local trace function
    # (the global tracer returns None, so newly entered scopes get no local
    # tracer; it only needs to be active so events reach frame.f_trace)
    sys.settrace(lambda frame, event, arg: None)
    frame.f_trace = update_tracer


def test_restoring_immutable():
    '''Restore a subscription expression (``foo[x]``) holding an int.

    Rebinding the item inside the block is undone on exit, while the ``as``
    target keeps whatever was last assigned to it.
    '''
    x = 'b'
    foo = {'a':3, 'b':4}
    with restoring('foo[x]') as y:
        assert y == foo[x] == 4
        foo[x] = y = None
        assert y == foo[x] == None
    assert foo[x] == 4 and y == None 
    # restoring() must not leak its temporary name into the caller's locals
    assert sorted(locals()) == ['foo', 'x', 'y']

def test_restoring_mutable():
    '''Restore an attribute (``sys.path``) using ``clone=list``.

    The yielded object is the original list, so in-place mutation through it
    is visible during the block. On exit ``sys.path`` is rebound to the copy
    taken on entry, while the original (mutated) list keeps the change.
    '''
    orig_path = sys.path[:]
    with restoring('sys.path', clone=list) as path: 
        assert path is sys.path
        path += ['foo']
        assert path == orig_path + ['foo']
    assert sys.path == orig_path 
    assert path == orig_path + ['foo']
    # restoring() must not leak its temporary name into the caller's locals
    assert sorted(locals()) == ['orig_path', 'path']

# Module-level global restored by test_restoring_global.
x = 1
def test_restoring_global():
    '''Restore module-level names declared ``global``, in nested blocks.

    ``y`` is created as a global by this test itself.
    '''
    global y; y = 2
    global x
    with restoring('x'):
        x = None
        with restoring('y'):
            y += 3
            assert x == None and y == 5 
        assert y == 2
    assert x == 1
    # globals must be restored in place, never bound as locals of this frame
    assert not locals()

def test_restoring_local():
    '''Restore a plain local variable.

    This exercises the ``_update_locals`` trace-hook path. In CPython the
    hook presumably fires on the next line event of this frame, i.e. before
    the assert after the ``with`` block runs.
    '''
    x = 5
    with restoring('x'):
        x = None
    assert x == 5
    assert sorted(locals()) == ['x']

def test_restoring_nonlocal():
    '''Restore a name from an enclosing function scope, with a ``list`` clone.

    ``a.append(1)`` mutates the original list in place. On exit ``a`` is
    rebound to the empty copy taken on entry. Because the enclosing function
    shares that binding, it also sees the restored (empty) list.
    '''
    a = []
    def nested():
        with restoring('a', list):
            a.append(1)
            assert a == [1]
        assert a == []
    nested()
    assert a == []


if __name__ == '__main__':
    # Self-test entry point: run each test in definition order; the first
    # failing assertion propagates as AssertionError.
    for _self_test in (test_restoring_immutable,
                       test_restoring_mutable,
                       test_restoring_global,
                       test_restoring_local,
                       test_restoring_nonlocal):
        _self_test()

Although this recipe uses eval and exec, it should not pose a security risk, since the expression passed in is expected to be hard-coded rather than supplied by an external, untrusted source.

The implementation is rather straightforward, except for one case: when the expression is a non-global identifier, i.e. a local or non-local name. This case requires modifying a frame's f_locals, which normally isn't possible. According to http://code.google.com/p/ouspg/wiki/AnonymousBlocksInPython#Hack_number_3:_Locals-forcing, though, f_locals can be modified from within a trace function. That does indeed work, but it is still an obscure and most likely implementation-dependent hack.