# Author: David Decotigny, Oct 1, 2008
# @brief Multiplexer for parallel transactions over a single data
# channel. This is like a pipe on which we provide a multithreaded
# request/response messaging system. This system allows multiple
# threads to issue several requests in parallel: they are treated in
# parallel on the receiving side and the responses are sent back to
# their respective requesting thread. The exceptions are correctly
# transferred: the _trace member of the exception object will
# indicate the traceback (as text).
#
# The basic synopsis is:
# - call Mux::transaction(*args, **kwds) from a "sender" thread
# - the transaction is sent to the DeMux via the channel
# - DeMux::process_transaction(*args, **kwds) gets called by another thread,
# on the other side of the channel (in general in
# another process/machine)
# - the result/exception of process_transaction() is sent back to the
# sender via the channel
# - Mux::transaction() on the sender returns or raises the exception
# raised by DeMux::process_transaction()
#
# The user code of this module has to override the
# DeMux::process_transaction() method for the whole system to be
# useful.
#
# This module makes the following assumptions about the channel used to
# transmit the requests/responses:
# - the channel is bidirectional: both parties can send and receive data on it
# - the channel can transmit arbitrary serializable python objects
# - the channel consists of 2 endpoints having the same API: one
# endpoint for the Multiplexer, one endpoint for the DeMultiplexer
# - the endpoints of the channel have the following methods:
# fileno(): return a file descriptor suitable for select/poll of data
# ready to be received in non-blocking mode (at least for the
# first byte)
# send(data): send the given python data to the receiving party
# data = recv(): wait for python data from the sending party and return it
# close(): close the endpoint in both send/receive directions
# - send() is thread-safe: several threads may call it concurrently
#
# To achieve parallel handling of "simultaneous" requests, the
# demultiplexer handles each request in a separate thread: either the
# threads are created on demand (nworkers = None), or a pool of
# pre-allocated threads is used (nworkers = integer). To manage the
# interleaving of the transactions, each transaction has its own ID,
# the "xid".
import sys, os, threading, Queue, itertools, traceback, select, struct
import cPickle as pickle # Only for SimpleChannelEndpoint
__all__ = ["Mux", "DeMux", "ChannelPair"]
## Magic token put by DeMux.stop() on the job queue, once per pool
## worker thread, to tell the workers that no more jobs will come
SENTINEL = "QUIT"
def is_sentinel(obj):
    """Return True when obj is the SENTINEL "terminate" order that a
    DeMux pool worker thread receives from the DeMux. Only an exact
    str equal to SENTINEL qualifies (str subclasses do not)."""
    if type(obj) is not str:
        return False
    return obj == SENTINEL
class ReceiverThread(threading.Thread):
    """Generic wrapper class to wait for data from a channel:
    handle_message() is called for each data received. Provides a
    stop() method to stop receiving the data. This is a thread
    object: call start() to start it"""
    def __init__(self, channel, *args, **kwds):
        """
        \param channel is a Channel endpoint (fileno/recv/send/close
        methods expected)
        Remaining positional/keyword arguments are forwarded to
        threading.Thread.__init__()
        """
        threading.Thread.__init__(self, *args, **kwds)
        self._channel = channel
        # Self-pipe written by stop() to wake up the poll() in run()
        self.__terms = os.pipe()
        self._recv = channel.recv
        self._send = channel.send
    def run(self):
        """
        Wait for either a call to stop() or for data to be available
        on the channel, then call handle_message(). Loop until stop()
        is called or the channel reports an event without readable
        data (e.g. a hang-up once all pending data has been consumed).
        An exception raised by the channel's recv() terminates the
        thread.
        """
        # Initialize poll()
        chan_fd = self._channel.fileno()
        eventmask = select.POLLIN | select.POLLERR \
                    | select.POLLHUP | select.POLLPRI
        waitset = select.poll()
        waitset.register(chan_fd, eventmask)
        waitset.register(self.__terms[0], eventmask)
        while True:
            exit_loop = False
            for fd, evt in waitset.poll():
                if fd != chan_fd:
                    # stop() wrote on the termination pipe
                    exit_loop = True
                    break
                if not evt & select.POLLIN:
                    # Error/hang-up with nothing left to read. When data
                    # is still pending alongside POLLHUP/POLLERR we keep
                    # reading, so the last messages sent are not lost.
                    exit_loop = True
                    break
            if exit_loop:
                break
            # Error while receiving => the exception terminates the thread
            data = self._recv()
            # Call handle_message (dump the exceptions, but ignore them)
            try:
                self.handle_message(data)
            except:
                traceback.print_exc()
    def handle_message(self, message):
        """Method to override: called each time a message is received"""
        raise NotImplementedError("Children classes expected to override it")
    def stop(self):
        """Stop receiving data: wake up run(), close the channel, wait
        until the thread is terminated and release the internal
        wake-up pipe. Further calls do nothing.
        DO NOT CALL THIS from inside handle_message()"""
        terms = self.__terms
        if terms is None:
            return  # already stopped
        os.write(terms[1], "TERMINATION")
        self._channel.close()
        self.join()
        # The thread is gone: nobody polls the wake-up pipe any more
        self.__terms = None
        for fd in terms:
            os.close(fd)
