Skip to content
Snippets Groups Projects
KDTreeBuilder.java 6.28 KiB
Newer Older
  • Learn to ignore specific revisions
  • Vojtech Moravec's avatar
    Vojtech Moravec committed
    package azgracompress.kdtree;
    
    import java.util.ArrayList;
    import java.util.Arrays;
    
    public class KDTreeBuilder {
        /**
         * Immutable pair of record groups produced by one split of the feature vectors.
         * Note that the constructor takes the high group first and the low group second.
         */
        private static class DividedRecords {
            private final int[][] lo;
            private final int[][] hi;

            /**
             * @param hiRecords Records to expose through {@link #getHiRecords()}.
             * @param loRecords Records to expose through {@link #getLoRecords()}.
             */
            DividedRecords(final int[][] hiRecords, final int[][] loRecords) {
                lo = loRecords;
                hi = hiRecords;
            }

            /**
             * @return The array that was passed to the constructor as {@code hiRecords}.
             */
            public int[][] getHiRecords() {
                return hi;
            }

            /**
             * @return The array that was passed to the constructor as {@code loRecords}.
             */
            public int[][] getLoRecords() {
                return lo;
            }
        }
    
        /** Inputs with at most this many feature vectors are stored directly in a terminal node. */
        private final int bucketSize;
        /** Number of vector components examined when searching for the split dimension. */
        private final int dimension;
        /** Count of all nodes (terminal and non-terminal) created during the current build. */
        private int nodeCount = 0;
        /** Count of terminal nodes created during the current build. */
        private int terminalNodeCount = 0;
    
    
        /**
         * Creates a builder that constructs KD-trees with the given parameters.
         *
         * @param dimension  Number of vector components considered when choosing a split.
         * @param bucketSize Input size at or below which a terminal node is created.
         */
        public KDTreeBuilder(final int dimension, final int bucketSize) {
            this.dimension = dimension;
            this.bucketSize = bucketSize;
        }
    
    
        /**
         * Builds a KD-tree over the provided feature vectors.
         * The node counters are reset first, so one builder can be reused for several trees.
         *
         * @param featureVectors Feature vectors to build the tree with.
         * @return KDTree wrapping the root node together with the builder parameters and node counts.
         */
        public KDTree buildTree(final int[][] featureVectors) {
            // Start counting from scratch for this tree.
            nodeCount = terminalNodeCount = 0;

            // The recursion must finish before the counters are read below.
            final KDNode root = buildTreeImpl(featureVectors);
            final int totalNodes = nodeCount;
            final int totalTerminalNodes = terminalNodeCount;

            return new KDTree(root, dimension, bucketSize, totalNodes, totalTerminalNodes);
        }
    
        /**
         * Build KDTree by recursion, feature vectors are split in the dimension with the greatest variance.
         *
         * @param featureVectors Feature vectors to build the tree with.
         * @return Node with its siblings.
         */
        public KDNode buildTreeImpl(final int[][] featureVectors) {
            if (featureVectors.length <= bucketSize) {
                return makeTerminalNode(featureVectors);
    
            int keyIndexMSE = findDimensionWithGreatestVariance(featureVectors);
            final int median = calculateKeyMedian(featureVectors, keyIndexMSE);
    
    Vojtech Moravec's avatar
    Vojtech Moravec committed
    
            // Divide records in one method to hi and lo.
    
            final DividedRecords dividedRecords = divideRecords(featureVectors, median, keyIndexMSE);
            return makeNonTerminalNode(keyIndexMSE, median, dividedRecords);
    
        /**
         * Divide feature vectors into low and high subgroups.
         *
         * @param featureVectors Feature vectors to divide.
         * @param median         Median in the dimension.
         * @param dimension      Dimension index.
         * @return Divided vectors.
         */
        private DividedRecords divideRecords(final int[][] featureVectors, final int median, final int dimension) {
    
    Vojtech Moravec's avatar
    Vojtech Moravec committed
            ArrayList<int[]> loRecords = new ArrayList<>();
            ArrayList<int[]> hiRecords = new ArrayList<>();
    
            for (final int[] record : featureVectors) {
                if (record[dimension] <= median) {
    
    Vojtech Moravec's avatar
    Vojtech Moravec committed
                    loRecords.add(record);
                } else {
                    hiRecords.add(record);
                }
            }
            return new DividedRecords(loRecords.toArray(new int[0][]), hiRecords.toArray(new int[0][]));
        }
    
    
        /**
         * Create internal tree node by recursion on buildTreeImpl.
         *
         * @param dimension      Dimension to split at.
         * @param median         Median in the selected dimension.
         * @param dividedRecords Records divided by the median.
         * @return New internal node.
         */
        private KDNode makeNonTerminalNode(final int dimension, final int median, final DividedRecords dividedRecords) {
    
            final KDNode loSon = buildTreeImpl(dividedRecords.getLoRecords());
            final KDNode hiSon = buildTreeImpl(dividedRecords.getHiRecords());
    
    Vojtech Moravec's avatar
    Vojtech Moravec committed
            ++nodeCount;
    
            return new KDNode(dimension, median, loSon, hiSon);
    
        /**
         * Construct terminal node with bucket of feature vectors.
         *
         * @param featureVectors Feature vectors.
         * @return New terminal node.
         */
        public KDNode makeTerminalNode(final int[][] featureVectors) {
    
    Vojtech Moravec's avatar
    Vojtech Moravec committed
            ++nodeCount;
            ++terminalNodeCount;
    
            System.out.printf("Terminal node bucket size: %d\n", featureVectors.length);
            return new TerminalKDNode(featureVectors);
    
        /**
         * Select the dimension in which the feature vectors vary the most.
         * When several dimensions share the greatest variance, the lowest index is returned.
         *
         * @param featureVectors Feature vectors.
         * @return Index of the dimension with greatest variance/spread.
         */
        private int findDimensionWithGreatestVariance(final int[][] featureVectors) {
            int bestDimension = 0;
            double bestVariance = -1.0;

            for (int candidate = 0; candidate < this.dimension; candidate++) {
                final double candidateVariance = calculateDimensionVariance(featureVectors, candidate);
                // Strict comparison keeps the first dimension that reached the maximum.
                if (candidateVariance > bestVariance) {
                    bestDimension = candidate;
                    bestVariance = candidateVariance;
                }
            }

            return bestDimension;
        }
    
        /**
         * Calculate the median in selected dimension.
         *
         * @param featureVectors Feature vectors.
         * @param dimension      Dimension index.
         * @return Median of the dimension.
         */
        private int calculateKeyMedian(final int[][] featureVectors, final int dimension) {
            assert (featureVectors.length > 1);
            final int[] sortedArray = new int[featureVectors.length];
            for (int i = 0; i < featureVectors.length; i++) {
                sortedArray[i] = featureVectors[i][dimension];
    
    Vojtech Moravec's avatar
    Vojtech Moravec committed
            }
            Arrays.sort(sortedArray);
    
            final int midIndex = sortedArray.length / 2;
            if ((sortedArray.length % 2) == 0) {
                return (int) (((double) sortedArray[midIndex] + (double) sortedArray[(midIndex - 1)]) / 2.0);
            } else {
                return sortedArray[midIndex];
            }
        }
    
    
    
        /**
         * Calculate variance of the values in selected dimension.
         *
         * @param featureVectors Feature vectors.
         * @param dimension      Dimension index.
         * @return Variance in the dimension.
         */
        private double calculateDimensionVariance(final int[][] featureVectors, final int dimension) {
            double mean = 0.0;
            for (final int[] record : featureVectors) {
                mean += record[dimension];
    
            mean /= (double) featureVectors.length;
    
            double var = 0.0;
    
            for (final int[] record : featureVectors) {
                var += Math.pow(((double) record[dimension] - mean), 2);
    
            return (var / (double) featureVectors.length);