This function returns a random element from a sequence. The probability of each element being selected can be weighted by a user-provided callable.
import bisect
import random
import unittest
# Compatibility shim: the code below uses the name ``xrange``.  If that name
# does not exist (Python 3.x), bind it to ``range``.
try:
    xrange
except NameError:
    xrange = range
def weighted_random_choice(seq, weight):
    """Return a random element of ``seq``, weighted by ``weight(elem)``.

    The chance of selecting ``elem`` is ``weight(elem)`` divided by the sum
    of all weights.  Elements whose weight is zero are never returned.

    ``seq`` may be any iterable; it is traversed exactly once and must yield
    at least one element with a non-zero weight.  Elements do not need to be
    comparable with each other.

    ``weight`` must be a callable accepting one argument, and returning a
    non-negative number.  Exceptions raised by ``weight`` itself propagate
    unchanged.

    Raises ``ValueError`` if a weight is negative or not a number (NaN
    included), or if no element with a non-zero weight is found (which
    covers an empty ``seq``).
    """
    total = 0
    # Parallel lists: cumulative[i] is the running weight total up to and
    # including candidates[i].  Keeping the elements out of the list that is
    # bisected means elements are never compared with each other (or with a
    # placeholder), which would raise TypeError on Python 3.
    cumulative = []
    candidates = []
    for elem in seq:
        w = weight(elem)
        try:
            is_neg = w < 0
        except TypeError:
            raise ValueError("Weight of element '%s' is not a number (%s)" %
                             (elem, w))
        if is_neg:
            raise ValueError("Weight of element '%s' is negative (%s)" %
                             (elem, w))
        if w != w:
            # NaN is the only value unequal to itself; it would silently
            # poison the running total and make the selection meaningless.
            raise ValueError("Weight of element '%s' is not a number (%s)" %
                             (elem, w))
        if w != 0:
            try:
                total += w
            except TypeError:
                raise ValueError("Weight of element '%s' is not a number "
                                 "(%s)" % (elem, w))
            cumulative.append(total)
            candidates.append(elem)
    if not candidates:
        raise ValueError("Empty sequence")
    ix = bisect.bisect(cumulative, random.uniform(0, total))
    # random.uniform() may return its upper bound, in which case bisect()
    # yields len(cumulative); clamp so the last candidate is chosen instead
    # of raising IndexError.
    return candidates[min(ix, len(candidates) - 1)]
class TestCase(unittest.TestCase):
    """Unit tests for ``weighted_random_choice``."""

    def test_empty(self):
        """Sequences without any selectable element raise ``ValueError``.
        """
        # A sequence with no elements at all.
        self.assertRaises(ValueError,
                          weighted_random_choice, [], lambda x: 0)
        # A non-empty sequence in which every element has zero weight.
        self.assertRaises(ValueError,
                          weighted_random_choice, [1, 2, 3], lambda x: 0)

    def test_single_element(self):
        """A single weighted element is always the one returned.
        """
        self.assertEqual(weighted_random_choice([42], lambda x: 1), 42)

    def test_invalid_weight(self):
        """Invalid weight values are detected.
        """
        self.assertRaises(ValueError,
                          weighted_random_choice, [1, 2, 3], lambda x: "foo")
        self.assertRaises(ValueError,
                          weighted_random_choice, [1, 2, 3], lambda x: -1)

        class Oops(Exception):
            pass

        def weight(elem):
            raise Oops()

        # Errors raised by the weight callable itself must not be masked.
        self.assertRaises(Oops, weighted_random_choice, [1, 2, 3], weight)

    def test_spread(self):
        """Results are consistent with weight function.
        """
        seq = range(0, 100)
        bias = 10.0
        samples = 5000

        def weight(elem):
            if elem % 2:
                return bias
            else:
                return 1

        odd_hits = 0
        for _ in xrange(0, samples):
            if weighted_random_choice(seq, weight) % 2:
                odd_hits += 1

        # ``seq`` holds as many odd as even numbers, so an odd number should
        # be drawn with probability bias / (bias + 1).  Comparing fractions
        # (instead of the odds/evens ratio) cannot divide by zero, and its
        # standard deviation is about 0.004 for 5000 draws, so a tolerance of
        # 0.03 keeps the test from failing spuriously while still rejecting
        # an unweighted choice (which would give about 0.5).
        expected = bias / (bias + 1)
        observed = odd_hits / float(samples)
        self.assertTrue(abs(observed - expected) < 0.03)
if __name__ == "__main__":
    # Executed as a script: re-seed the random number generator (no explicit
    # seed is passed), then hand control to the unittest runner.
    random.seed()
    unittest.main()
|
It looks similar to probchoice() here. There is also another algorithm coded as probchoice2() on Rosetta Code
@Paddy: Thanks for the pointer. Yep, basically like `probchoice()`. The version here does not assume that the sequence is a list -- although it builds one, so for large input lists, `probchoice()` is a better choice.