# ##### BEGIN GPL LICENSE BLOCK #####
#
#  SCA Tree Generator, a Blender add-on
#  (c) 2013 Michel J. Anders (varkenvarken)
#
#  This module is: kdtree.py
#  a pure python implementation of a kdtree
#
#  This program is free software; you can redistribute it and/or
#  modify it under the terms of the GNU General Public License
#  as published by the Free Software Foundation; either version 2
#  of the License, or (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License
#  along with this program; if not, write to the Free Software Foundation,
#  Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
# ##### END GPL LICENSE BLOCK #####

# <pep8 compliant>

from copy import deepcopy


class Hyperrectangle:
    '''an axis aligned bounding box of arbitrary dimension'''

    def __init__(self, dim, min, max):
        '''dim: number of axes; min, max: indexable corner points.

        Both corners are deep-copied: extend() mutates them in place, so they
        must never alias each other or the caller's vectors.
        '''
        self.dim = dim
        self.min = deepcopy(min)
        self.max = deepcopy(max)

    def extend(self, pos):
        '''grow the hyperrectangle in place, if necessary, so it contains pos.'''
        for i in range(self.dim):
            if pos[i] < self.min[i]:
                self.min[i] = pos[i]
            elif pos[i] > self.max[i]:
                self.max[i] = pos[i]

    def distance_squared(self, pos):
        '''return the squared distance from pos to the closest point of the
        hyperrectangle, or zero if pos lies within it.'''
        result = 0.0
        for i in range(self.dim):
            if pos[i] < self.min[i]:
                result += (pos[i] - self.min[i]) ** 2
            elif pos[i] > self.max[i]:
                result += (pos[i] - self.max[i]) ** 2
        return result

    def __str__(self):
        return "[(%d) %s:%s]" % (int(self.dim), str(self.min), str(self.max))


class Node:
    """implements a node in a kd-tree

    pos must be indexable, support len(), and, for distance_squared(),
    subtraction plus a .dot() method (like the vector class in the self-test
    below; presumably mathutils.Vector inside Blender -- confirm).
    """

    def __init__(self, pos, data=None):
        self.pos = deepcopy(pos)
        self.data = data  # payload; None marks an "empty" node for checkempty
        self.left = None
        self.right = None
        self.dim = len(pos)
        self.dir = 0  # split axis; set by the parent in addleft/addright
        self.count = 0  # kept for compatibility; not updated in this module
        self.level = 0  # depth below the root; set by the parent
        # bounding box of this node and, after inserts, of its whole subtree
        self.rect = Hyperrectangle(self.dim, pos, pos)

    def addleft(self, node):
        '''attach node as the left child and update its level/split axis.'''
        self.left = node
        self.rect.extend(node.pos)
        node.level = self.level + 1
        node.dir = (self.dir + 1) % self.dim

    def addright(self, node):
        '''attach node as the right child and update its level/split axis.'''
        self.right = node
        self.rect.extend(node.pos)
        node.level = self.level + 1
        node.dir = (self.dir + 1) % self.dim

    def distance_squared(self, pos):
        '''return the squared euclidean distance between this node and pos.'''
        d = self.pos - pos
        return d.dot(d)

    def _str(self, level):
        s = ' ' * level + str(self.dir) + ' ' + str(self.pos) + ' ' + str(self.rect) + '\n'
        return s + ('' if self.left is None else 'L:' + self.left._str(level + 1)) + \
            ('' if self.right is None else 'R:' + self.right._str(level + 1))

    def __str__(self):
        return self._str(0)


