Welcome, guest | Sign In | My Account | Store | Cart
import bisect
import random
import unittest

try:
    # Python 2: rebind the lazy builtin under the same module-level name.
    xrange = xrange
except NameError:
    # Python 3 dropped ``xrange``; its ``range`` is already lazy.
    xrange = range


def weighted_random_choice(seq, weight):
    """Return a random element from ``seq``, weighted by ``weight(elem)``.

    Each element ``elem`` of ``seq`` is selected with probability
    proportional to ``weight(elem)``.

    ``seq`` may be any iterable; it is iterated exactly once. At least one
    of its elements must have a non-zero weight.

    ``weight`` must be a callable accepting one argument and returning a
    non-negative number. Elements whose weight is zero are never selected.

    Raises ``ValueError`` if a weight is negative, NaN or not a number, or if
    no element has a non-zero weight (this includes an empty ``seq``).
    Exceptions raised by ``weight`` itself propagate unchanged.
    """
    total = 0
    cumulative = []  # running total of weights after each selectable element
    elems = []       # selectable elements, parallel to ``cumulative``
    for elem in seq:
        w = weight(elem)
        try:
            is_neg = w < 0
        except TypeError:
            raise ValueError("Weight of element '%s' is not a number (%s)" %
                             (elem, w))
        if is_neg:
            raise ValueError("Weight of element '%s' is negative (%s)" %
                             (elem, w))
        if w != w:
            # NaN is neither negative nor zero, but it would poison the
            # running total and make the bisection below meaningless.
            raise ValueError("Weight of element '%s' is not a number (%s)" %
                             (elem, w))
        if w != 0:
            try:
                total += w
            except TypeError:
                raise ValueError("Weight of element '%s' is not a number "
                                 "(%s)" % (elem, w))
            cumulative.append(total)
            elems.append(elem)
    if not elems:
        raise ValueError("Empty sequence")
    # Element i owns the half-open interval [cumulative[i-1], cumulative[i]).
    # Bisecting the bare totals (rather than (total, elem) tuples) never
    # compares elements with each other or with a sentinel, so a point that
    # lands exactly on a boundary cannot raise TypeError.
    point = random.uniform(0, total)
    ix = bisect.bisect_right(cumulative, point)
    # random.uniform may return ``total`` itself due to float rounding, which
    # bisects past the end; that point belongs to the last element.
    return elems[min(ix, len(elems) - 1)]
        

class TestCase(unittest.TestCase):
    """Unit tests for ``weighted_random_choice``."""

    def test_empty(self):
        """Sequences with nothing selectable raise ``ValueError``.

        Covers both a truly empty sequence and one whose elements all have
        zero weight.
        """
        self.assertRaises(ValueError,
                          weighted_random_choice, [], lambda x: 0)
        self.assertRaises(ValueError,
                          weighted_random_choice, [1, 2, 3], lambda x: 0)

    def test_invalid_weight(self):
        """Invalid weight values are detected.

        Non-numeric and negative weights raise ``ValueError``; exceptions
        raised by the weight callable itself propagate unchanged.
        """
        self.assertRaises(ValueError,
                          weighted_random_choice, [1, 2, 3], lambda x: "foo")
        self.assertRaises(ValueError,
                          weighted_random_choice, [1, 2, 3], lambda x: -1)

        class Oops(Exception):
            pass

        def weight(elem):
            raise Oops()

        self.assertRaises(Oops, weighted_random_choice, [1, 2, 3], weight)

    def test_spread(self):
        """Results are consistent with weight function.

        Odd numbers weigh ``bias`` times as much as even numbers, so the
        observed share of odd picks should match the share of the total
        weight that the odd numbers carry.
        """
        seq = range(0, 100)
        bias = 10.0
        trials = 5000

        def weight(elem):
            if elem % 2:
                return bias
            else:
                return 1

        total_weight = sum(weight(e) for e in seq)
        odd_weight = sum(weight(e) for e in seq if e % 2)
        expected = odd_weight / float(total_weight)

        # Fixed seed so a failure is reproducible; restore the global
        # generator afterwards so other tests are unaffected.
        state = random.getstate()
        random.seed(20240101)
        try:
            odds = 0
            for _ in xrange(0, trials):
                if weighted_random_choice(seq, weight) % 2:
                    odds += 1
        finally:
            random.setstate(state)

        observed = odds / float(trials)
        # The standard deviation of the observed share is
        # sqrt(p * (1 - p) / trials), about 0.004 here; a delta of 0.02 is
        # roughly 5 sigma, so the check does not fail by chance in practice.
        self.assertAlmostEqual(observed, expected, delta=0.02)


if __name__ == "__main__":
    # Reseed from system entropy (None means "use the OS/time source").
    random.seed(None)
    unittest.main()

History