#!/usr/bin/env python
#
# Copyright 2010- Hui Zhang
# E-mail: hui.zh012@gmail.com
#
# Distributed under the terms of the GPL (GNU Public License)
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
from functools import partial
from weakref import ref
### weak-reference helper for keys and values (falls back to a strong reference when weak references are unsupported)
def _ref(value, ondel):
try:
vref = ref(value, ondel)
except:
vref = lambda:value
return vref
class weakmap(object):
    """Identity-keyed map that holds both keys and values via ``_ref``.

    Entries are looked up by ``id(key)``. Keys therefore need not be
    hashable, and two equal-but-distinct objects are different keys.

    When a weakly referenced key dies, its entry is removed
    (``forget_key``). When a weakly referenced value dies, every entry
    pointing at it is removed (``forget_value``). Objects that cannot be
    weakly referenced are held strongly by ``_ref`` and never drop out on
    their own.

    Internal state, kept mutually consistent by every method:
      __map:     id(key)   -> (id(value), reference to key)
      __valmemo: id(value) -> reference to value (one per distinct value)
      __backref: id(value) -> set of id(key) currently mapped to that value
    """

    def __init__(self):
        self.__map = {}
        self.__valmemo = {}
        self.__backref = {}

    def set(self, key, value):
        """Map *key* to *value*, replacing any value *key* had before."""
        ikey, ival = id(key), id(value)
        # Detach the previous value first. Otherwise the old value's
        # back-reference set would still contain ikey, and collecting the
        # old value would later delete this *new* entry. A strongly held
        # old value would also never be released from __valmemo.
        self.forget_key(ikey)
        self.__map[ikey] = (ival, _ref(key, partial(self.forget_key, ikey)))
        if ival not in self.__valmemo:
            self.__valmemo[ival] = _ref(value, partial(self.forget_value, ival))
        self.__backref.setdefault(ival, set()).add(ikey)

    def get(self, key, **kwargs):
        """Return the value mapped to *key*.

        Raises KeyError if *key* is absent, unless a ``default=`` keyword
        argument is supplied, in which case that is returned instead.
        """
        ikey = id(key)
        if ikey in self.__map:
            return self.__valmemo[self.__map[ikey][0]]()
        if 'default' in kwargs:
            return kwargs['default']
        raise KeyError(str(key))

    def pop(self, key):
        """Remove *key* and return its value; raise KeyError if absent."""
        ikey = id(key)
        if ikey not in self.__map:
            raise KeyError(str(key))
        value = self.__valmemo[self.__map[ikey][0]]()
        self.forget_key(ikey)
        return value

    def forget_key(self, keyid, obj=None):
        """Drop the entry whose key has ``id(key) == keyid`` (no-op if absent).

        Also installed as the weakref callback for keys, which is why it
        accepts (and ignores) the dead reference as *obj*. The value's
        bookkeeping is released once no key maps to it any more.
        """
        if keyid not in self.__map:
            return
        ival = self.__map.pop(keyid)[0]
        keys = self.__backref[ival]
        keys.remove(keyid)
        if not keys:
            del self.__backref[ival]
            del self.__valmemo[ival]

    def forget_value(self, valueid, obj=None):
        """Drop every entry whose value has ``id(value) == valueid``.

        Installed as the weakref callback for values; *obj* (the dead
        reference) is ignored.
        """
        for ikey in self.__backref.get(valueid, ()):
            self.__map.pop(ikey, None)
        self.__backref.pop(valueid, None)
        self.__valmemo.pop(valueid, None)

    def __iter__(self):
        """Yield ``(key, value)`` pairs.

        Iterates over a snapshot of the key ids. A weakref callback can fire
        while the generator is suspended (e.g. when the caller drops the
        last reference to a yielded key) and mutate ``__map``. Iterating
        the dict directly would then fail with "dictionary changed size
        during iteration". Entries removed in the meantime are skipped.
        """
        for ikey in list(self.__map):
            entry = self.__map.get(ikey)
            if entry is None:
                continue
            ival, keyref = entry
            yield keyref(), self.__valmemo[ival]()
class weakkeymap(object):
    """Identity-keyed map that holds keys via ``_ref`` and values strongly.

    Entries are looked up by ``id(key)``, so keys need not be hashable.
    When a weakly referenced key dies, its entry is removed by
    ``forget_key``. Keys that cannot be weakly referenced are held strongly
    by ``_ref`` and stay until popped.
    """

    def __init__(self):
        # id(key) -> (value, reference to key)
        self.__map = {}

    def set(self, key, value):
        """Map *key* to *value*, replacing any previous value for *key*."""
        ikey = id(key)
        # Overwriting drops the previous key reference object. A weakref
        # that is itself destroyed never invokes its callback.
        self.__map[ikey] = (value, _ref(key, partial(self.forget_key, ikey)))

    def get(self, key, **kwargs):
        """Return the value mapped to *key*.

        Raises KeyError if *key* is absent, unless a ``default=`` keyword
        argument is supplied, in which case that is returned instead.
        """
        ikey = id(key)
        if ikey in self.__map:
            return self.__map[ikey][0]
        if 'default' in kwargs:
            return kwargs['default']
        raise KeyError(str(key))

    def pop(self, key):
        """Remove *key* and return its value; raise KeyError if absent."""
        ikey = id(key)
        if ikey not in self.__map:
            raise KeyError(str(key))
        return self.__map.pop(ikey)[0]

    def forget_key(self, keyid, obj=None):
        """Drop the entry whose key has ``id(key) == keyid`` (no-op if absent).

        Installed as the weakref callback for keys; *obj* (the dead
        reference) is ignored.
        """
        self.__map.pop(keyid, None)

    def __iter__(self):
        """Yield ``(key, value)`` pairs.

        Iterates over a snapshot of the key ids, because a key's weakref
        callback may fire while the generator is suspended and remove
        entries from ``__map``. Entries removed in the meantime are skipped.
        """
        for ikey in list(self.__map):
            entry = self.__map.get(ikey)
            if entry is None:
                continue
            value, keyref = entry
            yield keyref(), value