class Tree:
    """implements a kd-tree"""

    def __init__(self, dim):
        # NOTE(review): dim is accepted for API compatibility but unused; the
        # dimension of each node is taken from len(pos) of the inserted point.
        self.root = None
        self.nnearest = 0  # number of nearest neighbor queries
        self.count = 0  # number of nodes visited
        self.level = 0  # deepest node level

    def resetcounters(self):
        """reset the query statistics (the tree height in level is kept)."""
        self.nnearest = 0  # number of nearest neighbor queries
        self.count = 0  # number of nodes visited

    def _insert(self, node, pos, data):
        """insert (pos, data) below node and return the newly created Node.

        Walks down iteratively rather than recursively, so degenerate trees
        (e.g. many identical or collinear points, which always descend to the
        right) cannot exceed Python's recursion limit. Every node passed on
        the way down has its bounding box extended to include pos.
        """
        while True:
            axis = node.dir
            if pos[axis] < node.pos[axis]:
                if node.left is None:
                    node.addleft(Node(pos, data))
                    return node.left
                node.rect.extend(pos)
                node = node.left
            else:
                if node.right is None:
                    node.addright(Node(pos, data))
                    return node.right
                node.rect.extend(pos)
                node = node.right

    def insert(self, pos, data):
        """insert a point with associated data and return its Node."""
        if self.root is None:
            self.root = Node(pos, data)
            self.level = self.root.level
            return self.root
        else:
            node = self._insert(self.root, pos, data)
            if node.level > self.level:
                self.level = node.level
            return node

    def _nearest(self, node, pos, checkempty, level=0):
        """return (node, distsq) of the point in the subtree rooted at node
        that is closest to pos, or (None, None) if no candidate exists.

        If checkempty is true, nodes whose data is None are not candidates,
        although their subtrees are still searched.

        The subtree is walked with an explicit stack instead of recursion,
        so deep degenerate trees cannot hit the recursion limit. The visit
        order is the same as a recursive descent: node, then near subtree,
        then far subtree. Ties are kept by the node visited first.

        A far subtree is skipped when its bounding box cannot contain a
        strictly closer point than the best one found so far.

        level is accepted for backward compatibility and is unused.
        """
        best = None
        bestdistsq = None
        # entries: (subtree root, is_far_subtree); only far subtrees are
        # candidates for bounding-box pruning
        pending = [(node, False)]
        while pending:
            current, isfar = pending.pop()
            # the prune test runs when the far subtree is popped, i.e. after
            # the near subtree has been fully explored
            if isfar and best is not None and \
                    not (current.rect.distance_squared(pos) < bestdistsq):
                continue
            self.count += 1
            if not (checkempty and current.data is None):
                distsq = current.distance_squared(pos)
                # strict '<': on ties the node visited first wins
                if best is None or distsq < bestdistsq:
                    best, bestdistsq = current, distsq
            axis = current.dir
            if pos[axis] - current.pos[axis] <= 0:
                neartree, fartree = current.left, current.right
            else:
                neartree, fartree = current.right, current.left
            # push far first so the whole near subtree is popped before it
            if fartree is not None:
                pending.append((fartree, True))
            if neartree is not None:
                pending.append((neartree, False))
        return best, bestdistsq

    def nearest(self, pos, checkempty=False):
        """return (node, distsq) for the node closest to pos.

        Returns (None, None) for an empty tree, or when checkempty is true
        and every node has data None. Increments nnearest; count accumulates
        the number of nodes visited.
        """
        self.nnearest += 1
        if self.root is None:
            return None, None
        return self._nearest(self.root, pos, checkempty)

    def __str__(self):
        return str(self.root)


