# Welcome, guest | Sign In | My Account | Store | Cart
# Makes the KD-Tree for fast lookup
def make_kd_tree(points, dim, i=0):
    """Build a KD-tree from ``points``.

    Every node is a 3-tuple ``(left, right, point)`` where ``left`` and
    ``right`` are child nodes (or ``None`` for an empty subtree) and
    ``point`` is the median of the node's points along the splitting axis.
    The splitting axis advances by one (modulo ``dim``) at each level.

    Args:
        points: Collection of points; each point must be indexable by
            axis ``0 .. dim - 1``. The caller's collection is left
            untouched (a sorted copy is used internally).
        dim: Number of axes to cycle through when splitting.
        i: Axis used to split at this level. Recursive calls pass the next
            axis; the default starts at axis 0.

    Returns:
        The root node tuple, or ``None`` when ``points`` is empty.
    """
    count = len(points)
    if count == 0:
        return None
    if count == 1:
        return (None, None, points[0])

    # sorted() builds a new list, so the caller's ordering is preserved.
    ordered = sorted(points, key=lambda p: p[i])
    next_axis = (i + 1) % dim
    half = count >> 1  # index of the median point
    return (
        make_kd_tree(ordered[:half], dim, next_axis),
        make_kd_tree(ordered[half + 1:], dim, next_axis),
        ordered[half],
    )

# K nearest neighbors. The heap is a bounded priority queue.
def get_knn(kd_node, point, k, dim, dist_func, return_distances=False, i=0, heap=None):
    """Return the ``k`` points stored in a KD-tree that are closest to ``point``.

    Nodes are 3-tuples ``(left, right, point)``; ``None`` is an empty subtree.
    Candidates are kept in a heap of ``(-distance, point)`` tuples holding at
    most ``k`` entries, so ``-heap[0][0]`` is the largest distance currently
    kept.

    The pruning test compares the squared offset along the splitting axis
    (``dx * dx``) with values returned by ``dist_func``, so ``dist_func`` is
    expected to return a squared Euclidean distance (or another measure that
    is never smaller than that squared per-axis offset).

    Args:
        kd_node: Root node of the (sub)tree to search, or ``None``.
        point: Query point, indexable by axis.
        k: Number of neighbors wanted. Values below 1 yield an empty list.
        dim: Number of axes the tree cycles through.
        dist_func: Callable ``dist_func(point, other)`` returning a distance.
        return_distances: When true, return ``(distance, point)`` tuples
            instead of bare points.
        i: Splitting axis of ``kd_node`` (used by the recursion).
        heap: Shared candidate heap (used by the recursion).

    Returns:
        For the outermost call: a list of up to ``k`` results sorted by
        increasing distance (equal distances are ordered by comparing the
        points themselves). Recursive calls return ``None``.
    """
    import heapq

    # An empty/absent heap marks the outermost call: every recursive call
    # happens after at least one push, so it always sees a non-empty heap.
    is_root = not heap
    if is_root:
        if k < 1:
            # Nothing can be kept; avoids indexing an empty heap below.
            return []
        heap = []

    if kd_node:
        node_point = kd_node[2]
        dist = dist_func(point, node_point)
        dx = node_point[i] - point[i]

        if len(heap) < k:
            heapq.heappush(heap, (-dist, node_point))
        elif dist < -heap[0][0]:
            # Closer than the worst kept candidate: swap it in.
            heapq.heappushpop(heap, (-dist, node_point))

        i = (i + 1) % dim

        # Search the side of the splitting plane that contains the query first.
        get_knn(kd_node[dx < 0], point, k, dim,
                dist_func, return_distances, i, heap)

        # The far side must be searched while the heap is not yet full, or
        # when the plane is closer than the worst kept candidate.
        if len(heap) < k or dx * dx < -heap[0][0]:
            get_knn(kd_node[dx >= 0], point, k, dim,
                    dist_func, return_distances, i, heap)

    if is_root:
        neighbors = sorted((-entry[0], entry[1]) for entry in heap)
        return neighbors if return_distances else [n[1] for n in neighbors]

# For the closest neighbor
def get_nearest(kd_node, point, dim, dist_func, return_distances=False, i=0, best=None):
    """Return the point stored in a KD-tree that is closest to ``point``.

    Nodes are 3-tuples ``(left, right, point)``; ``None`` is an empty subtree.
    The best candidate so far is carried through the recursion as a mutable
    ``[distance, point]`` list that is updated in place.

    The pruning test compares the squared offset along the splitting axis
    (``dx * dx``) with values returned by ``dist_func``, so ``dist_func`` is
    expected to return a squared Euclidean distance (or another measure that
    is never smaller than that squared per-axis offset).

    Args:
        kd_node: Root node of the (sub)tree to search, or ``None``.
        point: Query point, indexable by axis.
        dim: Number of axes the tree cycles through.
        dist_func: Callable ``dist_func(point, other)`` returning a distance.
        return_distances: When true, return the ``[distance, point]`` list
            instead of the bare point.
        i: Splitting axis of ``kd_node`` (used by the recursion).
        best: Shared ``[distance, point]`` candidate (used by the recursion).

    Returns:
        The nearest point, or ``[distance, point]`` when ``return_distances``
        is true. ``None`` when the tree is empty.
    """
    if kd_node:
        node_point = kd_node[2]
        dist = dist_func(point, node_point)
        dx = node_point[i] - point[i]

        if not best:
            best = [dist, node_point]
        elif dist < best[0]:
            # Update in place so callers up the recursion see the new best.
            best[0], best[1] = dist, node_point

        i = (i + 1) % dim

        # Search the side of the splitting plane that contains the query first.
        get_nearest(kd_node[dx < 0], point, dim, dist_func,
                    return_distances, i, best)

        # The far side can only help if the plane is closer than the best hit.
        if dx * dx < best[0]:
            get_nearest(kd_node[dx >= 0], point, dim, dist_func,
                        return_distances, i, best)

    if return_distances:
        return best
    # An empty tree leaves no candidate; report that as None instead of
    # failing on best[1].
    return best[1] if best else None




""" Usage """

import random

def rand_point(dim):
    """Return a list of ``dim`` random coordinates, each drawn with
    ``random.uniform(-1, 1)``."""
    return [random.uniform(-1, 1) for _ in range(dim)]

dim = 3  # 3 dimensions
points = [rand_point(dim) for _ in range(5000)]  # 5k random points
kd_tree = make_kd_tree(points=points, dim=dim)  # make the kd tree

# If you need labeled points, check out my other recipe on adding attributes to Python lists
# https://code.activestate.com/recipes/users/4192908/


def squared_distance(a, b):
    """Return the squared Euclidean distance between ``a`` and ``b``.

    The search functions compare a squared per-axis offset (``dx * dx``)
    against this value, so the distance is deliberately left squared.
    """
    return sum((a[axis] - b[axis]) ** 2 for axis in range(dim))


# print(...) with a single argument behaves the same on Python 2 and Python 3.
print(get_knn(
    kd_node=kd_tree,
    point=[0] * dim,
    k=8,
    dim=dim,
    dist_func=squared_distance))

print("")

print(get_nearest(
    kd_node=kd_tree,
    point=[0] * dim,
    dim=dim,
    dist_func=squared_distance))

# History