Welcome, guest | Sign In | My Account | Store | Cart
try:
    # The container ABCs live in collections.abc on Python 3.3+; the aliases
    # in the top-level collections module were removed in Python 3.10.
    from collections.abc import Mapping, Set, Sequence
except ImportError:  # Python 2
    from collections import Mapping, Set, Sequence

# dual Python 2/3 compatibility, inspired by the "six" library
# On Python 2 ``str`` is the very same type as ``bytes`` and text lives in a
# separate ``unicode`` type; on Python 3 ``str`` is text and ``bytes`` is
# distinct.  ``unicode`` is only looked up on the Python 2 branch.
if str is bytes:
    string_types = (str, unicode)
else:
    string_types = (str, bytes)


def iteritems(mapping):
    """Return the key/value pairs of *mapping*.

    Calls ``mapping.iteritems()`` when that method exists and
    ``mapping.items()`` otherwise.
    """
    items_method = getattr(mapping, 'iteritems', mapping.items)
    return items_method()

def objwalk(obj, path=(), memo=None):
    """Yield a ``(path, value)`` pair for every leaf reachable from *obj*.

    Mappings are descended into through their items, each key becoming a
    path component.  Sequences and sets that are not strings are descended
    into with ``enumerate``, so the position becomes the path component.
    Every other object (strings included) is a leaf and is yielded together
    with the tuple of components that leads to it.  An empty container
    yields nothing.

    A container that is already being walked further up the current path
    (a reference cycle) is skipped without yielding anything.  The same
    container appearing several times outside of a cycle is walked each
    time.

    Args:
        obj: The object to walk.
        path: Tuple of path components accumulated so far.
        memo: Set of ``id()`` values of the containers currently being
            walked; a fresh set is created when omitted.

    Yields:
        ``(path, value)`` tuples, where ``path`` is a tuple of keys and/or
        indices.
    """
    if memo is None:
        memo = set()

    if isinstance(obj, Mapping):
        iterator = iteritems
    elif isinstance(obj, (Sequence, Set)) and not isinstance(obj, string_types):
        iterator = enumerate
    else:
        # Not a container we descend into: report it as a leaf.
        yield path, obj
        return

    obj_id = id(obj)
    if obj_id in memo:
        # Already on the current path: descending again would never end.
        return

    memo.add(obj_id)
    try:
        for path_component, value in iterator(obj):
            for result in objwalk(value, path + (path_component,), memo):
                yield result
    finally:
        # Always take the container off the current path, even when the walk
        # is aborted (an exception while iterating, or the consumer closing
        # the generator early), so a caller-supplied memo is never left
        # holding stale ids.
        memo.discard(obj_id)

# optional test code from here on
import unittest

class TestObjwalk(unittest.TestCase):
    """Unit tests for ``objwalk``."""

    def assertObjwalk(self, object_to_walk, *expected_results):
        """Assert that walking *object_to_walk* yields *expected_results*.

        The comparison ignores the order in which the ``(path, value)``
        pairs are produced.
        """
        # Paths may mix key types that cannot be ordered against each other
        # on Python 3 (e.g. None and str), so sorting the pairs would raise
        # TypeError.  Use unittest's order-insensitive comparison instead:
        # assertCountEqual on Python 3, assertItemsEqual on Python 2.7.
        assert_same_items = getattr(self, 'assertCountEqual', None)
        if assert_same_items is None:
            assert_same_items = self.assertItemsEqual
        return assert_same_items(expected_results, list(objwalk(object_to_walk)))

    def test_empty_containers(self):
        """Empty containers yield nothing."""
        self.assertObjwalk({})
        self.assertObjwalk([])

    def test_single_objects(self):
        """A non-container is yielded as a single leaf with an empty path."""
        for obj in (None, 42, True, "spam"):
            self.assertObjwalk(obj, ((), obj))

    def test_plain_containers(self):
        """Flat lists, dicts and sets yield one pair per element."""
        self.assertObjwalk([1, True, "spam"], ((0,), 1), ((1,), True), ((2,), "spam"))
        self.assertObjwalk({None: 'eggs', 'bacon': 'ham', 'spam': 1},
                           ((None,), 'eggs'), (('spam',), 1), (('bacon',), 'ham'))
        # sets are unordered, so we don't test the path, only that no object is missing
        self.assertEqual(set(obj for path, obj in objwalk(set((1,2,3)))), set((1,2,3)))

    def test_nested_containers(self):
        """Paths accumulate one component per nesting level."""
        self.assertObjwalk([1, [2, [3, 4]]],
                           ((0,), 1), ((1,0), 2), ((1, 1, 0), 3), ((1, 1, 1), 4))
        self.assertObjwalk({1: {2: {3: 'spam'}}},
                           ((1,2,3), 'spam'))

    def test_repeating_containers(self):
        """The same container referenced twice is walked twice."""
        repeated = (1,2)
        self.assertObjwalk([repeated, repeated],
                           ((0, 0), 1), ((0, 1), 2), ((1, 0), 1), ((1, 1), 2))

    def test_recursive_containers(self):
        """A container that contains itself is not descended into again."""
        recursive = [1, 2]
        recursive.append(recursive)
        recursive.append(3)
        self.assertObjwalk(recursive, ((0,), 1), ((1,), 2), ((3,), 3))

# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()

Diff to Previous Revision

--- revision 1 2011-12-13 09:42:13
+++ revision 2 2011-12-23 22:10:38
@@ -1,25 +1,59 @@
 from collections import Mapping, Set, Sequence 
 
-try:
-    from six import string_types, iteritems
-except ImportError:
-    string_types = (str, unicode) if str is bytes else (str, bytes)
-    iteritems = lambda mapping: getattr(mapping, 'iteritems', mapping.items)()
+# dual python 2/3 compatability, inspired by the "six" library
+string_types = (str, unicode) if str is bytes else (str, bytes)
+iteritems = lambda mapping: getattr(mapping, 'iteritems', mapping.items)()
 
 def objwalk(obj, path=(), memo=None):
     if memo is None:
         memo = set()
+    iterator = None
     if isinstance(obj, Mapping):
-        if id(obj) not in memo:
-            memo.add(id(obj)) 
-            for key, value in iteritems(obj):
-                for child in objwalk(value, path + (key,), memo):
-                    yield child
+        iterator = iteritems
     elif isinstance(obj, (Sequence, Set)) and not isinstance(obj, string_types):
+        iterator = enumerate
+    if iterator:
         if id(obj) not in memo:
             memo.add(id(obj))
-            for index, value in enumerate(obj):
-                for child in objwalk(value, path + (index,), memo):
-                    yield child
+            for path_component, value in iterator(obj):
+                for result in objwalk(value, path + (path_component,), memo):
+                    yield result
+            memo.remove(id(obj))
     else:
         yield path, obj
+
+# optional test code from here on
+import unittest
+
+class TestObjwalk(unittest.TestCase):
+    def assertObjwalk(self, object_to_walk, *expected_results):
+        return self.assertEqual(tuple(sorted(expected_results)), tuple(sorted(objwalk(object_to_walk))))
+    def test_empty_containers(self):
+        self.assertObjwalk({})
+        self.assertObjwalk([])
+    def test_single_objects(self):
+        for obj in (None, 42, True, "spam"):
+            self.assertObjwalk(obj, ((), obj))
+    def test_plain_containers(self):
+        self.assertObjwalk([1, True, "spam"], ((0,), 1), ((1,), True), ((2,), "spam"))
+        self.assertObjwalk({None: 'eggs', 'bacon': 'ham', 'spam': 1},
+                           ((None,), 'eggs'), (('spam',), 1), (('bacon',), 'ham'))
+        # sets are unordered, so we dont test the path, only that no object is missing
+        self.assertEqual(set(obj for path, obj in objwalk(set((1,2,3)))), set((1,2,3)))
+    def test_nested_containers(self):
+        self.assertObjwalk([1, [2, [3, 4]]],
+                           ((0,), 1), ((1,0), 2), ((1, 1, 0), 3), ((1, 1, 1), 4))
+        self.assertObjwalk({1: {2: {3: 'spam'}}},
+                           ((1,2,3), 'spam'))
+    def test_repeating_containers(self):
+        repeated = (1,2)
+        self.assertObjwalk([repeated, repeated],
+                           ((0, 0), 1), ((0, 1), 2), ((1, 0), 1), ((1, 1), 2))
+    def test_recursive_containers(self):
+        recursive = [1, 2]
+        recursive.append(recursive)
+        recursive.append(3)
+        self.assertObjwalk(recursive, ((0,), 1), ((1,), 2), ((3,), 3))
+
+if __name__ == '__main__':
+    unittest.main()

History