if __name__ == "__main__":

    class vector(list):

        def __init__(self, *args):
            super().__init__([float(a) for a in args])

        def __str__(self):
            return "<%.1f %.1f %.1f>" % tuple(self[0:3])

        def __sub__(self, other):
            return vector(self[0] - other[0], self[1] - other[1], self[2] - other[2])

        def __add__(self, other):
            return vector(self[0] + other[0], self[1] + other[1], self[2] + other[2])

        def __mul__(self, other):
            s = sum(self[i] * other[i] for i in (0, 1, 2))
            return s

        def dot(self, other):
            return sum(self[k] * other[k] for k in (0, 1, 2))

    from random import random, seed, shuffle
    from time import time
    import unittest

    class TestVector(unittest.TestCase):

        def test_ops(self):
            v1 = vector(1, 0, 0)
            v2 = vector(2, 1, 0)
            self.assertAlmostEqual(v1 * v2, 2.0)
            self.assertAlmostEqual(v1.dot(v2), 2.0)
            v3 = vector(-1, -1, 0)
            self.assertListEqual(v1 - v2, v3)
            v4 = vector(3, 1, 0)
            self.assertListEqual(v1 + v2, v4)

    class TestHyperRectangle(unittest.TestCase):

        def setUp(self):
            self.left = vector(0, 0, 0)
            self.right = vector(1, 1, 1)
            self.left1 = vector(-1, 0, 0)
            self.left2 = vector(0, -1, 0)
            self.left3 = vector(0, 0, -1)
            self.right1 = vector(2, 0, 0)
            self.right2 = vector(0, 2, 0)
            self.right3 = vector(0, 0, 2)

        def test_constructor(self):
            hr = Hyperrectangle(3, self.left, self.right)
            self.assertListEqual(hr.min, self.left)
            self.assertListEqual(hr.max, self.right)

        def test_extend(self):
            hr = Hyperrectangle(3, self.left, self.right)
            hr.extend(self.left1)
            self.assertListEqual(hr.min, [-1, 0, 0])
            self.assertListEqual(hr.max, [1, 1, 1])
            hr.extend(self.left2)
            self.assertListEqual(hr.min, [-1, -1, 0])
            self.assertListEqual(hr.max, [1, 1, 1])
            hr.extend(self.left3)
            self.assertListEqual(hr.min, [-1, -1, -1])
            self.assertListEqual(hr.max, [1, 1, 1])
            hr.extend(self.right1)
            self.assertListEqual(hr.min, [-1, -1, -1])
            self.assertListEqual(hr.max, [2, 1, 1])
            hr.extend(self.right2)
            self.assertListEqual(hr.min, [-1, -1, -1])
            self.assertListEqual(hr.max, [2, 2, 1])
            hr.extend(self.right3)
            self.assertListEqual(hr.min, [-1, -1, -1])
            self.assertListEqual(hr.max, [2, 2, 2])

        def test_distance_squared(self):
            hr = Hyperrectangle(3, self.left, self.right)
            self.assertAlmostEqual(hr.distance_squared(vector(0.5, 0.5, 0.5)), 0.0)
            self.assertAlmostEqual(hr.distance_squared(vector(0, 0, 0)), 0.0)
            self.assertAlmostEqual(hr.distance_squared(vector(-1, 0, 0)), 1.0)
            self.assertAlmostEqual(hr.distance_squared(vector(2, 0, 0)), 1.0)
            self.assertAlmostEqual(hr.distance_squared(vector(2, 2, 2)), 3.0)
            self.assertAlmostEqual(hr.distance_squared(vector(0.5, 2, 2)), 2.0)
            self.assertAlmostEqual(hr.distance_squared(vector(0.5, -1, -1)), 2.0)
            self.assertAlmostEqual(hr.distance_squared(vector(0.5, 0.5, 2)), 1.0)

    class TestTree(unittest.TestCase):

        def setUp(self):
            seed(42)
            r = (-1, 0, 1)
            self.points = [vector(k, l, m) for k in r for l in r for m in r]
            d = (-0.1, 0, 0.1)
            self.d = [vector(k, l, m) for k in d for l in d for m in d if (k * l * m) != 0]
            self.repeats = 4

        def test_simple(self):
            tree = Tree(3)
            p1 = vector(0, 0, 0)
            p2 = vector(-1, 0, 0)
            p3 = vector(-1, 1, 0)
            d = vector(0.1, 0.1, 0.1)
            tree.insert(p1, p1)
            node, distsq = tree.nearest(p1)
            self.assertListEqual(node.pos, p1)
            self.assertAlmostEqual(distsq, 0.0)
            node, distsq = tree.nearest(p1 + d)
            self.assertListEqual(node.pos, p1)
            self.assertAlmostEqual(distsq, 0.03)
            tree.insert(p2, p2)
            node, distsq = tree.nearest(p1)
            self.assertListEqual(node.pos, p1)
            self.assertAlmostEqual(distsq, 0.0)
            node, distsq = tree.nearest(p1 + d)
            self.assertListEqual(node.pos, p1)
            self.assertAlmostEqual(distsq, 0.03)
            node, distsq = tree.nearest(p2)
            self.assertListEqual(node.pos, p2)
            self.assertAlmostEqual(distsq, 0.0)
            node, distsq = tree.nearest(p2 + d)
            self.assertListEqual(node.pos, p2)
            self.assertAlmostEqual(distsq, 0.03)
            tree.insert(p3, p3)
            node, distsq = tree.nearest(p1)
            self.assertListEqual(node.pos, p1)
            self.assertAlmostEqual(distsq, 0.0)
            node, distsq = tree.nearest(p1 + d)
            self.assertListEqual(node.pos, p1)
            self.assertAlmostEqual(distsq, 0.03)
            node, distsq = tree.nearest(p2)
            self.assertListEqual(node.pos, p2)
            self.assertAlmostEqual(distsq, 0.0)
            node, distsq = tree.nearest(p2 + d)
            self.assertListEqual(node.pos, p2)
            self.assertAlmostEqual(distsq, 0.03)
            node, distsq = tree.nearest(p3)
            self.assertListEqual(node.pos, p3)
            self.assertAlmostEqual(distsq, 0.0)
            node, distsq = tree.nearest(p3 + d)
            self.assertListEqual(node.pos, p3)
            self.assertAlmostEqual(distsq, 0.03)

        def test_nearest(self):
            for n in range(self.repeats):
                tree = Tree(3)
                shuffle(self.points)
                for p in self.points:
                    tree.insert(p, p)  # data equal to position
                for p in self.points:
                    for d in self.d:
                        node, distsq = tree.nearest(p + d)
                        s = "%s %s %s %s\n%s" % (str(p + d), str(p), str(d), str(node.pos), str(tree.root))
                        self.assertListEqual(node.pos, p, msg=s)
                        self.assertListEqual(node.data, p)
                        self.assertAlmostEqual(distsq, d.dot(d), msg=s)
                for p in self.points:
                    node, distsq = tree.nearest(p)
                    self.assertListEqual(node.pos, p)
                    self.assertListEqual(node.data, p)
                    self.assertAlmostEqual(distsq, 0.0)

        def test_nearest_empty(self):
            for n in range(self.repeats):
                tree = Tree(3)
                shuffle(self.points)
                for p in self.points:
                    tree.insert(p, p)  # data equal to position
                for p in self.points:
                    for d in self.d:
                        node, distsq = tree.nearest(p + d)
                        s = "%s %s %s %s\n%s" % (str(p + d), str(p), str(d), str(node.pos), str(tree.root))
                        self.assertListEqual(node.pos, p, msg=s)
                        self.assertListEqual(node.data, p)
                        self.assertAlmostEqual(distsq, d.dot(d), msg=s)
                for p in self.points:
                    node, distsq = tree.nearest(p)
                    self.assertListEqual(node.pos, p)
                    self.assertListEqual(node.data, p)
                    self.assertAlmostEqual(distsq, 0.0)

                # zeroing out a node should not affect retrieving any other node ...
                node, _ = tree.nearest(self.points[-1])  # last point
                node.data = None
                for p in self.points[:-1]:  # all but last
                    for d in self.d:
                        node, distsq = tree.nearest(p + d)
                        s = "%s %s %s %s\n%s" % (str(p + d), str(p), str(d), str(node.pos), str(tree.root))
                        self.assertListEqual(node.pos, p, msg=s)
                        self.assertListEqual(node.data, p)
                        self.assertAlmostEqual(distsq, d.dot(d), msg=s)
                for p in self.points[:-1]:  # all but last
                    node, distsq = tree.nearest(p)
                    self.assertListEqual(node.pos, p)
                    self.assertListEqual(node.data, p)
                    self.assertAlmostEqual(distsq, 0.0)

                # ... also, we should be able to retrieve the node anyway ...
                node, distsq = tree.nearest(self.points[-1])
                self.assertListEqual(node.pos, self.points[-1])
                self.assertIsNone(node.data)
                self.assertAlmostEqual(distsq, 0.0)

                # ... even for points nearby ...
                for d in self.d:
                    node, distsq = tree.nearest(self.points[-1] + d)
                    self.assertEqual(node.pos, self.points[-1])
                    self.assertIsNone(node.data)
                    self.assertAlmostEqual(distsq, d.dot(d))

                # ... unless we set checkempty
                node, distsq = tree.nearest(self.points[-1], checkempty=True)
                self.assertNotEqual(node.pos, self.points[-1])
                self.assertIsNotNone(node.data)
                # on a perpendicular grid nearest neighbor is at most 1 away
                self.assertAlmostEqual(distsq, 1.0)

        def test_performance(self):
            tree = Tree(3)
            tsize = 1000
            qsize = 1000
            emptyq = 3
            print("<performance test, may take several seconds>")
            qpos = [vector(random(), random(), random()) for p in range(qsize)]
            for p in range(tsize):
                pos = vector(random(), random(), random())
                tree.insert(pos, pos)
            s = time()
            for p in qpos:
                node, distsq = tree.nearest(p)
            e = time() - s
            print("queries|tree size|tree height|empties|query load|query time")
            print("{0:7d}|{2:9d}|{1.level:11d}| 0|{3:10.2f}|{4:10.1f}".format(
                qsize, tree, tsize, float(tree.count) / qsize, e))
            tree.resetcounters()
            empty = []
            for p in range(tsize * 9):
                pos = vector(random(), random(), random())
                tree.insert(pos, pos)
                if (p % emptyq) == 1:
                    empty.append(pos)
            s = time()
            for p in qpos:
                node, distsq = tree.nearest(p)
            e2 = time() - s
            print("{0:7d}|{2:9d}|{1.level:11d}| 0|{3:10.2f}|{4:10.1f}".format(
                qsize, tree, tsize * 10, float(tree.count) / qsize, e2))
            self.assertLess(e2, 3 * e, msg="a 10x bigger tree shouldn't take more than 3x the time to query")
            for p in empty:
                node, distsq = tree.nearest(p)
                node.data = None
            tree.resetcounters()
            s = time()
            for p in qpos:
                node, distsq = tree.nearest(p, checkempty=True)
            e3 = time() - s
            # report the number of nodes that were actually emptied above
            print("{0:7d}|{2:9d}|{1.level:11d}|{5:7d}|{3:10.2f}|{4:10.1f}".format(
                qsize, tree, tsize * 10, float(tree.count) / qsize, e3, len(empty)))

    unittest.main()