#! /usr/bin/env python ###################################################################### # Written by Kevin L. Sitze on 2010-11-25 # This code may be used pursuant to the MIT License. ###################################################################### """This module contains four flatten() functions for generating either a sequence or an iterator. The goal is to "flatten" a tree (typically comprised of lists and tuples) by slicing the elements in each contained subtree over the subtree element itself. For example: ([a, b], [c, (d, e)]) => (a, b, c, d, e) The functions available via this module are flatten ( sequence[, max_depth[, ltypes]] ) => sequence xflatten ( sequence[, max_depth[, ltypes]] ) => iterator flatten_it ( iterable[, max_depth[, ltypes]] ) => sequence xflatten_it( iterable[, max_depth[, ltypes]] ) => iterator Each function takes as its only required argument the tree to flatten. The first two functions (flatten() and xflatten()) require their first argument to be a valid Python sequence object. The '_it' functions (flatten_it() and xflatten_it()) accept any iterable object. The return type for the flatten() and xflatten_it functions is the same type as the input sequence, when possible, otherwise the type will be 'list'. Wall clock speed of these functions increase from the top of the list down (i.e., where possible prefer the flatten() function to any other if speed is a concern). The "max_depth" argument is either a non-negative integer indicating the maximum tree depth to flatten; or "None" to flatten the entire tree. The required "sequence" argument has a 'depth' of zero; a list element of the sequence would be flattened if "max_depth" is greater than or equal to one (1). A negative depth is treated the same as a depth of zero. The "ltypes" argument indicates which elements are subtrees. It may be either a collection of sequence types or a predicate function. If "ltypes" is a collection a subtree is expanded if its type is in the collection. If "ltypes" is a predicate function a subtree is expanded if the predicate function returns "True" for that subtree. The implementation of flatten here runs in O(N) time where N is the number of elements in the traversed tree. It uses O(N+D) space where D is the maximum depth of the traversed tree. """ def flatten( l, max_depth = None, ltypes = ( list, tuple ) ): """flatten( sequence[, max_depth[, ltypes]] ) => sequence Flatten every sequence in "l" whose type is contained in "ltypes" to "max_depth" levels down the tree. See the module documentation for a complete description of this function. The sequence returned has the same type as the input sequence. """ if max_depth is None: make_flat = lambda x: True else: make_flat = lambda x: max_depth > len( x ) if callable( ltypes ): is_sequence = ltypes else: is_sequence = lambda x: isinstance( x, ltypes ) r = list() s = list() s.append(( 0, l )) while s: i, l = s.pop() while i < len( l ): while is_sequence( l[i] ): if not l[i]: break elif make_flat( s ): s.append(( i + 1, l )) l = l[i] i = 0 else: r.append( l[i] ) break else: r.append( l[i] ) i += 1 try: return type(l)(r) except TypeError: return r def xflatten( l, max_depth = None, ltypes = ( list, tuple ) ): """xflatten( sequence[, max_depth[, ltypes]] ) => iterable Flatten every sequence in "l" whose type is contained in "ltypes" to "max_depth" levels down the tree. See the module documentation for a complete description of this function. This is the iterator version of the flatten function. """ if max_depth is None: make_flat = lambda x: True else: make_flat = lambda x: max_depth > len( x ) if callable( ltypes ): is_sequence = ltypes else: is_sequence = lambda x: isinstance( x, ltypes ) r = list() s = list() s.append(( 0, l )) while s: i, l = s.pop() while i < len( l ): while is_sequence( l[i] ): if not l[i]: break elif make_flat( s ): s.append(( i + 1, l )) l = l[i] i = 0 else: yield l[i] break else: yield l[i] i += 1 def flatten_it( l, max_depth = None, ltypes = ( list, tuple ) ): """flatten_it( iterator[, max_depth[, ltypes]] ) => sequence Flatten every sequence in "l" whose type is contained in "ltypes" to "max_depth" levels down the tree. See the module documentation for a complete description of this function. The sequence returned has the same type as the input sequence. """ if max_depth is None: make_flat = lambda x: True else: make_flat = lambda x: max_depth > len( x ) if callable( ltypes ): is_iterable = ltypes else: is_iterable = lambda x: isinstance( x, ltypes ) r = list() s = list() s.append(( iter( l ) )) while s: i = s.pop() try: while True: e = i.next() if is_iterable( e ): if make_flat( s ): s.append(( i )) i = iter( e ) else: r.append( e ) else: r.append( e ) except StopIteration: pass try: return type(l)(r) except TypeError: return r def xflatten_it( l, max_depth = None, ltypes = ( list, tuple ) ): """xflatten_it( iterator[, max_depth[, ltypes]] ) => iterator Flatten every sequence in "l" whose type is contained in "ltypes" to "max_depth" levels down the tree. See the module documentation for a complete description of this function. This is the iterator version of the flatten_it function. """ if max_depth is None: make_flat = lambda x: True else: make_flat = lambda x: max_depth > len( x ) if callable( ltypes ): is_iterable = ltypes else: is_iterable = lambda x: isinstance( x, ltypes ) r = list() s = list() s.append(( iter( l ) )) while s: i = s.pop() try: while True: e = i.next() if is_iterable( e ): if make_flat( s ): s.append(( i )) i = iter( e ) else: yield e else: yield e except StopIteration: pass if __name__ == '__main__': import sys import traceback def assertEquals( exp, got ): if exp is got: r = True elif type( exp ) is not type( got ): r = False elif type( exp ) in ( float, complex ): r = abs( exp - got ) < 1e-8 else: r = ( exp == got ) if not r: print >>sys.stderr, "Error: expected <%s> but got <%s>" % ( repr( exp ), repr( got ) ) traceback.print_stack() def test( exp, got, depth = None ): assertEquals( exp, flatten( got, depth ) ) assertEquals( exp, tuple( xflatten( got, depth ) ) ) assertEquals( exp, flatten_it( got, depth ) ) assertEquals( exp, tuple( xflatten_it( got, depth ) ) ) test( (), () ) test( (), (()) ) test( (), ((),()) ) test( (), ((),((),()),()) ) test( (1,), ((1,),((),()),()) ) test( (1,), ((),1,((),()),()) ) test( (1,), ((),(1,(),()),()) ) test( (1,), ((),((1,),()),()) ) test( (1,), ((),((),1,()),()) ) test( (1,), ((),((),(1,)),()) ) test( (1,), ((),((),(),1),()) ) test( (1,), ((),((),()),1,()) ) test( (1,), ((),((),()),(1,)) ) test( (1,), ((),((),()),(),1) ) test( (1,), ((),1,()) ) test( (1,2,3), (1,2,3) ) test( (1,2,3), ((1,2),3) ) test( (1,2,3), (1,(2,3)) ) test( (1,2,3), ((1,),(2,),3) ) test( ((((((((((0,),1),2),3),4),5),6),7),8),9), ((((((((((0,),1),2),3),4),5),6),7),8),9), 0 ) test( (((((((((0,),1),2),3),4),5),6),7),8,9), ((((((((((0,),1),2),3),4),5),6),7),8),9), 1 ) test( ((((((((0,),1),2),3),4),5),6),7,8,9), ((((((((((0,),1),2),3),4),5),6),7),8),9), 2 ) test( (((((((0,),1),2),3),4),5),6,7,8,9), ((((((((((0,),1),2),3),4),5),6),7),8),9), 3 ) test( ((((((0,),1),2),3),4),5,6,7,8,9), ((((((((((0,),1),2),3),4),5),6),7),8),9), 4 ) test( (((((0,),1),2),3),4,5,6,7,8,9), ((((((((((0,),1),2),3),4),5),6),7),8),9), 5 ) test( ((((0,),1),2),3,4,5,6,7,8,9), ((((((((((0,),1),2),3),4),5),6),7),8),9), 6 ) test( (((0,),1),2,3,4,5,6,7,8,9), ((((((((((0,),1),2),3),4),5),6),7),8),9), 7 ) test( ((0,),1,2,3,4,5,6,7,8,9), ((((((((((0,),1),2),3),4),5),6),7),8),9), 8 ) test( (0,1,2,3,4,5,6,7,8,9), ((((((((((0,),1),2),3),4),5),6),7),8),9), 9 ) test( (0,1,2,3,4,5,6,7,8,9), ((((((((((0,),1),2),3),4),5),6),7),8),9), 10 ) test( ({1:2},3,4,set([5,6])), ({1:2},(3,4),set([5,6])) ) l = (1,) # Build a tree 1 million elements deep for i in xrange( 1000000 ): l = ( l, 2 ) # expected value is a 1 followed by 1 million 2's exp = (1,) + (2,) * 1000000 # # Under 5 seconds on my machine... # got = flatten( l ) # assert( exp == got ) # # Also under 5 seconds... # got = tuple( xflatten( l ) ) # assert( exp == got ) # # 6 seconds # got = flatten_it( l ) # assert( exp == got ) # # 7 seconds # got = tuple( xflatten_it( l ) ) # assert( exp == got )