Newer
Older
# SCA Tree Generator, a Blender add-on
# (c) 2013 Michel J. Anders (varkenvarken)
#
# This module is: kdtree.py
# a pure python implementation of a kdtree
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
# ##### END GPL LICENSE BLOCK #####
from copy import deepcopy
'''an axis aligned bounding box of arbitrary dimension'''
def __init__(self, dim, min, max):
self.dim = dim
self.min = deepcopy(min) # min and max should never point to the same instance
self.max = deepcopy(max)
def extend(self, pos):
'''adapt the hyperectangle if necessary so it will contain pos.'''
for i in range(self.dim):
if pos[i] < self.min[i]:
self.min[i] = pos[i]
elif pos[i] > self.max[i]:
self.max[i] = pos[i]
def distance_squared(self, pos):
'''return the distance squared to the nearest edge, or zero if pos lies within the hyperrectangle'''
result = 0.0
for i in range(self.dim):
if pos[i] < self.min[i]:
result += (pos[i] - self.min[i]) ** 2
elif pos[i] > self.max[i]:
result += (pos[i] - self.max[i]) ** 2
return result
def __str__(self):
return "[(%d) %s:%s]" % (int(self.dim), str(self.min), str(self.max))
"""implements a node in a kd-tree"""
def __init__(self, pos, data=None):
self.pos = deepcopy(pos)
self.data = data
self.left = None
self.right = None
self.dim = len(pos)
self.dir = 0
self.count = 0
self.level = 0
self.rect = Hyperrectangle(self.dim, pos, pos)
def addleft(self, node):
self.left = node
self.rect.extend(node.pos)
node.level = self.level + 1
node.dir = (self.dir + 1) % self.dim
def addright(self, node):
self.right = node
self.rect.extend(node.pos)
node.level = self.level + 1
node.dir = (self.dir + 1) % self.dim
def distance_squared(self, pos):
d = self.pos - pos
return d.dot(d)
def _str(self, level):
s = ' ' * level + str(self.dir) + ' ' + str(self.pos) + ' ' + str(self.rect) + '\n'
return s + ('' if self.left is None else 'L:' + self.left._str(level + 1)) + \
('' if self.right is None else 'R:' + self.right._str(level + 1))
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
"""implements a kd-tree"""
def __init__(self, dim):
self.root = None
self.nnearest = 0 # number of nearest neighbor queries
self.count = 0 # number of nodes visited
self.level = 0 # deepest node level
def resetcounters(self):
self.nnearest = 0 # number of nearest neighbor queries
self.count = 0 # number of nodes visited
def _insert(self, node, pos, data):
if pos[node.dir] < node.pos[node.dir]:
if node.left is None:
node.addleft(Node(pos, data))
return node.left
else:
node.rect.extend(pos)
return self._insert(node.left, pos, data)
else:
if node.right is None:
node.addright(Node(pos, data))
return node.right
else:
node.rect.extend(pos)
return self._insert(node.right, pos, data)
def insert(self, pos, data):
if self.root is None:
self.root = Node(pos, data)
self.level = self.root.level
return self.root
else:
node = self._insert(self.root, pos, data)
if node.level > self.level:
self.level = node.level
return node
def _nearest(self, node, pos, checkempty, level=0):
self.count += 1
dir = node.dir
d = pos[dir] - node.pos[dir]
result = node
distsq = None
if checkempty and (node.data is None):
result = None
else:
distsq = node.distance_squared(pos)
if d <= 0:
neartree = node.left
fartree = node.right
else:
neartree = node.right
fartree = node.left
if neartree is not None:
nearnode, neardistsq = self._nearest(neartree, pos, checkempty, level + 1)
if (result is None) or (neardistsq is not None and neardistsq < distsq):
result, distsq = nearnode, neardistsq
if fartree is not None:
if (result is None) or (fartree.rect.distance_squared(pos) < distsq):
farnode, fardistsq = self._nearest(fartree, pos, checkempty, level + 1)
if (result is None) or (fardistsq is not None and fardistsq < distsq):
result, distsq = farnode, fardistsq
return result, distsq
def nearest(self, pos, checkempty=False):
self.nnearest += 1
if self.root is None:
return None, None
self.root.count = 0
node, distsq = self._nearest(self.root, pos, checkempty)
self.count += self.root.count
return node, distsq
def __str__(self):
return str(self.root)
class vector(list):
def __init__(self, *args):
super().__init__([float(a) for a in args])
def __str__(self):
return "<%.1f %.1f %.1f>" % tuple(self[0:3])
def __sub__(self, other):
return vector(self[0] - other[0], self[1] - other[1], self[2] - other[2])
def __add__(self, other):
return vector(self[0] + other[0], self[1] + other[1], self[2] + other[2])
def __mul__(self, other):
s = sum(self[i] * other[i] for i in (0, 1, 2))
# print("ds",s,self,other,[self[i]*other[i] for i in (0,1,2)])
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
return s
def dot(self, other):
return sum(self[k] * other[k] for k in (0, 1, 2))
from random import random, seed, shuffle
from time import time
import unittest
class TestVector(unittest.TestCase):
def test_ops(self):
v1 = vector(1, 0, 0)
v2 = vector(2, 1, 0)
self.assertAlmostEqual(v1 * v2, 2.0)
self.assertAlmostEqual(v1.dot(v2), 2.0)
v3 = vector(-1, -1, 0)
self.assertListEqual(v1 - v2, v3)
v4 = vector(3, 1, 0)
self.assertListEqual(v1 + v2, v4)
class TestHyperRectangle(unittest.TestCase):
def setUp(self):
self.left = vector(0, 0, 0)
self.right = vector(1, 1, 1)
self.left1 = vector(-1, 0, 0)
self.left2 = vector(0, -1, 0)
self.left3 = vector(0, 0, -1)
self.right1 = vector(2, 0, 0)
self.right2 = vector(0, 2, 0)
self.right3 = vector(0, 0, 2)
def test_constructor(self):
hr = Hyperrectangle(3, self.left, self.right)
self.assertListEqual(hr.min, self.left)
self.assertListEqual(hr.max, self.right)
def test_extend(self):
hr = Hyperrectangle(3, self.left, self.right)
hr.extend(self.left1)
self.assertListEqual(hr.min, [-1, 0, 0])
self.assertListEqual(hr.max, [1, 1, 1])
hr.extend(self.left2)
self.assertListEqual(hr.min, [-1, -1, 0])
self.assertListEqual(hr.max, [1, 1, 1])
hr.extend(self.left3)
self.assertListEqual(hr.min, [-1, -1, -1])
self.assertListEqual(hr.max, [1, 1, 1])
hr.extend(self.right1)
self.assertListEqual(hr.min, [-1, -1, -1])
self.assertListEqual(hr.max, [2, 1, 1])
hr.extend(self.right2)
self.assertListEqual(hr.min, [-1, -1, -1])
self.assertListEqual(hr.max, [2, 2, 1])
hr.extend(self.right3)
self.assertListEqual(hr.min, [-1, -1, -1])
self.assertListEqual(hr.max, [2, 2, 2])
def test_distance_squared(self):
hr = Hyperrectangle(3, self.left, self.right)
self.assertAlmostEqual(hr.distance_squared(vector(0.5, 0.5, 0.5)), 0.0)
self.assertAlmostEqual(hr.distance_squared(vector(0, 0, 0)), 0.0)
self.assertAlmostEqual(hr.distance_squared(vector(-1, 0, 0)), 1.0)
self.assertAlmostEqual(hr.distance_squared(vector(2, 0, 0)), 1.0)
self.assertAlmostEqual(hr.distance_squared(vector(2, 2, 2)), 3.0)
self.assertAlmostEqual(hr.distance_squared(vector(0.5, 2, 2)), 2.0)
self.assertAlmostEqual(hr.distance_squared(vector(0.5, -1, -1)), 2.0)
self.assertAlmostEqual(hr.distance_squared(vector(0.5, 0.5, 2)), 1.0)
class TestTree(unittest.TestCase):
def setUp(self):
seed(42)
r = (-1, 0, 1)
self.points = [vector(k, l, m) for k in r for l in r for m in r]
d = (-0.1, 0, 0.1)
self.d = [vector(k, l, m) for k in d for l in d for m in d if (k * l * m) != 0]
self.repeats = 4
def test_simple(self):
tree = Tree(3)
p1 = vector(0, 0, 0)
p2 = vector(-1, 0, 0)
p3 = vector(-1, 1, 0)
d = vector(0.1, 0.1, 0.1)
tree.insert(p1, p1)
node, distsq = tree.nearest(p1)
self.assertListEqual(node.pos, p1)
self.assertAlmostEqual(distsq, 0.0)
node, distsq = tree.nearest(p1 + d)
self.assertListEqual(node.pos, p1)
self.assertAlmostEqual(distsq, 0.03)
tree.insert(p2, p2)
node, distsq = tree.nearest(p1)
self.assertListEqual(node.pos, p1)
self.assertAlmostEqual(distsq, 0.0)
node, distsq = tree.nearest(p1 + d)
self.assertListEqual(node.pos, p1)
self.assertAlmostEqual(distsq, 0.03)
node, distsq = tree.nearest(p2)
self.assertListEqual(node.pos, p2)
self.assertAlmostEqual(distsq, 0.0)
node, distsq = tree.nearest(p2 + d)
self.assertListEqual(node.pos, p2)
self.assertAlmostEqual(distsq, 0.03)
tree.insert(p3, p3)
node, distsq = tree.nearest(p1)
self.assertListEqual(node.pos, p1)
self.assertAlmostEqual(distsq, 0.0)
node, distsq = tree.nearest(p1 + d)
self.assertListEqual(node.pos, p1)
self.assertAlmostEqual(distsq, 0.03)
node, distsq = tree.nearest(p2)
self.assertListEqual(node.pos, p2)
self.assertAlmostEqual(distsq, 0.0)
node, distsq = tree.nearest(p2 + d)
self.assertListEqual(node.pos, p2)
self.assertAlmostEqual(distsq, 0.03)
node, distsq = tree.nearest(p3)
self.assertListEqual(node.pos, p3)
self.assertAlmostEqual(distsq, 0.0)
node, distsq = tree.nearest(p3 + d)
self.assertListEqual(node.pos, p3)
self.assertAlmostEqual(distsq, 0.03)
def test_nearest(self):
for n in range(self.repeats):
tree = Tree(3)
shuffle(self.points)
for p in self.points:
tree.insert(p, p) # data equal to position
for p in self.points:
for d in self.d:
node, distsq = tree.nearest(p + d)
s = "%s %s %s %s\n%s" % (str(p + d), str(p), str(d), str(node.pos), str(tree.root))
self.assertListEqual(node.pos, p, msg=s)
self.assertListEqual(node.data, p)
self.assertAlmostEqual(distsq, d.dot(d), msg=s)
for p in self.points:
node, distsq = tree.nearest(p)
self.assertListEqual(node.pos, p)
self.assertListEqual(node.data, p)
self.assertAlmostEqual(distsq, 0.0)
def test_nearest_empty(self):
for n in range(self.repeats):
tree = Tree(3)
shuffle(self.points)
for p in self.points:
tree.insert(p, p) # data equal to position
for p in self.points:
for d in self.d:
node, distsq = tree.nearest(p + d)
s = "%s %s %s %s\n%s" % (str(p + d), str(p), str(d), str(node.pos), str(tree.root))
self.assertListEqual(node.pos, p, msg=s)
self.assertListEqual(node.data, p)
self.assertAlmostEqual(distsq, d.dot(d), msg=s)
for p in self.points:
node, distsq = tree.nearest(p)
self.assertListEqual(node.pos, p)
self.assertListEqual(node.data, p)
self.assertAlmostEqual(distsq, 0.0)
# zeroing out a node should not affect retrieving any other node ...
node, _ = tree.nearest(self.points[-1]) # last point
node.data = None
for p in self.points[:-1]: # all but last
for d in self.d:
node, distsq = tree.nearest(p + d)
s = "%s %s %s %s\n%s" % (str(p + d), str(p), str(d), str(node.pos), str(tree.root))
self.assertListEqual(node.pos, p, msg=s)
self.assertListEqual(node.data, p)
self.assertAlmostEqual(distsq, d.dot(d), msg=s)
for p in self.points[:-1]: # all but last
node, distsq = tree.nearest(p)
self.assertListEqual(node.pos, p)
self.assertListEqual(node.data, p)
self.assertAlmostEqual(distsq, 0.0)
# ... also, we should be able to retrieve the node anyway ...
node, distsq = tree.nearest(self.points[-1])
self.assertListEqual(node.pos, self.points[-1])
self.assertIsNone(node.data)
self.assertAlmostEqual(distsq, 0.0)
# ... even for points nearby ...
for d in self.d:
node, distsq = tree.nearest(self.points[-1] + d)
self.assertEqual(node.pos, self.points[-1])
self.assertIsNone(node.data)
self.assertAlmostEqual(distsq, d.dot(d))
# ... unless we set checkempty
node, distsq = tree.nearest(self.points[-1], checkempty=True)
self.assertNotEqual(node.pos, self.points[-1])
self.assertIsNotNone(node.data)
self.assertAlmostEqual(distsq, 1.0) # on a perpendicular grid nearest neighbor is at most 1 away
def test_performance(self):
tree = Tree(3)
tsize = 1000
qsize = 1000
emptyq = 3
print("<performance test, may take several seconds>")
qpos = [vector(random(), random(), random()) for p in range(qsize)]
for p in range(tsize):
pos = vector(random(), random(), random())
tree.insert(pos, pos)
s = time()
for p in qpos:
node, distsq = tree.nearest(p)
e = time() - s
print("queries|tree size|tree height|empties|query load|query time")
print("{0:7d}|{2:9d}|{1.level:11d}| 0|{3:10.2f}|{4:10.1f}".format(
qsize, tree, tsize, float(tree.count) / qsize, e)
)
tree.resetcounters()
empty = []
for p in range(tsize * 9):
pos = vector(random(), random(), random())
tree.insert(pos, pos)
if (p % emptyq) == 1:
empty.append(pos)
s = time()
for p in qpos:
node, distsq = tree.nearest(p)
e2 = time() - s
print("{0:7d}|{2:9d}|{1.level:11d}| 0|{3:10.2f}|{4:10.1f}".format(
qsize, tree, tsize * 10, float(tree.count) / qsize, e2)
)
self.assertLess(e2, 3 * e, msg="a 10x bigger tree shouldn't take more than 3x the time to query")
for p in empty:
node, distsq = tree.nearest(p)
node.data = None
tree.resetcounters()
s = time()
for p in qpos:
node, distsq = tree.nearest(p, checkempty=True)
e3 = time() - s
print("{0:7d}|{2:9d}|{1.level:11d}|{5:7d}|{3:10.2f}|{4:10.1f}".format(
qsize, tree, tsize * 10, float(tree.count) / qsize, e3, tsize * 10 // emptyq)
)