A very simple and concise KD-tree for points in Python.
For labeled points, you may want to check out my other recipe: Python add/set attributes to list
# Makes the KD-Tree for fast lookup
def make_kd_tree(points, dim, i=0):
    """Build a KD-tree from `points` and return its root node.

    Every node is a tuple ``(left, right, point)``; a missing child is None.
    At each level the points are ordered along axis `i`, the median element
    becomes the node, and the axis advances cyclically (modulo `dim`).

    Args:
        points: sequence of points, each indexable by axis 0 .. dim-1.
        dim: number of coordinate axes to cycle through.
        i: axis used for the split at this level (0 for the root).

    Returns:
        The root node tuple, or None when `points` is empty.
    """
    count = len(points)
    if count == 0:
        return None
    if count == 1:
        return (None, None, points[0])
    # Sort a copy: the caller's list must not be reordered as a side effect.
    ordered = sorted(points, key=lambda p: p[i])
    next_axis = (i + 1) % dim
    half = count >> 1
    return (
        make_kd_tree(ordered[:half], dim, next_axis),
        make_kd_tree(ordered[half + 1:], dim, next_axis),
        ordered[half],
    )
# K nearest neighbors. The heap is a bounded priority queue.
def get_knn(kd_node, point, k, dim, dist_func, return_distances=False, i=0, heap=None):
    """Return the `k` points stored in the tree that are nearest to `point`.

    The heap is a bounded priority queue of ``(-distance, point)`` entries,
    so ``-heap[0][0]`` is always the largest distance currently kept.

    The far side of a split is pruned by comparing ``dx * dx`` (squared
    offset along the split axis) with values returned by `dist_func`, so
    `dist_func` is expected to return a squared Euclidean distance.

    Args:
        kd_node: node tuple ``(left, right, point)`` or None for an empty tree.
        point: query point, indexable by axis.
        k: maximum number of neighbours to return; ``k <= 0`` yields ``[]``.
        dim: number of coordinate axes the tree cycles through.
        dist_func: callable ``dist_func(point, other)`` returning a distance.
        return_distances: when true, return ``(distance, point)`` tuples.
        i: split axis of `kd_node` (0 for the root).
        heap: internal accumulator shared by the recursive calls.

    Returns:
        For the top-level call, a list sorted closest first: the points
        themselves, or ``(distance, point)`` tuples when `return_distances`
        is true.  Recursive calls (non-empty `heap`) return None.
    """
    import heapq

    is_root = not heap
    if is_root:
        heap = []
    # k <= 0 keeps nothing; skipping the walk avoids indexing an empty heap.
    if kd_node and k > 0:
        candidate = kd_node[2]
        dist = dist_func(point, candidate)
        dx = candidate[i] - point[i]
        if len(heap) < k:
            heapq.heappush(heap, (-dist, candidate))
        elif dist < -heap[0][0]:
            heapq.heappushpop(heap, (-dist, candidate))
        next_axis = (i + 1) % dim
        # Search the side of the split that contains the query point first.
        get_knn(kd_node[dx < 0], point, k, dim,
                dist_func, return_distances, next_axis, heap)
        # The other side must be searched while the heap is not yet full;
        # once full, only if the splitting plane is closer than the worst
        # neighbour kept so far.
        if len(heap) < k or dx * dx < -heap[0][0]:
            get_knn(kd_node[dx >= 0], point, k, dim,
                    dist_func, return_distances, next_axis, heap)
    if is_root:
        neighbors = sorted((-h[0], h[1]) for h in heap)
        return neighbors if return_distances else [n[1] for n in neighbors]
    return None
# For the closest neighbor
def get_nearest(kd_node, point, dim, dist_func, return_distances=False, i=0, best=None):
    """Return the single point in the tree that is nearest to `point`.

    The far side of a split is pruned by comparing ``dx * dx`` (squared
    offset along the split axis) with values returned by `dist_func`, so
    `dist_func` is expected to return a squared Euclidean distance.

    Args:
        kd_node: node tuple ``(left, right, point)`` or None for an empty tree.
        point: query point, indexable by axis.
        dim: number of coordinate axes the tree cycles through.
        dist_func: callable ``dist_func(point, other)`` returning a distance.
        return_distances: when true, return ``[distance, point]``.
        i: split axis of `kd_node` (0 for the root).
        best: internal ``[distance, point]`` list updated in place by the
            recursive calls.

    Returns:
        The nearest point, or the list ``[distance, point]`` when
        `return_distances` is true.  None when the tree is empty.
    """
    if kd_node:
        candidate = kd_node[2]
        dist = dist_func(point, candidate)
        dx = candidate[i] - point[i]
        if not best:
            best = [dist, candidate]
        elif dist < best[0]:
            best[0], best[1] = dist, candidate
        next_axis = (i + 1) % dim
        # Search the side of the split that contains the query point first.
        get_nearest(
            kd_node[dx < 0], point, dim, dist_func, return_distances, next_axis, best)
        # Visit the other side only if the splitting plane is closer than
        # the best match found so far.
        if dx * dx < best[0]:
            get_nearest(
                kd_node[dx >= 0], point, dim, dist_func, return_distances, next_axis, best)
    if not best:
        # Empty tree: nothing to return (previously raised on best[1]).
        return None
    return best if return_distances else best[1]
""" Usage """
import random
def rand_point(dim):
    """Return a list of `dim` floats, each drawn with random.uniform(-1, 1)."""
    coords = []
    for _ in range(dim):
        coords.append(random.uniform(-1, 1))
    return coords
dim = 3  # 3 dimensions
points = [rand_point(dim) for x in range(5000)]  # 5k random points
kd_tree = make_kd_tree(points=points, dim=dim)  # make the kd tree
# If you need labeled points, check out my other recipe on adding attributes to python list
# https://code.activestate.com/recipes/users/4192908/


def squared_distance(a, b):
    """Return the squared Euclidean distance between `a` and `b` over `dim` axes."""
    return sum((a[i] - b[i]) ** 2 for i in range(dim))


# print(...) with a single argument and range() behave the same on
# Python 2 and Python 3, so the demo runs on either interpreter.
print(get_knn(
    kd_node=kd_tree,
    point=[0] * dim,
    k=8,
    dim=dim,
    dist_func=squared_distance))
print("")
print(get_nearest(
    kd_node=kd_tree,
    point=[0] * dim,
    dim=dim,
    dist_func=squared_distance))
|