
The union find data structure is primarily used for Kruskal's Minimum Spanning Tree algorithm, though it can be used whenever one only needs to determine if two items are in the same set and to combine sets quickly.

Python, 136 lines
'''
unionfind.py

A class that implements the Union Find data structure and algorithm.  This
data structure allows one to find out which set an object belongs to, as well
as join two sets.

The algorithm's performance, given m union/find operations in any order on
n elements, has been shown to take amortized log* time per operation, where
log* is pronounced log-star.  log* is the inverse of the tower-of-twos
function given below, which the functions in this module call Ackerman (the
tighter known bound, alpha, is the inverse of the true Ackermann function):
A(0) = 1
A(n) = 2**A(n-1)

I include the functions to be complete.  Note that we can be 'inefficient'
when computing inverseAckerman, as it will only take a maximum of 6
iterations to perform; A(5) = 2**65536 is 65537 binary digits long (a 1 with
65536 zeroes following).  A(6) is 2**65536 + 1 binary digits long, and cannot
be represented by the memory of the entire universe.


The Union Find data structure is not a universal set implementation, but can
tell you if two objects are in the same set or in different sets, and can
combine two sets:
ufset.find(obja) == ufset.find(objb)    # same set
ufset.find(obja) != ufset.find(objb)    # different sets
ufset.union(obja, objb)                 # combine the two sets


This algorithm and data structure are primarily used for Kruskal's Minimum
Spanning Tree algorithm for graphs, but other uses have been found.

August 12, 2003 Josiah Carlson
'''

def Ackerman(inp, memo={0: 1}):
    # Despite the name, this is the tower-of-twos function from the docstring,
    # A(0) = 1 and A(n) = 2**A(n-1); the mutable default argument memoizes results.
    inp = max(int(inp), 0)
    if inp in memo:
        return memo[inp]
    elif inp <= 5:
        memo[inp] = 2**Ackerman(inp-1)
        return memo[inp]
    else:
        print("Such a number is not representable by all the subatomic\nparticles in the universe.")
        Ackerman(4)
        out = (inp-4)*"2**" + str(memo[4])
        print(out)
        raise Exception("NumberCannotBeRepresentedByAllSubatomicParticlesInUniverse")

def inverseAckerman(inp):
    # Smallest t with Ackerman(t) >= inp, which is log*(inp).
    t = 0
    while Ackerman(t) < inp:
        t += 1
    return t


class UnionFind:
    def __init__(self):
        '''\
Create an empty union find data structure.'''
        self.num_weights = {}       # root number -> size of that root's set
        self.parent_pointers = {}   # object number -> parent object number
        self.num_to_objects = {}
        self.objects_to_num = {}
    def insert_objects(self, objects):
        '''\
Insert a sequence of objects into the structure.  All must be Python hashable.'''
        for object in objects:
            self.find(object)
    def find(self, object):
        '''\
Find the root of the set that an object is in.
If the object was not known, will make it known, and it becomes its own set.
Object must be Python hashable.'''
        if object not in self.objects_to_num:
            obj_num = len(self.objects_to_num)
            self.num_weights[obj_num] = 1
            self.objects_to_num[object] = obj_num
            self.num_to_objects[obj_num] = object
            self.parent_pointers[obj_num] = obj_num
            return object
        # Walk up to the root, remembering every node on the path...
        stk = [self.objects_to_num[object]]
        par = self.parent_pointers[stk[-1]]
        while par != stk[-1]:
            stk.append(par)
            par = self.parent_pointers[par]
        # ...then point each of them directly at the root (path compression).
        for i in stk:
            self.parent_pointers[i] = par
        return self.num_to_objects[par]
    def union(self, object1, object2):
        '''\
Combine the sets that contain the two objects given.
Both objects must be Python hashable.
If either or both objects are unknown, will make them known, and combine them.'''
        o1p = self.find(object1)
        o2p = self.find(object2)
        if o1p != o2p:
            on1 = self.objects_to_num[o1p]
            on2 = self.objects_to_num[o2p]
            w1 = self.num_weights[on1]
            w2 = self.num_weights[on2]
            if w1 < w2:
                # Union by weight: always hang the lighter tree under the heavier.
                o1p, o2p, on1, on2, w1, w2 = o2p, o1p, on2, on1, w2, w1
            self.num_weights[on1] = w1+w2
            del self.num_weights[on2]
            self.parent_pointers[on2] = on1
    def __str__(self):
        '''\
Included for testing purposes only.
All information needed from the union find data structure can be attained
using find.'''
        sets = {}
        for i in range(len(self.objects_to_num)):
            sets[i] = []
        for i in self.objects_to_num:
            sets[self.objects_to_num[self.find(i)]].append(i)
        out = []
        for i in sets.values():
            if i:
                out.append(repr(i))
        return ', '.join(out)
    __repr__ = __str__

if __name__ == '__main__':
    import random
    print("Testing...")
    uf = UnionFind()
    az = "abcdefghijklmnopqrstuvwxyz"
    az += az.upper()
    uf.insert_objects(az)
    cnt = 0
    # num_weights keeps one entry per root, so its length is the number of sets.
    while len(uf.num_weights) > 20:
        cnt += 1
        uf.union(random.choice(az), random.choice(az))
    print(uf, cnt)
    print("Testing complete.")
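
The docstring names Kruskal's Minimum Spanning Tree algorithm as the main application, but the recipe stops at the data structure. The following is a minimal sketch of Kruskal's algorithm, assuming edges are given as (weight, u, v) tuples with hashable vertices. `kruskal` is a hypothetical name, and the local find/union helpers re-implement union by size with path compression so the block runs on its own instead of importing unionfind.py.

```python
def kruskal(edges):
    """Return a minimum spanning forest for edges given as (weight, u, v) tuples.

    Sketch only: the local find/union helpers are a hypothetical stand-in for
    the recipe's UnionFind class (same union-by-size + path-compression idea).
    """
    parent = {}
    size = {}

    def find(x):
        # Unknown vertices become singleton sets, as in UnionFind.find.
        if x not in parent:
            parent[x] = x
            size[x] = 1
            return x
        root = x
        while parent[root] != root:
            root = parent[root]
        # Path compression: point every node on the path straight at the root.
        while parent[x] != root:
            parent[x], x = root, parent[x]
        return root

    def union(a, b):
        ra, rb = find(a), find(b)
        if ra == rb:
            return False  # already connected; this edge would close a cycle
        if size[ra] < size[rb]:
            ra, rb = rb, ra
        parent[rb] = ra  # attach the smaller tree under the larger one
        size[ra] += size[rb]
        return True

    tree = []
    # Sort by weight only, so vertices never need to be comparable.
    for weight, u, v in sorted(edges, key=lambda e: e[0]):
        if union(u, v):
            tree.append((weight, u, v))
    return tree
```

For example, `kruskal([(1, 'a', 'b'), (3, 'a', 'c'), (2, 'b', 'c'), (4, 'c', 'd')])` keeps the edges of weight 1, 2 and 4 and skips the weight-3 edge, because 'a' and 'c' are already in the same set by the time it is considered.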

One could also use the class below, built on standard dictionaries, to do basically the same thing (though find() doesn't actually return an object; it returns the id() of the dictionary representing the set the object belongs to, which in CPython is its memory address).

class straightforward:
    def __init__(self):
        self.objects = {}
        self.count = 0
    def insert_objects(self, objects):
        for i in objects:
            self.find(i)
    def find(self, a):
        if a not in self.objects:
            self.objects[a] = {a: 1}
            self.count += 1
        return id(self.objects[a])
    def union(self, a, b):
        if self.find(a) != self.find(b):
            la = len(self.objects[a])
            lb = len(self.objects[b])
            if la > lb:
                a, b = b, a
            self.objects[b].update(self.objects[a])
            # Only a is repointed at the merged dict; other members of a's old
            # set still point at the stale dict (see Ryan Coleman's comment).
            self.objects[a] = self.objects[b]
            self.count -= 1
    def __str__(self):
        outp = {}
        for i in self.objects.values():
            outp[id(i)] = i
        out = []
        for i in outp.values():
            out.append(str(list(i.keys())))
        return ', '.join(out)

What about timings?

                    create   union   total
  UnionFind          1.953   1.781   3.734
  straightforward    3.157   1.171   4.328

These were tested on a PII-400 using 100,000 objects, unioning random pairs of objects 25,000 times (the pairs were the same for both structures, and union times were consistent per item).

On unions alone, this seems like a clear win for the straightforward method (1.171 vs. 1.781), even though UnionFind is faster in total because of its cheaper creation. However, dict.update is implemented in C, and it is STILL only about 33% faster than pure Python over a large number of joins. That makes me wonder how fast UnionFind would be in C.

8 comments

Josiah Carlson (author) 18 years, 3 months ago

For the actual data structure application, it is faster. When counting the number of dictionary lookups (access and insertion), the union-find structure completely annihilates the straightforward approach when we end up joining every individual set into one larger set, as is appropriate for Kruskal's minimum spanning tree algorithm.

This is because each operation on the union-find structure is known to take O(inverse_ackerman(n)) amortized time, while a single operation using the straightforward approach can take O(n).

Even with the C dict.update(), using the union-find structure is FAR faster.

David Eppstein 18 years, 2 months ago

Looks useful. I need a union-find structure for Edmonds' blossom-contraction algorithm for maximum matching in general graphs (unfortunately the graphs I need this for are not bipartite so I can't use my previous recipe) and this looks like a good implementation. I like the way you set up the API to allow arbitrary hashables to take part in the structure.

A few comments, though:

  • What is the point of converting objects to numbers and back again? Does the performance gain really outweigh the added code complexity? And if you're going to do that, why do you use the numbers to index into a dict instead of into a list?

  • Maybe it would be more pythonic to use __getitem__ instead of find. That is, instead of calling UF.find(x), simply write UF[x] to find the set associated with x, just as other data structures (e.g. dict) use the same notation for other values associated with x. I can't think of a similar neat and understandable syntax for union, though.

  • You appear to have mixed up log* and alpha. log*(n) is the minimum height of a tower of powers of 2, 2^(2^(...)), that is at least n. alpha is the inverse of the Ackermann function, which can be defined by several recurrences but not the one you give; a typical one is

        A(1, x) = 2x
        A(x, 1) = 2
        A(x, y) = A(x-1, A(x, y-1))    for x, y >= 2

Although both alpha and log* grow incredibly slowly, alpha grows much more slowly than log*.
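
To make the distinction concrete, here is a small sketch with hypothetical helper names (`log_star`, `ackermann`, `alpha`). `log_star` follows the tower definition above; the recipe's Ackerman() is exactly that tower function, so its inverseAckerman() actually returns log*(n). `ackermann` follows the recurrence above, and `alpha(n)` is taken as the smallest k with A(k, k) >= n, which is one common convention and an assumption here. Values above a cap are saturated so the enormous intermediate results are never built.

```python
def log_star(n):
    """Smallest h such that a tower of h twos (tower(0) = 1) is >= n."""
    height, tower = 0, 1
    while tower < n:
        tower = 2 ** tower
        height += 1
    return height


def ackermann(x, y, cap):
    """A(1, y) = 2y, A(x, 1) = 2, A(x, y) = A(x-1, A(x, y-1)), for x, y >= 1.

    Any value above `cap` is reported as cap + 1. This is safe because A is
    increasing in both arguments, so once an intermediate value exceeds the
    cap, the final result does too.
    """
    if x == 1:
        return min(2 * y, cap + 1)
    value = 2  # A(x, 1)
    for _ in range(y - 1):  # iterate in y; recursion depth is only x
        value = ackermann(x - 1, value, cap)
        if value > cap:
            return cap + 1
    return value


def alpha(n):
    """Smallest k with A(k, k) >= n (assumed diagonal-inverse convention)."""
    k = 1
    while ackermann(k, k, cap=n) < n:
        k += 1
    return k
```

For n = 65536 both functions return 4; for n = 2**65536, log_star already returns 5 while alpha still returns 4, since A(4, 4) is a tower of 65536 twos.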

Josiah Carlson (author) 18 years, 2 months ago

Why convert? More pythonic?

  • You make a good point about conversions between objects and integers. I used dictionaries purely out of personal preference. While storing parent pointers in dictionaries makes little sense, keeping a dictionary of child pointers ended up helping considerably when I went about implementing the 'delete' operation for the Blossom Contraction algorithm you emailed me about (included below). With dictionaries for child pointers, delete can be amortized to O(1) per deletion, paid for by earlier union and find operations.

  • In terms of being more pythonic, certainly using __getitem__ would be significantly more pythonic; no arguments here. I wrote the methods in terms of the operations performed on the structure (union, find) because that is explicit. It doesn't make sense to me to make something more pythonic when UF[x] = y has no sensible meaning. The delete operation added below could also have a Python equivalent, __delitem__, but again, I shy away from using too many accessors when the structure is so different from what is offered in standard Python, and they don't help with understanding how the structure is used.

  • My mistake.

class BlossomContraction:
    def __init__(self):
        '''\
Create an empty blossom contraction data structure.'''
        self.num_weights = {}
        self.parent_pointers = {}
        self.num_to_objects = {}
        self.objects_to_num = {}
        self.child_pointers = {}
        self.next = -1
        self.setcount = 0
        self.__repr__ = self.__str__
    def insert_objects(self, objects):
        '''\
Insert a sequence of objects into the structure.  All must be Python hashable.'''
        for object in objects:
            self.find(object)
    def find(self, object):
        '''\
Find the root of the set that an object is in.
If the object was not known, will make it known, and it becomes its own set.
Object must be Python hashable.'''
        if object not in self.objects_to_num:
            self.next += 1
            obj_num = self.next
            self.num_weights[obj_num] = 1
            self.objects_to_num[object] = obj_num
            self.num_to_objects[obj_num] = object
            self.parent_pointers[obj_num] = obj_num
            self.child_pointers[obj_num] = {}
            self.setcount += 1
            return object
        stk = [self.objects_to_num[object]]
        par = self.parent_pointers[stk[-1]]
        while par != stk[-1]:
            del self.child_pointers[par][stk[-1]]
            stk.append(par)
            par = self.parent_pointers[par]
        for i in range(1, len(stk)-1):
            self.num_weights[stk[i]] -= i
        par = stk.pop()
        for i in stk:


            self.parent_pointers[i] = par
            self.child_pointers[par][i] = None
        return self.num_to_objects[par]
    def union(self, object1, object2):
        '''\
Combine the sets that contain the two objects given.
Both objects must be Python hashable.
If either or both objects are unknown, will make them known, and combine them.'''
        o1p = self.find(object1)
        o2p = self.find(object2)
        if o1p != o2p:
            on1 = self.objects_to_num[o1p]
            on2 = self.objects_to_num[o2p]
            w1 = self.num_weights[on1]
            w2 = self.num_weights[on2]
            if w2 > w1:
                o1p, o2p, on1, on2, w1, w2 = o2p, o1p, on2, on1, w2, w1
            self.num_weights[on1] = w1+w2
            self.parent_pointers[on2] = on1
            self.child_pointers[on1][on2] = None
            self.setcount -= 1
    def delete(self, object):
        '''\
Remove an object from the Blossom Contraction structure.
Object does not necessarily need to be a root.'''
        if not object in self.objects_to_num:
            return
        rootn = self.objects_to_num[self.find(object)]
        objectn = self.objects_to_num[object]
        #uncomment the following line to require object to be a root.
        #assert objectn == rootn
        self.num_weights[rootn] -= self.num_weights[objectn]
        for child in self.child_pointers[objectn]:
            self.parent_pointers[child] = child
        self.setcount += len(self.child_pointers[objectn])-1
        del self.parent_pointers[objectn]
        del self.child_pointers[objectn]
        del self.num_weights[objectn]
        del self.num_to_objects[objectn]
        del self.objects_to_num[object]
    def __str__(self):
        '''\
Included for testing purposes only.
All information needed from the union find data structure can be attained
using find.'''
        sets = {}
        for i in self.parent_pointers:
            if i == self.parent_pointers[i]:
                sets[i] = []
        for i in self.objects_to_num:
            sets[self.objects_to_num[self.find(i)]].append(i)
        out = []
        for i in sets.values():
            out.append(repr(i))
        return ', '.join(out)
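
For the __getitem__ suggestion discussed in the two comments above, a minimal sketch could look like the following. `PythonicUnionFind` is a hypothetical name, and the class is a self-contained stand-in rather than a subclass of the recipe's UnionFind. Indexing `uf[x]` returns the representative of x's set, while union keeps its explicit name and `uf[x] = y` is deliberately left undefined, in line with the objection that assignment has no sensible meaning here.

```python
class PythonicUnionFind:
    """Union-find over hashable objects with dict-like lookup of representatives.

    Hypothetical API sketch: uf[x] plays the role of find(x); union() stays a method.
    """

    def __init__(self, objects=()):
        self._parent = {}
        self._size = {}
        for obj in objects:
            self[obj]  # looking an object up registers it as a singleton set

    def __getitem__(self, obj):
        # Unknown objects become their own singleton set, as in the recipe's find().
        if obj not in self._parent:
            self._parent[obj] = obj
            self._size[obj] = 1
            return obj
        root = obj
        while self._parent[root] != root:
            root = self._parent[root]
        while self._parent[obj] != root:  # path compression
            self._parent[obj], obj = root, self._parent[obj]
        return root

    def __contains__(self, obj):
        # Without this, `x in uf` would fall back to __getitem__(0), __getitem__(1), ...
        return obj in self._parent

    def union(self, a, b):
        ra, rb = self[a], self[b]
        if ra != rb:
            if self._size[ra] < self._size[rb]:
                ra, rb = rb, ra
            self._parent[rb] = ra  # union by size
            self._size[ra] += self._size[rb]
```

Defining __contains__ explicitly matters here: because find() silently registers unknown objects, a membership test that fell back to __getitem__ would add integer keys to the structure instead of answering the question.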

Josiah Carlson (author) 17 years, 11 months ago

Note on blossom contraction... This doesn't actually implement blossom contraction. It implements a related algorithm and structure, but it isn't really useful.

Alain Pointdexter 16 years, 10 months ago

help the needy ;-). All well and fine... but where is the useful recipe? I mean the Kruskal algorithm.

Ryan Coleman 15 years, 3 months ago

straightforward method buggy. So, I could be wrong here, but I think I'm not, so I thought I'd share... feel free to correct.

Hey, maybe the reason the straightforward method wins in timings is that it is fundamentally broken. You have to reset all members of the old set to point to the new, larger dict, which you don't do...

self.objects[b].update(self.objects[a])
self.objects[a] = self.objects[b]

These lines are the problem. Well, the first is okay, but the second one needs to iterate through the old members of self.objects[a] and set them all to point to the new, bigger dict.
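
A minimal sketch of the fix described above, assuming the goal is to keep the straightforward dictionary approach: when two sets merge, every member of the smaller set is repointed at the larger set. `FixedStraightforward` is a hypothetical name; it stores Python sets instead of {key: 1} dicts, and find() still returns the id() of the set object.

```python
class FixedStraightforward:
    """Dict-of-sets union/find in which every member points at its current set."""

    def __init__(self):
        self.objects = {}  # object -> the set object it currently belongs to
        self.count = 0     # number of disjoint sets

    def insert_objects(self, objects):
        for obj in objects:
            self.find(obj)

    def find(self, a):
        if a not in self.objects:
            self.objects[a] = {a}
            self.count += 1
        return id(self.objects[a])

    def union(self, a, b):
        if self.find(a) == self.find(b):
            return
        small, large = self.objects[a], self.objects[b]
        if len(small) > len(large):
            small, large = large, small
        large.update(small)
        # The fix: repoint *every* member of the smaller set, not just `a`.
        for member in small:
            self.objects[member] = large
        self.count -= 1
```

This keeps the original smaller-into-larger rule, so each union costs time proportional to the smaller set, and any single element is moved at most about log2(n) times, because the set containing it at least doubles on every move.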

Peter Wood 13 years, 9 months ago

straightforward method. I have had erroneous results with the straightforward method. It is better to use the original method.

Created by Josiah Carlson on Wed, 13 Aug 2003 (PSF)