diff --git a/src/main/java/azgracompress/kdtree/KDNode.java b/src/main/java/azgracompress/kdtree/KDNode.java index dfa09f3118d26d232a84c892289aafb520660821..c911b5463111523581e55707233c0f9947c56945 100644 --- a/src/main/java/azgracompress/kdtree/KDNode.java +++ b/src/main/java/azgracompress/kdtree/KDNode.java @@ -1,23 +1,23 @@ package azgracompress.kdtree; public class KDNode { - private final int keyIndex; - private final int median; + private final int discriminator; + private final int partition; private final KDNode loSon; private final KDNode hiSon; protected KDNode() { - keyIndex = -1; - median = -1; + discriminator = -1; + partition = -1; loSon = null; hiSon = null; } public KDNode(final int keyIndex, final int median, final KDNode loSon, final KDNode hiSon) { - this.keyIndex = keyIndex; - this.median = median; + this.discriminator = keyIndex; + this.partition = median; this.loSon = loSon; this.hiSon = hiSon; } @@ -30,15 +30,98 @@ public class KDNode { return hiSon; } - public final int getKeyIndex() { - return keyIndex; + public final int getDiscriminator() { + return discriminator; } - public final int getMedian() { - return median; + public final int getPartition() { + return partition; } public boolean isTerminal() { return false; } + + public void findNearestNeighbor(final int[] queryRecord, final KDTree.SearchInfo searchInfo) { + + if (isTerminal()) { + // TODO: Examine records in bucket(node), updating PQD , PQR > , + ((TerminalKDNode) this).findNearestNeighborInBucket(queryRecord, searchInfo); + + if (!ballWithinBounds(queryRecord, searchInfo)) + return; + } + assert (loSon != null && hiSon != null); + + if (queryRecord[discriminator] <= partition) { + double tmp = searchInfo.getUpperBounds()[discriminator]; + searchInfo.getUpperBounds()[discriminator] = partition; + loSon.findNearestNeighbor(queryRecord, searchInfo); + searchInfo.getUpperBounds()[discriminator] = tmp; + } else { + double tmp = searchInfo.getLowerBounds()[discriminator]; + searchInfo.getLowerBounds()[discriminator] = partition; + hiSon.findNearestNeighbor(queryRecord, searchInfo); + searchInfo.getLowerBounds()[discriminator] = tmp; + } + + if (queryRecord[discriminator] <= partition) { + double tmp = searchInfo.getLowerBounds()[discriminator]; + searchInfo.getLowerBounds()[discriminator] = partition; + if (boundsOverlapBall(queryRecord, searchInfo)) { + hiSon.findNearestNeighbor(queryRecord, searchInfo); + } + searchInfo.getLowerBounds()[discriminator] = tmp; + } else { + double tmp = searchInfo.getUpperBounds()[discriminator]; + searchInfo.getUpperBounds()[discriminator] = partition; + if (boundsOverlapBall(queryRecord, searchInfo)) { + loSon.findNearestNeighbor(queryRecord, searchInfo); + } + searchInfo.getUpperBounds()[discriminator] = tmp; + } + + if (!ballWithinBounds(queryRecord, searchInfo)) { + return; + } + } + + private static double coordinateDistance(final double x, final double y) { + return Math.pow((x - y), 2); + } + + private static double dissim(final double value) { + return Math.sqrt(value); + } + + @SuppressWarnings("BooleanMethodIsAlwaysInverted") + private boolean ballWithinBounds(final int[] queryRecord, final KDTree.SearchInfo searchInfo) { + double lbDist, ubDist; + for (int d = 0; d < searchInfo.getDimension(); d++) { + lbDist = coordinateDistance(searchInfo.getLowerBounds()[d], queryRecord[d]); + ubDist = coordinateDistance(searchInfo.getUpperBounds()[d], queryRecord[d]); + if ((lbDist <= searchInfo.getCurrentClosestDistance()) || (ubDist <= searchInfo.getCurrentClosestDistance())) { + return false; + } + } + return true; + } + + private boolean boundsOverlapBall(final int[] queryRecord, final KDTree.SearchInfo searchInfo) { + double sum = 0.0; + for (int d = 0; d < searchInfo.getDimension(); d++) { + if (queryRecord[d] < searchInfo.getLowerBounds()[d]) { + sum += coordinateDistance(queryRecord[d], searchInfo.getLowerBounds()[d]); + if (dissim(sum) > searchInfo.getCurrentClosestDistance()) { + return true; + } + } else if (queryRecord[d] > searchInfo.getUpperBounds()[d]) { + sum += coordinateDistance(queryRecord[d], searchInfo.getUpperBounds()[d]); + if (dissim(sum) > searchInfo.getCurrentClosestDistance()) { + return true; + } + } + } + return false; + } } diff --git a/src/main/java/azgracompress/kdtree/KDTree.java b/src/main/java/azgracompress/kdtree/KDTree.java new file mode 100644 index 0000000000000000000000000000000000000000..e97f09b348acc59d94f9faffdb14d4d0f2460453 --- /dev/null +++ b/src/main/java/azgracompress/kdtree/KDTree.java @@ -0,0 +1,84 @@ +package azgracompress.kdtree; + +import java.util.Arrays; + +public class KDTree { + private final int maximumBucketSize; + private final KDNode root; + + private final int dimension; + private final int totalNodeCount; + private final int terminalNodeCount; + + public static class SearchInfo { + private double currentClosestDistance; + private int[] currentClosestRecord = null; + private final double[] coordinateUpperBound; + private final double[] coordinateLowerBound; + private final int dimension; + + public SearchInfo(final int dimension) { + this.dimension = dimension; + currentClosestDistance = Double.POSITIVE_INFINITY; + coordinateUpperBound = new double[dimension]; + coordinateLowerBound = new double[dimension]; + Arrays.fill(coordinateLowerBound, Double.NEGATIVE_INFINITY); + Arrays.fill(coordinateUpperBound, Double.POSITIVE_INFINITY); + } + + public int getDimension() { + return dimension; + } + + public double getCurrentClosestDistance() { + return currentClosestDistance; + } + + public void setCurrentClosestDistance(double currentClosestDistance) { + this.currentClosestDistance = currentClosestDistance; + } + + public int[] getCurrentClosestRecord() { + return currentClosestRecord; + } + + public void setCurrentClosestRecord(int[] currentClosestRecord) { + this.currentClosestRecord = currentClosestRecord; + } + + public double[] getUpperBounds() { + return coordinateUpperBound; + } + + public double[] getLowerBounds() { + return coordinateLowerBound; + } + } + + public KDTree(final KDNode root, + final int dimension, + final int maximumBucketSize, + final int totalNodeCount, + final int terminalNodeCount) { + this.root = root; + this.dimension = dimension; + this.maximumBucketSize = maximumBucketSize; + this.totalNodeCount = totalNodeCount; + this.terminalNodeCount = terminalNodeCount; + } + + public int[] findNearestNeighbor(final int[] queryRecord) { + // TODO(Moravec): Read more about Ball Within Bounds and Bounds Overlap Ball + SearchInfo searchInfo = new SearchInfo(dimension); + root.findNearestNeighbor(queryRecord, searchInfo); + return searchInfo.currentClosestRecord; + } + + public int getTotalNodeCount() { + return totalNodeCount; + } + + public int getTerminalNodeCount() { + return terminalNodeCount; + } +} diff --git a/src/main/java/azgracompress/kdtree/KDTreeBuilder.java b/src/main/java/azgracompress/kdtree/KDTreeBuilder.java index 91ac7203b9287937996a3dc7e04db445b166a8f9..8503db4c313d434ce316224b7cbb59c8847e8830 100644 --- a/src/main/java/azgracompress/kdtree/KDTreeBuilder.java +++ b/src/main/java/azgracompress/kdtree/KDTreeBuilder.java @@ -33,8 +33,14 @@ public class KDTreeBuilder { this.dimension = dimension; } + public KDTree buildTree(final int[][] records) { + nodeCount = 0; + terminalNodeCount = 0; + final KDNode rootNode = buildTreeImpl(records); + return new KDTree(rootNode, dimension, bucketSize, nodeCount, terminalNodeCount); + } - public KDNode buildTree(final int[][] records) { + public KDNode buildTreeImpl(final int[][] records) { if (records.length <= bucketSize) { return makeTerminalNode(records); } @@ -73,8 +79,8 @@ public class KDTreeBuilder { } private KDNode makeNonTerminalNode(final int keyIndex, final int median, final DividedRecords dividedRecords) { - final KDNode loSon = buildTree(dividedRecords.getLoRecords()); - final KDNode hiSon = buildTree(dividedRecords.getHiRecords()); + final KDNode loSon = buildTreeImpl(dividedRecords.getLoRecords()); + final KDNode hiSon = buildTreeImpl(dividedRecords.getHiRecords()); ++nodeCount; return new KDNode(keyIndex, median, loSon, hiSon); } @@ -113,16 +119,9 @@ public class KDTreeBuilder { for (final int[] record : records) { spread += Math.pow(((double) center - (double) record[keyIndex]), 2); + // spread += Math.abs(center - record[keyIndex]); } - return Math.sqrt(spread); - } - - public int getNodeCount() { - return nodeCount; - } - - public int getTerminalNodeCount() { - return terminalNodeCount; + // return (spread / (double) records.length); } } diff --git a/src/main/java/azgracompress/kdtree/TerminalKDNode.java b/src/main/java/azgracompress/kdtree/TerminalKDNode.java index 6dc239b1f0b1aef3a092fb246b807a4f22411a97..881ba78cfd64d362a51356c7bab31e6e929957d5 100644 --- a/src/main/java/azgracompress/kdtree/TerminalKDNode.java +++ b/src/main/java/azgracompress/kdtree/TerminalKDNode.java @@ -17,4 +17,8 @@ public class TerminalKDNode extends KDNode { public int[][] getBucket() { return bucket; } + + public void findNearestNeighborInBucket(final int[] queryRecord, final KDTree.SearchInfo searchInfo) { + + } }