Welcome, guest | Sign In | My Account | Store | Cart
"""
wchoice.py -- by bearophile, V.1.0 Oct 30 2006

Weighted choice: like random.choice(), but for elements whose
probabilities are not all equal.
"""

from random import random
from bisect import bisect
from itertools import izip

def wchoice(objects, frequences, filter=True, normalize=True):
    """wchoice(objects, frequences, filter=True, normalize=True): return
    a function that returns the given objects with the specified frequency
    distribution. If no object has a usable frequency (with filter=True,
    frequences <= 1e-8 count as zero), return a constant function that
    returns None.

    Input:
      objects: ordered sequence of elements to choose from.
      frequences: ordered sequence of their (relative) frequences, one per
        object.
      filter=False disables the validation and filtering, speeding up the
        object creation, but fewer bad cases are detected. Frequences must
        then be numbers > 0, and objects must support fast indexing.
      normalize=False skips scaling each draw by the total frequency. The
        choice becomes faster, but sum(frequences) must be 1.

    Output:
      A function callable with no arguments; each call returns one of the
      objects. When more than one object can be chosen, the function also
      accepts optional arguments rnd (a zero-argument callable returning a
      float in [0, 1), default random.random) and bis (default
      bisect.bisect).

    Raises (only when filter is True):
      TypeError: objects or frequences is a set, frozenset or dict.
      ValueError: objects and frequences differ in length, or a frequence
        is negative. Non-numeric frequences propagate float()'s own error.
    """
    if filter:
        # Test and clean the frequences.
        if isinstance(frequences, (set, frozenset, dict)):
            raise TypeError("in wchoice: frequences: only ordered sequences.")
        if isinstance(objects, (set, frozenset, dict)):
            raise TypeError("in wchoice: objects: only ordered sequences.")
        if len(frequences) != len(objects):
            raise ValueError("in wchoice: objects and frequences must have the same length.")
        frequences = [float(freq) for freq in frequences]
        filteredFreq = []
        filteredObj = []
        for freq, obj in zip(frequences, objects):
            if freq < 0:
                raise ValueError("in wchoice: only positive frequences.")
            elif freq > 1e-8:
                # Zero and near-zero frequences are dropped entirely.
                filteredFreq.append(freq)
                filteredObj.append(obj)

        if not filteredFreq:
            return lambda: None
        if len(filteredFreq) == 1:
            return lambda: filteredObj[0]
        frequences = filteredFreq
        objects = filteredObj
    elif len(objects) == 1:
        return lambda: objects[0]
    # Without filtering, objects is used unaltered, so it must have a
    # fast __getitem__.

    # Cumulative frequences: addedFreq[i] == frequences[0] + ... + frequences[i].
    # bisect() on this sorted list maps a draw in [0, lastSum) to an index.
    addedFreq = []
    lastSum = 0
    for freq in frequences:
        lastSum += freq
        addedFreq.append(lastSum)

    if not normalize and 1.0 - 1e-9 < lastSum < 1.0:
        # The caller promises sum(frequences) == 1, but float rounding can
        # leave the total a few ulps short (0.7 + 0.2 + 0.1 gives
        # 0.9999999999999999). random() could then return a value >= the
        # last bound, bisect would return len(objects) and the lookup would
        # raise IndexError. Raising the last bound (still the largest entry)
        # keeps every draw in range.
        addedFreq[-1] = 1.0

    # Default arguments bind random/bisect once, when the chooser is created.
    if normalize:
        # Frequences may have any positive total: scale the [0, 1) draw.
        return lambda rnd=random, bis=bisect: objects[bis(addedFreq, rnd()*lastSum)]
    else:
        # Frequences already sum to 1, so a plain random() draw suffices.
        return lambda rnd=random, bis=bisect: objects[bis(addedFreq, rnd())]


if __name__ == '__main__':
    # Self-test for wchoice. Uses only syntax valid in both Python 2 and 3
    # (single-argument print(...), range instead of xrange).
    from random import seed

    # Seed the shared generator behind random.random() (wchoice's default
    # source of randomness). This makes the statistical check below
    # reproducible instead of failing on occasional unlucky runs.
    seed(20061030)

    print("wchoice tests:")
    objs = "ABCDE"
    freqs = [1, 3, 1.1, 0, 5]
    sumf = sum(freqs)
    wc = wchoice(objs, freqs)
    freq1 = dict.fromkeys(objs, 0)
    nestractions = 100000
    for i in range(nestractions):
        freq1[wc()] += 1

    # Observed frequences, rescaled to the same total as freqs. They are
    # taken in objs order so they line up with freqs; this does not rely
    # on objs happening to be sorted.
    freq3 = [sumf * float(freq1[k]) / nestractions for k in objs]

    diffs = [abs(f1 - f2) for (f1, f2) in zip(freq3, freqs)]
    print(" ".join(str(d) for d in diffs))
    for d in diffs:
        assert d < 0.05
    print("\n")

    wc = wchoice(["a"], [1])
    assert set(wc() for i in range(20000)) == set(["a"])

    wc = wchoice(["a"], [0])
    assert set(wc() for i in range(20000)) == set([None])

    wc = wchoice(["a", "b"], [0, 0])
    assert set(wc() for i in range(20000)) == set([None])

    objs = ["A"]
    freqs = [1.5]
    wc = wchoice(objs, freqs, filter=False)
    assert [wc() for _ in range(10)] == ["A"] * 10

    objs = "ABCDE"
    freqs = [1, 3, 1.1, 0.1, 5]
    wc = wchoice(objs, freqs, filter=False)
    print([wc() for _ in range(50)])

    print("Tests done.")

History