class Mux(ReceiverThread):
    """Thread that multiplexes calls to the transaction() method on
    the given channel. Each call gets a unique transaction id (xid);
    the request is sent on the channel and the calling thread blocks
    until handle_message() delivers the response carrying that xid"""
    def __init__(self, channel):
        """
        \param channel is a Channel endpoint (fileno/recv/send/close
        methods expected)
        """
        ReceiverThread.__init__(self, channel)
        self.__lock = threading.Lock()  # protects the 3 attributes below
        self.__waitq = dict()           # xid -> [Event, response or None]
        self.__nextxid = 42             # next transaction id to allocate
        self.__stopped = False          # set once run() has terminated
    def transaction(self, *args, **kwds):
        """Send the given args on the wire and wait for the response.

        \return the value returned by the remote process_transaction()
        \raise the exception raised by the remote process_transaction()
        \raise EOFError when the Mux stops before a response arrives
        \raise RuntimeError when the response has an unknown status
        Errors raised by the channel's send() are propagated.
        """
        evt = threading.Event()
        slot = [evt, None]  # [wake-up event, (xid, result) response]
        # Allocate a transaction ID and register the waiting slot
        self.__lock.acquire()
        try:
            if self.__stopped:
                raise EOFError("MUX has been stopped.")
            xid = self.__nextxid
            self.__nextxid += 1
            self.__waitq[xid] = slot
        finally:
            self.__lock.release()
        # Send the request
        try:
            self._send((xid, args, kwds))
        except:
            # Do not leave behind a slot that nobody will ever wait for
            self.__pop_slot(xid)
            raise
        # Wait for the answer (set by handle_message() or by run())
        evt.wait()
        self.__pop_slot(xid)
        response = slot[1]
        if response is None:
            # Woken up by run() on termination: no response will come
            raise EOFError("MUX has been stopped.")
        # Return the answer/raise the exception to the caller
        xid_, result = response
        assert xid_ == xid, \
               "Expected txn id %s != received (%s)" % (xid, xid_)
        status, details = result
        if status == "OK":
            return details
        if status == "EXCEPTION":
            raise details
        raise RuntimeError("Invalid status %s !" % repr(status))
    def run(self):
        """Listen to the messages coming from the endpoint and
        dispatch them to the threads which sent them. When listening
        ends (stop() or channel failure), refuse new transactions and
        wake up every thread still waiting: those without a response
        get EOFError from transaction(), responses already delivered
        are still returned"""
        try:
            ReceiverThread.run(self)
        except:
            traceback.print_exc()
        self.__lock.acquire()
        try:
            self.__stopped = True
            for evt, _response in self.__waitq.values():
                evt.set()
        finally:
            self.__lock.release()
    def handle_message(self, msg):
        """Needed by the ReceiverThread object: dispatch the response
        msg = (xid, result) to the thread waiting for transaction xid.
        Raises KeyError for an unknown xid (printed by
        ReceiverThread.run())"""
        xid, _result = msg
        self.__lock.acquire()
        try:
            slot = self.__waitq[xid]
            slot[1] = msg
            slot[0].set()  # wake up the caller thread
        finally:
            self.__lock.release()
    def __pop_slot(self, xid):
        """Unregister transaction xid; return its slot (or None)"""
        self.__lock.acquire()
        try:
            return self.__waitq.pop(xid, None)
        finally:
            self.__lock.release()
