// KDTreeBuilder.java
package azgracompress.kdtree;

import java.util.ArrayList;
import java.util.Arrays;

/**
 * Builds a {@link KDTree} from integer feature vectors.
 * <p>
 * The vectors are split recursively at the median of the dimension with the greatest variance until a group
 * holds at most {@code bucketSize} vectors; such a group becomes a terminal node.
 * <p>
 * The node counters are kept in instance fields that are reset by {@link #buildTree(int[][])} and mutated
 * without synchronization, so one builder instance must not run several builds concurrently.
 */
public class KDTreeBuilder {
    /**
     * Result of partitioning feature vectors by a split value in one dimension.
     */
    private static class DividedRecords {
        private final int[][] loRecords;
        private final int[][] hiRecords;

        /**
         * @param loRecords Records whose key is lower than or equal to the split value.
         * @param hiRecords Records whose key is greater than the split value.
         */
        DividedRecords(final int[][] loRecords, final int[][] hiRecords) {
            this.loRecords = loRecords;
            this.hiRecords = hiRecords;
        }

        public int[][] getHiRecords() {
            return hiRecords;
        }

        public int[][] getLoRecords() {
            return loRecords;
        }
    }

    private final int bucketSize;
    private final int dimension;
    private int nodeCount = 0;
    private int terminalNodeCount = 0;

    /**
     * Create KDTree builder.
     *
     * @param dimension  Dimension of the feature vectors.
     * @param bucketSize Bucket size for the terminal nodes.
     */
    public KDTreeBuilder(final int dimension, final int bucketSize) {
        this.bucketSize = bucketSize;
        this.dimension = dimension;
    }

    /**
     * Construct the KDTree for provided feature vectors. The node counters are reset before every build.
     *
     * @param featureVectors Feature vectors to build the tree with.
     * @return KDTree.
     */
    public KDTree buildTree(final int[][] featureVectors) {
        nodeCount = 0;
        terminalNodeCount = 0;
        final KDNode rootNode = buildTreeImpl(featureVectors);
        return new KDTree(rootNode, dimension, bucketSize, nodeCount, terminalNodeCount);
    }

    /**
     * Build KDTree by recursion, feature vectors are split in the dimension with the greatest variance.
     * <p>
     * The split value is normally the median of the chosen dimension. When the median equals the largest key
     * (many duplicated keys, or truncation of a negative half value), every record would fall into the low
     * group and the recursion would never terminate. In that case the split value is lowered to the largest key
     * strictly below the median. When all keys in the chosen dimension are equal no split is possible and a
     * terminal node is created, even though it then holds more than {@code bucketSize} vectors.
     *
     * @param featureVectors Feature vectors to build the tree with.
     * @return Root node of the subtree built from the feature vectors.
     */
    public KDNode buildTreeImpl(final int[][] featureVectors) {
        if (featureVectors.length <= bucketSize) {
            return makeTerminalNode(featureVectors);
        }

        final int keyIndexMSE = findDimensionWithGreatestVariance(featureVectors);
        int splitValue = calculateKeyMedian(featureVectors, keyIndexMSE);
        DividedRecords dividedRecords = divideRecords(featureVectors, splitValue, keyIndexMSE);

        if (dividedRecords.getHiRecords().length == 0) {
            // Degenerate split: nothing is greater than the median. Retry with a lower split value.
            final int loweredSplitValue = findLargestValueBelow(featureVectors, keyIndexMSE, splitValue);
            if (loweredSplitValue == splitValue) {
                // All keys are equal, the vectors can not be separated in this dimension.
                return makeTerminalNode(featureVectors);
            }
            splitValue = loweredSplitValue;
            dividedRecords = divideRecords(featureVectors, splitValue, keyIndexMSE);
        }

        return makeNonTerminalNode(keyIndexMSE, splitValue, dividedRecords);
    }


    /**
     * Divide feature vectors into low and high subgroups.
     *
     * @param featureVectors Feature vectors to divide.
     * @param median         Split value in the dimension; keys equal to it go to the low group.
     * @param dimension      Dimension index.
     * @return Divided vectors.
     */
    private DividedRecords divideRecords(final int[][] featureVectors, final int median, final int dimension) {
        final ArrayList<int[]> loRecords = new ArrayList<>();
        final ArrayList<int[]> hiRecords = new ArrayList<>();
        for (final int[] record : featureVectors) {
            if (record[dimension] <= median) {
                loRecords.add(record);
            } else {
                hiRecords.add(record);
            }
        }
        return new DividedRecords(loRecords.toArray(new int[0][]), hiRecords.toArray(new int[0][]));
    }

