Welcome, guest | Sign In | My Account | Store | Cart

To choose among a list of objects using a specified frequency distribution.

Python, 108 lines
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
"""
wchoice.py -- by bearophile, V.1.0 Oct 30 2006

Weighted choice: like the random.choice() when the probabilities of
the single elements aren't the same.
"""

from random import random
from bisect import bisect
from itertools import izip

def wchoice(objects, frequences, filter=True, normalize=True):
    """wchoice(objects, frequences, filter=True, normalize=True): return
    a function that return the given objects with the specified frequency
    distribution. If no objects with frequency>0 are given, return a
    constant function that return None.

    Input:
      objects: sequence of elements to choose.
      frequences: sequence of their frequences.
      filter=False disables the filtering, speeding up the object creation,
        but less bad cases are controlled. Frequences must be float > 0.
      normalize=False disables the probablitity normalization. The choice
        becomes faster, but sum(frequences) must be 1
    """
    if filter:
        # Test and clean the frequencies.
        if isinstance(frequences, (set, dict)):
            raise "in wchoice: frequences: only ordered sequences."
        if isinstance(objects, (set, dict)):
            raise "in wchoice: objects: only ordered sequences."
        if len(frequences) != len(objects):
            raise "in wchoice: objects and frequences must have the same lenght."
        frequences = map(float, frequences)
        filteredFreq = []
        filteredObj = []
        for freq, obj in izip(frequences, objects):
            if freq < 0:
                raise "in wchoice: only positive frequences."
            elif freq >1e-8:
                filteredFreq.append(freq)
                filteredObj.append(obj)

        if len(filteredFreq) == 0:
            return lambda: None
        if len(filteredFreq) == 1:
            return lambda: filteredObj[0]
        frequences = filteredFreq
        objects = filteredObj
    else:
        if len(objects) == 1:
            return lambda: objects[0]
        # Here objects is unaltered, so it must have a fast __getitem__

    addedFreq = []
    lastSum = 0
    for freq in frequences:
        lastSum += freq
        addedFreq.append(lastSum)

    # If the choice method is called many times, then the frequences
    # can be normalized to sum 1, so instead of random()*self.sumFreq
    # a random() suffices.
    if normalize:
        return lambda rnd=random, bis=bisect: objects[bis(addedFreq, rnd()*lastSum)]
    else:
        return lambda rnd=random, bis=bisect: objects[bis(addedFreq, rnd())]


if __name__ == '__main__':
    print "wchoice tests:"
    objs = "ABCDE"
    freqs = [1, 3, 1.1, 0, 5]
    sumf = sum(freqs)
    wc = wchoice(objs, freqs)
    freq1 = dict.fromkeys(objs, 0)
    nestractions = 100000
    for i in xrange(nestractions):
        freq1[wc()] += 1

    freq2 = sorted(freq1.items())
    freq3 = [sumf*float(v)/nestractions for (k,v) in freq2]

    for (f1,f2) in zip(freq3, freqs):
        print abs(f1-f2),
        assert abs(f1-f2) < 0.05
    print "\n"

    wc = wchoice(["a"], [1])
    assert set(wc() for i in xrange(20000)) == set(["a"])

    wc = wchoice(["a"], [0])
    assert set(wc() for i in xrange(20000)) == set([None])

    wc = wchoice(["a","b"], [0,0])
    assert set(wc() for i in xrange(20000)) == set([None])

    objs = ["A"]
    freqs = [1.5]
    wc = wchoice(objs, freqs, filter=False)
    assert [wc() for _ in xrange(10)] == ["A"] * 10

    objs = "ABCDE"
    freqs = [1, 3, 1.1, 0.1, 5]
    wc = wchoice(objs, freqs, filter=False)
    print [wc() for _ in xrange(50)]

    print "Tests done."

Probably some scientific libraries do this already (but this is pure python). In the beginning this was a class, but this version that returns a function is faster, shorter and simpler. I don't know if there is a faster way to conpute the addedFreq list. It uses a binary search with the bisect, but doing a quantization of the frequences, this algorithm can be approximated with an O(1) one. I don't know if a C-based bisect can speed up this code some.