import bisect
import random
import unittest
# Python 2/3 compatibility shim: make the name ``xrange`` usable on both.
try:
    # Merely referencing the name is enough to probe for it; on Python 2 this
    # is the lazy range builtin and nothing needs to be done.
    xrange
except NameError:
    # Python 3.x: the name does not exist, so alias it to ``range``.
    xrange = range
def weighted_random_choice(seq, weight):
    """Returns a random element from ``seq``. The probability for each element
    ``elem`` in ``seq`` to be selected is weighted by ``weight(elem)``.

    ``seq`` must be an iterable containing at least one element whose weight
    is non-zero. It is traversed exactly once, so one-shot iterators are fine.

    ``weight`` must be a callable accepting one argument, and returning a
    non-negative number. If ``weight(elem)`` is zero, ``elem`` will not be
    considered. Exceptions raised by ``weight`` itself propagate unchanged.

    Raises ``ValueError`` if a weight is negative or is not a number, or if
    no element with a non-zero weight was found.
    """
    total = 0
    # Parallel lists: ``running_totals[i]`` is the sum of the weights of
    # ``candidates[0..i]``. Keeping the numbers separate from the elements
    # means the binary search below only ever compares numbers. Searching a
    # list of ``(total, elem)`` tuples instead makes a tie on the total fall
    # through to comparing an element with ``None``, which raises
    # ``TypeError`` on Python 3 for most element types.
    running_totals = []
    candidates = []
    for elem in seq:
        # Called outside any ``try`` so errors from ``weight`` are not masked.
        w = weight(elem)
        try:
            is_neg = w < 0
        except TypeError:
            raise ValueError("Weight of element '%s' is not a number (%s)" %
                             (elem, w))
        if is_neg:
            raise ValueError("Weight of element '%s' is negative (%s)" %
                             (elem, w))
        if w != 0:
            try:
                total += w
            except TypeError:
                raise ValueError("Weight of element '%s' is not a number "
                                 "(%s)" % (elem, w))
            running_totals.append(total)
            candidates.append(elem)
    if not candidates:
        raise ValueError("Empty sequence")
    # ``random.uniform`` may return its upper end point. ``bisect_left``
    # resolves a draw that equals a running total in favour of the element
    # owning that total, so a draw equal to ``total`` selects the last
    # candidate instead of running off the end of the list.
    ix = bisect.bisect_left(running_totals, random.uniform(0, total))
    # The clamp is a cheap safeguard: the index can never pass the last slot.
    return candidates[min(ix, len(candidates) - 1)]
class TestCase(unittest.TestCase):
    """Unit tests for ``weighted_random_choice``.
    """

    def test_empty(self):
        """Sequences without a selectable element raise ``ValueError``.
        """
        # Nothing to iterate over at all.
        self.assertRaises(ValueError,
                          weighted_random_choice, [], lambda x: 0)
        # Elements are present, but every one of them has weight zero.
        self.assertRaises(ValueError,
                          weighted_random_choice, [1, 2, 3], lambda x: 0)

    def test_invalid_weight(self):
        """Invalid weight values are detected.
        """
        # A weight that is not a number is reported as ``ValueError``.
        self.assertRaises(ValueError,
                          weighted_random_choice, [1, 2, 3], lambda x: "foo")

        class Oops(Exception):
            pass

        def weight(elem):
            raise Oops()

        # An exception raised by the weight callable itself must reach the
        # caller unchanged.
        self.assertRaises(Oops, weighted_random_choice, [1, 2, 3], weight)

    def test_spread(self):
        """Results are consistent with weight function.
        """
        seq = range(0, 100)
        bias = 10.0
        samples = 5000

        def weight(elem):
            # Odd numbers are ``bias`` times as likely as even numbers.
            if elem % 2:
                return bias
            else:
                return 1

        odd_hits = 0
        for _ in xrange(0, samples):
            if weighted_random_choice(seq, weight) % 2:
                odd_hits += 1

        # ``seq`` holds as many odd as even numbers, so the expected share of
        # odd picks is bias / (bias + 1).
        expected = bias / (bias + 1)
        observed = odd_hits / float(samples)
        # The standard deviation of the observed share is about 0.004 for
        # 5000 draws, so 0.03 is more than seven standard deviations. The
        # check still catches a wrong distribution (unweighted picking would
        # give 0.5) without failing spuriously. Checking that the odds/evens
        # ratio stays within 1 of ``bias`` is only about a two-sigma bound,
        # which fails in a few percent of runs and divides by zero if no even
        # number is drawn.
        self.assertTrue(abs(observed - expected) < 0.03,
                        "odd share %.4f too far from expected %.4f" %
                        (observed, expected))
# Running this file as a script executes the test suite defined above.
if __name__ == "__main__":
    # Called without an argument, so no fixed seed is installed: every run
    # of the tests draws a different random sequence.
    random.seed()
    unittest.main()