From 6c6e06e1c45d79695786a748d7a8643821be80fc Mon Sep 17 00:00:00 2001
From: Vojtech Moravec <vojtech.moravec.st@vsb.cz>
Date: Fri, 4 Sep 2020 10:45:18 +0200
Subject: [PATCH] Finish first implementation of searching in k-DTree.

---
 .../java/azgracompress/kdtree/KDNode.java     | 29 +++++++++-----
 .../java/azgracompress/kdtree/KDTree.java     | 39 +++++++++++--------
 .../azgracompress/kdtree/TerminalKDNode.java  | 10 ++++-
 .../quantization/vector/VectorQuantizer.java  |  8 ++--
 .../java/azgracompress/utilities/Utils.java   | 16 ++++++++
 5 files changed, 69 insertions(+), 33 deletions(-)

diff --git a/src/main/java/azgracompress/kdtree/KDNode.java b/src/main/java/azgracompress/kdtree/KDNode.java
index c911b54..a6aa6a4 100644
--- a/src/main/java/azgracompress/kdtree/KDNode.java
+++ b/src/main/java/azgracompress/kdtree/KDNode.java
@@ -44,26 +44,34 @@ public class KDNode {
 
     public void findNearestNeighbor(final int[] queryRecord, final KDTree.SearchInfo searchInfo) {
 
+        if (searchInfo.stopSearching())
+            return;
+
         if (isTerminal()) {
-            // TODO: Examine records in bucket(node), updating PQD , PQR > ,
             ((TerminalKDNode) this).findNearestNeighborInBucket(queryRecord, searchInfo);
 
-            if (!ballWithinBounds(queryRecord, searchInfo))
-                return;
+            if (ballWithinBounds(queryRecord, searchInfo)) {
+                searchInfo.setContinueSearching(false);
+            }
+            return; // A terminal node has no sons to descend into.
         }
-        assert (loSon != null && hiSon != null);
 
+        assert (loSon != null && hiSon != null);
         if (queryRecord[discriminator] <= partition) {
             double tmp = searchInfo.getUpperBounds()[discriminator];
             searchInfo.getUpperBounds()[discriminator] = partition;
             loSon.findNearestNeighbor(queryRecord, searchInfo);
             searchInfo.getUpperBounds()[discriminator] = tmp;
+
         } else {
             double tmp = searchInfo.getLowerBounds()[discriminator];
             searchInfo.getLowerBounds()[discriminator] = partition;
             hiSon.findNearestNeighbor(queryRecord, searchInfo);
             searchInfo.getLowerBounds()[discriminator] = tmp;
         }
+        if (searchInfo.stopSearching())
+            return;
+
 
         if (queryRecord[discriminator] <= partition) {
             double tmp = searchInfo.getLowerBounds()[discriminator];
@@ -80,9 +88,11 @@ public class KDNode {
             }
             searchInfo.getUpperBounds()[discriminator] = tmp;
         }
-
-        if (!ballWithinBounds(queryRecord, searchInfo)) {
+        if (searchInfo.stopSearching())
             return;
+
+        if (ballWithinBounds(queryRecord, searchInfo)) {
+            searchInfo.setContinueSearching(false);
         }
     }
 
@@ -90,17 +100,16 @@ public class KDNode {
         return Math.pow((x - y), 2);
     }
 
-    private static double dissim(final double value) {
+    private static double dissimilarity(final double value) {
         return Math.sqrt(value);
     }
 
-    @SuppressWarnings("BooleanMethodIsAlwaysInverted")
     private boolean ballWithinBounds(final int[] queryRecord, final KDTree.SearchInfo searchInfo) {
         double lbDist, ubDist;
         for (int d = 0; d < searchInfo.getDimension(); d++) {
             lbDist = coordinateDistance(searchInfo.getLowerBounds()[d], queryRecord[d]);
             ubDist = coordinateDistance(searchInfo.getUpperBounds()[d], queryRecord[d]);
-            if ((lbDist <= searchInfo.getCurrentClosestDistance()) || (ubDist <= searchInfo.getCurrentClosestDistance())) {
+            if ((dissimilarity(lbDist) <= searchInfo.getNearestRecordDistance()) || (dissimilarity(ubDist) <= searchInfo.getNearestRecordDistance())) {
                 return false;
             }
         }
@@ -112,12 +121,12 @@ public class KDNode {
         for (int d = 0; d < searchInfo.getDimension(); d++) {
             if (queryRecord[d] < searchInfo.getLowerBounds()[d]) {
                 sum += coordinateDistance(queryRecord[d], searchInfo.getLowerBounds()[d]);
-                if (dissim(sum) > searchInfo.getCurrentClosestDistance()) {
+                if (dissimilarity(sum) > searchInfo.getNearestRecordDistance()) {
                     return true;
                 }
             } else if (queryRecord[d] > searchInfo.getUpperBounds()[d]) {
                 sum += coordinateDistance(queryRecord[d], searchInfo.getUpperBounds()[d]);
-                if (dissim(sum) > searchInfo.getCurrentClosestDistance()) {
+                if (dissimilarity(sum) > searchInfo.getNearestRecordDistance()) {
                     return true;
                 }
             }
diff --git a/src/main/java/azgracompress/kdtree/KDTree.java b/src/main/java/azgracompress/kdtree/KDTree.java
index e97f09b..71e3890 100644
--- a/src/main/java/azgracompress/kdtree/KDTree.java
+++ b/src/main/java/azgracompress/kdtree/KDTree.java
@@ -11,15 +11,16 @@ public class KDTree {
     private final int terminalNodeCount;
 
     public static class SearchInfo {
-        private double currentClosestDistance;
-        private int[] currentClosestRecord = null;
+        private boolean continueSearching = true;
+        private double nearestRecordDistance;
+        private int[] nearestRecord = null;
         private final double[] coordinateUpperBound;
         private final double[] coordinateLowerBound;
         private final int dimension;
 
         public SearchInfo(final int dimension) {
             this.dimension = dimension;
-            currentClosestDistance = Double.POSITIVE_INFINITY;
+            nearestRecordDistance = Double.POSITIVE_INFINITY;
             coordinateUpperBound = new double[dimension];
             coordinateLowerBound = new double[dimension];
             Arrays.fill(coordinateLowerBound, Double.NEGATIVE_INFINITY);
@@ -30,20 +31,12 @@ public class KDTree {
             return dimension;
         }
 
-        public double getCurrentClosestDistance() {
-            return currentClosestDistance;
+        public double getNearestRecordDistance() {
+            return nearestRecordDistance;
         }
 
-        public void setCurrentClosestDistance(double currentClosestDistance) {
-            this.currentClosestDistance = currentClosestDistance;
-        }
-
-        public int[] getCurrentClosestRecord() {
-            return currentClosestRecord;
-        }
-
-        public void setCurrentClosestRecord(int[] currentClosestRecord) {
-            this.currentClosestRecord = currentClosestRecord;
+        public int[] getNearestRecord() {
+            return nearestRecord;
         }
 
         public double[] getUpperBounds() {
@@ -53,6 +46,19 @@ public class KDTree {
         public double[] getLowerBounds() {
             return coordinateLowerBound;
         }
+
+        public boolean stopSearching() {
+            return !continueSearching;
+        }
+
+        public void setContinueSearching(boolean continueSearching) {
+            this.continueSearching = continueSearching;
+        }
+
+        public void setNearestRecord(final int[] record, final double recordDistance) {
+            this.nearestRecord = record;
+            this.nearestRecordDistance = recordDistance;
+        }
     }
 
     public KDTree(final KDNode root,
@@ -68,10 +74,9 @@ public class KDTree {
     }
 
     public int[] findNearestNeighbor(final int[] queryRecord) {
-        // TODO(Moravec): Read more about Ball Within Bounds and Bounds Overlap Ball
         SearchInfo searchInfo = new SearchInfo(dimension);
         root.findNearestNeighbor(queryRecord, searchInfo);
-        return searchInfo.currentClosestRecord;
+        return searchInfo.nearestRecord;
     }
 
     public int getTotalNodeCount() {
diff --git a/src/main/java/azgracompress/kdtree/TerminalKDNode.java b/src/main/java/azgracompress/kdtree/TerminalKDNode.java
index 881ba78..e1425f3 100644
--- a/src/main/java/azgracompress/kdtree/TerminalKDNode.java
+++ b/src/main/java/azgracompress/kdtree/TerminalKDNode.java
@@ -1,5 +1,7 @@
 package azgracompress.kdtree;
 
+import azgracompress.utilities.Utils;
+
 public class TerminalKDNode extends KDNode {
 
     private final int[][] bucket;
@@ -19,6 +21,12 @@ public class TerminalKDNode extends KDNode {
     }
 
     public void findNearestNeighborInBucket(final int[] queryRecord, final KDTree.SearchInfo searchInfo) {
-
+        // Linear scan of the bucket, keeping the record closest to the query seen so far.
+        for (final int[] candidate : bucket) {
+            final double candidateDistance = Utils.calculateEuclideanDistance(queryRecord, candidate);
+            if (candidateDistance < searchInfo.getNearestRecordDistance()) {
+                searchInfo.setNearestRecord(candidate, candidateDistance);
+            }
+        }
     }
 }
diff --git a/src/main/java/azgracompress/quantization/vector/VectorQuantizer.java b/src/main/java/azgracompress/quantization/vector/VectorQuantizer.java
index 5413e9d..d00696e 100644
--- a/src/main/java/azgracompress/quantization/vector/VectorQuantizer.java
+++ b/src/main/java/azgracompress/quantization/vector/VectorQuantizer.java
@@ -1,5 +1,7 @@
 package azgracompress.quantization.vector;
 
+import azgracompress.utilities.Utils;
+
 public class VectorQuantizer {
 
     private final VectorDistanceMetric metric = VectorDistanceMetric.Euclidean;
@@ -104,11 +106,7 @@ public class VectorQuantizer {
                 return sum;
             }
             case Euclidean: {
-                double sum = 0.0;
-                for (int i = 0; i < originalDataVector.length; i++) {
-                    sum += Math.pow(((double) originalDataVector[i] - (double) codebookEntry[i]), 2);
-                }
-                return Math.sqrt(sum);
+                return Utils.calculateEuclideanDistance(originalDataVector, codebookEntry);
             }
             case MaxDiff: {
                 double maxDiff = Double.MIN_VALUE;
diff --git a/src/main/java/azgracompress/utilities/Utils.java b/src/main/java/azgracompress/utilities/Utils.java
index cf0bc15..7ec9b1e 100644
--- a/src/main/java/azgracompress/utilities/Utils.java
+++ b/src/main/java/azgracompress/utilities/Utils.java
@@ -139,4 +139,20 @@ public class Utils {
         final double mse = (sum / (double) difference.length);
         return mse;
     }
+
+    /**
+     * Compute the Euclidean (L2) distance between two integer vectors.
+     * @param v1 First vector.
+     * @param v2 Second vector; equal length to {@code v1} is checked by assertion only.
+     * @return Square root of the summed squared coordinate differences.
+     */
+    public static double calculateEuclideanDistance(final int[] v1, final int[] v2) {
+        assert (v1.length == v2.length);
+        double squaredSum = 0.0;
+        for (int d = 0; d < v1.length; d++) {
+            final double delta = (double) v1[d] - (double) v2[d];
+            squaredSum += Math.pow(delta, 2);
+        }
+        return Math.sqrt(squaredSum);
+    }
 }
\ No newline at end of file
-- 
GitLab