
Extremely fast, non-recursive, depth-limited flatten with powerful control over which subtrees are to be expanded. If this is what you need then look no further.

Python, 256 lines
#! /usr/bin/env python
######################################################################
#  Written by Kevin L. Sitze on 2010-11-25
#  This code may be used pursuant to the MIT License.
######################################################################

"""This module contains four flatten() functions for generating either
a sequence or an iterator.  The goal is to "flatten" a tree (typically
comprised of lists and tuples) by slicing the elements in each
contained subtree over the subtree element itself.  For example:

    ([a, b], [c, (d, e)]) => (a, b, c, d, e)

The functions available via this module are

    flatten    ( sequence[, max_depth[, ltypes]] ) => sequence
    xflatten   ( sequence[, max_depth[, ltypes]] ) => iterator
    flatten_it ( iterable[, max_depth[, ltypes]] ) => sequence
    xflatten_it( iterable[, max_depth[, ltypes]] ) => iterator

Each function takes as its only required argument the tree to flatten.
The first two functions (flatten() and xflatten()) require their first
argument to be a valid Python sequence object.  The '_it' functions
(flatten_it() and xflatten_it()) accept any iterable object.

The return type for the flatten() and flatten_it() functions is the
same type as the input sequence, when possible, otherwise the type
will be 'list'.

Wall clock running time of these functions increases from the top of
the list down (i.e., where possible prefer the flatten() function to
any other if speed is a concern).

The "max_depth" argument is either a non-negative integer indicating
the maximum tree depth to flatten; or "None" to flatten the entire
tree.  The required "sequence" argument has a 'depth' of zero; a list
element of the sequence would be flattened if "max_depth" is greater
than or equal to one (1).  A negative depth is treated the same as a
depth of zero.

The "ltypes" argument indicates which elements are subtrees.  It may
be either a collection of sequence types or a predicate function.  If
"ltypes" is a collection a subtree is expanded if its type is in the
collection.  If "ltypes" is a predicate function a subtree is expanded
if the predicate function returns "True" for that subtree.

The implementation of flatten here runs in O(N) time where N is the
number of elements in the traversed tree.  It uses O(N+D) space
where D is the maximum depth of the traversed tree.
"""

def flatten( l, max_depth = None, ltypes = ( list, tuple ) ):
    """flatten( sequence[, max_depth[, ltypes]] ) => sequence

    Flatten every sequence in "l" whose type is contained in "ltypes"
    to "max_depth" levels down the tree.  See the module documentation
    for a complete description of this function.

    The sequence returned has the same type as the input sequence.
    """
    if max_depth is None: make_flat = lambda x: True
    else: make_flat = lambda x: max_depth > len( x )
    if callable( ltypes ): is_sequence = ltypes
    else: is_sequence = lambda x: isinstance( x, ltypes )

    r = list()
    s = list()
    s.append(( 0, l ))
    while s:
        i, l = s.pop()
        while i < len( l ):
            while is_sequence( l[i] ):
                if not l[i]: break
                elif make_flat( s ):
                    s.append(( i + 1, l ))
                    l = l[i]
                    i = 0
                else:
                    r.append( l[i] )
                    break
            else: r.append( l[i] )
            i += 1
    try: return type(l)(r)
    except TypeError: return r

def xflatten( l, max_depth = None, ltypes = ( list, tuple ) ):
    """xflatten( sequence[, max_depth[, ltypes]] ) => iterator

    Flatten every sequence in "l" whose type is contained in "ltypes"
    to "max_depth" levels down the tree.  See the module documentation
    for a complete description of this function.

    This is the iterator version of the flatten function.
    """
    if max_depth is None: make_flat = lambda x: True
    else: make_flat = lambda x: max_depth > len( x )
    if callable( ltypes ): is_sequence = ltypes
    else: is_sequence = lambda x: isinstance( x, ltypes )

    s = list()
    s.append(( 0, l ))
    while s:
        i, l = s.pop()
        while i < len( l ):
            while is_sequence( l[i] ):
                if not l[i]: break
                elif make_flat( s ):
                    s.append(( i + 1, l ))
                    l = l[i]
                    i = 0
                else:
                    yield l[i]
                    break
            else: yield l[i]
            i += 1

def flatten_it( l, max_depth = None, ltypes = ( list, tuple ) ):
    """flatten_it( iterable[, max_depth[, ltypes]] ) => sequence

    Flatten every sequence in "l" whose type is contained in "ltypes"
    to "max_depth" levels down the tree.  See the module documentation
    for a complete description of this function.

    The sequence returned has the same type as the input iterable when
    possible, otherwise it is a list.
    """
    if max_depth is None: make_flat = lambda x: True
    else: make_flat = lambda x: max_depth > len( x )
    if callable( ltypes ): is_iterable = ltypes
    else: is_iterable = lambda x: isinstance( x, ltypes )

    r = list()
    s = list()
    s.append(( iter( l ) ))
    while s:
        i = s.pop()
        try:
            while True:
                e = next( i )
                if is_iterable( e ):
                    if make_flat( s ):
                        s.append(( i ))
                        i = iter( e )
                    else:
                        r.append( e )
                else:
                    r.append( e )
        except StopIteration: pass
    try: return type(l)(r)
    except TypeError: return r

def xflatten_it( l, max_depth = None, ltypes = ( list, tuple ) ):
    """xflatten_it( iterable[, max_depth[, ltypes]] ) => iterator

    Flatten every sequence in "l" whose type is contained in "ltypes"
    to "max_depth" levels down the tree.  See the module documentation
    for a complete description of this function.

    This is the iterator version of the flatten_it function.
    """
    if max_depth is None: make_flat = lambda x: True
    else: make_flat = lambda x: max_depth > len( x )
    if callable( ltypes ): is_iterable = ltypes
    else: is_iterable = lambda x: isinstance( x, ltypes )

    s = list()
    s.append(( iter( l ) ))
    while s:
        i = s.pop()
        try:
            while True:
                e = next( i )
                if is_iterable( e ):
                    if make_flat( s ):
                        s.append(( i ))
                        i = iter( e )
                    else:
                        yield e
                else:
                    yield e
        except StopIteration: pass

if __name__ == '__main__':

    import sys
    import traceback
    def assertEquals( exp, got ):
        if exp is got:
            r = True
        elif type( exp ) is not type( got ):
            r = False
        elif type( exp ) in ( float, complex ):
            r = abs( exp - got ) < 1e-8
        else:
            r = ( exp == got )
        if not r:
            sys.stderr.write( "Error: expected <%s> but got <%s>\n" % ( repr( exp ), repr( got ) ) )
            traceback.print_stack()

    def test( exp, got, depth = None ):
        assertEquals( exp, flatten( got, depth ) )
        assertEquals( exp, tuple( xflatten( got, depth ) ) )
        assertEquals( exp, flatten_it( got, depth ) )
        assertEquals( exp, tuple( xflatten_it( got, depth ) ) )

    test( (),      () )
    test( (),      ((),) )
    test( (),      ((),()) )
    test( (),      ((),((),()),()) )
    test( (1,),    ((1,),((),()),()) )
    test( (1,),    ((),1,((),()),()) )
    test( (1,),    ((),(1,(),()),()) )
    test( (1,),    ((),((1,),()),()) )
    test( (1,),    ((),((),1,()),()) )
    test( (1,),    ((),((),(1,)),()) )
    test( (1,),    ((),((),(),1),()) )
    test( (1,),    ((),((),()),1,()) )
    test( (1,),    ((),((),()),(1,)) )
    test( (1,),    ((),((),()),(),1) )
    test( (1,),    ((),1,()) )
    test( (1,2,3), (1,2,3) )
    test( (1,2,3), ((1,2),3) )
    test( (1,2,3), (1,(2,3)) )
    test( (1,2,3), ((1,),(2,),3) )
    test( ((((((((((0,),1),2),3),4),5),6),7),8),9), ((((((((((0,),1),2),3),4),5),6),7),8),9), 0 )
    test( (((((((((0,),1),2),3),4),5),6),7),8,9), ((((((((((0,),1),2),3),4),5),6),7),8),9), 1 )
    test( ((((((((0,),1),2),3),4),5),6),7,8,9), ((((((((((0,),1),2),3),4),5),6),7),8),9), 2 )
    test( (((((((0,),1),2),3),4),5),6,7,8,9), ((((((((((0,),1),2),3),4),5),6),7),8),9), 3 )
    test( ((((((0,),1),2),3),4),5,6,7,8,9), ((((((((((0,),1),2),3),4),5),6),7),8),9), 4 )
    test( (((((0,),1),2),3),4,5,6,7,8,9), ((((((((((0,),1),2),3),4),5),6),7),8),9), 5 )
    test( ((((0,),1),2),3,4,5,6,7,8,9), ((((((((((0,),1),2),3),4),5),6),7),8),9), 6 )
    test( (((0,),1),2,3,4,5,6,7,8,9), ((((((((((0,),1),2),3),4),5),6),7),8),9), 7 )
    test( ((0,),1,2,3,4,5,6,7,8,9), ((((((((((0,),1),2),3),4),5),6),7),8),9), 8 )
    test( (0,1,2,3,4,5,6,7,8,9), ((((((((((0,),1),2),3),4),5),6),7),8),9), 9 )
    test( (0,1,2,3,4,5,6,7,8,9), ((((((((((0,),1),2),3),4),5),6),7),8),9), 10 )

    test( ({1:2},3,4,set([5,6])), ({1:2},(3,4),set([5,6])) )

    l = (1,)
    # Build a tree 1 million elements deep
    for i in range( 1000000 ): l = ( l, 2 )
    # expected value is a 1 followed by 1 million 2's
    exp = (1,) + (2,) * 1000000
    # # Under 5 seconds on my machine...
    # got = flatten( l )
    # assert( exp == got )
    # # Also under 5 seconds...
    # got = tuple( xflatten( l ) )
    # assert( exp == got )
    # # 6 seconds
    # got = flatten_it( l )
    # assert( exp == got )
    # # 7 seconds
    # got = tuple( xflatten_it( l ) )
    # assert( exp == got )

OK, I know what you're thinking, "Why are you placing Yet Another Flatten Function (YAFF) on ActiveState? Look at all the discussions on this topic, hasn't it been beaten to death?"

http://code.activestate.com/recipes/576487-flatten-sequences/
http://code.activestate.com/recipes/363051-flatten/
http://code.activestate.com/recipes/577250-flatten-a-list/
http://code.activestate.com/recipes/577255-flatten-a-list-or-list-of-lists-etc/
http://code.activestate.com/recipes/577387-non-recursive-flatten-leaves-strings-and-dicts-alo/

Well, after looking around, reading comments, and generally getting miffed at so-called fast code and so-called "change-in-place" code, I decided to put my own 2 cents' worth into the fray. The various flatten functions I have seen here suffer from one or more of the following issues:

  1. it is too slow for large data sets.
  2. it is a recursive algorithm, so the tree depth matters.
  3. which sequences or iterators in the tree are flattened is ad-hoc, often with the programmer having to hard code changes in the algorithm to customize the operation for their own needs.
  4. it offers no control over the maximum tree depth to which the flatten operation is applied.
  5. even so-called "change in place" code often isn't; the Python system still has to manage resizing your list as new elements are sliced in, so your memory footprint is still the same despite using "only one list."
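
Issue 2 is easy to reproduce. The sketch below is not taken from any of the linked recipes; `recursive_flatten` and `stack_flatten` are illustrative names, and the tree depth is chosen relative to whatever recursion limit the interpreter reports. It assumes Python 3.5 or later, where the failure surfaces as `RecursionError`.

```python
import sys

def recursive_flatten(tree, ltypes=(list, tuple)):
    """Naive recursive flatten (illustrative): one call frame per tree level."""
    out = []
    for item in tree:
        if isinstance(item, ltypes):
            out.extend(recursive_flatten(item, ltypes))
        else:
            out.append(item)
    return out

def stack_flatten(tree, ltypes=(list, tuple)):
    """Iterative flatten (illustrative): suspended iterators sit on a list."""
    stack = [iter(tree)]
    while stack:
        for item in stack[-1]:
            if isinstance(item, ltypes):
                stack.append(iter(item))  # descend; the parent resumes later
                break
            yield item
        else:
            stack.pop()  # this level is exhausted

# Build a tree several times deeper than the interpreter's recursion limit.
depth = 5 * sys.getrecursionlimit()
tree = (1,)
for _ in range(depth):
    tree = (tree, 2)

try:
    recursive_flatten(tree)
    recursion_failed = False
except RecursionError:
    recursion_failed = True

flat = list(stack_flatten(tree))
print(recursion_failed, len(flat))
```

The recursive version gives up once the nesting exceeds the recursion limit, while the explicit stack is bounded only by available memory.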

So for you flatten aficionados out there, I give you not just one flatten function, I give you four! Each of these functions performs optimally in space and time (for Python types anyway) and will chug merrily away at your flatten needs until this function finally appears as a builtin. The solution provided here:

  1. is linear in time with your tree growth.
  2. is not recursive, keeping track of the descent via an internal stack.
  3. allows complete control over what items are to be flattened at every stage. The ability to indicate a predicate function makes your weird flatten projects just work. You need to flatten only sequences having an odd number of elements? No problem! (The predicate is called on every element, leaves included, so it has to cope with non-sequences.)

    flatten( seq, ltypes = lambda x: isinstance( x, ( list, tuple ) ) and len( x ) % 2 == 1 )

  4. allows a maximum depth restriction on how deep to apply the flatten operation.
  5. tries its best to move each element only once (subject to internal list resizing).
  6. allows you to consume your flattened tree as either a sequence or an iterator.
  7. allows you to flatten a tree produced by a generator.
  8. is FAST!
  9. is DOCUMENTED!
  10. is VERSATILE!
  11. has UNIT TESTS!
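
To make items 3, 4 and 7 concrete, here is a condensed sketch of the same idea. `flatten_sketch` is a hypothetical stand-in written for this example, not the recipe's code: it keeps the `max_depth` and `ltypes` parameters, always returns a list, and uses a stack of iterators so that generators are accepted as input.

```python
def flatten_sketch(tree, max_depth=None, ltypes=(list, tuple)):
    """Stack-based flatten returning a list (condensed, illustrative)."""
    if callable(ltypes):
        is_subtree = ltypes                      # predicate function
    else:
        is_subtree = lambda x: isinstance(x, ltypes)
    out = []
    stack = [iter(tree)]
    while stack:
        for item in stack[-1]:
            # len(stack) - 1 is the depth of the level being read, so a
            # subtree is expanded only while len(stack) <= max_depth.
            if is_subtree(item) and (max_depth is None or len(stack) <= max_depth):
                stack.append(iter(item))
                break
            out.append(item)
        else:
            stack.pop()
    return out

nested = (0, (1, (2, (3, (4,)))))
print(flatten_sketch(nested))               # [0, 1, 2, 3, 4]
print(flatten_sketch(nested, max_depth=1))  # [0, 1, (2, (3, (4,)))]

# Predicate: expand only sequences with an odd number of elements.
# The isinstance() guard keeps len() away from the integer leaves.
odd_length = lambda x: isinstance(x, (list, tuple)) and len(x) % 2 == 1
mixed = [(1, 2), (3, 4, 5), [6]]
print(flatten_sketch(mixed, ltypes=odd_length))  # [(1, 2), 3, 4, 5, 6]

# A generator works as input because only iter() is required of it.
pairs = (pair for pair in [(1, 2), (3, (4,))])
print(flatten_sketch(pairs))                # [1, 2, 3, 4]
```

A `max_depth` of zero returns the top level unchanged, matching the depth convention described in the module documentation.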

What these functions don't do:

  1. provide a minimum depth restriction to specify when to start flattening subtrees.
  2. ...
  3. wash your clothes
  4. ...
  5. do non-flattening things...
  6. ...
  7. ...

Well, that's it. What are you waiting for? Download the code, save it to 'yaff.py' in your local python library directory and start flattening those trees (and don't waste paper)!

2 comments

someonesdad 13 years, 4 months ago

Kevin: Bravo! What immediately prompted me to download your stuff is that you put some thought into the self tests. This is important for production-use code.

The biggest advantages for my purposes are: your routines aren't recursive and the depth control/predicate option give the programmer flexibility.

Thanks for all the work!

A minor nit: I'd recommend you follow PEP8 coding style guidelines, as it makes the code easier to read. But that's minor -- you put in a lot of good work and it will help others do their stuff better.

François ALLAIN 13 years, 2 months ago

That's nicely coded and tested, thank you. However, I'm stuck on a problem that I haven't managed to solve with your flatten function:

data = [1, 567, 2, (23, 'a'), [789, 7, 9, [700, 777, 284]], 13, (435, 'b')]

I want to flatten everything except the tuples, which would result in:

result = [1, 567, 2, (23, 'a'), 789, 7, 9, 700, 777, 284, 13, (435, 'b')]

I have tried flatten(data, ltypes=(list)) with no success... Could you tell me how to do this? Thanks in advance.
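
For readers who hit the same problem, a likely cause is Python's tuple syntax rather than the recipe: `(list)` is simply `list`, and because `list` is callable the `ltypes` dispatch treats it as a predicate and ends up calling `list(1)`. The sketch below reproduces this with `flatten_sketch`, a hypothetical condensed stand-in that copies only the recipe's `ltypes` handling; it is an illustration, not the author's reply.

```python
def flatten_sketch(tree, ltypes=(list, tuple)):
    """Condensed stand-in for flatten() with the same ltypes dispatch."""
    # As in the recipe: a callable is used as a predicate, anything
    # else is handed to isinstance().
    if callable(ltypes):
        is_subtree = ltypes
    else:
        is_subtree = lambda x: isinstance(x, ltypes)
    out, stack = [], [iter(tree)]
    while stack:
        for item in stack[-1]:
            if is_subtree(item):
                stack.append(iter(item))
                break
            out.append(item)
        else:
            stack.pop()
    return out

data = [1, 567, 2, (23, 'a'), [789, 7, 9, [700, 777, 284]], 13, (435, 'b')]

# Parentheses alone do not create a tuple.
assert (list) is list

try:
    flatten_sketch(data, ltypes=(list))  # list is callable, so list(1) is tried
    failed = False
except TypeError:
    failed = True

# The trailing comma makes a real one-element tuple.
result = flatten_sketch(data, ltypes=(list,))
print(failed)
print(result)
```

A predicate such as `lambda x: isinstance(x, list)` should have the same effect as the one-element tuple.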