This recipe implements the `parallel` context manager, which executes the `__enter__` and `__exit__` methods of its arguments concurrently.
from __future__ import with_statement
import contextlib
import sys
import threading
import traceback
import unittest
# Public API: the names exported by ``from <this module> import *``.
__all__ = ["MultipleError", "parallel"]
class parallel(object):
    """Concurrently start and stop several context managers in different
    threads.

    Typical usage::

        with parallel(Foo(), Bar()) as managers:
            foo, bar = managers
            foo.do_something()
            bar.do_something()

    Each manager's ``__enter__`` (and later ``__exit__``) runs in its own
    thread, and ``parallel`` waits for all of them before continuing.  The
    ``as`` target is the tuple of managers passed to the constructor, not
    the values returned by their ``__enter__`` methods.

    If any ``__enter__`` raises, the managers whose ``__enter__`` succeeded
    are exited again (concurrently, with ``(None, None, None)``) before a
    :class:`MultipleError` listing the ``__enter__`` failures is raised.
    The with-body never runs in that case, so nothing else would ever exit
    them.  Failures raised during that rollback are kept in
    ``MultipleError.cleanup_errors``.
    """

    def __init__(self, *managers):
        self.managers = managers

    def __enter__(self):
        outcomes = _run_all([(mgr.__enter__, ()) for mgr in self.managers])
        errors = [outcome for outcome in outcomes if outcome is not None]
        if errors:
            # Roll back the managers that did start so no resource leaks.
            entered = [mgr for mgr, outcome in zip(self.managers, outcomes)
                       if outcome is None]
            rollback = _run_all(
                [(mgr.__exit__, (None, None, None)) for mgr in entered])
            raise MultipleError(
                errors, [outcome for outcome in rollback if outcome is not None])
        return self.managers

    def __exit__(self, *exc_info):
        outcomes = _run_all(
            [(mgr.__exit__, exc_info) for mgr in self.managers])
        errors = [outcome for outcome in outcomes if outcome is not None]
        if errors:
            raise MultipleError(errors)
        # Implicitly returns None, so an exception from the with-body is
        # never suppressed.


class MultipleError(Exception):
    """Exception class to collect several errors in a single object.

    ``errors`` is a list of ``sys.exc_info()`` triples, in manager order.
    ``cleanup_errors`` (empty unless given) holds the triples of failures
    raised while ``parallel`` rolled back managers after an ``__enter__``
    failure.
    """

    def __init__(self, errors, cleanup_errors=None):
        super(MultipleError, self).__init__()
        self.errors = errors
        if cleanup_errors is None:
            cleanup_errors = []
        self.cleanup_errors = cleanup_errors

    def __str__(self):
        bits = []
        for exc_type, exc_val, exc_tb in self.errors:
            bits.extend(traceback.format_exception(exc_type, exc_val, exc_tb))
        if self.cleanup_errors:
            bits.append("Errors while exiting managers that had been entered:\n")
            for exc_type, exc_val, exc_tb in self.cleanup_errors:
                bits.extend(
                    traceback.format_exception(exc_type, exc_val, exc_tb))
        return "".join(bits)


def run(func, args, errors):
    """Helper for ``parallel``: call ``func(*args)``, recording failures.

    Any exception is appended to the *errors* list as a ``sys.exc_info()``
    triple instead of propagating.  The bare ``except`` is deliberate.  An
    exception escaping a worker thread would only be printed by
    :mod:`threading` and would never reach the thread that called
    ``parallel``.
    """
    try:
        func(*args)
    except:
        errors.append(sys.exc_info())


def _run_all(calls):
    """Run each ``(func, args)`` pair of *calls* in its own thread.

    Waits for every thread.  Returns one entry per call, in the same order
    as *calls*: ``None`` if the call returned normally, otherwise its
    ``sys.exc_info()`` triple.
    """
    # One private error list per call, so failures map back to positions.
    outcomes = [[] for _ in calls]
    threads = []
    for (func, args), outcome in zip(calls, outcomes):
        thread = threading.Thread(target=run, args=(func, args, outcome))
        thread.start()
        threads.append(thread)
    for thread in threads:
        thread.join()
    return [outcome[0] if outcome else None for outcome in outcomes]
class ParallelTest(unittest.TestCase):
    """Self-tests for ``parallel`` built on the sample context managers."""

    def _multiple_error(self, *managers):
        """Run an empty ``with parallel(*managers)`` block.

        Returns the ``MultipleError`` it raises, and fails the test if
        nothing is raised.  ``sys.exc_info()`` is used instead of
        ``except ... as err`` so the code stays valid on Python 2.5 (targeted
        by the ``with_statement`` future import) as well as on Python 3.
        """
        try:
            with parallel(*managers):
                pass
        except MultipleError:
            return sys.exc_info()[1]
        self.fail("MultipleError not raised")

    def test_parallel(self):
        """Basic tests: no errors, and the ``as`` target has every manager."""
        with parallel(database(), web_server()) as managers:
            self.assertEqual(2, len(managers))

    def test_errors(self):
        """Tests for errors thrown in context manager methods."""
        err = self._multiple_error(error_in_enter())
        self.assertEqual(1, len(err.errors))
        self.assertEqual("enter", str(err.errors[0][1]))

        err = self._multiple_error(error_in_exit())
        self.assertEqual(1, len(err.errors))
        self.assertEqual("exit", str(err.errors[0][1]))

        # Only the __enter__ failure is listed in ``errors``.
        err = self._multiple_error(error_in_enter(), error_in_exit())
        self.assertEqual(1, len(err.errors))
        self.assertEqual("enter", str(err.errors[0][1]))

        err = self._multiple_error(error_in_enter(), error_in_enter())
        self.assertEqual(2, len(err.errors))
        self.assertEqual("enter", str(err.errors[0][1]))
        self.assertEqual("enter", str(err.errors[1][1]))

        err = self._multiple_error(error_in_exit(), error_in_exit())
        self.assertEqual(2, len(err.errors))
        self.assertEqual("exit", str(err.errors[0][1]))
        self.assertEqual("exit", str(err.errors[1][1]))
@contextlib.contextmanager
def web_server():
    """Stand-in context manager that does nothing on entry or exit.

    Its ``as`` target is ``None``.
    """
    yield None
@contextlib.contextmanager
def database():
    """Stand-in context manager that does nothing on entry or exit.

    Its ``as`` target is ``None``.
    """
    yield None
class error_in_enter(object):
    """Dummy context manager whose ``__enter__`` always fails.

    ``__enter__`` raises ``Exception("enter")``.  ``__exit__`` does nothing
    and returns ``None``.
    """

    def __enter__(self):
        failure = Exception("enter")
        raise failure

    def __exit__(self, *exc_info):
        return None
class error_in_exit(object):
    """Dummy context manager whose ``__exit__`` always fails.

    ``__enter__`` returns the manager itself.  ``__exit__`` raises
    ``Exception("exit")`` regardless of its arguments.
    """

    def __enter__(self):
        manager = self
        return manager

    def __exit__(self, *exc_info):
        failure = Exception("exit")
        raise failure
if __name__ == "__main__":
    # Executed as a script: run the unittest test cases defined in this module.
    unittest.main(module="__main__")
I have written several context managers to start server processes for my unit tests. Each of my context managers waits for its process to be up and running before leaving the `__enter__` method:
with mysql():
with memcached():
with twisted_app1():
with twisted_app2():
unittest.main()
Many of those processes, however, do not need to be started and stopped sequentially. A lot of time can be saved when they are started in parallel, like this:
with parallel(memcached(), mysql()):
with parallel(twisted_app1(), twisted_app2()):
unittest.main()