The heapq module provides efficient functions for getting the top-N smallest and largest elements of an iterable. A caveat of these functions is that if there are ties (i.e. elements that are equal with respect to the comparison key), some of the tied elements may end up in the returned top-N list while others equal to them may not:
>>> nsmallest(3, [4,3,-2,-3,2], key=abs)
[-2, 2, 3]
Although 3 and -3 are equal with respect to the key function, only one of them is chosen to be returned. For several applications, an all-or-nothing approach with respect to ties is preferable or even required.
A new optional boolean parameter 'ties' is proposed to accommodate these cases. If ties=True and the iterable contains more than N elements, the returned sorted list can be shorter than N if the elements tied at the last position cannot all fit in it:
>>> nsmallest(3, [4,3,-2,-3,2], key=abs, ties=True)
[-2, 2]
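One way to pin down the proposed ties=True semantics is a sort-based reference. The following is a minimal Python 3 sketch (the name nsmallest_ties_reference is hypothetical, not part of heapq); it materializes and sorts the whole input, so it is only meant as an oracle to test faster implementations against:

```python
from itertools import groupby


def nsmallest_ties_reference(n, iterable, key=None):
    """Hypothetical sort-based reference for the proposed ties=True behaviour.

    Returns the n smallest items, except that a group of items with equal
    keys is either returned whole or dropped whole.
    """
    if key is None:
        key = lambda x: x
    result = []
    # sorted() is stable, so tied items keep their input order
    for _, group in groupby(sorted(iterable, key=key), key=key):
        group = list(group)
        if len(result) + len(group) > n:
            break  # this tie group cannot fit entirely: drop all of it
        result.extend(group)
    return result
```

It keeps whole groups of equal keys from the smallest upward and stops at the first group that would overflow N, which reproduces the [-2, 2] result above.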
import heapq, bisect
from operator import itemgetter, neg
from itertools import islice, repeat, count, imap, izip, tee

def nlargest(n, iterable, key=None, ties=False):
    '''Find the n largest elements in iterable.

    @param ties: If False, equivalent to heapq.nlargest(); ties are not taken
        into account. If True, the returned list is guaranteed to contain all
        the elements tied with the smallest element of the top-`n`. If this
        constraint cannot be satisfied by returning exactly `n` elements, the
        top-`k` elements are returned instead, where `k` is the largest
        integer smaller than `n` for which the constraint can be satisfied
        (possibly 0).

    >>> s = [-4,3,5,7,4,-7,-4,-3]
    >>> for i in xrange(1,len(s)+1):
    ...     print i, nlargest(i,s,key=abs)
    1 [7]
    2 [7, -7]
    3 [7, -7, 5]
    4 [7, -7, 5, -4]
    5 [7, -7, 5, -4, 4]
    6 [7, -7, 5, -4, 4, -4]
    7 [7, -7, 5, -4, 4, -4, 3]
    8 [7, -7, 5, -4, 4, -4, 3, -3]
    >>> for i in xrange(1,len(s)+1):
    ...     print i, nlargest(i,s,key=abs,ties=True)
    1 []
    2 [7, -7]
    3 [7, -7, 5]
    4 [7, -7, 5]
    5 [7, -7, 5]
    6 [7, -7, 5, -4, 4, -4]
    7 [7, -7, 5, -4, 4, -4]
    8 [7, -7, 5, -4, 4, -4, 3, -3]
    '''
    if not ties:
        return heapq.nlargest(n, iterable, key)
    in1, in2 = tee(iterable)
    it = izip(imap(key, in1), imap(neg, count()), in2)  # decorate
    result = list(islice(it, n))
    if not result:
        return result
    heapq.heapify(result)
    heappush, heappop, heapreplace = heapq.heappush, heapq.heappop, heapq.heapreplace
    # smallest_key: smallest key of the nlargest
    smallest_key = result[0][0]
    # overflow: True if there are currently ties that don't fit in result
    overflow = False
    for elem in it:
        elem_key = elem[0]
        if elem_key < smallest_key:
            continue
        if not overflow:
            assert len(result) == n and result[0][0] == smallest_key
            if elem_key > smallest_key:
                elem_key = heapreplace(result, elem)[0]
                smallest_key = result[0][0]
            assert elem_key <= smallest_key
            # if the pending element (new or replaced) is equal to the smallest
            # we've got a tie that can't fit in result: drop ties
            if elem_key == smallest_key:
                overflow = True
                while result and result[0][0] == elem_key:
                    heappop(result)
        else:
            assert len(result) < n
            if elem_key > smallest_key:
                heappush(result, elem)
                if len(result) == n:
                    # result just filled and the last element is larger
                    # than smallest: existing ties are invalidated
                    overflow = False
                    smallest_key = result[0][0]
    result.sort(reverse=True)
    return map(itemgetter(2), result)  # undecorate

def nsmallest(n, iterable, key=None, ties=False):
    '''Find the n smallest elements in iterable.

    @param ties: If False, equivalent to heapq.nsmallest(); ties are not taken
        into account. If True, the returned list is guaranteed to contain all
        the elements tied with the largest element of the top-`n`. If this
        constraint cannot be satisfied by returning exactly `n` elements, the
        top-`k` elements are returned instead, where `k` is the largest
        integer smaller than `n` for which the constraint can be satisfied
        (possibly 0).

    >>> s = [-4,3,5,7,4,-7,-4,-3]
    >>> for i in xrange(1,len(s)+1):
    ...     print i, nsmallest(i,s,key=abs)
    1 [3]
    2 [3, -3]
    3 [3, -3, -4]
    4 [3, -3, -4, 4]
    5 [3, -3, -4, 4, -4]
    6 [3, -3, -4, 4, -4, 5]
    7 [3, -3, -4, 4, -4, 5, 7]
    8 [3, -3, -4, 4, -4, 5, 7, -7]
    >>> for i in xrange(1,len(s)+1):
    ...     print i, nsmallest(i,s,key=abs,ties=True)
    1 []
    2 [3, -3]
    3 [3, -3]
    4 [3, -3]
    5 [3, -3, -4, 4, -4]
    6 [3, -3, -4, 4, -4, 5]
    7 [3, -3, -4, 4, -4, 5]
    8 [3, -3, -4, 4, -4, 5, 7, -7]
    '''
    if not ties:
        return heapq.nsmallest(n, iterable, key)
    in1, in2 = tee(iterable)
    it = izip(imap(key, in1), count(), in2)  # decorate
    if hasattr(iterable, '__len__') and n * 10 <= len(iterable):
        # For smaller values of n, the bisect method is faster than a minheap.
        # It is also memory efficient, consuming only n elements of space.
        result = sorted(islice(it, n))
        if not result:
            return []
        insort = bisect.insort
        pop = result.pop
        # largest_key: largest key of the nsmallest
        largest_key = result[-1][0]
        # overflow: True if there are currently ties that don't fit in result
        overflow = False
        for elem in it:
            elem_key = elem[0]
            if elem_key > largest_key:
                continue
            if not overflow:
                assert len(result) == n and result[-1][0] == largest_key
                if elem_key < largest_key:
                    insort(result, elem)
                    # pop the largest from the result
                    elem_key = pop()[0]
                    # and update largest to the new largest
                    largest_key = result[-1][0]
                assert elem_key >= largest_key
                # if the pending element (new or popped) is equal to the largest
                # we've got a tie that can't fit in result: drop ties
                if elem_key == largest_key:
                    overflow = True
                    while result and result[-1][0] == elem_key:
                        pop()
            else:
                assert len(result) < n
                if elem_key < largest_key:
                    insort(result, elem)
                    if len(result) == n:
                        # result just filled and the last element is smaller
                        # than largest: existing ties are invalidated
                        overflow = False
                        largest_key = result[-1][0]
    else:
        # An alternative approach materializes the whole iterable in memory but
        # saves comparisons by heapifying all at once. It also saves time
        # over bisect.insort(), which has O(n) data movement time for every
        # insertion. Finding the n smallest of an m length iterable requires
        # O(m) + O(n log m) comparisons.
        h = list(it)
        heapq.heapify(h)
        result = map(heapq.heappop, repeat(h, min(n, len(h))))
        if result:
            largest_key = result[-1][0]
            # is largest_key equal to the next smallest in heap?
            if h and h[0][0] == largest_key:
                # if yes, delete all trailing ties
                while result and result[-1][0] == largest_key:
                    del result[-1]
    return map(itemgetter(2), result)  # undecorate

if __name__ == '__main__':
    import doctest; doctest.testmod()
Rather than allow arbitrarily large memory consumption (where you might as well use sorted() instead) and rather than trying to accommodate all possible user choices for discarding, preserving, and whatnot, consider modifying the API to return the standard nlargest()/nsmallest() result plus a count of the number of ties (perhaps using a namedtuple for clarity):
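A hedged sketch of that suggestion in Python 3, assuming "a count of the number of ties" means the number of items tied with the last returned item that were left out; TopN, missing_ties and nsmallest_with_tie_count are hypothetical names. It consumes the input in one pass and keeps at most n items:

```python
import heapq
from collections import namedtuple

# Hypothetical result type: the usual top-n list plus how many items tie
# (by key) with the last returned item but were left out.
TopN = namedtuple('TopN', ['items', 'missing_ties'])


class _MaxEntry:
    """Heap entry ordered in reverse, so heapq's min-heap acts as a max-heap."""
    __slots__ = ('key', 'index', 'value')

    def __init__(self, key, index, value):
        self.key, self.index, self.value = key, index, value

    def __lt__(self, other):
        # the heap top is the entry with the largest (key, index)
        return (self.key, self.index) > (other.key, other.index)


def nsmallest_with_tie_count(n, iterable, key=None):
    if key is None:
        key = lambda x: x
    heap = []     # at most n entries: the n smallest seen so far
    missing = 0   # items tied with heap[0].key that are not in the heap
    for index, value in enumerate(iterable):
        k = key(value)
        if len(heap) < n:
            heapq.heappush(heap, _MaxEntry(k, index, value))
        elif not heap or k > heap[0].key:
            continue                  # n <= 0, or larger than every kept item
        elif k == heap[0].key:
            missing += 1              # a later tie at the cut-off is left out
        else:
            # k is smaller: evict the current largest (the latest of its ties)
            evicted = heapq.heapreplace(heap, _MaxEntry(k, index, value))
            missing = missing + 1 if heap[0].key == evicted.key else 0
    heap.sort(key=lambda e: (e.key, e.index))
    return TopN([e.value for e in heap], missing)
```

The items part matches heapq.nsmallest() (including its stable order among equal keys). Because the cut-off key can only decrease as smaller items arrive, a tie count that refers to an older cut-off can simply be reset to zero.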
Limiting the maximum number of ties is a good suggestion, but I don't think returning just the count of ties is sufficient; typically one wants to get the actual tied objects back, just like all the other equal objects that do fit in the N-sized buffer. How about:
so that the returned list contains between n and n+extra_ties items? Perhaps also allow extra_ties=None to mean no limit.
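A possible reading of extra_ties as a Python 3 sketch (nsmallest_extra_ties is a hypothetical name). Asking heapq.nsmallest() for n + extra_ties + 1 items is enough to tell whether the tie group at position n ends within the allowance. What to return when it does not is left open above; this sketch assumes the whole tie group is then dropped:

```python
import heapq


def nsmallest_extra_ties(n, iterable, key=None, extra_ties=0):
    """Hypothetical: the n smallest plus up to extra_ties items tied with the
    n-th one. extra_ties=None means no limit. If the ties exceed the limit,
    the whole tie group is dropped (an assumption of this sketch)."""
    if key is None:
        key = lambda x: x
    if n <= 0:
        return []
    if extra_ties is None:
        cand = sorted(iterable, key=key)  # no limit: any item may be needed
    else:
        # one extra item shows whether the tie group runs past the limit
        cand = heapq.nsmallest(n + extra_ties + 1, iterable, key=key)
    if len(cand) <= n:
        return cand
    boundary = key(cand[n - 1])
    # cand is sorted, so the items with key <= boundary form a prefix
    keep = [x for x in cand if key(x) <= boundary]
    if extra_ties is None or len(keep) <= n + extra_ties:
        return keep
    return [x for x in keep if key(x) < boundary]
```

With extra_ties=0 this degenerates to the all-or-nothing behaviour of the original ties=True proposal.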
Returning the count of non-included ties would allow usage such as:
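For example, assuming the count of left-out ties is returned alongside the usual sorted top-n list, a caller wanting the all-or-nothing behaviour could drop the split tie group itself (drop_split_ties is a hypothetical helper, Python 3):

```python
def drop_split_ties(items, missing_ties, key=None):
    """Hypothetical caller-side helper.

    `items` is a sorted top-n list and `missing_ties` the number of items
    tied with its last element that were left out. If that number is
    nonzero, the last tie group was split, so drop all of it.
    """
    if key is None:
        key = lambda x: x
    if not items or missing_ties == 0:
        return list(items)
    last_key = key(items[-1])
    # items is sorted, so only the trailing group can have last_key
    return [x for x in items if key(x) != last_key]
```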
The more I look at the code and think of possible API alternatives and implementation challenges, the more I think it's better to just use sorted() or to call nsmallest/nlargest() with a larger value of n (to accommodate the maximum allowable number of ties).
In terms of API, I'd be happy with even the minimal API I suggested at the beginning: a single 'ties' boolean parameter that, if True, returns all ties (i.e. ties='preserve' in the current recipe). I don't see the memory consumption issue as a show-stopper; it should be rare in real use cases, and even when it occurs it's no worse (memory-wise) than using sorted().
As for the implementation, I'm not too attached to it; if you value the clarity of calling nsmallest() internally more than once over the performance gain of doing it in a single pass, I'm fine with that. We can always optimize it later if necessary.
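As an illustration of the "call nsmallest() more than once" approach, here is a Python 3 sketch of the preserve semantics (nsmallest_preserve is a hypothetical name): it widens the window by doubling until the item just past it no longer ties with the n-th item. Like sorted(), it holds the whole input in memory, and the number of nsmallest() calls grows only logarithmically with the size of the tie group:

```python
import heapq


def nsmallest_preserve(n, iterable, key=None):
    """Hypothetical: the n smallest items plus every item tied with the n-th."""
    if key is None:
        key = lambda x: x
    if n <= 0:
        return []
    data = list(iterable)  # re-scanned by each nsmallest() call
    m = n
    while True:
        cand = heapq.nsmallest(m + 1, data, key=key)
        if len(cand) <= n:
            return cand  # the whole input fits
        boundary = key(cand[n - 1])
        # stop once the window covers the whole input, or the item just
        # past the window has a larger key (the tie group ends inside it)
        if len(cand) <= m or key(cand[m]) != boundary:
            return [x for x in cand if key(x) <= boundary]
        m *= 2
```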
Put your best version here and I'll link to it from the main docs.
I'm not good at picking universally "best" solutions; that's why I tend to prefer more flexible APIs at the cost of simplicity. Here, for instance, I can see use cases for both the "preserve" and "discard" semantics. If I had to pick only one, I'd probably go with "discard", since it can be implemented without requiring potentially unlimited memory usage (the current implementation just discards the ties at the end, but that can be fixed).
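A sketch of the "discard" semantics without an unbounded buffer, in Python 3 (nlargest_discard is a hypothetical name): asking heapq.nlargest() for one extra item reveals whether the n-th item's ties spill past the cut-off. For a general iterator, CPython's heapq.nlargest() works through a heap of n + 1 entries, so memory stays O(n):

```python
import heapq


def nlargest_discard(n, iterable, key=None):
    """Hypothetical: the n largest items, dropping the tie group at the
    cut-off when it does not fit entirely."""
    if key is None:
        key = lambda x: x
    if n <= 0:
        return []
    cand = heapq.nlargest(n + 1, iterable, key=key)
    if len(cand) <= n:
        return cand  # the whole input fits
    boundary = key(cand[n - 1])
    if key(cand[n]) != boundary:
        return cand[:n]  # clean cut: no tie straddles it
    # cand is sorted in descending order, so the boundary ties are a suffix
    return [x for x in cand[:n] if key(x) != boundary]
```

On the list from the recipe's docstring, this reproduces the ties=True column of the nlargest() doctest.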