Welcome, guest | Sign In | My Account | Store | Cart

This function returns a random element from a sequence. The probability of each element being selected is weighted by a user-provided callable.

Python, 102 lines
import bisect
import random
import unittest

try:  # Python 2 ships a lazy ``xrange`` builtin; probing it is enough.
    xrange
except NameError:  # Python 3.x: ``range`` is already lazy, so alias it.
    xrange = range


def weighted_random_choice(seq, weight):
    """Returns a random element from ``seq``. The probability for each element
    ``elem`` in ``seq`` to be selected is weighted by ``weight(elem)``.

    ``seq`` must be an iterable containing at least one element with a
    non-zero weight; otherwise ``ValueError("Empty sequence")`` is raised.

    ``weight`` must be a callable accepting one argument, and returning a
    non-negative number. If ``weight(elem)`` is zero, ``elem`` will not be
    considered. ``ValueError`` is raised when a weight is negative, NaN, or
    not a number. Exceptions raised by ``weight`` itself propagate unchanged.

    The elements themselves are never compared with each other, so they
    may be of any (even unorderable) type.

    """
    total = 0
    # Running weight totals, kept apart from the elements so that bisect
    # only ever compares numbers (comparing elements against a sentinel
    # raises TypeError on Python 3 whenever a draw ties a running total).
    cumulative = []
    elems = []
    for elem in seq:
        w = weight(elem)
        try:
            is_neg = w < 0
        except TypeError:
            raise ValueError("Weight of element '%s' is not a number (%s)" %
                             (elem, w))
        if is_neg:
            raise ValueError("Weight of element '%s' is negative (%s)" %
                             (elem, w))
        if w != w:
            # NaN is the only value unequal to itself; it would poison the
            # running total and make every later comparison meaningless.
            raise ValueError("Weight of element '%s' is not a number (%s)" %
                             (elem, w))
        if w != 0:
            try:
                total += w
            except TypeError:
                raise ValueError("Weight of element '%s' is not a number "
                                 "(%s)" % (elem, w))
            cumulative.append(total)
            elems.append(elem)
    if not elems:
        raise ValueError("Empty sequence")
    # uniform() may return either endpoint, so the draw lies in [0, total].
    # bisect_left selects the first element whose running total reaches the
    # draw, i.e. element i owns the interval (cumulative[i-1], cumulative[i]].
    ix = bisect.bisect_left(cumulative, random.uniform(0, total))
    # Guard against float rounding pushing the draw past an exact int total.
    return elems[min(ix, len(elems) - 1)]
        

class TestCase(unittest.TestCase):
    """Unit tests for ``weighted_random_choice``."""

    def test_empty(self):
        """Empty sequences raise ``ValueError``.

        """
        self.assertRaises(ValueError,
                          weighted_random_choice, [], lambda x: 0)
        self.assertRaises(ValueError,
                          weighted_random_choice, [1, 2, 3], lambda x: 0)

    def test_invalid_weight(self):
        """Invalid weight values are detected.

        """
        self.assertRaises(ValueError,
                          weighted_random_choice, [1, 2, 3], lambda x: "foo")

        class Oops(Exception):
            pass

        def weight(elem):
            raise Oops()

        # Errors raised by the weight callable itself propagate unchanged.
        self.assertRaises(Oops, weighted_random_choice, [1, 2, 3], weight)

    def test_negative_weight(self):
        """Negative weights raise ``ValueError``.

        """
        self.assertRaises(ValueError,
                          weighted_random_choice, [1, 2, 3], lambda x: -1)

    def test_zero_weight_excluded(self):
        """Elements whose weight is zero are never selected.

        """
        for _ in xrange(0, 100):
            self.assertEqual(
                weighted_random_choice([1, 2, 3],
                                       lambda x: 1 if x == 2 else 0),
                2)

    def test_spread(self):
        """Results are consistent with weight function.

        Odd elements get weight ``bias`` and even elements weight 1. The
        sequence holds 50 of each, so the expected fraction of odd picks is
        ``bias / (bias + 1)``. The tolerance is roughly five standard
        deviations of the sampled fraction, which keeps the test from
        failing spuriously on unlucky draws.

        """
        seq = range(0, 100)
        bias = 10.0
        samples = 5000

        def weight(elem):
            if elem % 2:
                return bias
            else:
                return 1

        odds = 0
        for _ in xrange(0, samples):
            if weighted_random_choice(seq, weight) % 2:
                odds += 1

        expected = bias / (bias + 1)
        self.assertAlmostEqual(float(odds) / samples, expected, delta=0.02)


if __name__ == "__main__":
    # Reseed the module-level generator (None means: use the default
    # seeding source) so every run draws a fresh sequence, then hand
    # control to the unittest command-line runner.
    random.seed(None)
    unittest.main()

2 comments

Paddy McCarthy 11 years, 3 months ago  # | flag

It looks similar to probchoice() here. There is also another algorithm, coded as probchoice2(), on Rosetta Code.

Carlos Valiente (author) 11 years, 3 months ago  # | flag

@Paddy: Thanks for the pointer. Yep, basically like probchoice(). The version here does not assume that the sequence is a list -- although it builds one, so for large input lists, probchoice() is a better choice.