Welcome, guest | Sign In | My Account | Store | Cart
from functools import reduce
from collections import deque
from operator import getitem, setitem

def nested_enumerate(lst):
    """An analogue of enumerate for nested lists.

    Returns an iterator over the ``(index, element)`` pairs of `lst`, where
    `index` is a list of integers ``[x0, x1, ..., xn]`` such that
    ``lst[x0][x1]...[xn] == element``.

    Only ``list`` instances (including subclasses) are descended into; any
    other object -- tuples and strings included -- is yielded as a leaf.
    Empty sub-lists contribute no pairs.  Leaves are produced depth-first,
    left to right.  Every yielded `index` is a fresh list object.

    >>> for i, e in nested_enumerate([0, [[[1, [2, [[[3]]]]]]], [[[4]]]]):
    ...     print('%s %s' % (str(i), str(e)))
    [0] 0
    [1, 0, 0, 0] 1
    [1, 0, 0, 1, 0] 2
    [1, 0, 0, 1, 1, 0, 0, 0] 3
    [2, 0, 0, 0] 4
    """
    # Work queue of (partial index, object) pairs still to be visited,
    # seeded with the top-level elements of lst.
    partial_index = deque([([i], e) for (i, e) in enumerate(lst)])

    while partial_index:
        index, obj = partial_index.popleft()
        if isinstance(obj, list):
            # obj needs further indexing: put its children at the *front* of
            # the queue so they are visited before obj's later siblings.
            # extendleft inserts one item at a time (reversing them), hence
            # the reversed() to keep the children in left-to-right order.
            new_dimension = [(index + [i], e) for (i, e) in enumerate(obj)]
            partial_index.extendleft(reversed(new_dimension))
        else:
            # obj is a leaf: its index is complete
            yield index, obj


# complementary functions #

def nested_getitem(lst, index):
    """Return the item of `lst` reached by following the keys in `index`.

    Equivalent to the expression ``lst[index[0]][index[1]]...[index[n]]``;
    an empty `index` gives back `lst` itself.
    """
    target = lst
    for key in index:
        target = target[key]
    return target


def nested_setitem(lst, index, value):
    """Equivalent to the statement ``lst[index[0]]...[index[n]] = value``.

    `index` may be any finite iterable of keys (list, tuple, iterator, ...).
    Every key but the last is used to walk down into `lst`; the last key
    selects the slot that receives `value`.  `lst` is modified in place and
    None is returned.

    Raises:
        IndexError: if `index` is empty, because there is then no subscript
            to assign through.  Errors raised by the individual subscript
            operations propagate unchanged.
    """
    # Materialise the keys so that non-sliceable iterables are accepted.
    keys = list(index)
    if not keys:
        raise IndexError('nested_setitem() requires a non-empty index')
    parent = reduce(getitem, keys[:-1], lst)
    setitem(parent, keys[-1], value)


# quick test #

deeplist = [0, [[[1, [2, [[[3]]]]]]], [[[4]]]]

# every index produced by nested_enumerate must address its element
for index, element in nested_enumerate(deeplist):
    assert nested_getitem(deeplist, index) == element

# example usage: applying a function to each element in a nested list #


def square(x):
    """Return x squared."""
    return x**2


# Replacing leaves while iterating is safe here: nested_enumerate has
# already queued all of a list's children before yielding any of them, and
# only leaf slots are overwritten, so the traversal is unaffected.
for index, element in nested_enumerate(deeplist):
    nested_setitem(deeplist, index, square(element))

assert deeplist == [0, [[[1, [4, [[[9]]]]]]], [[[16]]]]

# not recommended, but demonstrates different ways of traversing a list
# (plus, we all love flatten, right? ;-)

def flatten(lst):
    """Return the leaves of the nested list `lst` as a flat list.

    Leaves appear in the same depth-first, left-to-right order in which
    nested_enumerate produces them; their indices are discarded.
    """
    return [leaf for _index, leaf in nested_enumerate(lst)]

def flatten2(lst):
    """Return the leaves of `lst` as a flat list, fetched via their indices.

    Rather than using the elements nested_enumerate yields, each leaf is
    looked up again with nested_getitem from its index.  This shows that
    the two helpers agree.
    """
    return [nested_getitem(lst, path) for path, _leaf in nested_enumerate(lst)]

assert flatten(deeplist) == flatten2(deeplist) == [0, 1, 4, 9, 16]

# sort elements based on their depth of nesting, with deepest first.
# sorted() is stable, so elements at equal depth keep their left-to-right
# order (1 stays ahead of 16 below).  The key reads the index by position;
# the old ``lambda (i, e): ...`` tuple-parameter form is Python 2-only and
# a SyntaxError on Python 3.
depthfirst = [
    e for (i, e) in sorted(nested_enumerate(deeplist),
                           key=lambda pair: -len(pair[0]))
]

assert depthfirst == [9, 4, 1, 16, 0]

Diff to Previous Revision

--- revision 5 2012-12-11 15:54:31
+++ revision 6 2012-12-11 16:59:14
@@ -17,7 +17,7 @@
        [1, 0, 0, 1, 0] 2
        [1, 0, 0, 1, 1, 0, 0, 0] 3
        [2, 0, 0, 0] 4
-        """
+    """
     
     # initial, partial index of lst
     partial_index = deque([([i], e) for (i, e) in enumerate(lst)])

History