Diff to Previous Revision
--- revision 1 2011-02-18 09:40:16
+++ revision 2 2011-02-20 16:26:38
@@ -17,68 +17,98 @@
from functools import partial
from weakref import ref
+### weak ref key holder
+def _ref(value, ondel):
+ try:
+ vref = ref(value, ondel)
+ except:
+ vref = lambda:value
+ return vref
+
class weakmap(object):
- def __init__(self, weakkey=True, weakvalue=False, refanyway=True):
- self.__conf = namedtuple('conf', 'weakkey weakvalue refanyway')(weakkey, weakvalue, refanyway)
+ def __init__(self):
self.__map = {}
- self.__memo = {}
+ self.__valmemo = {}
self.__backref = {}
def set(self, key, value):
- if self.__conf.weakkey:
- self.remember(key,
- partial(self.forget_key, id(key)))
- ikey = id(key)
- else:
- ikey = key
-
- if self.__conf.weakvalue:
- self.__map[ikey] = self.remember(value,
- partial(self.forget_value, id(value)))
- self.__backref.setdefault(id(value), set()).add(id(key))
- else:
- self.__map[ikey] = value
+ ikey, ival = id(key), id(value)
+ self.__map[ikey] = (ival,
+ _ref(key, partial(self.forget_key, ikey))
+ )
+ if ival not in self.__valmemo:
+ self.__valmemo[ival] = _ref(value, partial(self.forget_value, ival))
+ self.__backref.setdefault(ival, set()).add(ikey)
def get(self, key, **kwargs):
- if self.__conf.weakkey:
- key = id(key)
-
- if key in self.__map:
- if self.__conf.weakvalue:
- return self.__map[key]()
- else:
- return self.__map[key]
-
+ ikey = id(key)
+ if ikey in self.__map:
+ return self.__valmemo[self.__map[ikey][0]]()
if 'default' in kwargs:
return kwargs['default']
-
raise KeyError(str(key))
-
- def remember(self, value, ondel):
- try:
- vref = ref(value, ondel)
- self.__memo[id(value)] = vref
- except Exception as e:
- if self.__conf.refanyway:
- vref = lambda:value
- else:
- raise e
- return vref
+ def pop(self, key):
+ ikey = id(key)
+ if ikey in self.__map:
+ value = self.__valmemo[self.__map[ikey][0]]()
+ self.forget_key(ikey)
+ return value
+ else:
+ raise KeyError(str(key))
+
def forget_key(self, keyid, obj=None):
- for d in [self.__map, self.__memo]:
- if keyid in d: d.pop(keyid)
+ if keyid in self.__map:
+ ival = self.__map.pop(keyid)[0]
+ self.__backref[ival].remove(keyid)
+ if not self.__backref[ival]:
+ self.__backref.pop(ival)
+ self.__valmemo.pop(ival)
def forget_value(self, valueid, obj=None):
- for keyid in self.__backref.get(valueid, ()):
- if keyid in self.__map: self.__map.pop(keyid)
- for d in [self.__backref, self.__memo]:
- if valueid in d: d.pop(valueid)
+ for ikey in self.__backref.get(valueid, ()):
+ if ikey in self.__map:
+ self.__map.pop(ikey)
+ for d in [self.__backref, self.__valmemo]:
+ if valueid in d:
+ d.pop(valueid)
def __iter__(self):
for k, v in self.__map.iteritems():
- if self.__conf.weakkey:
- k = self.__memo[k]()
- if self.__conf.weakvalue:
- v = self.__memo[v]()
- yield (k, v)
+ k = v[1]()
+ v = self.__valmemo[v[0]]()
+ yield (k, v)
+
+class weakkeymap(object):
+ def __init__(self):
+ self.__map = {}
+
+ def set(self, key, value):
+ ikey = id(key)
+ self.__map[ikey] = (value,
+ _ref(key, partial(self.forget_key, ikey))
+ )
+
+ def get(self, key, **kwargs):
+ ikey = id(key)
+ if ikey in self.__map:
+ return self.__map[ikey][0]
+ if 'default' in kwargs:
+ return kwargs['default']
+ raise KeyError(str(key))
+
+ def pop(self, key):
+ ikey = id(key)
+ if ikey in self.__map:
+ return self.__map.pop(ikey)[0]
+ else:
+ raise KeyError(str(key))
+
+ def forget_key(self, keyid, obj=None):
+ if keyid in self.__map:
+ self.__map.pop(keyid)
+
+ def __iter__(self):
+ for k, v in self.__map.iteritems():
+ k = v[1]()
+ yield (k, v[0])