Welcome, guest | Sign In | My Account | Store | Cart
from sets import Set
import thread
from threading import Thread


class Lockable:
   def __init__(self):
        self.mutex = thread.allocate_lock()
        
def locked(method):
   def _locked(self, *args, **kwargs):
       self.mutex.acquire()
       try:
           result = method(self, *args, **kwargs)
       finally:
           self.mutex.release()
       return result
   return _locked


class TaggedSignal:
   def __init__(self, names):
       self.names = names
       self.lock = thread.allocate_lock()
       self.lock.acquire()
       self.tag = None
       
   def set(self, tag):
       self.tag = tag
       self.lock.release()

   def wait(self):
       self.lock.acquire()
       self.lock.release()
       return self.tag


class MultiQueue(Lockable):
    def __init__(self, maxsize):
        self.maxsize = maxsize
        self.queues = {}
        self.writeLocks = {}
        self.waitingLists = {}
        Lockable.__init__(self)

    def put(self, name, value):
        writeLock = self.getWriteLock(name)
        must_release = True
        writeLock.acquire()
        try:
            must_release = self._put(name, value, writeLock)
        finally:
            if must_release:
                writeLock.release()

    def _put(self, name, value, writeLock):        
        if self.waitingLists.has_key(name):
            self.notify_signal(name, value)
            return True
        else:
            queue = self.getQueue(name)
            queue.append(value)
            return len(queue)<self.maxsize
    _put = locked(_put)

    def notify_signal(self, name, value):
        signal = self.waitingLists[name][0]
        signal.set((name, value))
        for name in signal.names:
            waitingList = self.waitingLists[name]
            waitingList.remove(signal)
            if len(waitingList)==0:
                del self.waitingLists[name]

    def get(self, *names):
        must_wait, result = self._get(names)
        if must_wait:
            return result.wait()
        else:
            return result
            
    def _get(self, names):
        populatedQueues = Set(self.queues.keys()).intersection(names)
        if len(populatedQueues)==0:
            return True, self.createSignal(names)
        else:
            return False, self.getValue(populatedQueues.pop())
    _get = locked(_get)
    
    def getValue(self, queueName):
        queue = self.queues[queueName]
        was_full = len(queue)==self.maxsize
        value = queue.pop(0)
        if len(queue)==0:
            del self.queues[queueName]
            del self.writeLocks[queueName]
        else:
            if was_full:
                self.writeLocks[queueName].release()
        return queueName, value
            
    def createSignal(self, names):
        signal = TaggedSignal(names)
        for name in names:
            try:
                waitingList = self.waitingLists[name]
            except KeyError:
                waitingList = self.waitingLists[name] = []
            waitingList.append(signal)
        return signal

    def getWriteLock(self, name):
        try:
            writeLock = self.writeLocks[name]
        except KeyError:
            writeLock = self.writeLocks[name] = thread.allocate_lock()
        return writeLock
    getWriteLock = locked(getWriteLock)

    def getQueue(self, name):
        try:
            queue = self.queues[name]
        except KeyError:
            queue = self.queues[name] = []
        return queue

History

  • revision 3 (19 years ago)
  • previous revisions are not available