# ##### BEGIN GPL LICENSE BLOCK #####
#
# SCA Tree Generator, a Blender addon
# (c) 2013 Michel J. Anders (varkenvarken)
#
# This module is: kdtree.py
# a pure python implementation of a kdtree
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
# ##### END GPL LICENSE BLOCK #####

from copy import copy, deepcopy


class Hyperrectangle:
    '''an axis aligned bounding box of arbitrary dimension'''

    def __init__(self, dim, min, max):
        '''create a box spanning min..max (both indexable with dim items).

        Both corners are deep-copied: extend() mutates them in place, so they
        must never alias each other or the caller's position objects.
        '''
        self.dim = dim
        self.min = deepcopy(min)
        self.max = deepcopy(max)

    def extend(self, pos):
        '''adapt the hyperrectangle if necessary so it will contain pos.'''
        for i in range(self.dim):
            if pos[i] < self.min[i]:
                self.min[i] = pos[i]
            elif pos[i] > self.max[i]:
                self.max[i] = pos[i]

    def distance_squared(self, pos):
        '''return the distance squared to the nearest edge, or zero if pos
        lies within the hyperrectangle'''
        result = 0.0
        for i in range(self.dim):
            if pos[i] < self.min[i]:
                result += (pos[i] - self.min[i]) ** 2
            elif pos[i] > self.max[i]:
                result += (pos[i] - self.max[i]) ** 2
        return result

    def __str__(self):
        return "[(%d) %s:%s]" % (int(self.dim), str(self.min), str(self.max))


class Node:
    """implements a node in a kd-tree"""

    def __init__(self, pos, data=None):
        '''create a leaf at pos (deep-copied) carrying an arbitrary data payload.'''
        self.pos = deepcopy(pos)
        self.data = data
        self.left = None
        self.right = None
        self.dim = len(pos)
        self.dir = 0     # splitting axis; set by the parent in addleft/addright
        self.count = 0   # not updated anywhere in this module; kept for compatibility
        self.level = 0   # depth below the root (the root is level 0)
        # bounding box of this node and all of its descendants, used to prune searches
        self.rect = Hyperrectangle(self.dim, pos, pos)

    def addleft(self, node):
        '''attach node as the left child, one level deeper, splitting on the next axis.'''
        self.left = node
        self.rect.extend(node.pos)
        node.level = self.level + 1
        node.dir = (self.dir + 1) % self.dim

    def addright(self, node):
        '''attach node as the right child, one level deeper, splitting on the next axis.'''
        self.right = node
        self.rect.extend(node.pos)
        node.level = self.level + 1
        node.dir = (self.dir + 1) % self.dim

    def distance_squared(self, pos):
        '''return the squared distance between this node's position and pos.

        Requires a vector type: self.pos - pos must yield an object with a
        dot() method (presumably Blender's mathutils.Vector in the addon).
        '''
        d = self.pos - pos
        return d.dot(d)

    def _str(self, level):
        '''render this subtree, one node per line, indented by depth.'''
        s = ' ' * level + str(self.dir) + ' ' + str(self.pos) + ' ' + str(self.rect) + '\n'
        if self.left is not None:
            s += 'L:' + self.left._str(level + 1)
        if self.right is not None:
            s += 'R:' + self.right._str(level + 1)
        return s

    def __str__(self):
        return self._str(0)