    /**
     * Find the largest key in the selected dimension which is strictly lower than the bound.
     *
     * @param featureVectors Feature vectors.
     * @param dimension      Dimension index.
     * @param bound          Exclusive upper bound of the searched key.
     * @return Largest key lower than the bound, or the bound itself when no such key exists.
     */
    private int findLargestValueBelow(final int[][] featureVectors, final int dimension, final int bound) {
        boolean found = false;
        int best = bound;
        for (final int[] record : featureVectors) {
            final int value = record[dimension];
            if (value < bound && (!found || value > best)) {
                best = value;
                found = true;
            }
        }
        return best;
    }

    /**
     * Create internal tree node by recursion on buildTreeImpl.
     *
     * @param dimension      Dimension to split at.
     * @param median         Split value in the selected dimension.
     * @param dividedRecords Records divided by the split value.
     * @return New internal node.
     */
    private KDNode makeNonTerminalNode(final int dimension, final int median, final DividedRecords dividedRecords) {
        final KDNode loSon = buildTreeImpl(dividedRecords.getLoRecords());
        final KDNode hiSon = buildTreeImpl(dividedRecords.getHiRecords());
        ++nodeCount;
        return new KDNode(dimension, median, loSon, hiSon);
    }

    /**
     * Construct terminal node with bucket of feature vectors. Increments both node counters.
     *
     * @param featureVectors Feature vectors.
     * @return New terminal node.
     */
    public KDNode makeTerminalNode(final int[][] featureVectors) {
        ++nodeCount;
        ++terminalNodeCount;
        return new TerminalKDNode(featureVectors);
    }

    /**
     * Find the dimension with the greatest variance for the feature vectors.
     *
     * @param featureVectors Feature vectors.
     * @return Index of the dimension with greatest variance/spread; the lowest index wins a tie.
     */
    private int findDimensionWithGreatestVariance(final int[][] featureVectors) {
        double maxVar = -1.0;
        int dimension = 0;
        for (int j = 0; j < this.dimension; j++) {
            // Find coordinate with greatest spread.
            final double dimVar = calculateDimensionVariance(featureVectors, j);
            if (dimVar > maxVar) {
                maxVar = dimVar;
                dimension = j;
            }
        }
        return dimension;
    }

    /**
     * Calculate the median in selected dimension. For an even number of vectors the result is the mean of the
     * two middle keys truncated towards zero.
     *
     * @param featureVectors Feature vectors.
     * @param dimension      Dimension index.
     * @return Median of the dimension.
     */
    private int calculateKeyMedian(final int[][] featureVectors, final int dimension) {
        assert (featureVectors.length > 1);
        final int[] sortedArray = new int[featureVectors.length];
        for (int i = 0; i < featureVectors.length; i++) {
            sortedArray[i] = featureVectors[i][dimension];
        }
        Arrays.sort(sortedArray);

        final int midIndex = sortedArray.length / 2;
        if ((sortedArray.length % 2) == 0) {
            // Sum in double, so two large keys can not overflow the int range.
            return (int) (((double) sortedArray[midIndex] + (double) sortedArray[(midIndex - 1)]) / 2.0);
        } else {
            return sortedArray[midIndex];
        }
    }


    /**
     * Calculate variance of the values in selected dimension.
     *
     * @param featureVectors Feature vectors.
     * @param dimension      Dimension index.
     * @return Population variance (sum of squared deviations divided by the vector count) in the dimension.
     */
    private double calculateDimensionVariance(final int[][] featureVectors, final int dimension) {
        double mean = 0.0;
        for (final int[] record : featureVectors) {
            mean += record[dimension];
        }
        mean /= (double) featureVectors.length;

        double var = 0.0;

        for (final int[] record : featureVectors) {
            final double deviation = (double) record[dimension] - mean;
            var += deviation * deviation;
        }
        return (var / (double) featureVectors.length);
    }

}