from random import *
from heapq import *
# Running total of ordering comparisons made through Int's comparison
# methods; count_cmps resets it to 0 before each measured call.
cmps = 0


class Int(int):
    """int subclass that counts every ordering comparison in ``cmps``.

    All four ordering operators (<, <=, >, >=) are instrumented, so an
    algorithm written in terms of ``>`` / ``>=`` is measured as faithfully
    as one written with ``<`` / ``<=``.  A mixed comparison with a plain int
    on the left (e.g. ``2 < Int(3)``) is counted too: Python tries the
    subclass's reflected method (here ``__gt__``) first.  Equality tests
    are not counted.
    """

    def __lt__(self, other):
        global cmps
        cmps += 1
        return int.__lt__(self, other)

    def __le__(self, other):
        global cmps
        cmps += 1
        return int.__le__(self, other)

    def __gt__(self, other):
        global cmps
        cmps += 1
        return int.__gt__(self, other)

    def __ge__(self, other):
        global cmps
        cmps += 1
        return int.__ge__(self, other)
def count_cmps(f, data, k):
    """Count the comparisons made by one call to f(k, data).

    f is called on a shuffled copy of data (the caller's list is left
    untouched) and must return the k smallest elements in ascending order.
    The returned number is the value of the global ``cmps`` counter right
    after f returns; the counter is reset just before the call, so the
    reference answer computed below is not included.

    Raises AssertionError if the first k items returned by f are not the
    k smallest elements of data.
    """
    global cmps
    data = data[:]
    shuffle(data)
    # Reference answer: taken before f runs (f may reorder data) and before
    # the counter is reset, so its own comparisons are not counted.
    expected = nsmallest(k, data)
    cmps = 0
    result = f(k, data)
    count = cmps
    # Explicit raise instead of `assert` so the check survives `python -O`.
    # Only the first k items are compared.
    if list(result[:k]) != expected:
        raise AssertionError('%s returned a wrong result for k=%d'
                             % (f.__name__, k))
    return count
# -------- variants of nsmallest -------
def heapifying_smallest(k, data):
    """Return the k smallest elements of data in ascending order.

    Heapifies data in place and pops k items from it, then appends the
    popped items back, so afterwards data holds the same elements as before
    but in a different order.  A k larger than len(data) returns every
    element (sorted), matching heapq.nsmallest; a k <= 0 returns [].
    """
    heapify(data)
    # Clamp so an oversized k doesn't heappop from an empty list.
    result = [heappop(data) for _ in range(min(k, len(data)))]
    # Restore the popped items so the caller's list is not shrunk.
    data.extend(result)
    return result
def select_nth(data, n):
    """Return the element of rank n (0-based) in data, i.e. sorted(data)[n].

    Randomized quickselect.  Every recursive call works on a strictly
    smaller list, so duplicate values cannot cause unbounded recursion.
    If n >= len(data), the largest element is returned.  Raises IndexError
    (from random.choice) when data is empty.
    """
    if len(data) == 1:
        return data[0]
    pivot = choice(data)
    lhs, rhs = [], []
    for elem in data:
        (lhs if elem < pivot else rhs).append(elem)
    if len(lhs) >= n + 1:
        return select_nth(lhs, n)
    if not lhs:
        # pivot is a minimum of data, so rhs is all of data; recursing on
        # it unchanged could repeat forever (it always does when every
        # element is equal).  Peel the pivot off instead: it is rank 0.
        if n == 0:
            return pivot
        rhs.remove(pivot)
        return select_nth(rhs, n - 1)
    return select_nth(rhs, n - len(lhs))
def selecting_smallest(k, data):
    """Return the k smallest items of data, ascending, using selection.

    select_nth locates the element of rank k (the (k+1)-th smallest).
    Every item no larger than that threshold is a candidate; sorting the
    candidates and keeping the first k gives the answer.
    """
    threshold = select_nth(data, k)
    candidates = [item for item in data if item <= threshold]
    candidates.sort()
    return candidates[:k]
def partitioning_smallest(n, data):
    """Return the n smallest elements of data in ascending order.

    Randomized partitioning: split data around a random pivot into the
    elements <= pivot (lhs) and those > pivot (rhs).  If lhs holds more
    than n elements the answer lies entirely inside it; otherwise all of
    lhs belongs to the answer (sorted) and the remainder comes from rhs.
    The result always has exactly min(n, len(data)) elements; n <= 0
    gives [].
    """
    if n <= 0:
        return []
    if len(data) <= n:
        # Every element belongs to the answer.
        return sorted(data)
    pivot = choice(data)
    lhs, rhs = [], []
    for elem in data:
        (lhs if elem <= pivot else rhs).append(elem)
    if n < len(lhs):
        if not rhs:
            # pivot is a maximum, so lhs is all of data; recursing on it
            # unchanged could repeat forever (it always does when every
            # element is equal).  Since n < len(lhs), dropping one copy of
            # the maximum leaves the n smallest unchanged.
            lhs.remove(pivot)
        return partitioning_smallest(n, lhs)
    return sorted(lhs) + partitioning_smallest(n - len(lhs), rhs)
if __name__ == '__main__':
    # Benchmark: how many comparisons does each nsmallest variant need?
    n, k = 100000, 100
    print('n: %d\tk: %d' % (n, k))
    data = [Int(value) for value in range(n)]
    variants = (nsmallest, heapifying_smallest,
                selecting_smallest, partitioning_smallest)
    for variant in variants:
        # Five independently shuffled trials, reported in ascending order.
        trial_counts = [count_cmps(variant, data, k) for _ in range(5)]
        trial_counts.sort()
        print(trial_counts, variant.__name__)