class DeMux(ReceiverThread):
    """Thread that demultiplexes transactions coming from a
    multiplexer, and calls process_transaction() for each of them. The
    transactions are processed in parallel in different worker
    threads. The worker threads are either consisting in a pool of
    threads (when nworkers is not None), or are created on-demand when
    requests arrive (when nworkers is None).

    Each response sent back is (xid, ("OK", value)) or
    (xid, ("EXCEPTION", ex)); ex carries a _trace attribute holding
    the formatted traceback when the exception object accepts it"""
    __lock = None      # Lock protecting __workers (on-demand mode)
    __workq = None     # Queue object or None (in on-demand mode)
    __nworkers = None  # Specified size of the pool of threads
    __workers = None   # Either a list of threads (pool) or a dict xid->thread
                       # (in on-demand mode)
    def __init__(self, channel, nworkers = None):
        """
        \param channel is a Channel endpoint (fileno/recv/send/close
        methods expected)
        \param nworkers (integer) number of threads in the pool able
        to process the transaction requests, or None when threads have
        to be created on demand
        """
        ReceiverThread.__init__(self, channel)
        self.__nworkers = nworkers
        self.__lock = threading.Lock()
        if nworkers is not None:
            self.__workers = []
            self.__workq = Queue.Queue()
            for _ in range(nworkers):
                thr = threading.Thread(target=self._pool_work)
                self.__workers.append(thr)
                thr.start()
        else:
            self.__workers = dict()
    def handle_message(self, msg):
        """Required by ReceiverThread: hand one (xid, args, kwds)
        request over to a worker thread"""
        xid, args, kwds = msg
        if self.__nworkers is not None:
            # In pool mode: send the job to the pool
            self.__workq.put((xid, args, kwds))
            return
        # In on-demand mode: spawn a new thread to do the job
        thr = threading.Thread(target=self._do_process_transaction,
                               args=(xid,)+args, kwargs=kwds)
        # Register the thread for this transaction
        self.__lock.acquire()
        try:
            self.__workers[xid] = thr
        finally:
            self.__lock.release()
        try:
            thr.start()
        except:
            # Oops, cannot start worker: report the failure to the sender
            result = self._exception_result()
            self.__unregister(xid)
            self._send_result(xid, result)
    def _pool_work(self):
        """Method run by the pool worker threads in pool mode"""
        while True:
            # Simply consume the jobs from the queue until we get the
            # sentinel token
            data = self.__workq.get()
            if is_sentinel(data):
                break
            xid, args, kwds = data
            # Send failures are reported, not raised: the worker survives
            self._do_process_transaction(xid, *args, **kwds)
    def _do_process_transaction(self, xid, *args, **kwds):
        """Method run by the worker threads to process one transaction:
        call process_transaction() and send back ("OK", value) or
        ("EXCEPTION", ex). Failures to send the response are printed
        on stderr instead of being raised"""
        try:
            try:
                result = ("OK", self.process_transaction(*args, **kwds))
            except:
                # Any exception (even non-Exception ones) goes back to
                # the sender, together with its traceback
                result = self._exception_result()
            self._send_result(xid, result)
        finally:
            if self.__nworkers is None:
                # In on-demand mode: unregister the thread for this transaction
                self.__unregister(xid)
    @staticmethod
    def _exception_result():
        """Build the ("EXCEPTION", ex) response for the exception being
        handled, attaching the formatted traceback as ex._trace"""
        exc_type, exc_value = sys.exc_info()[:2]
        if exc_value is None:
            # No exception instance available: transmit the type itself
            return ("EXCEPTION", exc_type)
        try:
            exc_value._trace = traceback.format_exc()
        except Exception:
            pass  # best effort: some exception objects reject new attributes
        return ("EXCEPTION", exc_value)
    def _send_result(self, xid, result):
        """Send the (xid, result) response. If that fails (e.g. the
        result cannot be pickled), print the error and try to send a
        RuntimeError instead, so that the sender is not left waiting"""
        try:
            self._send((xid, result))
            return
        except Exception:
            error = sys.exc_info()[1]
            traceback.print_exc()
        try:
            self._send((xid, ("EXCEPTION", RuntimeError(
                "Cannot send the response of transaction %s: %r"
                % (xid, error)))))
        except Exception:
            traceback.print_exc()
    def __unregister(self, xid):
        """On-demand mode: forget the worker thread of transaction xid
        (no-op when stop() already took it)"""
        self.__lock.acquire()
        try:
            self.__workers.pop(xid, None)
        finally:
            self.__lock.release()
    def process_transaction(self, *args, **kwds):
        """Implement this method in order to generate a response from
        the given transaction arguments"""
        raise NotImplementedError("Children must implement this method")
    def stop(self):
        """Stop receiving requests, close the channel, then stop the
        worker threads, waiting for the transactions in progress.
        Queued requests not yet started are dropped. Note: the channel
        is closed before the workers are joined, so the responses of
        transactions still in progress cannot be delivered"""
        ReceiverThread.stop(self)
        # The listening thread is dead: no new job/worker can appear
        if self.__nworkers is not None:
            # Clearing job queue, then one sentinel per pool thread
            while True:
                try:
                    self.__workq.get_nowait()
                except Queue.Empty:
                    break
            for _ in range(self.__nworkers):
                self.__workq.put(SENTINEL)
            for thr in self.__workers:
                thr.join()
        else:
            # Workers unregister themselves concurrently: pick each
            # thread under the lock, but never hold it while joining
            while True:
                self.__lock.acquire()
                try:
                    if not self.__workers:
                        break
                    _xid, thr = self.__workers.popitem()
                finally:
                    self.__lock.release()
                thr.join()
