From 8bec85f8c40a05fc0986677b27ee25950f3e56ad Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Vojt=C4=9Bch=20Moravec?= <theazgra@gmail.com>
Date: Sun, 6 Sep 2020 21:41:12 +0200
Subject: [PATCH] First test of BBF search algorithm.

Mostly based on https://github.com/iwyoo/kd_tree
---
 .../java/azgracompress/kdtree/KDNode.java     |  3 +
 .../java/azgracompress/kdtree/KDTree.java     | 98 ++++++++++++++++---
 .../azgracompress/kdtree/TerminalKDNode.java  |  2 +-
 3 files changed, 86 insertions(+), 17 deletions(-)

diff --git a/src/main/java/azgracompress/kdtree/KDNode.java b/src/main/java/azgracompress/kdtree/KDNode.java
index 1cc3a2a..8d8081f 100644
--- a/src/main/java/azgracompress/kdtree/KDNode.java
+++ b/src/main/java/azgracompress/kdtree/KDNode.java
@@ -44,6 +44,9 @@ public class KDNode {
 
     public void findNearestNeighbor(final int[] queryRecord, final KDTree.SearchInfo searchInfo) {
 
+
+
+
         if (searchInfo.stopSearching())
             return;
 
diff --git a/src/main/java/azgracompress/kdtree/KDTree.java b/src/main/java/azgracompress/kdtree/KDTree.java
index fb1938f..1bb50c1 100644
--- a/src/main/java/azgracompress/kdtree/KDTree.java
+++ b/src/main/java/azgracompress/kdtree/KDTree.java
@@ -1,12 +1,16 @@
 package azgracompress.kdtree;
 
+import org.jetbrains.annotations.NotNull;
+
 import java.util.Arrays;
+import java.util.PriorityQueue;
 
 // TODO(Moravec):   One more time read the paper and check the implementation!
 //                  Fix the spreadest function (max-min) may be used.
 //                  https://dl.acm.org/doi/pdf/10.1145/355744.355745
 //                  Actually get rid of this kdTree and look into BBF (Best Bin First)
 //                  https://www.cs.ubc.ca/~lowe/papers/cvpr97.pdf
+//                  https://github.com/iwyoo/kd_tree/blob/master/kd_tree.cxx
 
 public class KDTree {
     private final int maximumBucketSize;
@@ -16,17 +20,37 @@ public class KDTree {
     private final int totalNodeCount;
     private final int terminalNodeCount;
 
-    public static class SearchInfo {
-        private boolean continueSearching = true;
+    public static class BBFSearchInfo {
         private double nearestRecordDistance;
-        private int[] nearestRecord = null;
+        private int[] nearestRecord;
+
+        public BBFSearchInfo() {
+            nearestRecord = null;
+            nearestRecordDistance = Double.POSITIVE_INFINITY;
+        }
+
+        public double getNearestRecordDistance() {
+            return nearestRecordDistance;
+        }
+
+        public int[] getNearestRecord() {
+            return nearestRecord;
+        }
+
+        public void setNearestRecord(final int[] record, final double recordDistance) {
+            this.nearestRecord = record;
+            this.nearestRecordDistance = recordDistance;
+        }
+    }
+
+    public static class SearchInfo extends BBFSearchInfo {
+        private boolean continueSearching = true;
         private final double[] coordinateUpperBound;
         private final double[] coordinateLowerBound;
         private final int dimension;
 
         public SearchInfo(final int dimension) {
             this.dimension = dimension;
-            nearestRecordDistance = Double.POSITIVE_INFINITY;
             coordinateUpperBound = new double[dimension];
             coordinateLowerBound = new double[dimension];
             Arrays.fill(coordinateLowerBound, Double.NEGATIVE_INFINITY);
@@ -37,14 +61,6 @@ public class KDTree {
             return dimension;
         }
 
-        public double getNearestRecordDistance() {
-            return nearestRecordDistance;
-        }
-
-        public int[] getNearestRecord() {
-            return nearestRecord;
-        }
-
         public double[] getUpperBounds() {
             return coordinateUpperBound;
         }
@@ -60,10 +76,28 @@ public class KDTree {
         public void setContinueSearching(boolean continueSearching) {
             this.continueSearching = continueSearching;
         }
+    }
 
-        public void setNearestRecord(final int[] record, final double recordDistance) {
-            this.nearestRecord = record;
-            this.nearestRecordDistance = recordDistance;
+    private static class NodeWithDistance implements Comparable<NodeWithDistance> {
+        private final KDNode node;
+        private final double distance;
+
+        private NodeWithDistance(KDNode node, double distance) {
+            this.node = node;
+            this.distance = distance;
+        }
+
+        public KDNode getNode() {
+            return node;
+        }
+
+        public double getDistance() {
+            return distance;
+        }
+
+        @Override
+        public int compareTo(@NotNull KDTree.NodeWithDistance o) {
+            return Double.compare(distance, o.distance);
         }
     }
 
@@ -82,7 +116,39 @@ public class KDTree {
     public int[] findNearestNeighbor(final int[] queryRecord) {
         SearchInfo searchInfo = new SearchInfo(dimension);
         root.findNearestNeighbor(queryRecord, searchInfo);
-        return searchInfo.nearestRecord;
+        return searchInfo.getNearestRecord();
+    }
+
+    public int[] findNearestBBF(final int[] queryVector, final int maxE) {
+
+        PriorityQueue<NodeWithDistance> priorityQueue = new PriorityQueue<>();
+        priorityQueue.add(new NodeWithDistance(root, 0.0));
+
+        BBFSearchInfo searchInfo = new BBFSearchInfo();
+        int tryIndex = 0;
+        int partition, discriminator;
+        while (!priorityQueue.isEmpty() && tryIndex < maxE) {
+            NodeWithDistance current = priorityQueue.remove();
+            if (current.getNode().isTerminal()) {
+                ((TerminalKDNode) current.getNode()).findNearestNeighborInBucket(queryVector, searchInfo);
+                ++tryIndex;
+            } else {
+                discriminator = current.getNode().getDiscriminator();
+                partition = current.getNode().getPartition();
+                if (queryVector[discriminator] < partition) {
+                    priorityQueue.add(new NodeWithDistance(current.getNode().getLoSon(),
+                            0.0));
+                    priorityQueue.add(new NodeWithDistance(current.getNode().getHiSon(),
+                            (double) partition - (double) queryVector[discriminator]));
+                } else {
+                    priorityQueue.add(new NodeWithDistance(current.getNode().getHiSon(),
+                            0.0));
+                    priorityQueue.add(new NodeWithDistance(current.getNode().getLoSon(),
+                            (double) queryVector[discriminator] - (double) partition));
+                }
+            }
+        }
+        return searchInfo.getNearestRecord();
     }
 
     public int getTotalNodeCount() {
diff --git a/src/main/java/azgracompress/kdtree/TerminalKDNode.java b/src/main/java/azgracompress/kdtree/TerminalKDNode.java
index e1425f3..eed1906 100644
--- a/src/main/java/azgracompress/kdtree/TerminalKDNode.java
+++ b/src/main/java/azgracompress/kdtree/TerminalKDNode.java
@@ -20,7 +20,7 @@ public class TerminalKDNode extends KDNode {
         return bucket;
     }
 
-    public void findNearestNeighborInBucket(final int[] queryRecord, final KDTree.SearchInfo searchInfo) {
+    public void findNearestNeighborInBucket(final int[] queryRecord, final KDTree.BBFSearchInfo searchInfo) {
         double recordDistance;
         for (final int[] record : bucket) {
             recordDistance = Utils.calculateEuclideanDistance(queryRecord, record);
-- 
GitLab