class Tree:
    """implements a kd-tree"""

    def __init__(self, dim):
        # dim is accepted for API compatibility but is not used: every Node
        # derives its own dimension from len(pos).
        self.root = None
        self.nnearest = 0  # number of nearest neighbor queries
        self.count = 0     # number of nodes visited
        self.level = 0     # deepest node level

    def resetcounters(self):
        '''zero the query statistics (the tree height in self.level is kept).'''
        self.nnearest = 0  # number of nearest neighbor queries
        self.count = 0     # number of nodes visited

    def _insert(self, node, pos, data):
        '''add a new Node for (pos, data) below node and return it.

        Positions strictly smaller than a node on its splitting axis go left,
        all others go right. The walk is iterative so that degenerate
        (e.g. sorted) input cannot exhaust the interpreter's recursion limit.
        '''
        while True:
            axis = node.dir
            if pos[axis] < node.pos[axis]:
                if node.left is None:
                    node.addleft(Node(pos, data))
                    return node.left
                node.rect.extend(pos)  # pos will end up somewhere in this subtree
                node = node.left
            else:
                if node.right is None:
                    node.addright(Node(pos, data))
                    return node.right
                node.rect.extend(pos)
                node = node.right

    def insert(self, pos, data):
        '''insert a point with its payload; return the newly created Node.'''
        if self.root is None:
            self.root = Node(pos, data)
            self.level = self.root.level
            return self.root
        node = self._insert(self.root, pos, data)
        if node.level > self.level:
            self.level = node.level
        return node

    def _nearest(self, node, pos, checkempty, level=0):
        '''return (node, distsq) for the closest acceptable node in the subtree.

        With checkempty true, nodes whose data is None are skipped, and
        (None, None) is returned when the subtree has no acceptable node.
        The level argument is unused; it is kept for signature compatibility.
        '''
        self.count += 1
        axis = node.dir
        d = pos[axis] - node.pos[axis]
        result = node
        distsq = None
        if checkempty and (node.data is None):
            result = None
        else:
            distsq = node.distance_squared(pos)
        # search first in the half-space that contains pos
        if d <= 0:
            neartree, fartree = node.left, node.right
        else:
            neartree, fartree = node.right, node.left
        if neartree is not None:
            nearnode, neardistsq = self._nearest(neartree, pos, checkempty, level + 1)
            if (result is None) or (neardistsq is not None and neardistsq < distsq):
                result, distsq = nearnode, neardistsq
        # the far subtree can only hold a closer node if its bounding box is
        # closer to pos than the best candidate found so far
        if fartree is not None:
            if (result is None) or (fartree.rect.distance_squared(pos) < distsq):
                farnode, fardistsq = self._nearest(fartree, pos, checkempty, level + 1)
                if (result is None) or (fardistsq is not None and fardistsq < distsq):
                    result, distsq = farnode, fardistsq
        return result, distsq

    def nearest(self, pos, checkempty=False):
        '''return (node, distance squared) of the node closest to pos.

        Returns (None, None) for an empty tree, or when checkempty is true and
        every node's data is None. Each call increments self.nnearest and adds
        the number of visited nodes to self.count. The search recurses once
        per tree level.
        '''
        self.nnearest += 1
        if self.root is None:
            return None, None
        return self._nearest(self.root, pos, checkempty)

    def __str__(self):
        return str(self.root)


