Queue subclass to simplify working with consumer threads. It keeps a count of tasks put into the queue and lets the consumer thread report when each task has been retrieved AND PROCESSED COMPLETELY. This reporting supports a join() method that blocks until all submitted tasks have been fully processed.
```python
import threading
from Queue import Queue

class TaskQueue(Queue):

    def __init__(self):
        Queue.__init__(self)
        # Shares the queue's own mutex, so the counter is updated atomically
        # with the underlying put/get operations.
        self.all_tasks_done = threading.Condition(self.mutex)
        self.unfinished_tasks = 0

    def _put(self, item):
        # Called by put() while self.mutex is held: count every new task.
        Queue._put(self, item)
        self.unfinished_tasks += 1

    def task_done(self):
        """Indicate that a formerly enqueued task is complete.

        Used by Queue consumer threads.  For each get() used to fetch a task,
        a subsequent call to task_done() tells the queue that the processing
        on the task is complete.

        If a join() is currently blocking, it will resume when all items
        have been processed (meaning that a task_done() call was received
        for every item that had been put() into the queue).

        Raises a ValueError if called more times than there were items
        placed in the queue.
        """
        self.all_tasks_done.acquire()
        try:
            unfinished = self.unfinished_tasks - 1
            if unfinished <= 0:
                if unfinished < 0:
                    raise ValueError('task_done() called too many times')
                # Count reached zero: wake every thread blocked in join().
                self.all_tasks_done.notifyAll()
            self.unfinished_tasks = unfinished
        finally:
            self.all_tasks_done.release()

    def join(self):
        """Blocks until all items in the Queue have been gotten and processed.

        The count of unfinished tasks goes up whenever an item is added to the
        queue. The count goes down whenever a consumer thread calls task_done()
        to indicate the item was retrieved and all work on it is complete.

        When the count of unfinished tasks drops to zero, join() unblocks.
        """
        self.all_tasks_done.acquire()
        try:
            while self.unfinished_tasks:
                self.all_tasks_done.wait()
        finally:
            self.all_tasks_done.release()

#### Example code ####################

import sys

def worker():
    'Stop at 1; otherwise if odd then enqueue 3*x+1; or if even then divide by two'
    name = threading.currentThread().getName()
    while True:
        x = q.get()
        sys.stdout.write('%s\t%d\n' % (name, x))
        if x <= 1:
            sys.stdout.write('!\n')
        elif x % 2 == 1:
            q.put(x * 3 + 1)
        else:
            q.put(x // 2)
        # The follow-up value (if any) was put() before this call, so the
        # unfinished count cannot reach zero while work remains.
        q.task_done()

q = TaskQueue()
numworkers = 4
for i in range(numworkers):
    t = threading.Thread(target=worker)
    t.setDaemon(True)   # daemon threads never exit their loop; let the program end
    t.start()
for x in range(1, 50):
    q.put(x)
q.join()
print 'All inputs found their way to 1. Queue is empty and all processing complete.'
```
Put people in line at a bank and have them serviced by consumer threads (i.e. bank tellers). How do you know when all the customers have been serviced and you can close the doors at the bank? It is not enough to know that the line to see a teller is empty; you also need to know that the tellers have finished serving their customers.
So you keep a count of the customers going into the line (done automatically by each q.put(customer) call). Tellers use q.get() to take a customer, and when they are finished with that customer they call q.task_done(), which decrements the count. The security guard is assigned to close the doors, but his task is held up by q.join(), which blocks until the count has dropped to zero.
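The bank analogy maps directly onto code. A minimal Python 3 sketch, assuming the standard queue.Queue (which has task_done() and join() built in); the customer names, the served list, and the teller function are hypothetical stand-ins:

```python
import queue
import threading

line = queue.Queue()            # the line of customers
served = []                     # (teller, customer) pairs, for checking
served_lock = threading.Lock()

def teller(name):
    while True:
        customer = line.get()             # take the next customer
        with served_lock:
            served.append((name, customer))
        line.task_done()                  # this customer is completely served

for i in range(3):
    threading.Thread(target=teller, args=(f'teller-{i}',), daemon=True).start()

for n in range(20):
    line.put(f'customer-{n}')             # count goes up by one per customer

line.join()                               # the security guard waits here
print(f'Closing the doors: {len(served)} customers served')
```

The main thread plays the security guard: it does no serving itself and simply blocks in join() until every put() has been matched by a task_done().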
This approach is simpler and more flexible than the alternatives, such as submitting a None item to the queue to tell each consumer thread to terminate (essentially adding sentinel customers to the end of the line), or using a second result queue so that the producer can match up every enqueued task with a notice that the task was completed (i.e., matching every arriving customer with a served customer).
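For comparison, a minimal Python 3 sketch of the sentinel approach, assuming None is never a real task. Each consumer needs its own sentinel, the threads must be non-daemon and joined one by one, and the scheme relies on consumers never putting new items: in the Collatz example a worker could re-enqueue a value after the sentinels were already in line, and the remaining workers might exit before it was processed.

```python
import queue
import threading

SENTINEL = None            # marker telling a consumer to stop
NUM_WORKERS = 3

tasks = queue.Queue()
results = []
results_lock = threading.Lock()

def consumer():
    while True:
        item = tasks.get()
        if item is SENTINEL:
            return                        # leave the loop; the thread ends
        with results_lock:
            results.append(item * 2)

threads = [threading.Thread(target=consumer) for _ in range(NUM_WORKERS)]
for t in threads:
    t.start()
for n in range(10):
    tasks.put(n)
for _ in range(NUM_WORKERS):
    tasks.put(SENTINEL)                   # one "sentinel customer" per teller
for t in threads:
    t.join()                              # waits on threads, not on the queue
print(sorted(results))  # [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
```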
This recipe was accepted for inclusion in Py2.5.
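From Python 2.5 on, Queue.Queue (queue.Queue in Python 3) provides task_done() and join() itself. In CPython's implementation its put() already increments unfinished_tasks after calling _put(), so the TaskQueue subclass should not be used there: the overridden _put() would count each item twice and join() would never return. A sketch of the Collatz example ported to Python 3 with the standard queue; the reached_one list and its lock are additions for illustration:

```python
import queue
import threading

q = queue.Queue()
reached_one = []                 # one entry per chain that arrives at 1
reached_lock = threading.Lock()

def worker():
    """Stop at 1; otherwise if odd then enqueue 3*x+1, or if even then divide by two."""
    while True:
        x = q.get()
        if x <= 1:
            with reached_lock:
                reached_one.append(x)
        else:
            # Put the follow-up value BEFORE task_done(), so the unfinished
            # count never touches zero while a chain is still running.
            q.put(3 * x + 1 if x % 2 else x // 2)
        q.task_done()

for _ in range(4):
    threading.Thread(target=worker, daemon=True).start()

for x in range(1, 50):
    q.put(x)
q.join()
print('All inputs found their way to 1. Queue is empty and all processing complete.')
```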