
Expected O(n), quicksort-style algorithm (often called quickselect) for looking up data by rank order. Useful for finding medians, percentiles, quartiles, and deciles. Equivalent to sorted(data)[n], or simply data[n] when the data is already sorted.

Python, 26 lines
import random

def select(data, n):
    "Find the nth rank ordered element (the least value has rank 0)."
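    data = list(data)                       # accept any iterable; work on a copy
    if not 0 <= n < len(data):
        raise ValueError('not enough elements for the given rank')
    while True:
        pivot = random.choice(data)         # random pivot: expected linear time
        pcount = 0                          # number of elements equal to the pivot
        under, over = [], []
        uappend, oappend = under.append, over.append   # cache bound methods for speed
        for elem in data:
            if elem < pivot:
                uappend(elem)
            elif elem > pivot:
                oappend(elem)
            else:
                pcount += 1
        if n < len(under):
            data = under                    # answer lies among the smaller elements
        elif n < len(under) + pcount:
            return pivot                    # answer is one of the pivot's ties
        else:
            data = over                     # answer lies among the larger elements
            n -= len(under) + pcount        # re-rank relative to the larger elements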

The input data can be any iterable.

Randomizing the pivot keeps the expected running time linear even for unfavorable data orderings (the kind that wreak havoc on a quicksort with a fixed pivot choice). Makes on the order of lg2(N) calls to random.choice(), one per partitioning pass.

Revised to include the pivot counts after David Eppstein pointed out that the originally posted algorithm ran slowly when all the inputs were equal.
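To watch the pass count, here is a lightly instrumented Python 3 copy of the recipe. The name select_with_passes and the returned pass counter are illustrative additions (the bound-method caching is dropped for brevity); each pass makes exactly one random.choice() call, so the pass count equals the number of those calls.

```python
import random

def select_with_passes(data, n):
    """The recipe's select(), also returning the number of partitioning passes.

    Hypothetical instrumented variant, for illustration only.
    """
    data = list(data)
    if not 0 <= n < len(data):
        raise ValueError('not enough elements for the given rank')
    passes = 0
    while True:
        passes += 1                      # one random.choice() call per pass
        pivot = random.choice(data)
        pcount = 0
        under, over = [], []
        for elem in data:
            if elem < pivot:
                under.append(elem)
            elif elem > pivot:
                over.append(elem)
            else:
                pcount += 1
        if n < len(under):
            data = under
        elif n < len(under) + pcount:
            return pivot, passes
        else:
            data = over
            n -= len(under) + pcount

values = list(range(100_000))
random.shuffle(values)
median, passes = select_with_passes(values, len(values) // 2)
print(median, passes)   # median is 50000; the pass count varies from run to run

# With the pivot count, all-equal input finishes in a single pass.
print(select_with_passes([7] * 100_000, 500))   # (7, 1)
```

The pass count on random data fluctuates between runs; only the all-equal case is deterministic.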

8 comments

David Eppstein 20 years, 2 months ago

Not always O(n). This takes quadratic time when all items are equal, no?
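The v1.2 code isn't reproduced on this page, so the following is a hypothetical two-way partition (every tie with the pivot goes to the upper side) written only to illustrate the failure mode described above. On all-equal input each pass sets aside just one copy of the pivot, so asking for the largest rank takes N passes over shrinking lists, which is quadratic work. The name two_way_passes is made up for this sketch.

```python
import random

def two_way_passes(data, n):
    """Hypothetical two-way quickselect with no separate bucket for pivot ties.

    Returns (value, passes). Not the recipe's code; illustration only.
    """
    data = list(data)
    if not 0 <= n < len(data):
        raise ValueError('not enough elements for the given rank')
    passes = 0
    while True:
        passes += 1
        i = random.randrange(len(data))
        pivot = data[i]
        rest = data[:i] + data[i + 1:]                 # set aside one copy of the pivot
        under = [e for e in rest if e < pivot]
        over = [e for e in rest if not e < pivot]      # ties land here
        if n < len(under):
            data = under
        elif n == len(under):
            return pivot, passes
        else:
            data = over
            n -= len(under) + 1

print(two_way_passes([5] * 2000, 1999))   # (5, 2000): one pass per element
```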

Shane Holloway 20 years, 2 months ago

O(n) for a sequence of identical values.

# Changed the algorithm slightly to count the number of pivot entries
# so that O(n) can be achieved for identical values.  Also note that
# there is an early out condition when you ask for the Nth value, and
# len(data) 3:
            swap = random.randrange(len(data))
            data[0], data[swap] = data[swap], data[0]

        it = iter(data)
        pivot, pivotcount = it.next(), 1
        under, over = [], []
        ua, oa = under.append, over.append
        for elem in it:
            r = cmp(elem, pivot)
            if r

Shane Holloway 20 years, 2 months ago

Invalid previous post. Don't know what's up, but I can't post the entire modified algorithm...

Raymond Hettinger (author) 20 years, 2 months ago

Using cmp(). Even within the <pre> block, some additional markup is necessary to post code that includes less than or greater than comparisons. The markup is &lt; and &gt; respectively.

The idea you were starting to post looks like an attempt to use cmp() to avoid doing two comparisons. My benchmarks show that this runs more slowly than using two comparisons, but that may depend on the type of data being selected.
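One way to set up such a benchmark is sketched below. Python 3 has no builtin cmp(), so the sketch substitutes the common `(a > b) - (a < b)` idiom as a stand-in (an assumption; that idiom itself performs two comparisons, so this does not reproduce a Python 2 measurement exactly). The function names are hypothetical, and timings will vary with the data type and interpreter.

```python
import random
import timeit

def partition_two_compares(data, pivot):
    """Recipe-style split: up to two rich comparisons per element."""
    under, over, pcount = [], [], 0
    for elem in data:
        if elem < pivot:
            under.append(elem)
        elif elem > pivot:
            over.append(elem)
        else:
            pcount += 1
    return under, over, pcount

def three_way(a, b):
    """Stand-in for Python 2's cmp(); always returns -1, 0 or 1."""
    return (a > b) - (a < b)

def partition_three_way(data, pivot):
    """cmp()-style split: one three-way comparison call per element."""
    under, over, pcount = [], [], 0
    for elem in data:
        r = three_way(elem, pivot)
        if r < 0:
            under.append(elem)
        elif r > 0:
            over.append(elem)
        else:
            pcount += 1
    return under, over, pcount

data = [random.randrange(1000) for _ in range(100_000)]
pivot = random.choice(data)
# Both strategies must produce the same split before their speed is compared.
assert partition_two_compares(data, pivot) == partition_three_way(data, pivot)
for fn in (partition_two_compares, partition_three_way):
    print(fn.__name__, timeit.timeit(lambda: fn(data, pivot), number=10))
```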

David Eppstein 20 years, 2 months ago

My previous comment was for v1.2. v1.3 seems to fix the worst case performance, but at the expense of a larger number of comparisons. I wonder if something like the following code would be faster. I'm not sure the ops[cmp(...)] trickery is very Pythonic, though, especially as cmp doesn't seem to be guaranteed to return -1,0,1.

import random

def select(data, n):
    "Find the nth rank ordered element (the least value has rank 0)."
    data = list(data)
    try:
        while True:
            pivot = random.choice(data)
            under, over, match = [], [], []
            ops = match.append, over.append, under.append
            for elem in data:
                ops[cmp(elem,pivot)](elem)
            if n < len(under):
                data = under
            else:
                lowcount = len(under) + len(match)
                if n < lowcount:
                    return pivot
                else:
                    data = over
                    n -= lowcount
    except IndexError:
        raise ValueError('not enough elements for the given rank')
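For readers on Python 3, where cmp() no longer exists, the same ops-table trick can be kept by supplying a three-way helper that is guaranteed to return -1, 0 or 1, which also addresses the concern above about cmp()'s return values. The helper name sign_cmp is hypothetical, and this sketch replaces the try/except IndexError with an explicit range check so that an unrelated IndexError is not reported as a bad rank.

```python
import random

def sign_cmp(a, b):
    """Three-way compare that always returns -1, 0 or 1 (Python 3 has no cmp())."""
    return (a > b) - (a < b)

def select(data, n):
    "Find the nth rank ordered element (the least value has rank 0)."
    data = list(data)
    if not 0 <= n < len(data):
        raise ValueError('not enough elements for the given rank')
    while True:
        pivot = random.choice(data)
        under, over, match = [], [], []
        # Index 0 -> match, 1 -> over, -1 (the last slot) -> under.
        ops = match.append, over.append, under.append
        for elem in data:
            ops[sign_cmp(elem, pivot)](elem)
        if n < len(under):
            data = under
        else:
            lowcount = len(under) + len(match)
            if n < lowcount:
                return pivot
            data = over
            n -= lowcount
```

For example, select([5, 1, 4, 2, 3], 2) returns 3.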

Raymond Hettinger (author) 20 years, 2 months ago

Using cmp() and computed append reference. This approach seems inspired but does not win in timings with integer data. Ultimately, it should win whenever comparisons are very expensive and the data objects implement a clever __cmp__() method (otherwise, the builtin cmp() function will end up making two comparisons to differentiate the three cases).

Chris Perkins 20 years, 1 month ago

Why? I'm unclear as to what kind of data you would use this for. I did some very simple tests, on lists of integers and strings, and in all cases that I tried, the obvious solution runs faster:

def select(data, n):
    data = list(data)
    data.sort()
    return data[n]

Can you give some guidelines as to when your solution would be better?

Raymond Hettinger (author) 20 years, 1 month ago

Why use select()? The difference between an O(n) algorithm and an O(n lg n) algorithm becomes more pronounced with longer lists. This is particularly important when the constant factor is large (perhaps due to an expensive compare function).
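A rough way to see the constant-factor argument is to count comparisons instead of timing them. The sketch below wraps values in a hypothetical Counted class that tallies every < and > call, then compares sorted()[mid] with a Python 3 copy of the recipe's select() (bound-method caching dropped). The input size is arbitrary; on large random inputs select() typically makes noticeably fewer comparisons, but the exact counts vary from run to run.

```python
import random

class Counted:
    """Wraps a value and counts every < and > comparison (illustrative helper)."""
    comparisons = 0

    def __init__(self, value):
        self.value = value

    def __lt__(self, other):
        Counted.comparisons += 1
        return self.value < other.value

    def __gt__(self, other):
        Counted.comparisons += 1
        return self.value > other.value

def select(data, n):
    "The recipe's select(), without the bound-method caching."
    data = list(data)
    if not 0 <= n < len(data):
        raise ValueError('not enough elements for the given rank')
    while True:
        pivot = random.choice(data)
        pcount = 0
        under, over = [], []
        for elem in data:
            if elem < pivot:
                under.append(elem)
            elif elem > pivot:
                over.append(elem)
            else:
                pcount += 1
        if n < len(under):
            data = under
        elif n < len(under) + pcount:
            return pivot
        else:
            data = over
            n -= len(under) + pcount

items = [Counted(random.random()) for _ in range(50_001)]
mid = len(items) // 2

Counted.comparisons = 0
by_sort = sorted(items)[mid]          # sorted() only uses __lt__
sort_cost = Counted.comparisons

Counted.comparisons = 0
by_select = select(items, mid)        # select() uses both < and >
select_cost = Counted.comparisons

print(by_sort.value == by_select.value, sort_cost, select_cost)
```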

For instance, try generating a million and one random points and finding the median point using a custom ordering function (dist, angle):

import math

# Python 2 only: __cmp__ and cmp() were removed in Python 3
class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y
    def __cmp__(self, other):
        # order by distance from the origin and then by angle from the x-axis
        c = cmp(self.x ** 2 + self.y ** 2, other.x ** 2 + other.y ** 2)
        if c != 0:
            return c
        if self.x == 0 and self.y == 0:
            return 0
        return cmp(math.atan2(self.y, self.x), math.atan2(other.y, other.x))
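Python 3 ignores __cmp__, so a port of the Point ordering needs rich comparisons instead. The sketch below is one possible translation (not the author's code): it orders points by the key (squared distance, angle) and lets functools.total_ordering derive the remaining operators from __eq__ and __lt__. The _key helper name is an assumption; the origin gets angle 0.0, mirroring the special case in the original.

```python
import math
import random
from functools import total_ordering

@total_ordering
class Point:
    """Python 3 sketch of the (distance, angle) ordering above."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def _key(self):
        # Order by squared distance from the origin, then by angle from the x-axis.
        dist2 = self.x ** 2 + self.y ** 2
        if self.x == 0 and self.y == 0:
            return (dist2, 0.0)            # the origin has no meaningful angle
        return (dist2, math.atan2(self.y, self.x))

    def __eq__(self, other):
        return self._key() == other._key()

    def __lt__(self, other):
        return self._key() < other._key()

points = [Point(random.uniform(-1, 1), random.uniform(-1, 1)) for _ in range(1001)]
median_point = sorted(points)[500]
```

Because total_ordering supplies __gt__, these points also work with the recipe's select(), which uses both < and >.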

Also, when you time comparisons against sort(), be sure to randomize the data beforehand (sort() takes shortcuts when the data is already partially ordered).
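The effect of presorted input is easy to observe directly. This sketch times sorted(...)[mid] on the same values in sorted and in shuffled order; the list size and repeat count are arbitrary choices, and the absolute numbers depend on the machine. The shuffled timing is the fair baseline to compare against select().

```python
import random
import timeit

data = [random.random() for _ in range(200_001)]
presorted = sorted(data)
shuffled = presorted[:]
random.shuffle(shuffled)   # randomize order so sort() cannot exploit existing runs

mid = len(data) // 2
t_presorted = timeit.timeit(lambda: sorted(presorted)[mid], number=5)
t_shuffled = timeit.timeit(lambda: sorted(shuffled)[mid], number=5)
print(f"presorted: {t_presorted:.3f}s  shuffled: {t_shuffled:.3f}s")
```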