import time
import threading
import weakref


class Cache(object):
    """Decorator that memoizes a function's results, keyed on its arguments.

    Each decorated function gets its own store of ``(creation_time, result)``
    entries.  Entries are not aged out on lookup; they are only removed by
    ``invalidate``/``invalidate_one`` or by ``collect``/``collectAll``.
    Collection can run periodically in a background thread via
    ``startCollection()``.

    Keys are built from the arguments exactly as passed, so ``f(1)`` and
    ``f(1, 3)`` are separate entries even if 3 is the default for the second
    parameter.
    """

    # Every live Cache instance, so collectAll() can reach every store.  A
    # WeakSet means registering here does not keep an otherwise unreferenced
    # cache (and all of its entries) alive forever.
    __allInstances = weakref.WeakSet()

    # Default max allowed age of a cache entry, in seconds.  May be overridden
    # per instance (e.g. ``foo.maxAge = 1``); a value <= 0 disables collection.
    maxAge = 3600
    # Seconds to wait between automatic collection passes.
    collectionInterval = 2
    # Stop signal of the most recently started collector thread (None until a
    # collector is started).  A fresh Event per start means a stop followed by
    # a restart can never be defeated by a stale flag.
    __stopEvent = None
    # The background collector thread, set by startCollection().
    collectorThread = None

    def __init__(self, func):
        Cache.__allInstances.add(self)
        self._store = {}  # cache key -> (creation time, cached result)
        self.__func = func

    @staticmethod
    def _makeKey(args, kw):
        "Build a hashable cache key from positional and keyword arguments"
        # Keyword arguments are sorted so their call-site order does not matter.
        return (args, tuple(sorted(kw.items())))

    def __call__(self, *args, **kw):
        key = self._makeKey(args, kw)
        # One .get() instead of check-then-index: the collector thread may
        # delete the entry between a membership test and the lookup.
        entry = self._store.get(key)
        if entry is not None:
            return entry[1]
        result = self.__func(*args, **kw)
        self._store[key] = (time.time(), result)
        return result

    def invalidate(self):
        "Invalidate all cache entries for this function"
        self._store.clear()

    def invalidate_one(self, *args, **kw):
        "Invalidate the cache entry for a particular set of arguments for this function"
        # A missing key (or one the collector already removed) is a no-op.
        self._store.pop(self._makeKey(args, kw), None)

    def collect(self):
        "Clean out any cache entries in this store that are currently older than allowed"
        maxAge = self.maxAge
        if maxAge <= 0:  # max ages of zero (or less) mean don't collect
            return
        now = time.time()
        # Iterate over a snapshot: deleting from a dict while iterating its
        # live items() view raises RuntimeError on Python 3.
        for key, (created, _) in list(self._store.items()):
            if now - created > maxAge:
                # pop() tolerates the entry having been invalidated meanwhile.
                self._store.pop(key, None)

    @classmethod
    def collectAll(cls):
        "Clean out all old cache entries in all functions being cached"
        # Snapshot the registry so a concurrent registration (or an instance
        # being garbage-collected) cannot change the set mid-iteration.
        for instance in list(cls.__allInstances):
            instance.collect()

    @classmethod
    def _startCollection(cls, stopEvent=None):
        """Periodically clean up old entries until the stop event is set.

        stopEvent: the Event that ends the loop; defaults to the class's
        current stop signal (created if none exists yet).
        """
        if stopEvent is None:
            if cls.__stopEvent is None:
                cls.__stopEvent = threading.Event()
            stopEvent = cls.__stopEvent
        # Event.wait() returns True as soon as the event is set, so a stop
        # request ends the loop promptly instead of after a full interval
        # plus one extra collection pass.
        while not stopEvent.wait(cls.collectionInterval):
            cls.collectAll()

    @classmethod
    def startCollection(cls):
        "Start the automatic collection process in its own thread"
        # Stop any collector already running so repeated calls never leave
        # more than one thread collecting.
        cls.stopCollection()
        stopEvent = threading.Event()
        cls.__stopEvent = stopEvent
        cls.collectorThread = threading.Thread(
            target=cls._startCollection, args=(stopEvent,))
        cls.collectorThread.daemon = False  # attribute form; setDaemon() is deprecated
        cls.collectorThread.start()

    @classmethod
    def stopCollection(cls):
        "Signal the current collector thread (if any) to stop"
        if cls.__stopEvent is not None:
            cls.__stopEvent.set()
# -------------------
# Example usage:

@Cache
def foo(arg):
    print('foo called with arg=%s' % arg)
    return arg * 2


@Cache
def bar(arg1, arg2=3):
    print('bar called with arg1=%s arg2=%s' % (arg1, arg2))
    return arg1 + arg2


foo(2)  # cache misses, foo invoked, cache entry created for arg=2
foo(2)  # cache hit, cached value retrieved
foo.invalidate()  # all cache entries for foo are deleted (in this case, only 1)
foo(2)  # cache misses, etc.
bar(1)  # cache misses, cache entry created for the arguments (1,)
# Cache miss too: keys use the arguments exactly as passed, so (1,) and
# (1, 3) are separate entries even though arg2 defaults to 3.
bar(1, 3)
bar(1, 2)  # cache misses, cache entry created for arg1=1, arg2=2
bar.invalidate_one(1, 3)  # deletes only the (1, 3) entry
bar(1, 3)  # cache miss
bar(1, 2)  # cache hit
Cache.collectionInterval = 1.5
Cache.startCollection()  # starts cache collection for all funcs every 1.5 seconds
foo(2)  # cache hit
foo.maxAge = 1  # set the max age for foo's (and only foo's) cache entries to 1 second
# Wait long enough for foo's entry to be older than 1 second AND for a
# collection pass (every 1.5 seconds) to have run and removed it.
time.sleep(2)
foo(2)  # cache miss
Cache.stopCollection()
# ------------------