
The heapq module provides efficient functions for getting the N smallest and largest elements of an iterable. A caveat of these functions is that when there are ties (i.e. elements that are equal with respect to the comparison key), some of the tied elements may end up in the returned top-N list while others do not:

>>> nsmallest(3, [4,3,-2,-3,2], key=abs)
[-2, 2, 3]

Although 3 and -3 are equal with respect to the key function, only one of them is chosen to be returned. For several applications, an all-or-nothing approach with respect to ties is preferable or even required.

A new optional boolean parameter 'ties' is proposed to accommodate these cases. If ties=True and the iterable contains more than N elements, the returned sorted list can be shorter than N: when the ties at the last position do not all fit in the list, all of them are dropped:

>>> nsmallest(3, [4,3,-2,-3,2], key=abs, ties=True)
[-2, 2]
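
The all-or-nothing semantics can be pinned down with a small sort-based reference. This is only a sketch for illustration, written for Python 3; the name nsmallest_discard_ties is hypothetical, and it sorts the whole input instead of using a bounded buffer as the recipe does:

```python
def nsmallest_discard_ties(n, iterable, key=None):
    """Sort-based reference for ties=True (illustrative, not the recipe's code)."""
    keyfunc = key if key is not None else (lambda x: x)
    ordered = sorted(iterable, key=keyfunc)   # stable sort, O(m log m)
    if n <= 0:
        return []
    if len(ordered) <= n:
        return ordered                        # everything fits, nothing to drop
    boundary = keyfunc(ordered[n])            # key of the first excluded element
    result = ordered[:n]
    # If the last included elements tie with an excluded one, drop them all.
    while result and keyfunc(result[-1]) == boundary:
        result.pop()
    return result


print(nsmallest_discard_ties(3, [4, 3, -2, -3, 2], key=abs))   # [-2, 2]
```

A reference like this is mainly useful for checking a bounded-memory implementation against arbitrary inputs.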
Python, 178 lines
import heapq, bisect
from operator import itemgetter, neg
from itertools import islice, repeat, count, imap, izip, tee


def nlargest(n, iterable, key=None, ties=False):
    '''Find the n largest elements in iterable.
    
    @param ties: If False, equivalent to heapq.nlargest(); ties are not taken
        into account. If True, the returned list is guaranteed to contain all
        the elements equal to the smallest element of the top-`n`. If it is
        not possible to satisfy this constraint by returning a list of size
        `n` exactly, then the top-`k` elements are returned, where `k` is the
        largest integer smaller than `n` for which the constraint holds.
        
    >>> s = [-4,3,5,7,4,-7,-4,-3]
    >>> for i in xrange(1,len(s)+1):
    ...     print i, nlargest(i,s,key=abs)
    1 [7]
    2 [7, -7]
    3 [7, -7, 5]
    4 [7, -7, 5, -4]
    5 [7, -7, 5, -4, 4]
    6 [7, -7, 5, -4, 4, -4]
    7 [7, -7, 5, -4, 4, -4, 3]
    8 [7, -7, 5, -4, 4, -4, 3, -3]

    >>> for i in xrange(1,len(s)+1):
    ...     print i, nlargest(i,s,key=abs,ties=True)
    1 []
    2 [7, -7]
    3 [7, -7, 5]
    4 [7, -7, 5]
    5 [7, -7, 5]
    6 [7, -7, 5, -4, 4, -4]
    7 [7, -7, 5, -4, 4, -4]
    8 [7, -7, 5, -4, 4, -4, 3, -3]
    '''
    if not ties:
        return heapq.nlargest(n, iterable, key)
    in1, in2 = tee(iterable)
    it = izip(imap(key, in1), imap(neg, count()), in2)      # decorate
    result = list(islice(it, n))
    if not result:
        return result
    heapq.heapify(result)    
    heappush, heappop, heapreplace = heapq.heappush, heapq.heappop, heapq.heapreplace
    # smallest_key: smallest key of the nlargest
    smallest_key = result[0][0]
    # overflow: True if there are currently ties that don't fit in result
    overflow = False    
    for elem in it:        
        elem_key = elem[0]
        if elem_key < smallest_key:
            continue
        if not overflow:
            assert len(result) == n and result[0][0] == smallest_key
            if elem_key > smallest_key:
                elem_key = heapreplace(result, elem)[0]
                smallest_key = result[0][0]
            assert elem_key <= smallest_key
            # if the pending element (new or replaced) is equal to the smallest
            # we've got a tie that can't fit in result: drop ties
            if elem_key == smallest_key:
                overflow = True
                while result and result[0][0] == elem_key:
                    heappop(result)
        else:
            assert len(result) < n
            if elem_key > smallest_key:
                heappush(result, elem)
                if len(result) == n:
                    # result just filled and the last element is larger
                    # than smallest: existing ties are invalidated
                    overflow = False
                    smallest_key = result[0][0]
    result.sort(reverse=True)
    return map(itemgetter(2), result)                       # undecorate 


def nsmallest(n, iterable, key=None, ties=False):
    '''Find the n smallest elements in iterable.
    
    @param ties: If False, equivalent to heapq.nsmallest(); ties are not taken
        into account. If True, the returned list is guaranteed to contain all
        the elements equal to the largest element of the top-`n`. If it is
        not possible to satisfy this constraint by returning a list of size
        `n` exactly, then the top-`k` elements are returned, where `k` is the
        largest integer smaller than `n` for which the constraint holds.
        
    >>> s = [-4,3,5,7,4,-7,-4,-3]
    >>> for i in xrange(1,len(s)+1):
    ...     print i, nsmallest(i,s,key=abs)
    1 [3]
    2 [3, -3]
    3 [3, -3, -4]
    4 [3, -3, -4, 4]
    5 [3, -3, -4, 4, -4]
    6 [3, -3, -4, 4, -4, 5]
    7 [3, -3, -4, 4, -4, 5, 7]
    8 [3, -3, -4, 4, -4, 5, 7, -7]

    >>> for i in xrange(1,len(s)+1):
    ...     print i, nsmallest(i,s,key=abs,ties=True)
    1 []
    2 [3, -3]
    3 [3, -3]
    4 [3, -3]
    5 [3, -3, -4, 4, -4]
    6 [3, -3, -4, 4, -4, 5]
    7 [3, -3, -4, 4, -4, 5]
    8 [3, -3, -4, 4, -4, 5, 7, -7]
    '''
    if not ties:
        return heapq.nsmallest(n, iterable, key)
    in1, in2 = tee(iterable)
    it = izip(imap(key, in1), count(), in2)     # decorate
    if hasattr(iterable, '__len__') and n * 10 <= len(iterable):
        # For smaller values of n, the bisect method is faster than a minheap.
        # It is also memory efficient, consuming only n elements of space.
        result = sorted(islice(it, n))
        if not result:
            return []
        insort = bisect.insort
        pop = result.pop
        # largest_key: largest key of the nsmallest
        largest_key = result[-1][0]
        # overflow: True if there are currently ties that don't fit in result
        overflow = False
        for elem in it:
            elem_key = elem[0]
            if elem_key > largest_key:
                continue
            if not overflow:
                assert len(result) == n and result[-1][0] == largest_key
                if elem_key < largest_key:
                    insort(result, elem)
                    # pop the largest from the result
                    elem_key = pop()[0]
                    # and update largest to the new largest
                    largest_key = result[-1][0]
                assert elem_key >= largest_key
                # if the pending element (new or popped) is equal to the largest
                # we've got a tie that can't fit in result: drop ties
                if elem_key == largest_key:
                    overflow = True
                    while result and result[-1][0] == elem_key:
                        pop()
            else:
                assert len(result) < n
                if elem_key < largest_key:
                    insort(result, elem)
                    if len(result) == n:
                        # result just filled and the last element is smaller
                        # than largest: existing ties are invalidated
                        overflow = False
                        largest_key = result[-1][0]
    else:
        # An alternative approach manifests the whole iterable in memory but
        # saves comparisons by heapifying all at once.  Also, saves time
        # over bisect.insort() which has O(n) data movement time for every
        # insertion.  Finding the n smallest of an m length iterable requires
        #    O(m) + O(n log m) comparisons.
        h = list(it)
        heapq.heapify(h)
        result = map(heapq.heappop, repeat(h, min(n, len(h))))
        if result:
            largest_key = result[-1][0]
            # is largest_key equal to the next smallest in heap ?
            if h and h[0][0] == largest_key:
                # if yes, delete all trailing ties
                while result and result[-1][0] == largest_key:
                    del result[-1]
    return map(itemgetter(2), result)     # undecorate 


if __name__ == '__main__':    
    import doctest; doctest.testmod()
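
The listing above targets Python 2 (imap, izip, xrange, the print statement). As a rough sketch, the bounded-heap loop of nlargest with ties=True might be ported to Python 3 as follows. The name nlargest_ties and the guard for n <= 0 are additions made here for illustration; the logic otherwise follows the recipe:

```python
import heapq
from itertools import count, islice


def nlargest_ties(n, iterable, key=None):
    """Python 3 sketch of nlargest(..., ties=True); never holds more than n items."""
    if n <= 0:
        return []
    if key is None:
        key = lambda x: x
    # Decorate as (key, -index, item): the unique index breaks ties, so the
    # items themselves are never compared.
    it = ((key(x), -i, x) for i, x in zip(count(), iterable))
    result = list(islice(it, n))
    if not result:
        return []
    heapq.heapify(result)
    smallest_key = result[0][0]   # smallest key currently admitted
    overflow = False              # True while dropped ties do not fit in result
    for elem in it:
        elem_key = elem[0]
        if elem_key < smallest_key:
            continue
        if not overflow:
            if elem_key > smallest_key:
                # Admit the new element; the evicted one becomes the pending key.
                elem_key = heapq.heapreplace(result, elem)[0]
                smallest_key = result[0][0]
            if elem_key == smallest_key:
                # A tie with the smallest admitted key cannot fit: drop them all.
                overflow = True
                while result and result[0][0] == elem_key:
                    heapq.heappop(result)
        elif elem_key > smallest_key:
            heapq.heappush(result, elem)
            if len(result) == n:
                # Refilled with strictly larger keys: the dropped ties are moot.
                overflow = False
                smallest_key = result[0][0]
    result.sort(reverse=True)
    return [item for _, _, item in result]   # undecorate


print(nlargest_ties(3, [-4, 3, 5, 7, 4, -7, -4, -3], key=abs))   # [7, -7, 5]
```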

7 comments

Raymond Hettinger 15 years ago

Rather than allow arbitrarily large memory consumption (where you might as well use sorted() instead) and rather than trying to accommodate all possible user choices for discarding, preserving, and whatnot, consider modifying the API to return the standard nlargest/nsmallest result plus a count of the number of ties (perhaps using a namedtuple for clarity):

>>> nsmallest(3, [4,3,-2,-3,2], key=abs)
NSmallest(result=[-2, 2, 3], ties=2)
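
One possible reading of this suggestion, sketched in Python 3: the 'ties' field is assumed here to count every element whose key equals the key of the last returned element, which matches the example above (3 and -3). The function name nsmallest_with_tie_count is hypothetical, and the sketch copies the input into a list so it can be scanned twice, which a real implementation would want to avoid:

```python
import heapq
from collections import namedtuple

NSmallest = namedtuple('NSmallest', 'result ties')


def nsmallest_with_tie_count(n, iterable, key=None):
    """Standard nsmallest result plus a tie count (assumed meaning, see lead-in)."""
    keyfunc = key if key is not None else (lambda x: x)
    items = list(iterable)                       # copied so it can be scanned twice
    result = heapq.nsmallest(n, items, key=key)
    if not result:
        return NSmallest(result, 0)
    boundary = keyfunc(result[-1])               # key at the last position
    ties = sum(1 for x in items if keyfunc(x) == boundary)
    return NSmallest(result, ties)


print(nsmallest_with_tie_count(3, [4, 3, -2, -3, 2], key=abs))
# NSmallest(result=[-2, 2, 3], ties=2)
```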

Limiting the maximum number of ties is a good suggestion, but I don't think returning just the count of ties is sufficient; typically one wants to get the actual tied objects back, just like all other equal objects that do fit in the N-sized buffer. How about:

def nsmallest(n, iterable, key=None, extra_ties=0)

so that the returned list contains between n and n+extra_ties items? Perhaps also allow extra_ties=None to mean no limit.

Returning the count of non-included ties would allow usage such as:

top3,missing_ties = nsmallest(3, seq, key=abs, extra_ties=2)
if missing_ties:
    # log it somewhere
    if missing_ties > 10:
        # missing too many; retry with larger extra_ties
        top3,missing_ties = nsmallest(3, seq, key=abs, extra_ties=12)
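
A minimal Python 3 sketch of these proposed semantics, assuming the function returns a pair (items, missing_ties) as in the usage above. It sorts the whole input for brevity, so it illustrates the API only, not a memory-bounded implementation; the name nsmallest_extra_ties is hypothetical:

```python
def nsmallest_extra_ties(n, iterable, key=None, extra_ties=0):
    """Return (items, missing_ties); items has between n and n+extra_ties elements."""
    keyfunc = key if key is not None else (lambda x: x)
    ordered = sorted(iterable, key=keyfunc)
    top = ordered[:max(n, 0)]
    if not top:
        return top, 0
    boundary = keyfunc(top[-1])
    # Collect the elements just past position n that tie with the last one.
    tied = []
    for item in ordered[len(top):]:
        if keyfunc(item) != boundary:
            break
        tied.append(item)
    # extra_ties=None means no limit on how many ties are returned.
    allowed = len(tied) if extra_ties is None else min(extra_ties, len(tied))
    return top + tied[:allowed], len(tied) - allowed


seq = [4, 3, -2, -3, 2]
print(nsmallest_extra_ties(3, seq, key=abs))                  # ([-2, 2, 3], 1)
print(nsmallest_extra_ties(3, seq, key=abs, extra_ties=2))    # ([-2, 2, 3, -3], 0)
```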

Raymond Hettinger 15 years ago

The more I look at the code and think of possible API alternatives and implementation challenges, the more I think it's better to just use sorted() or to call nsmallest/nlargest() with a larger value of n (to accommodate the maximum allowable number of ties).

In terms of API, I'd be happy with even the minimal one I suggested in the beginning: a single 'ties' boolean parameter that, if True, returns all ties (i.e. ties='preserve' in the current recipe). I don't see the memory consumption issue as a show-stopper; it should be rare in real use cases, and even if it occurs it's no worse (memory-wise) than using sorted().

As for the implementation, I'm not too attached to it; if you value the clarity of calling internally nsmallest() more than once over the performance gain of doing it once, I'm fine with it. We can always optimize it later if necessary.

Raymond Hettinger 15 years ago

Put your best version here and I'll link to it from the main docs.

I'm not good at picking universally "best" solutions, that's why I tend to prefer more flexible APIs at the cost of simplicity; here for instance I can see use cases for both the "preserve" and "discard" semantics. If I had to pick one only, I'd probably go with "discard" since it can be implemented without requiring potentially unlimited memory usage (the current implementation just discards the ties at the end but that can be fixed).

Edit:

  • Simplified the API down to a single optional boolean parameter (using the original "discard" semantics).
  • Simplified and condensed the implementation; memory usage is now independent of #ties.