Skip to content
Snippets Groups Projects
kdtree.py 17.2 KiB
Newer Older
  • Learn to ignore specific revisions
  • Michel Anders's avatar
    Michel Anders committed
    # ##### BEGIN GPL LICENSE BLOCK #####
    #
    
    #  SCA Tree Generator, a Blender add-on
    #
    #  (c) 2013 Michel J. Anders (varkenvarken)
    #
    #  This module is: kdtree.py
    #  a pure python implementation of a kdtree
    #
    #  This program is free software; you can redistribute it and/or
    #  modify it under the terms of the GNU General Public License
    #  as published by the Free Software Foundation; either version 2
    #  of the License, or (at your option) any later version.
    #
    #  This program is distributed in the hope that it will be useful,
    #  but WITHOUT ANY WARRANTY; without even the implied warranty of
    #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    #  GNU General Public License for more details.
    #
    #  You should have received a copy of the GNU General Public License
    #  along with this program; if not, write to the Free Software Foundation,
    #  Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
    #
    # ##### END GPL LICENSE BLOCK #####
    
    
    Michel Anders's avatar
    Michel Anders committed
    # <pep8 compliant>
    
    
    from copy import deepcopy
    
    Michel Anders's avatar
    Michel Anders committed
    
    
    Michel Anders's avatar
    Michel Anders committed
    
    
    Michel Anders's avatar
    Michel Anders committed
    class Hyperrectangle:
        '''an axis aligned bounding box of arbitrary dimension'''

        def __init__(self, dim, min, max):
            """Create a box of `dim` dimensions spanning min[i]..max[i] on each axis.

            min and max are deep-copied: extend() mutates the corners in place, so
            the box must own them and they must never alias each other (a caller
            may pass the very same position object for both corners).
            """
            self.dim = dim
            self.min = deepcopy(min)
            self.max = deepcopy(max)

        def extend(self, pos):
            '''adapt the hyperrectangle if necessary so it will contain pos.'''
            for i in range(self.dim):
                # two independent checks (not if/elif) so that a box created
                # with min > max on some axis is still grown correctly
                if pos[i] < self.min[i]:
                    self.min[i] = pos[i]
                if pos[i] > self.max[i]:
                    self.max[i] = pos[i]

        def distance_squared(self, pos):
            '''return the distance squared to the nearest edge, or zero if pos lies within the hyperrectangle'''
            result = 0.0
            for i in range(self.dim):
                # only axes on which pos lies outside the box contribute
                if pos[i] < self.min[i]:
                    result += (pos[i] - self.min[i]) ** 2
                elif pos[i] > self.max[i]:
                    result += (pos[i] - self.max[i]) ** 2
            return result

        def __str__(self):
            return "[(%d) %s:%s]" % (int(self.dim), str(self.min), str(self.max))
    
    
    Michel Anders's avatar
    Michel Anders committed
    
    class Node:
        """implements a node in a kd-tree"""

        def __init__(self, pos, data=None):
            self.pos = deepcopy(pos)  # private copy, independent of the caller's object
            self.data = data
            self.left = None
            self.right = None
            self.dim = len(pos)
            self.dir = 0  # splitting axis; children split on (dir + 1) % dim
            self.count = 0  # NOTE(review): not incremented anywhere in this module; kept for compatibility
            self.level = 0  # depth below the root
            # bounding box of this node and its descendants, extended as children are added
            self.rect = Hyperrectangle(self.dim, pos, pos)

        def _attach(self, node):
            """Grow this node's box to cover `node` and set `node`'s depth and splitting axis."""
            self.rect.extend(node.pos)
            node.level = self.level + 1
            node.dir = (self.dir + 1) % self.dim

        def addleft(self, node):
            """Make `node` the left child (points below this node on the split axis)."""
            self.left = node
            self._attach(node)

        def addright(self, node):
            """Make `node` the right child (points at or above this node on the split axis)."""
            self.right = node
            self._attach(node)

        def distance_squared(self, pos):
            """Return the squared distance to pos; positions must support `-` and `.dot()`."""
            d = self.pos - pos
            return d.dot(d)

        def _str(self, level):
            """Return an indented dump of this subtree, children prefixed with 'L:' / 'R:'."""
            parts = ['  ' * level + str(self.dir) + ' ' + str(self.pos) + ' ' + str(self.rect) + '\n']
            if self.left is not None:
                parts.append('L:' + self.left._str(level + 1))
            if self.right is not None:
                parts.append('R:' + self.right._str(level + 1))
            return ''.join(parts)

        def __str__(self):
            return self._str(0)
    
    
    Michel Anders's avatar
    Michel Anders committed
    
    class Tree:
        """implements a kd-tree"""

        def __init__(self, dim):
            self.dim = dim  # dimensionality requested by the caller (nodes derive theirs from len(pos))
            self.root = None
            self.nnearest = 0  # number of nearest neighbor queries
            self.count = 0  # number of nodes visited
            self.level = 0  # deepest node level

        def resetcounters(self):
            """Zero the query statistics (nnearest and count); level is left untouched."""
            self.nnearest = 0  # number of nearest neighbor queries
            self.count = 0  # number of nodes visited

        def _insert(self, node, pos, data):
            """Attach a new Node for pos/data below `node` and return it.

            Descends iteratively rather than recursively so that a degenerate tree
            (e.g. built from sorted input, one node per level) deeper than the
            interpreter recursion limit can still grow. Every node passed on the
            way down has its bounding box extended to include pos.
            """
            while True:
                axis = node.dir
                if pos[axis] < node.pos[axis]:
                    if node.left is None:
                        node.addleft(Node(pos, data))
                        return node.left
                    node.rect.extend(pos)
                    node = node.left
                else:
                    # ties on the split axis go right
                    if node.right is None:
                        node.addright(Node(pos, data))
                        return node.right
                    node.rect.extend(pos)
                    node = node.right

        def insert(self, pos, data):
            """Store pos with payload data, return the Node created for it and track the deepest level."""
            if self.root is None:
                self.root = Node(pos, data)
                self.level = self.root.level
                return self.root
            node = self._insert(self.root, pos, data)
            if node.level > self.level:
                self.level = node.level
            return node

        def _nearest(self, node, pos, checkempty, level=0):
            """Return (node, distance squared) of the point nearest to pos in the subtree at `node`.

            With checkempty, nodes whose data is None are not candidates (their
            subtrees are still searched); (None, None) is returned when there is no
            candidate. Among equally distant nodes, the first one met in a
            depth-first walk that enters the near side of each split first wins.

            An explicit stack replaces recursion so trees deeper than the
            interpreter recursion limit can be searched. A subtree is skipped when
            its bounding box is no closer than the best distance found so far.
            self.count is incremented once per node examined. `level` is unused
            and kept only for signature compatibility.
            """
            best = None
            bestdistsq = None
            pending = [node]
            while pending:
                current = pending.pop()
                if best is not None and current.rect.distance_squared(pos) >= bestdistsq:
                    continue  # nothing in this subtree can be strictly closer

                self.count += 1

                if not (checkempty and current.data is None):
                    distsq = current.distance_squared(pos)
                    if best is None or distsq < bestdistsq:
                        best, bestdistsq = current, distsq

                axis = current.dir
                if pos[axis] - current.pos[axis] <= 0:
                    neartree, fartree = current.left, current.right
                else:
                    neartree, fartree = current.right, current.left

                # push the far side first so the near side is searched first
                if fartree is not None:
                    pending.append(fartree)
                if neartree is not None:
                    pending.append(neartree)

            return best, bestdistsq

        def nearest(self, pos, checkempty=False):
            """Return (node, distance squared) of the stored point nearest to pos.

            Returns (None, None) for an empty tree, or when checkempty is set and
            every node's data is None.
            """
            self.nnearest += 1
            if self.root is None:
                return None, None
            return self._nearest(self.root, pos, checkempty)

        def __str__(self):
            return str(self.root)
    
    Michel Anders's avatar
    Michel Anders committed
    
    
    Michel Anders's avatar
    Michel Anders committed
    if __name__ == "__main__":
    
    
    Michel Anders's avatar
    Michel Anders committed
        class vector(list):
            """Minimal 3-component float vector for the self-tests.

            Provides the `-` operator and `.dot()` that Node.distance_squared uses,
            plus `+` and `*` (dot product) for convenience.
            """

            def __init__(self, *args):
                super().__init__([float(a) for a in args])

            def __str__(self):
                return "<%.1f %.1f %.1f>" % tuple(self[0:3])

            def __sub__(self, other):
                return vector(self[0] - other[0], self[1] - other[1], self[2] - other[2])

            def __add__(self, other):
                return vector(self[0] + other[0], self[1] + other[1], self[2] + other[2])

            def __mul__(self, other):
                """Return the dot product with other (same as dot())."""
                return self.dot(other)

            def dot(self, other):
                return sum(self[k] * other[k] for k in (0, 1, 2))
    
        from random import random, seed, shuffle
        from time import time
        import unittest
    
        class TestVector(unittest.TestCase):
            """Checks the arithmetic of the test-only vector helper."""

            def test_ops(self):
                x_axis = vector(1, 0, 0)
                other = vector(2, 1, 0)
                # `*` and dot() must both give the dot product
                for product in (x_axis * other, x_axis.dot(other)):
                    self.assertAlmostEqual(product, 2.0)
                self.assertListEqual(x_axis - other, vector(-1, -1, 0))
                self.assertListEqual(x_axis + other, vector(3, 1, 0))
    
        class TestHyperRectangle(unittest.TestCase):
            """Unit tests for Hyperrectangle construction, growth and distance."""

            def setUp(self):
                self.left = vector(0, 0, 0)
                self.right = vector(1, 1, 1)
                self.left1, self.left2, self.left3 = vector(-1, 0, 0), vector(0, -1, 0), vector(0, 0, -1)
                self.right1, self.right2, self.right3 = vector(2, 0, 0), vector(0, 2, 0), vector(0, 0, 2)

            def _box(self):
                """Return a fresh unit box spanning self.left .. self.right."""
                return Hyperrectangle(3, self.left, self.right)

            def test_constructor(self):
                box = self._box()
                self.assertListEqual(box.min, self.left)
                self.assertListEqual(box.max, self.right)

            def test_extend(self):
                box = self._box()
                # (point added, expected min corner, expected max corner), applied cumulatively
                steps = [
                    (self.left1, [-1, 0, 0], [1, 1, 1]),
                    (self.left2, [-1, -1, 0], [1, 1, 1]),
                    (self.left3, [-1, -1, -1], [1, 1, 1]),
                    (self.right1, [-1, -1, -1], [2, 1, 1]),
                    (self.right2, [-1, -1, -1], [2, 2, 1]),
                    (self.right3, [-1, -1, -1], [2, 2, 2]),
                ]
                for point, lo, hi in steps:
                    box.extend(point)
                    self.assertListEqual(box.min, lo)
                    self.assertListEqual(box.max, hi)

            def test_distance_squared(self):
                box = self._box()
                cases = [
                    ((0.5, 0.5, 0.5), 0.0),
                    ((0, 0, 0), 0.0),
                    ((-1, 0, 0), 1.0),
                    ((2, 0, 0), 1.0),
                    ((2, 2, 2), 3.0),
                    ((0.5, 2, 2), 2.0),
                    ((0.5, -1, -1), 2.0),
                    ((0.5, 0.5, 2), 1.0),
                ]
                for coords, expected in cases:
                    self.assertAlmostEqual(box.distance_squared(vector(*coords)), expected)
    
        class TestTree(unittest.TestCase):
            """End-to-end tests for Tree.insert / Tree.nearest, plus a timing report."""

            def setUp(self):
                seed(42)
                r = (-1, 0, 1)
                # the 27 points of a unit-spaced 3x3x3 grid
                self.points = [vector(k, l, m) for k in r for l in r for m in r]
                d = (-0.1, 0, 0.1)
                # the 8 diagonal offsets (+-0.1, +-0.1, +-0.1); each offset point is
                # far closer to its own grid point than to any neighbour
                self.d = [vector(k, l, m) for k in d for l in d for m in d if (k * l * m) != 0]
                self.repeats = 4

            def _build(self):
                """Shuffle self.points and insert them all into a new Tree."""
                tree = Tree(3)
                shuffle(self.points)
                for p in self.points:
                    tree.insert(p, p)  # data equal to position
                return tree

            def _check_lookups(self, tree, points):
                """Assert each point, and each point moved by every offset in self.d, finds itself."""
                for p in points:
                    for d in self.d:
                        node, distsq = tree.nearest(p + d)
                        s = "%s %s %s %s\n%s" % (str(p + d), str(p), str(d), str(node.pos), str(tree.root))
                        self.assertListEqual(node.pos, p, msg=s)
                        self.assertListEqual(node.data, p)
                        self.assertAlmostEqual(distsq, d.dot(d), msg=s)

                for p in points:
                    node, distsq = tree.nearest(p)
                    self.assertListEqual(node.pos, p)
                    self.assertListEqual(node.data, p)
                    self.assertAlmostEqual(distsq, 0.0)

            def test_simple(self):
                tree = Tree(3)
                d = vector(0.1, 0.1, 0.1)
                inserted = []
                # after each insertion every point inserted so far must be found,
                # both exactly and from a small offset
                for p in (vector(0, 0, 0), vector(-1, 0, 0), vector(-1, 1, 0)):
                    tree.insert(p, p)
                    inserted.append(p)
                    for q in inserted:
                        for query, expected in ((q, 0.0), (q + d, 0.03)):
                            node, distsq = tree.nearest(query)
                            self.assertListEqual(node.pos, q)
                            self.assertAlmostEqual(distsq, expected)

            def test_nearest(self):
                for _ in range(self.repeats):
                    tree = self._build()
                    self._check_lookups(tree, self.points)

            def test_nearest_empty(self):
                for _ in range(self.repeats):
                    tree = self._build()
                    self._check_lookups(tree, self.points)

                    # zeroing out a node should not affect retrieving any other node ...
                    last = self.points[-1]
                    node, _ = tree.nearest(last)
                    node.data = None
                    self._check_lookups(tree, self.points[:-1])  # all but last

                    # ... also, we should be able to retrieve the node anyway ...
                    node, distsq = tree.nearest(last)
                    self.assertListEqual(node.pos, last)
                    self.assertIsNone(node.data)
                    self.assertAlmostEqual(distsq, 0.0)

                    # ... even for points nearby ...
                    for d in self.d:
                        node, distsq = tree.nearest(last + d)
                        self.assertListEqual(node.pos, last)
                        self.assertIsNone(node.data)
                        self.assertAlmostEqual(distsq, d.dot(d))

                    # ... unless we set checkempty
                    node, distsq = tree.nearest(last, checkempty=True)
                    self.assertNotEqual(node.pos, last)
                    self.assertIsNotNone(node.data)
                    self.assertAlmostEqual(distsq, 1.0)  # on a unit grid the closest other point is exactly 1 away

            def test_performance(self):
                tree = Tree(3)
                tsize = 1000
                qsize = 1000
                emptyq = 3
                # columns: queries | tree size | tree height | empties | query load | query time (s)
                row = "{0:7d}|{2:9d}|{1.level:11d}|{5:7d}|{3:10.2f}|{4:10.3f}"

                print("<performance test, may take several seconds>")
                qpos = [vector(random(), random(), random()) for _ in range(qsize)]
                for _ in range(tsize):
                    pos = vector(random(), random(), random())
                    tree.insert(pos, pos)
                s = time()
                for p in qpos:
                    tree.nearest(p)
                e = time() - s
                print("queries|tree size|tree height|empties|query load|query time")
                print(row.format(qsize, tree, tsize, float(tree.count) / qsize, e, 0))

                tree.resetcounters()
                empty = []
                for i in range(tsize * 9):
                    pos = vector(random(), random(), random())
                    tree.insert(pos, pos)
                    if (i % emptyq) == 1:
                        empty.append(pos)
                s = time()
                for p in qpos:
                    tree.nearest(p)
                e2 = time() - s
                print(row.format(qsize, tree, tsize * 10, float(tree.count) / qsize, e2, 0))

                # NOTE(review): wall-clock ratio; may be flaky on a loaded machine
                self.assertLess(e2, 3 * e, msg="a 10x bigger tree shouldn't take more than 3x the time to query")

                for p in empty:
                    node, _ = tree.nearest(p)
                    node.data = None
                tree.resetcounters()
                s = time()
                for p in qpos:
                    tree.nearest(p, checkempty=True)
                e3 = time() - s
                # report the number of nodes actually emptied, not tsize * 10 // emptyq
                print(row.format(qsize, tree, tsize * 10, float(tree.count) / qsize, e3, len(empty)))
    
    Michel Anders's avatar
    Michel Anders committed
    
        # run the unittest.TestCase classes defined in this module
        unittest.main()