class SimpleChannelEndpoint:
    """Construct a channel compliant with the channel specifications
    from a pair of r/w file descriptors. Each message is pickled and
    prefixed with its length packed with struct format 'I' (native
    byte order and size), so both ends are expected to share the
    same native format.

    Security note: recv() unpickles whatever it reads; only connect
    the descriptors to a trusted peer."""
    SZI = struct.calcsize('I')
    def __init__(self, fd_r, fd_w):
        """
        \param fd_r, fd_w The read and write file descriptors used for
        this endpoint (both are closed by close())
        """
        self._fd_r = fd_r
        self._fd_w = fd_w
        self._closed = False  # set by close(); send() refuses to write then
        self._wlock = threading.Lock() # send() has to be thread-safe
    def fileno(self):
        """Return a file descriptor suitable for select/poll of data
        ready to be received in non-blocking mode (at least for the
        first byte)"""
        return self._fd_r
    def send(self, data):
        """send the given python data to the receiving party
        \raise OSError (EBADF) if the endpoint has been closed"""
        payload = pickle.dumps(data)
        frame = struct.pack('I', len(payload)) + payload
        self._wlock.acquire()
        try:
            if self._closed:
                # Never write to a closed (possibly reused) descriptor
                import errno
                raise OSError(errno.EBADF, "channel endpoint is closed")
            # os.write() may write less than requested: loop until done
            while frame:
                written = os.write(self._fd_w, frame)
                frame = frame[written:]
        finally:
            self._wlock.release()
    def recv(self):
        """wait for python data from the sending party and return it
        \raise EOFError if the peer closed its end before a complete
        message was received"""
        (expected,) = struct.unpack('I', self._read_exactly(self.SZI))
        return pickle.loads(self._read_exactly(expected))
    def _read_exactly(self, size):
        """Read exactly size bytes from the read descriptor
        \raise EOFError if end-of-file is reached first"""
        data = os.read(self._fd_r, size)
        while len(data) < size:
            chunk = os.read(self._fd_r, size - len(data))
            if not chunk:
                raise EOFError("Channel closed by the peer "
                               "(got %d of %d bytes)" % (len(data), size))
            data += chunk
        return data
    def close(self):
        """close the endpoint in both send/receive directions.
        Calling it more than once is harmless"""
        self._wlock.acquire()
        try:
            if self._closed:
                return
            self._closed = True
            os.close(self._fd_w)
        finally:
            self._wlock.release()
        os.close(self._fd_r)
def ChannelPair():
    """Very simple function returning a connected pair of channels:
    what one endpoint send()s, the other recv()s, in both directions.
    Built from two os.pipe() pairs; no descriptor leaks if creating
    the second pipe fails"""
    r1, w2 = os.pipe()
    try:
        r2, w1 = os.pipe()
    except OSError:
        os.close(r1)
        os.close(w2)
        raise
    return (SimpleChannelEndpoint(r1, w1), SimpleChannelEndpoint(r2, w2))
def _test():
"""
Some tests
"""
import time, thread
c1, c2 = ChannelPair()
mux = Mux(c1)
class MyDeMux(DeMux):
"""A demultiplexer in which each transaction is a call to sleep()"""
def process_transaction(self, message_before, duration, message_after):
"""One trasaction is just a call to sleep"""
print "[%d] BEGIN: %s (sleep %fs)" % (thread.get_ident(),
message_before, duration)
time.sleep(duration)
print "[%d] END: %s" % (thread.get_ident(), message_after)
class Submitter(threading.Thread):
"""A thread that submits 3 transactions to the mux object"""
def run(self):
"""Submit 3 transactions and stop"""
mux.transaction("msg1", 3, "msg2")
mux.transaction("msg3", 2, "msg4")
mux.transaction("msg5", 1, "msg6")
try:
mux.transaction("msgE", -1, "msgEE")
except IOError, ex:
print "Got expected exception from the DeMux: %s" % repr(ex)
demux = MyDeMux(c2, 100)
# demux = MyDeMux(c2)
# Starting mux/demux
mux.start()
demux.start()
# Starting as many threads that run transactions as possible
children = []
for i in range(700):
thr = Submitter()
try:
thr.start()
children.append(thr)
except:
break
print "Started %d submission threads" % len(children)
# Waiting for the children
for thr in children:
try:
thr.join()
except KeyboardInterrupt:
print "User interruption."
break
# Stopping mux/demux
mux.stop()
demux.stop()
print "Bye."
if __name__ == "__main__":
    # Executed as a script: run the module self-test
    _test()