From e75cc073b03ff863b290e40e0cfbe659610b48a8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Vojt=C4=9Bch=20Moravec?= <theazgra@gmail.com>
Date: Sun, 6 Sep 2020 20:24:16 +0200
Subject: [PATCH] Clean KDTree building code.

Refactored the build code, with new proper variable names and some
simplifications.
Also the dimension to split on is calculated as the variance in the
dimension.
---
 .../azgracompress/kdtree/KDTreeBuilder.java   | 150 +++++++++++++-----
 1 file changed, 106 insertions(+), 44 deletions(-)

diff --git a/src/main/java/azgracompress/kdtree/KDTreeBuilder.java b/src/main/java/azgracompress/kdtree/KDTreeBuilder.java
index 8503db4..fd53e54 100644
--- a/src/main/java/azgracompress/kdtree/KDTreeBuilder.java
+++ b/src/main/java/azgracompress/kdtree/KDTreeBuilder.java
@@ -4,7 +4,6 @@ import java.util.ArrayList;
 import java.util.Arrays;
 
 public class KDTreeBuilder {
-
     private static class DividedRecords {
         private final int[][] hiRecords;
         private final int[][] loRecords;
@@ -28,48 +27,63 @@ public class KDTreeBuilder {
     private int nodeCount = 0;
     private int terminalNodeCount = 0;
 
+    /**
+     * Create KDTree builder.
+     *
+     * @param dimension  Dimension of the feature vectors.
+     * @param bucketSize Bucket size for the terminal nodes.
+     */
     public KDTreeBuilder(final int dimension, final int bucketSize) {
         this.bucketSize = bucketSize;
         this.dimension = dimension;
     }
 
-    public KDTree buildTree(final int[][] records) {
+    /**
+     * Construct the KDTree for provided feature vectors.
+     *
+     * @param featureVectors Feature vectors to build the tree with
+     * @return KDTree.
+     */
+    public KDTree buildTree(final int[][] featureVectors) {
         nodeCount = 0;
         terminalNodeCount = 0;
-        final KDNode rootNode = buildTreeImpl(records);
+        final KDNode rootNode = buildTreeImpl(featureVectors);
         return new KDTree(rootNode, dimension, bucketSize, nodeCount, terminalNodeCount);
     }
 
-    public KDNode buildTreeImpl(final int[][] records) {
-        if (records.length <= bucketSize) {
-            return makeTerminalNode(records);
-        }
-
-        double maxSpread = -1.0;
-        int keyIndex = 0;
-
-        for (int j = 0; j < dimension; j++) {
-            // Find coordinate with greatest spread.
-            final double greatestSpread = calculateKeySpread(records, j);
-            if (greatestSpread > maxSpread) {
-                maxSpread = greatestSpread;
-                keyIndex = j;
-            }
+    /**
+     * Build KDTree by recursion, feature vectors are split in the dimension with the greatest variance.
+     *
+     * @param featureVectors Feature vectors to build the tree with.
+     * @return Node with its siblings.
+     */
+    public KDNode buildTreeImpl(final int[][] featureVectors) {
+        if (featureVectors.length <= bucketSize) {
+            return makeTerminalNode(featureVectors);
         }
-        final int median = calculateKeyMedian(records, keyIndex);
 
+        int keyIndexMSE = findDimensionWithGreatestVariance(featureVectors);
+        final int median = calculateKeyMedian(featureVectors, keyIndexMSE);
 
         // Divide records in one method to hi and lo.
-        final DividedRecords dividedRecords = divideRecords(records, median, keyIndex);
-        return makeNonTerminalNode(keyIndex, median, dividedRecords);
+        final DividedRecords dividedRecords = divideRecords(featureVectors, median, keyIndexMSE);
+        return makeNonTerminalNode(keyIndexMSE, median, dividedRecords);
     }
 
 
-    private DividedRecords divideRecords(final int[][] records, final int median, final int keyIndex) {
+    /**
+     * Divide feature vectors into low and high subgroups.
+     *
+     * @param featureVectors Feature vectors to divide.
+     * @param median         Median in the dimension.
+     * @param dimension      Dimension index.
+     * @return Divided vectors.
+     */
+    private DividedRecords divideRecords(final int[][] featureVectors, final int median, final int dimension) {
         ArrayList<int[]> loRecords = new ArrayList<>();
         ArrayList<int[]> hiRecords = new ArrayList<>();
-        for (final int[] record : records) {
-            if (record[keyIndex] <= median) {
+        for (final int[] record : featureVectors) {
+            if (record[dimension] <= median) {
                 loRecords.add(record);
             } else {
                 hiRecords.add(record);
@@ -78,24 +92,66 @@ public class KDTreeBuilder {
         return new DividedRecords(loRecords.toArray(new int[0][]), hiRecords.toArray(new int[0][]));
     }
 
-    private KDNode makeNonTerminalNode(final int keyIndex, final int median, final DividedRecords dividedRecords) {
+    /**
+     * Create internal tree node by recursion on buildTreeImpl.
+     *
+     * @param dimension      Dimension to split at.
+     * @param median         Median in the selected dimension.
+     * @param dividedRecords Records divided by the median.
+     * @return New internal node.
+     */
+    private KDNode makeNonTerminalNode(final int dimension, final int median, final DividedRecords dividedRecords) {
         final KDNode loSon = buildTreeImpl(dividedRecords.getLoRecords());
         final KDNode hiSon = buildTreeImpl(dividedRecords.getHiRecords());
         ++nodeCount;
-        return new KDNode(keyIndex, median, loSon, hiSon);
+        return new KDNode(dimension, median, loSon, hiSon);
     }
 
-    public KDNode makeTerminalNode(final int[][] records) {
+    /**
+     * Construct terminal node with bucket of feature vectors.
+     *
+     * @param featureVectors Feature vectors.
+     * @return New terminal node.
+     */
+    public KDNode makeTerminalNode(final int[][] featureVectors) {
         ++nodeCount;
         ++terminalNodeCount;
-        return new TerminalKDNode(records);
+        System.out.printf("Terminal node bucket size: %d\n", featureVectors.length);
+        return new TerminalKDNode(featureVectors);
     }
 
-    private int calculateKeyMedian(final int[][] records, final int keyIndex) {
-        assert (records.length > 1);
-        final int[] sortedArray = new int[records.length];
-        for (int i = 0; i < records.length; i++) {
-            sortedArray[i] = records[i][keyIndex];
+    /**
+     * Find the dimension with the greatest variance for the feature vectors.
+     *
+     * @param featureVectors Feature vectors.
+     * @return Index of the dimension with greatest variance/spread.
+     */
+    private int findDimensionWithGreatestVariance(final int[][] featureVectors) {
+        double maxVar = -1.0;
+        int dimension = 0;
+        for (int j = 0; j < this.dimension; j++) {
+            // Find coordinate with greatest spread.
+            final double dimVar = calculateDimensionVariance(featureVectors, j);
+            if (dimVar > maxVar) {
+                maxVar = dimVar;
+                dimension = j;
+            }
+        }
+        return dimension;
+    }
+
+    /**
+     * Calculate the median in selected dimension.
+     *
+     * @param featureVectors Feature vectors.
+     * @param dimension      Dimension index.
+     * @return Median of the dimension.
+     */
+    private int calculateKeyMedian(final int[][] featureVectors, final int dimension) {
+        assert (featureVectors.length > 1);
+        final int[] sortedArray = new int[featureVectors.length];
+        for (int i = 0; i < featureVectors.length; i++) {
+            sortedArray[i] = featureVectors[i][dimension];
         }
         Arrays.sort(sortedArray);
 
@@ -108,20 +164,26 @@ public class KDTreeBuilder {
     }
 
 
-    private double calculateKeySpread(final int[][] records, final int keyIndex) {
-        double center = 0.0;
-        for (final int[] record : records) {
-            center += record[keyIndex];
+    /**
+     * Calculate variance of the values in selected dimension.
+     *
+     * @param featureVectors Feature vectors.
+     * @param dimension      Dimension index.
+     * @return Variance in the dimension.
+     */
+    private double calculateDimensionVariance(final int[][] featureVectors, final int dimension) {
+        double mean = 0.0;
+        for (final int[] record : featureVectors) {
+            mean += record[dimension];
         }
-        center /= (double) records.length;
+        mean /= (double) featureVectors.length;
 
-        double spread = 0.0;
+        double var = 0.0;
 
-        for (final int[] record : records) {
-            spread += Math.pow(((double) center - (double) record[keyIndex]), 2);
-            // spread += Math.abs(center - record[keyIndex]);
+        for (final int[] record : featureVectors) {
+            var += Math.pow(((double) record[dimension] - mean), 2);
         }
-        return Math.sqrt(spread);
-        // return (spread / (double) records.length);
+        return (var / (double) featureVectors.length);
     }
+
 }
-- 
GitLab