diff --git a/src/main/java/azgracompress/kdtree/KDTreeBuilder.java b/src/main/java/azgracompress/kdtree/KDTreeBuilder.java index 8503db4c313d434ce316224b7cbb59c8847e8830..fd53e542f7614e1ad4bac87dd14764220ae1c350 100644 --- a/src/main/java/azgracompress/kdtree/KDTreeBuilder.java +++ b/src/main/java/azgracompress/kdtree/KDTreeBuilder.java @@ -4,7 +4,6 @@ import java.util.ArrayList; import java.util.Arrays; public class KDTreeBuilder { - private static class DividedRecords { private final int[][] hiRecords; private final int[][] loRecords; @@ -28,48 +27,68 @@ public class KDTreeBuilder { private int nodeCount = 0; private int terminalNodeCount = 0; + /** + * Create KDTree builder. + * + * @param dimension Dimension of the feature vectors. + * @param bucketSize Bucket size for the terminal nodes. + */ public KDTreeBuilder(final int dimension, final int bucketSize) { this.bucketSize = bucketSize; this.dimension = dimension; } - public KDTree buildTree(final int[][] records) { + /** + * Construct the KDTree for the provided feature vectors; the node counters are reset before building. + * + * @param featureVectors Feature vectors to build the tree with. + * @return KDTree wrapping the root node together with the node counts. + */ + public KDTree buildTree(final int[][] featureVectors) { nodeCount = 0; terminalNodeCount = 0; - final KDNode rootNode = buildTreeImpl(records); + final KDNode rootNode = buildTreeImpl(featureVectors); return new KDTree(rootNode, dimension, bucketSize, nodeCount, terminalNodeCount); } - public KDNode buildTreeImpl(final int[][] records) { - if (records.length <= bucketSize) { - return makeTerminalNode(records); - } - - double maxSpread = -1.0; - int keyIndex = 0; - - for (int j = 0; j < dimension; j++) { - // Find coordinate with greatest spread. - final double greatestSpread = calculateKeySpread(records, j); - if (greatestSpread > maxSpread) { - maxSpread = greatestSpread; - keyIndex = j; - } + /** + * Build the KDTree recursively; feature vectors are split in the dimension with the greatest variance. + * A split that leaves one side empty ends the recursion with a terminal node holding all the vectors. 
+ * + * @param featureVectors Feature vectors to build the tree with. + * @return Root node of the subtree built from the feature vectors. + */ + public KDNode buildTreeImpl(final int[][] featureVectors) { + if (featureVectors.length <= bucketSize) { + return makeTerminalNode(featureVectors); } - final int median = calculateKeyMedian(records, keyIndex); + final int keyIndexMSE = findDimensionWithGreatestVariance(featureVectors); + final int median = calculateKeyMedian(featureVectors, keyIndexMSE); // Divide records in one method to hi and lo. - final DividedRecords dividedRecords = divideRecords(records, median, keyIndex); - return makeNonTerminalNode(keyIndex, median, dividedRecords); + final DividedRecords dividedRecords = divideRecords(featureVectors, median, keyIndexMSE); + // A one-sided split would recurse on the same vectors forever, so stop here. + if (dividedRecords.getLoRecords().length == 0 || dividedRecords.getHiRecords().length == 0) { + return makeTerminalNode(featureVectors); + } + return makeNonTerminalNode(keyIndexMSE, median, dividedRecords); } - private DividedRecords divideRecords(final int[][] records, final int median, final int keyIndex) { + /** + * Divide feature vectors into a low group (value at most the median) and a high group (value above the median). + * + * @param featureVectors Feature vectors to divide. + * @param median Split value in the dimension. + * @param dimension Index of the dimension to compare in. + * @return Divided vectors. + */ + private DividedRecords divideRecords(final int[][] featureVectors, final int median, final int dimension) { ArrayList<int[]> loRecords = new ArrayList<>(); ArrayList<int[]> hiRecords = new ArrayList<>(); - for (final int[] record : records) { - if (record[keyIndex] <= median) { + for (final int[] record : featureVectors) { + if (record[dimension] <= median) { loRecords.add(record); } else { hiRecords.add(record); @@ -78,24 +97,65 @@ public class KDTreeBuilder { return new DividedRecords(loRecords.toArray(new int[0][]), hiRecords.toArray(new int[0][])); } - private KDNode makeNonTerminalNode(final int keyIndex, final int median, final DividedRecords dividedRecords) { + /** + * Create an internal tree node; both sons are built by recursion on buildTreeImpl. + * + * @param dimension Dimension to split at. + * @param median Median in the selected dimension. 
+ * @param dividedRecords Records divided by the median. + * @return New internal node. + */ + private KDNode makeNonTerminalNode(final int dimension, final int median, final DividedRecords dividedRecords) { final KDNode loSon = buildTreeImpl(dividedRecords.getLoRecords()); final KDNode hiSon = buildTreeImpl(dividedRecords.getHiRecords()); ++nodeCount; - return new KDNode(keyIndex, median, loSon, hiSon); + return new KDNode(dimension, median, loSon, hiSon); } - public KDNode makeTerminalNode(final int[][] records) { + /** + * Construct a terminal node with a bucket of feature vectors and count it in both node counters. + * + * @param featureVectors Feature vectors stored in the bucket. + * @return New terminal node. + */ + public KDNode makeTerminalNode(final int[][] featureVectors) { ++nodeCount; ++terminalNodeCount; - return new TerminalKDNode(records); + return new TerminalKDNode(featureVectors); } - private int calculateKeyMedian(final int[][] records, final int keyIndex) { - assert (records.length > 1); - final int[] sortedArray = new int[records.length]; - for (int i = 0; i < records.length; i++) { - sortedArray[i] = records[i][keyIndex]; + /** + * Find the dimension with the greatest variance for the feature vectors. + * + * @param featureVectors Feature vectors. + * @return Index of the dimension with the greatest variance; the lowest index wins a tie. + */ + private int findDimensionWithGreatestVariance(final int[][] featureVectors) { + double maxVar = -1.0; + int dimension = 0; + for (int j = 0; j < this.dimension; j++) { + // Keep the coordinate with the greatest variance seen so far. + final double dimVar = calculateDimensionVariance(featureVectors, j); + if (dimVar > maxVar) { + maxVar = dimVar; + dimension = j; + } + } + return dimension; + } + + /** + * Calculate the median in the selected dimension. + * + * @param featureVectors Feature vectors. + * @param dimension Dimension index. + * @return Median of the dimension. 
+ */ + private int calculateKeyMedian(final int[][] featureVectors, final int dimension) { + assert (featureVectors.length > 1); + final int[] sortedArray = new int[featureVectors.length]; + for (int i = 0; i < featureVectors.length; i++) { + sortedArray[i] = featureVectors[i][dimension]; } Arrays.sort(sortedArray); @@ -108,20 +168,26 @@ public class KDTreeBuilder { } - private double calculateKeySpread(final int[][] records, final int keyIndex) { - double center = 0.0; - for (final int[] record : records) { - center += record[keyIndex]; + /** + * Calculate the population variance (mean of squared deviations from the mean) in the selected dimension. + * + * @param featureVectors Feature vectors. + * @param dimension Dimension index. + * @return Variance in the dimension. + */ + private double calculateDimensionVariance(final int[][] featureVectors, final int dimension) { + double mean = 0.0; + for (final int[] record : featureVectors) { + mean += record[dimension]; } - center /= (double) records.length; + mean /= (double) featureVectors.length; - double spread = 0.0; + double sumOfSquares = 0.0; - for (final int[] record : records) { - spread += Math.pow(((double) center - (double) record[keyIndex]), 2); - // spread += Math.abs(center - record[keyIndex]); + for (final int[] record : featureVectors) { + sumOfSquares += Math.pow(((double) record[dimension] - mean), 2); } - return Math.sqrt(spread); - // return (spread / (double) records.length); + return (sumOfSquares / (double) featureVectors.length); } + }