Welcome, guest | Sign In | My Account | Store | Cart

This recipe implements the parallel context manager, which executes the __enter__ and __exit__ methods of its arguments concurrently, each in its own thread.

Python, 180 lines
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
from __future__ import with_statement

import contextlib
import sys
import threading
import traceback
import unittest


# Names exported by ``from <module> import *``.  The thread helper ``run``,
# the sample context managers and the unit tests are not exported.
__all__ = ["MultipleError", "parallel"]


class parallel(object):

    """Concurrently start and stop several context managers in different
    threads.

    Each manager's ``__enter__`` (and later its ``__exit__``) is called in
    its own thread, and ``parallel`` waits for all of those threads before
    continuing.  Exceptions raised by those calls are collected and raised
    together as a single ``MultipleError``.

    If one or more managers fail to enter, the managers that *did* enter
    are exited again (concurrently) before the ``MultipleError`` is raised,
    so that whatever they acquired is not leaked.  Each of these rollback
    ``__exit__`` calls receives the ``exc_info`` of the first recorded
    ``__enter__`` failure, as it would with nested ``with`` statements.
    Their return values are ignored.  Failures raised during the rollback
    are not mixed into ``MultipleError.errors``; they are attached to the
    raised exception as the list ``rollback_errors``.

    The ``as`` target is the tuple of managers themselves, not the values
    returned by their ``__enter__`` methods.

    Typical usage::

        with parallel(Foo(), Bar()) as managers:
            foo, bar = managers
            foo.do_something()
            bar.do_something()

    """

    def __init__(self, *managers):
        # Tuple of context managers, in the order they were passed.
        self.managers = managers

    def __enter__(self):
        errors = []
        entered = []

        def enter(mgr):
            mgr.__enter__()
            # Only reached when __enter__ did not raise.  Shared between
            # worker threads; like ``errors`` this relies on list.append
            # being safe to call concurrently under the GIL.
            entered.append(mgr)

        self._call_concurrently([(enter, (mgr,)) for mgr in self.managers],
                                errors)

        if errors:
            err = MultipleError(errors)
            # Roll back the partial start: exit every manager that entered,
            # otherwise e.g. servers that did come up would keep running.
            err.rollback_errors = []
            self._call_concurrently(
                [(mgr.__exit__, errors[0]) for mgr in entered],
                err.rollback_errors)
            raise err

        return self.managers

    def __exit__(self, *exc_info):
        errors = []
        self._call_concurrently(
            [(mgr.__exit__, exc_info) for mgr in self.managers], errors)

        if errors:
            raise MultipleError(errors)
        # Implicitly returns None: an exception raised in the with-body is
        # never suppressed, whatever the managers' __exit__ methods return.

    @staticmethod
    def _call_concurrently(calls, errors):
        """Run each ``(func, args)`` pair of *calls* in its own thread.

        Blocks until every started thread has finished.  Failures are
        appended to *errors* by ``run`` as ``sys.exc_info()`` triples.
        """
        threads = []
        try:
            for func, args in calls:
                thread = threading.Thread(target=run,
                                          args=(func, args, errors))
                thread.start()
                threads.append(thread)
        finally:
            # Join even if starting a later thread failed, so no worker is
            # left running behind the caller's back.
            for thread in threads:
                thread.join()


class MultipleError(Exception):

    """Exception class to collect several errors in a single object.

    ``errors`` is the list passed to the constructor.  Each item is an
    ``(exc_type, exc_value, traceback)`` triple as returned by
    ``sys.exc_info()``.  ``str()`` of the exception is the concatenation of
    the formatted tracebacks of all collected errors.
    """

    def __init__(self, errors):
        # super() must name this class, not Exception: super(Exception, self)
        # starts the MRO lookup *after* Exception and would skip the
        # __init__ of any class that a subclass mixes in between.
        super(MultipleError, self).__init__()
        self.errors = errors

    def __str__(self):
        bits = []
        for exc_type, exc_val, exc_tb in self.errors:
            bits.extend(traceback.format_exception(exc_type, exc_val, exc_tb))
        return "".join(bits)


def run(func, args, errors):
    """Call ``func(*args)``, recording a failure instead of propagating it.

    ``parallel`` uses this as the target of its worker threads.  An
    exception escaping a thread target would not reach the code that joins
    the thread, so any exception is captured as a ``sys.exc_info()`` triple
    and appended to the shared *errors* list.  The bare ``except`` is
    intentional: even ``SystemExit``/``KeyboardInterrupt`` raised by a
    manager must be reported, not silently end the thread.  Returns ``None``.
    """
    caught = None
    try:
        func(*args)
    except:
        caught = sys.exc_info()
    if caught is not None:
        errors.append(caught)


class ParallelTest(unittest.TestCase):

    """Unit tests for the ``parallel`` context manager."""

    def test_parallel(self):
        """Managers that enter and exit cleanly run without error.

        """
        with parallel(database(), web_server()):
            pass

    def _assert_errors(self, managers, expected):
        """Assert ``with parallel(*managers): pass`` raises MultipleError.

        *expected* is the list of ``str()`` values of the collected error
        values.  Fails the test if no ``MultipleError`` is raised at all.
        """
        try:
            with parallel(*managers):
                pass
        except MultipleError:
            # sys.exc_info() instead of ``except ... as err`` / ``, err``
            # keeps this valid on both Python 2.5 and Python 3.
            err = sys.exc_info()[1]
        else:
            self.fail("MultipleError not raised")
        self.assertEqual(expected,
                         [str(exc_val) for _, exc_val, _ in err.errors])

    def test_errors(self):
        """Tests for errors thrown in context manager methods.

        """
        self._assert_errors([error_in_enter()], ["enter"])
        self._assert_errors([error_in_exit()], ["exit"])
        # The with-body never runs, so only the __enter__ failure is
        # reported in ``errors``.
        self._assert_errors([error_in_enter(), error_in_exit()], ["enter"])
        self._assert_errors([error_in_enter(), error_in_enter()],
                            ["enter", "enter"])
        self._assert_errors([error_in_exit(), error_in_exit()],
                            ["exit", "exit"])


@contextlib.contextmanager
def web_server():
    """Sample no-op context manager used by the tests.

    Does nothing on entry or exit; the value bound by ``as`` is ``None``.
    """
    yield None


@contextlib.contextmanager
def database():
    """Sample no-op context manager used by the tests.

    Does nothing on entry or exit; the value bound by ``as`` is ``None``.
    """
    yield None


class error_in_enter(object):

    """Sample context manager whose ``__enter__`` always fails.

    ``__enter__`` raises ``Exception("enter")``.  ``__exit__`` does nothing
    and returns ``None``, so it never suppresses an exception.
    """

    def __enter__(self):
        message = "enter"
        raise Exception(message)

    def __exit__(self, *exc_info):
        return None


class error_in_exit(object):

    """Sample context manager whose ``__exit__`` always fails.

    ``__enter__`` succeeds and returns the manager itself.  ``__exit__``
    raises ``Exception("exit")`` whatever arguments it receives.
    """

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        message = "exit"
        raise Exception(message)


# Run the unit tests defined in this module when it is executed as a script.
if __name__ == "__main__":
    unittest.main()

I have written several context managers that start server processes for my unit tests. Each of them waits until its process is up and running before leaving its __enter__ method:

with mysql():
    with memcached():
        with twisted_app1():
            with twisted_app2():
                unittest.main()

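Many of those processes, however, do not need to be started and stopped one after another. A lot of time can be saved by starting them in parallel. Nesting still imposes an order where one is needed: in the example below, the two Twisted applications are started only after both memcached and MySQL are running.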

with parallel(memcached(), mysql()):
    with parallel(twisted_app1(), twisted_app2()):
        unittest.main()