if __name__ == "__main__":

    class vector(list):
        '''minimal 3D vector used by the tests below'''

        def __init__(self, *args):
            super().__init__([float(a) for a in args])

        def __str__(self):
            return "<%.1f %.1f %.1f>" % tuple(self[0:3])

        def __sub__(self, other):
            return vector(self[0] - other[0], self[1] - other[1], self[2] - other[2])

        def __add__(self, other):
            return vector(self[0] + other[0], self[1] + other[1], self[2] + other[2])

        def __mul__(self, other):
            return self.dot(other)

        def dot(self, other):
            return sum(self[k] * other[k] for k in (0, 1, 2))

    from random import random, seed, shuffle
    from time import time
    import unittest

    class TestVector(unittest.TestCase):

        def test_ops(self):
            v1 = vector(1, 0, 0)
            v2 = vector(2, 1, 0)
            self.assertAlmostEqual(v1 * v2, 2.0)
            self.assertAlmostEqual(v1.dot(v2), 2.0)
            v3 = vector(-1, -1, 0)
            self.assertListEqual(v1 - v2, v3)
            v4 = vector(3, 1, 0)
            self.assertListEqual(v1 + v2, v4)

    class TestHyperRectangle(unittest.TestCase):

        def setUp(self):
            self.left = vector(0, 0, 0)
            self.right = vector(1, 1, 1)
            self.left1 = vector(-1, 0, 0)
            self.left2 = vector(0, -1, 0)
            self.left3 = vector(0, 0, -1)
            self.right1 = vector(2, 0, 0)
            self.right2 = vector(0, 2, 0)
            self.right3 = vector(0, 0, 2)

        def test_constructor(self):
            hr = Hyperrectangle(3, self.left, self.right)
            self.assertListEqual(hr.min, self.left)
            self.assertListEqual(hr.max, self.right)

        def test_extend(self):
            hr = Hyperrectangle(3, self.left, self.right)
            hr.extend(self.left1)
            self.assertListEqual(hr.min, [-1, 0, 0])
            self.assertListEqual(hr.max, [1, 1, 1])
            hr.extend(self.left2)
            self.assertListEqual(hr.min, [-1, -1, 0])
            self.assertListEqual(hr.max, [1, 1, 1])
            hr.extend(self.left3)
            self.assertListEqual(hr.min, [-1, -1, -1])
            self.assertListEqual(hr.max, [1, 1, 1])
            hr.extend(self.right1)
            self.assertListEqual(hr.min, [-1, -1, -1])
            self.assertListEqual(hr.max, [2, 1, 1])
            hr.extend(self.right2)
            self.assertListEqual(hr.min, [-1, -1, -1])
            self.assertListEqual(hr.max, [2, 2, 1])
            hr.extend(self.right3)
            self.assertListEqual(hr.min, [-1, -1, -1])
            self.assertListEqual(hr.max, [2, 2, 2])

        def test_distance_squared(self):
            hr = Hyperrectangle(3, self.left, self.right)
            self.assertAlmostEqual(hr.distance_squared(vector(0.5, 0.5, 0.5)), 0.0)
            self.assertAlmostEqual(hr.distance_squared(vector(0, 0, 0)), 0.0)
            self.assertAlmostEqual(hr.distance_squared(vector(-1, 0, 0)), 1.0)
            self.assertAlmostEqual(hr.distance_squared(vector(2, 0, 0)), 1.0)
            self.assertAlmostEqual(hr.distance_squared(vector(2, 2, 2)), 3.0)
            self.assertAlmostEqual(hr.distance_squared(vector(0.5, 2, 2)), 2.0)
            self.assertAlmostEqual(hr.distance_squared(vector(0.5, -1, -1)), 2.0)
            self.assertAlmostEqual(hr.distance_squared(vector(0.5, 0.5, 2)), 1.0)

    class TestTree(unittest.TestCase):

        def setUp(self):
            seed(42)
            r = (-1, 0, 1)
            self.points = [vector(k, l, m) for k in r for l in r for m in r]
            d = (-0.1, 0, 0.1)
            self.d = [vector(k, l, m) for k in d for l in d for m in d if (k * l * m) != 0]
            self.repeats = 4

        def _check_exact_and_offset(self, tree, p, d):
            '''p itself and p+d must both resolve to the node at p'''
            node, distsq = tree.nearest(p)
            self.assertListEqual(node.pos, p)
            self.assertAlmostEqual(distsq, 0.0)
            node, distsq = tree.nearest(p + d)
            self.assertListEqual(node.pos, p)
            self.assertAlmostEqual(distsq, 0.03)

        def _check_retrieves(self, tree, points):
            '''each point, exact and with every small offset, must be found'''
            for p in points:
                for d in self.d:
                    node, distsq = tree.nearest(p + d)
                    s = "%s %s %s %s\n%s" % (str(p + d), str(p), str(d), str(node.pos), str(tree.root))
                    self.assertListEqual(node.pos, p, msg=s)
                    self.assertListEqual(node.data, p)
                    self.assertAlmostEqual(distsq, d.dot(d), msg=s)
            for p in points:
                node, distsq = tree.nearest(p)
                self.assertListEqual(node.pos, p)
                self.assertListEqual(node.data, p)
                self.assertAlmostEqual(distsq, 0.0)

        def _build_shuffled_tree(self):
            '''shuffle self.points and insert them with data equal to position'''
            tree = Tree(3)
            shuffle(self.points)
            for p in self.points:
                tree.insert(p, p)  # data equal to position
            return tree

        def test_simple(self):
            tree = Tree(3)
            p1 = vector(0, 0, 0)
            p2 = vector(-1, 0, 0)
            p3 = vector(-1, 1, 0)
            d = vector(0.1, 0.1, 0.1)
            tree.insert(p1, p1)
            self._check_exact_and_offset(tree, p1, d)
            tree.insert(p2, p2)
            self._check_exact_and_offset(tree, p1, d)
            self._check_exact_and_offset(tree, p2, d)
            tree.insert(p3, p3)
            self._check_exact_and_offset(tree, p1, d)
            self._check_exact_and_offset(tree, p2, d)
            self._check_exact_and_offset(tree, p3, d)

        def test_nearest(self):
            for _ in range(self.repeats):
                tree = self._build_shuffled_tree()
                self._check_retrieves(tree, self.points)

        def test_nearest_empty(self):
            for _ in range(self.repeats):
                tree = self._build_shuffled_tree()
                self._check_retrieves(tree, self.points)
                # zeroing out a node should not affect retrieving any other node ...
                node, _ = tree.nearest(self.points[-1])  # last point
                node.data = None
                self._check_retrieves(tree, self.points[:-1])  # all but last
                # ... also, we should be able to retrieve the node anyway ...
                node, distsq = tree.nearest(self.points[-1])
                self.assertListEqual(node.pos, self.points[-1])
                self.assertIsNone(node.data)
                self.assertAlmostEqual(distsq, 0.0)
                # ... even for points nearby ...
                for d in self.d:
                    node, distsq = tree.nearest(self.points[-1] + d)
                    self.assertEqual(node.pos, self.points[-1])
                    self.assertIsNone(node.data)
                    self.assertAlmostEqual(distsq, d.dot(d))
                # ... unless we set checkempty
                node, distsq = tree.nearest(self.points[-1], checkempty=True)
                self.assertNotEqual(node.pos, self.points[-1])
                self.assertIsNotNone(node.data)
                # on a perpendicular grid nearest neighbor is at most 1 away
                self.assertAlmostEqual(distsq, 1.0)

        def test_performance(self):
            tree = Tree(3)
            tsize = 1000
            qsize = 1000
            emptyq = 3
            print("<performance test, may take several seconds>")
            qpos = [vector(random(), random(), random()) for _ in range(qsize)]
            for _ in range(tsize):
                pos = vector(random(), random(), random())
                tree.insert(pos, pos)
            s = time()
            for p in qpos:
                node, distsq = tree.nearest(p)
            e = time() - s
            print("queries|tree size|tree height|empties|query load|query time")
            print("{0:7d}|{2:9d}|{1.level:11d}| 0|{3:10.2f}|{4:10.1f}".format(qsize, tree, tsize, float(tree.count) / qsize, e))
            tree.resetcounters()
            empty = []
            for p in range(tsize * 9):
                pos = vector(random(), random(), random())
                tree.insert(pos, pos)
                if (p % emptyq) == 1:
                    empty.append(pos)
            s = time()
            for p in qpos:
                node, distsq = tree.nearest(p)
            e2 = time() - s
            print("{0:7d}|{2:9d}|{1.level:11d}| 0|{3:10.2f}|{4:10.1f}".format(qsize, tree, tsize * 10, float(tree.count) / qsize, e2))
            self.assertLess(e2, 3 * e, msg="a 10x bigger tree shouldn't take more than 3x the time to query")
            for p in empty:
                node, distsq = tree.nearest(p)
                node.data = None
            tree.resetcounters()
            s = time()
            for p in qpos:
                node, distsq = tree.nearest(p, checkempty=True)
            e3 = time() - s
            print("{0:7d}|{2:9d}|{1.level:11d}|{5:7d}|{3:10.2f}|{4:10.1f}".format(qsize, tree, tsize * 10, float(tree.count) / qsize, e3, tsize * 10 // emptyq))

    unittest.main()