From 6c6e06e1c45d79695786a748d7a8643821be80fc Mon Sep 17 00:00:00 2001 From: Vojtech Moravec <vojtech.moravec.st@vsb.cz> Date: Fri, 4 Sep 2020 10:45:18 +0200 Subject: [PATCH] Finish first implementation of searching in k-DTree. --- .../java/azgracompress/kdtree/KDNode.java | 29 +++++++++----- .../java/azgracompress/kdtree/KDTree.java | 39 +++++++++++-------- .../azgracompress/kdtree/TerminalKDNode.java | 10 ++++- .../quantization/vector/VectorQuantizer.java | 8 ++-- .../java/azgracompress/utilities/Utils.java | 16 ++++++++ 5 files changed, 69 insertions(+), 33 deletions(-) diff --git a/src/main/java/azgracompress/kdtree/KDNode.java b/src/main/java/azgracompress/kdtree/KDNode.java index c911b54..a6aa6a4 100644 --- a/src/main/java/azgracompress/kdtree/KDNode.java +++ b/src/main/java/azgracompress/kdtree/KDNode.java @@ -44,26 +44,34 @@ public class KDNode { public void findNearestNeighbor(final int[] queryRecord, final KDTree.SearchInfo searchInfo) { + if (searchInfo.stopSearching()) + return; + if (isTerminal()) { - // TODO: Examine records in bucket(node), updating PQD , PQR > , ((TerminalKDNode) this).findNearestNeighborInBucket(queryRecord, searchInfo); - if (!ballWithinBounds(queryRecord, searchInfo)) + if (ballWithinBounds(queryRecord, searchInfo)) { + searchInfo.setContinueSearching(false); return; + } } - assert (loSon != null && hiSon != null); + assert (loSon != null && hiSon != null); if (queryRecord[discriminator] <= partition) { double tmp = searchInfo.getUpperBounds()[discriminator]; searchInfo.getUpperBounds()[discriminator] = partition; loSon.findNearestNeighbor(queryRecord, searchInfo); searchInfo.getUpperBounds()[discriminator] = tmp; + } else { double tmp = searchInfo.getLowerBounds()[discriminator]; searchInfo.getLowerBounds()[discriminator] = partition; hiSon.findNearestNeighbor(queryRecord, searchInfo); searchInfo.getLowerBounds()[discriminator] = tmp; } + if (searchInfo.stopSearching()) + return; + if (queryRecord[discriminator] <= partition) { double tmp = searchInfo.getLowerBounds()[discriminator]; @@ -80,9 +88,11 @@ public class KDNode { } searchInfo.getUpperBounds()[discriminator] = tmp; } - - if (!ballWithinBounds(queryRecord, searchInfo)) { + if (searchInfo.stopSearching()) return; + + if (ballWithinBounds(queryRecord, searchInfo)) { + searchInfo.setContinueSearching(false); } } @@ -90,17 +100,16 @@ public class KDNode { return Math.pow((x - y), 2); } - private static double dissim(final double value) { + private static double dissimilarity(final double value) { return Math.sqrt(value); } - @SuppressWarnings("BooleanMethodIsAlwaysInverted") private boolean ballWithinBounds(final int[] queryRecord, final KDTree.SearchInfo searchInfo) { double lbDist, ubDist; for (int d = 0; d < searchInfo.getDimension(); d++) { lbDist = coordinateDistance(searchInfo.getLowerBounds()[d], queryRecord[d]); ubDist = coordinateDistance(searchInfo.getUpperBounds()[d], queryRecord[d]); - if ((lbDist <= searchInfo.getCurrentClosestDistance()) || (ubDist <= searchInfo.getCurrentClosestDistance())) { + if ((lbDist <= searchInfo.getNearestRecordDistance()) || (ubDist <= searchInfo.getNearestRecordDistance())) { return false; } } @@ -112,12 +121,12 @@ public class KDNode { for (int d = 0; d < searchInfo.getDimension(); d++) { if (queryRecord[d] < searchInfo.getLowerBounds()[d]) { sum += coordinateDistance(queryRecord[d], searchInfo.getLowerBounds()[d]); - if (dissim(sum) > searchInfo.getCurrentClosestDistance()) { + if (dissimilarity(sum) > searchInfo.getNearestRecordDistance()) { return true; } } else if (queryRecord[d] > searchInfo.getUpperBounds()[d]) { sum += coordinateDistance(queryRecord[d], searchInfo.getUpperBounds()[d]); - if (dissim(sum) > searchInfo.getCurrentClosestDistance()) { + if (dissimilarity(sum) > searchInfo.getNearestRecordDistance()) { return true; } } diff --git a/src/main/java/azgracompress/kdtree/KDTree.java b/src/main/java/azgracompress/kdtree/KDTree.java index e97f09b..71e3890 100644 --- a/src/main/java/azgracompress/kdtree/KDTree.java +++ b/src/main/java/azgracompress/kdtree/KDTree.java @@ -11,15 +11,16 @@ public class KDTree { private final int terminalNodeCount; public static class SearchInfo { - private double currentClosestDistance; - private int[] currentClosestRecord = null; + private boolean continueSearching = true; + private double nearestRecordDistance; + private int[] nearestRecord = null; private final double[] coordinateUpperBound; private final double[] coordinateLowerBound; private final int dimension; public SearchInfo(final int dimension) { this.dimension = dimension; - currentClosestDistance = Double.POSITIVE_INFINITY; + nearestRecordDistance = Double.POSITIVE_INFINITY; coordinateUpperBound = new double[dimension]; coordinateLowerBound = new double[dimension]; Arrays.fill(coordinateLowerBound, Double.NEGATIVE_INFINITY); @@ -30,20 +31,12 @@ public class KDTree { return dimension; } - public double getCurrentClosestDistance() { - return currentClosestDistance; + public double getNearestRecordDistance() { + return nearestRecordDistance; } - public void setCurrentClosestDistance(double currentClosestDistance) { - this.currentClosestDistance = currentClosestDistance; - } - - public int[] getCurrentClosestRecord() { - return currentClosestRecord; - } - - public void setCurrentClosestRecord(int[] currentClosestRecord) { - this.currentClosestRecord = currentClosestRecord; + public int[] getNearestRecord() { + return nearestRecord; } public double[] getUpperBounds() { @@ -53,6 +46,19 @@ public class KDTree { public double[] getLowerBounds() { return coordinateLowerBound; } + + public boolean stopSearching() { + return !continueSearching; + } + + public void setContinueSearching(boolean continueSearching) { + this.continueSearching = continueSearching; + } + + public void setNearestRecord(final int[] record, final double recordDistance) { + this.nearestRecord = record; + this.nearestRecordDistance = recordDistance; + } } public KDTree(final KDNode root, @@ -68,10 +74,9 @@ public class KDTree { } public int[] findNearestNeighbor(final int[] queryRecord) { - // TODO(Moravec): Read more about Ball Within Bounds and Bounds Overlap Ball SearchInfo searchInfo = new SearchInfo(dimension); root.findNearestNeighbor(queryRecord, searchInfo); - return searchInfo.currentClosestRecord; + return searchInfo.nearestRecord; } public int getTotalNodeCount() { diff --git a/src/main/java/azgracompress/kdtree/TerminalKDNode.java b/src/main/java/azgracompress/kdtree/TerminalKDNode.java index 881ba78..e1425f3 100644 --- a/src/main/java/azgracompress/kdtree/TerminalKDNode.java +++ b/src/main/java/azgracompress/kdtree/TerminalKDNode.java @@ -1,5 +1,7 @@ package azgracompress.kdtree; +import azgracompress.utilities.Utils; + public class TerminalKDNode extends KDNode { private final int[][] bucket; @@ -19,6 +21,12 @@ public class TerminalKDNode extends KDNode { } public void findNearestNeighborInBucket(final int[] queryRecord, final KDTree.SearchInfo searchInfo) { - + double recordDistance; + for (final int[] record : bucket) { + recordDistance = Utils.calculateEuclideanDistance(queryRecord, record); + if (recordDistance < searchInfo.getNearestRecordDistance()) { + searchInfo.setNearestRecord(record, recordDistance); + } + } } } diff --git a/src/main/java/azgracompress/quantization/vector/VectorQuantizer.java b/src/main/java/azgracompress/quantization/vector/VectorQuantizer.java index 5413e9d..d00696e 100644 --- a/src/main/java/azgracompress/quantization/vector/VectorQuantizer.java +++ b/src/main/java/azgracompress/quantization/vector/VectorQuantizer.java @@ -1,5 +1,7 @@ package azgracompress.quantization.vector; +import azgracompress.utilities.Utils; + public class VectorQuantizer { private final VectorDistanceMetric metric = VectorDistanceMetric.Euclidean; @@ -104,11 +106,7 @@ public class VectorQuantizer { return sum; } case Euclidean: { - double sum = 0.0; - for (int i = 0; i < originalDataVector.length; i++) { - sum += Math.pow(((double) originalDataVector[i] - (double) codebookEntry[i]), 2); - } - return Math.sqrt(sum); + return Utils.calculateEuclideanDistance(originalDataVector, codebookEntry); } case MaxDiff: { double maxDiff = Double.MIN_VALUE; diff --git a/src/main/java/azgracompress/utilities/Utils.java b/src/main/java/azgracompress/utilities/Utils.java index cf0bc15..7ec9b1e 100644 --- a/src/main/java/azgracompress/utilities/Utils.java +++ b/src/main/java/azgracompress/utilities/Utils.java @@ -139,4 +139,20 @@ public class Utils { final double mse = (sum / (double) difference.length); return mse; } + + /** + * Calculate the euclidean distance between two vectors. + * + * @param v1 First vector. + * @param v2 Second vector. + * @return Euclidean distance. + */ + public static double calculateEuclideanDistance(final int[] v1, final int[] v2) { + assert (v1.length == v2.length); + double sum = 0.0; + for (int i = 0; i < v1.length; i++) { + sum += Math.pow(((double) v1[i] - (double) v2[i]), 2); + } + return Math.sqrt(sum); + } } \ No newline at end of file -- GitLab