# Welcome, guest | Sign In | My Account | Store | Cart
# Makes the KD-Tree for fast lookup
def make_kd_tree(points, dim, i=0):
    """Build a k-d tree from ``points``.

    Every node is a tuple ``(left, right, point)``. ``left`` holds the points
    that sort before ``point`` along the node's splitting axis, ``right`` the
    ones that sort after it; a missing child is ``None``.

    Args:
        points: Sequence of points; each point is indexable by axis number
            ``0 .. dim - 1``. The sequence itself is not modified.
        dim: Number of axes; the splitting axis cycles through them level by
            level.
        i: Splitting axis for this level. Callers normally keep the default.

    Returns:
        The root node tuple, or ``None`` when ``points`` is empty.
    """
    count = len(points)
    if count == 0:
        return None
    if count == 1:
        return (None, None, points[0])
    # sorted() works on a copy, so the caller's sequence keeps its order
    # (list.sort() used to reorder it in place and rejected tuples).
    ordered = sorted(points, key=lambda p: p[i])
    next_axis = (i + 1) % dim
    half = count >> 1  # index of the median element
    return (
        make_kd_tree(ordered[:half], dim, next_axis),
        make_kd_tree(ordered[half + 1:], dim, next_axis),
        ordered[half],
    )

# K nearest neighbors. The heap is a bounded priority queue.
def get_knn(kd_node, point, k, dim, dist_func, return_distances=False, i=0, heap=None):
    """Find the ``k`` points in a k-d tree that are closest to ``point``.

    Args:
        kd_node: Tree node ``(left, right, point)`` or ``None`` for an empty
            (sub)tree.
        point: Query point, indexable by axis number.
        k: Maximum number of neighbours to return. ``k <= 0`` yields ``[]``.
        dim: Number of axes the tree cycles through.
        dist_func: ``dist_func(point, candidate)`` returning the distance used
            for ranking. It is compared against the squared per-axis offset
            ``dx * dx`` when pruning, so it should be a squared distance
            (e.g. squared Euclidean).
        return_distances: When true, return ``(distance, point)`` pairs
            instead of bare points.
        i: Splitting axis of ``kd_node``; internal recursion state.
        heap: Shared bounded max-heap of ``(-distance, point)`` entries;
            internal recursion state, leave as ``None`` when calling.

    Returns:
        For the outermost call (``heap`` left as ``None``): the neighbours
        sorted by increasing distance (at most ``k`` of them). Recursive
        calls that receive a heap return ``None``.
    """
    import heapq

    # Identity check: an empty heap handed down by a caller is still "not root".
    is_root = heap is None
    if is_root:
        heap = []
    if kd_node and k > 0:
        dist = dist_func(point, kd_node[2])
        dx = kd_node[2][i] - point[i]
        if len(heap) < k:
            heapq.heappush(heap, (-dist, kd_node[2]))
        elif dist < -heap[0][0]:
            # Candidate beats the current worst neighbour: swap it in.
            heapq.heappushpop(heap, (-dist, kd_node[2]))
        i = (i + 1) % dim
        # Descend first into the side of the splitting plane containing the query.
        get_knn(kd_node[dx < 0], point, k, dim,
                dist_func, return_distances, i, heap)
        # -heap[0][0] is the largest distance in the heap. The far side must
        # be searched while the heap is not yet full, or when the splitting
        # plane is closer than the current worst neighbour.
        if len(heap) < k or dx * dx < -heap[0][0]:
            get_knn(kd_node[dx >= 0], point, k, dim,
                    dist_func, return_distances, i, heap)
    if is_root:
        neighbors = sorted((-h[0], h[1]) for h in heap)
        return neighbors if return_distances else [n[1] for n in neighbors]

# For the closest neighbor
def get_nearest(kd_node, point, dim, dist_func, return_distances=False, i=0, best=None):
    """Find the single point in a k-d tree that is closest to ``point``.

    Args:
        kd_node: Tree node ``(left, right, point)`` or ``None`` for an empty
            (sub)tree.
        point: Query point, indexable by axis number.
        dim: Number of axes the tree cycles through.
        dist_func: ``dist_func(point, candidate)`` returning the distance used
            for ranking. It is compared against the squared per-axis offset
            ``dx * dx`` when pruning, so it should be a squared distance
            (e.g. squared Euclidean).
        return_distances: When true, return ``[distance, point]`` instead of
            just the point.
        i: Splitting axis of ``kd_node``; internal recursion state.
        best: Shared ``[distance, point]`` record updated in place during
            recursion; leave as ``None`` when calling.

    Returns:
        The nearest point (or ``[distance, point]`` when ``return_distances``
        is true), or ``None`` when the tree is empty.
    """
    if kd_node:
        dist = dist_func(point, kd_node[2])
        dx = kd_node[2][i] - point[i]
        if best is None:
            best = [dist, kd_node[2]]
        elif dist < best[0]:
            # Mutate in place so callers higher up the recursion see the update.
            best[0], best[1] = dist, kd_node[2]
        i = (i + 1) % dim
        # Descend first into the side of the splitting plane containing the query.
        get_nearest(
            kd_node[dx < 0], point, dim, dist_func, return_distances, i, best)
        # Only cross the plane if it is closer than the best match so far.
        if dx * dx < best[0]:
            get_nearest(
                kd_node[dx >= 0], point, dim, dist_func, return_distances, i, best)
    if best is None:
        # Empty tree: nothing to report (previously raised TypeError on best[1]).
        return None
    return best if return_distances else best[1]




""" Usage """

import random

def rand_point(dim):
    """Return a list of ``dim`` floats, each drawn via ``random.uniform(-1, 1)``."""
    coords = []
    for _ in range(dim):
        coords.append(random.uniform(-1, 1))
    return coords

dim = 3  # 3 dimensions
points = [rand_point(dim) for _ in range(5000)]  # 5k random points
kd_tree = make_kd_tree(points=points, dim=dim)  # make the kd tree

# If you need labeled points, check out my other recipe on adding attributes
# to a Python list:
# https://code.activestate.com/recipes/users/4192908/


def squared_euclidean(a, b):
    """Return the squared Euclidean distance between ``a`` and ``b``.

    Only the first ``dim`` coordinates (module-level ``dim``) are used. The
    square root is deliberately skipped: the tree searches compare this value
    against squared per-axis offsets.
    """
    return sum((a[i] - b[i]) ** 2 for i in range(dim))


# print(...) with a single argument and range() behave identically on
# Python 2 and Python 3, unlike the print statement and xrange.
print(get_knn(
    kd_node=kd_tree,
    point=[0] * dim,
    k=8,
    dim=dim,
    dist_func=squared_euclidean))

print("")  # blank separator line

print(get_nearest(
    kd_node=kd_tree,
    point=[0] * dim,
    dim=dim,
    dist_func=squared_euclidean))

# History