Often one wants to rebind a name or modify a mutable object, perform a bunch of actions and finally restore the name/object to its original state. An example is redirecting stdout/stderr temporarily (http://www.diveintopython.org/scripts_and_streams/stdin_stdout_stderr.html). The restoring context manager shown below simplifies this pattern::
import sys
# Printed to the console: sys.stdout has not been rebound yet.
print "hello world!"
# restoring() evaluates sys.stdout on entry and assigns that value back on exit.
with restoring('sys.stdout'):
    # The inner "with" rebinds sys.stdout to the opened file object.
    with open('hello.txt', 'w') as sys.stdout:
        # Printed to hello.txt.
        print "hello world!"
# Printed to the console again: sys.stdout was restored on exit.
print "hello world!"
from __future__ import with_statement
from contextlib import contextmanager
import sys
__docformat__ = "restructuredtext"
@contextmanager
def restoring(expr, clone=None):
    '''A context manager that evaluates an expression when entering the runtime
    context and restores its value when exiting.

    This context manager makes

    .. python::

        with restoring(expr, clone) as value:
            BODY

    a shortcut for

    .. python::

        value = EXPR
        __cloned = clone(value) if clone is not None else value
        try:
            BODY
        finally:
            EXPR = __cloned
            del __cloned

    where ``__cloned`` is a temporary hidden name and ``EXPR`` is ``expr``
    substituted textually in the code snippet. Therefore ``expr`` can only be an
    assignable expression, i.e. an expression that is allowed on the left hand
    side of '=' (e.g. identifier, subscription, attribute reference, etc.).

    The expression is evaluated with ``eval`` and re-assigned with ``exec`` in
    the namespaces of the calling frame, so it must come from trusted
    (hardcoded) code, never from external input.

    :param expr: The expression whose value is to be evaluated and restored.
    :type expr: str
    :param clone: A callable that takes the object ``expr`` evaluates to and
        returns an appropriate copy to be used for restoring. If None, the
        original object is used.
    :type clone: callable or None
    :return: A context manager whose ``as`` target receives the value ``expr``
        evaluated to on entry (the original object, not the clone).
    '''
    # Frame 0 is this generator and frame 1 is the contextmanager machinery
    # that drives it, so frame 2 is the code containing the ``with`` statement.
    f = sys._getframe(2)
    # Evaluate the expression and make a clone of the value to be restored.
    value = eval(expr, f.f_globals, f.f_locals)
    restored_value = clone(value) if clone is not None else value
    try:
        yield value
    finally:
        if expr in f.f_locals:  # local or nonlocal name
            _update_locals(f, {expr: restored_value})
        elif expr in f.f_globals:  # global name
            f.f_globals[expr] = restored_value
        else:
            # Not a bare name (attribute, subscription, ...): run the
            # assignment "EXPR = <tmp>" against a copy of the caller's locals
            # in which <tmp> is bound to the value being restored.
            tmp_locals = dict(f.f_locals)
            # The temporary name must not shadow anything ``expr`` may refer
            # to, so extend it until it is unused in both namespaces. This
            # also works when the caller has no local names at all.
            tmp_name = '__restoring_value'
            while tmp_name in tmp_locals or tmp_name in f.f_globals:
                tmp_name += '_'
            tmp_locals[tmp_name] = restored_value
            # The call form of exec is accepted by both Python 2 and Python 3.
            exec('%s = %s' % (expr, tmp_name), f.f_globals, tmp_locals)
def _update_locals(frame, new_locals, clear=False):
    '''Arrange for the local variables of ``frame`` to be updated.

    Nothing is written to ``frame.f_locals`` directly here. Instead a one-shot
    local trace function is installed on ``frame``; when it is invoked it
    applies ``new_locals`` and then puts back the trace functions that were
    in place before this call.

    :param frame: The frame object whose locals are to be changed.
    :param new_locals: Mapping of names to the values to bind in the frame.
    :param clear: If true, ``f_locals`` is emptied before ``new_locals`` is
        applied.
    '''
    # XXX: obscure, most likely implementation-dependent fact:
    # f_locals can be modified (only?) from within a trace function
    # Remember the frame's current local trace function so it can be put back.
    f_trace = frame.f_trace
    try:
        # Likewise remember the system-wide trace function.
        sys_trace = sys.gettrace()
    except AttributeError: # sys.gettrace() not available before 2.6
        sys_trace = None
    def update_tracer(frm, event, arg):
        # Update the frame's locals and restore both the local and the system's
        # trace function
        assert frm is frame
        if clear:
            frm.f_locals.clear()
        frm.f_locals.update(new_locals)
        frm.f_trace = f_trace
        sys.settrace(sys_trace)
    # Force tracing on with setting the global tracing function and set
    # the frame's local trace function
    # (the global tracer itself does nothing; update_tracer does the work)
    sys.settrace(lambda frame, event, arg: None)
    frame.f_trace = update_tracer
def test_restoring_immutable():
    '''Check restoring a subscription expression bound to an immutable value.'''
    x = 'b'
    foo = {'a':3, 'b':4}
    with restoring('foo[x]') as y:
        # The "as" target receives the value of the expression on entry.
        assert y == foo[x] == 4
        foo[x] = y = None
        assert y is None and foo[x] is None
    # foo[x] is restored on exit; the local name y is left as the body set it.
    assert foo[x] == 4 and y is None
    # No helper names may have leaked into this frame.
    assert sorted(locals()) == ['foo', 'x', 'y']
def test_restoring_mutable():
    '''Check restoring an attribute that refers to a mutable object.'''
    orig_path = sys.path[:]
    # clone=list snapshots the list on entry; the snapshot is what is restored.
    with restoring('sys.path', clone=list) as path:
        # The "as" target is the original object, not the clone.
        assert path is sys.path
        path += ['foo']
        assert path == orig_path + ['foo']
    # sys.path is rebound to the snapshot taken on entry ...
    assert sys.path == orig_path
    # ... while the object mutated inside the block keeps its modification.
    assert path == orig_path + ['foo']
    # No helper names may have leaked into this frame.
    assert sorted(locals()) == ['orig_path', 'path']
# Module-level name rebound and restored by test_restoring_global.
x = 1
def test_restoring_global():
    '''Check restoring global names, including nested restoring blocks.'''
    global y
    y = 2
    global x
    with restoring('x'):
        x = None
        with restoring('y'):
            y += 3
            assert x is None and y == 5
        # The inner block has exited: y is restored, x is still rebound.
        assert y == 2
    # The outer block has exited: x is back to its module-level value.
    assert x == 1
    # Both names are global, so this frame must have no locals at all.
    assert not locals()
def test_restoring_local():
    '''Check restoring a plain local name of the calling function.'''
    x = 5
    with restoring('x'):
        x = None
    # The local must hold its entry value again after the block.
    assert x == 5
    # No helper names may have leaked into this frame.
    assert sorted(locals()) == ['x']
def test_restoring_nonlocal():
    '''Check restoring a name that belongs to an enclosing function scope.'''
    a = []
    def nested():
        # "a" is not assigned in nested(); it refers to the enclosing
        # function's variable. list is passed as the clone callable.
        with restoring('a', list):
            a.append(1)
            assert a == [1]
        assert a == []
    nested()
    # The enclosing scope must see the restored (empty) value as well.
    assert a == []
if __name__ == '__main__':
    # Run every self-test in definition order; the first failing assert
    # aborts the run with an AssertionError.
    for run_test in (
        test_restoring_immutable,
        test_restoring_mutable,
        test_restoring_global,
        test_restoring_local,
        test_restoring_nonlocal,
    ):
        run_test()
|
Although this recipe uses eval and exec, it shouldn't pose a security risk, since the passed expression is expected to be hardcoded rather than provided by an external, untrusted source.
The implementation is rather straightforward, except for one case: when the expression is a non-global identifier, i.e. a local or nonlocal name. Restoring such a name requires modifying a frame's f_locals, which normally isn't possible. According to http://code.google.com/p/ouspg/wiki/AnonymousBlocksInPython#Hack_number_3:_Locals-forcing, however, f_locals can be modified from within a trace function. That does work, but it is still an obscure and most likely implementation-dependent hack.