diff --git a/src/main/java/azgracompress/kdtree/KDTree.java b/src/main/java/azgracompress/kdtree/KDTree.java index d45b0b8152e713005ea854bd36d51f818aa13f45..04ab7fc340dee6384e44160a8439334276c2bb49 100644 --- a/src/main/java/azgracompress/kdtree/KDTree.java +++ b/src/main/java/azgracompress/kdtree/KDTree.java @@ -12,6 +12,7 @@ import java.util.PriorityQueue; // https://github.com/iwyoo/kd_tree/blob/master/kd_tree.cxx public class KDTree { + private final int[][] featureVectors; private final int maximumBucketSize; private final KDNode root; @@ -20,14 +21,20 @@ public class KDTree { private final int terminalNodeCount; public static class BBFSearchInfo { + private final int[][] featureVectors; private double nearestRecordDistance; private int[] nearestRecord; - public BBFSearchInfo() { + public BBFSearchInfo(final int[][] featureVectors) { + this.featureVectors = featureVectors; nearestRecord = null; nearestRecordDistance = Double.POSITIVE_INFINITY; } + public int[][] getFeatureVectors() { + return featureVectors; + } + public double getNearestRecordDistance() { return nearestRecordDistance; } @@ -65,13 +72,14 @@ public class KDTree { } } - public KDTree(final KDNode root, - final int dimension, + public KDTree(final int[][] featureVectors, + final KDNode root, final int maximumBucketSize, final int totalNodeCount, final int terminalNodeCount) { + this.featureVectors = featureVectors; this.root = root; - this.dimension = dimension; + this.dimension = featureVectors[0].length; this.maximumBucketSize = maximumBucketSize; this.totalNodeCount = totalNodeCount; this.terminalNodeCount = terminalNodeCount; @@ -82,7 +90,7 @@ public class KDTree { PriorityQueue<NodeWithDistance> priorityQueue = new PriorityQueue<>(); priorityQueue.add(new NodeWithDistance(root, 0.0)); - BBFSearchInfo searchInfo = new BBFSearchInfo(); + BBFSearchInfo searchInfo = new BBFSearchInfo(featureVectors); int tryIndex = 0; int partition, discriminator; while (!priorityQueue.isEmpty() && tryIndex < maxE) { diff --git a/src/main/java/azgracompress/kdtree/KDTreeBuilder.java b/src/main/java/azgracompress/kdtree/KDTreeBuilder.java index 7f0b5dbc4620ae60daaba24766bb2f9294c1c2aa..ace1feead1911c68997e09ae3f9193436f59f7ab 100644 --- a/src/main/java/azgracompress/kdtree/KDTreeBuilder.java +++ b/src/main/java/azgracompress/kdtree/KDTreeBuilder.java @@ -5,23 +5,24 @@ import java.util.Arrays; public class KDTreeBuilder { private static class DividedRecords { - private final int[][] hiRecords; - private final int[][] loRecords; + private final int[] loIndices; + private final int[] hiIndices; - DividedRecords(int[][] hiRecords, int[][] loRecords) { - this.hiRecords = hiRecords; - this.loRecords = loRecords; + DividedRecords(final int[] loIndices, final int[] hiIndices) { + this.loIndices = loIndices; + this.hiIndices = hiIndices; } - public int[][] getHiRecords() { - return hiRecords; + public int[] getHiIndices() { + return hiIndices; } - public int[][] getLoRecords() { - return loRecords; + public int[] getLoIndices() { + return loIndices; } } + private int[][] featureVectors; private final int bucketSize; private final int dimension; private int nodeCount = 0; @@ -47,49 +48,59 @@ public class KDTreeBuilder { public KDTree buildTree(final int[][] featureVectors) { nodeCount = 0; terminalNodeCount = 0; - final KDNode rootNode = buildTreeImpl(featureVectors); - return new KDTree(rootNode, dimension, bucketSize, nodeCount, terminalNodeCount); + this.featureVectors = featureVectors; + final int[] indices = new int[featureVectors.length]; + for (int i = 0; i < featureVectors.length; i++) { + indices[i] = i; + } + + final KDNode rootNode = buildTreeImpl(indices); + return new KDTree(featureVectors, rootNode, bucketSize, nodeCount, terminalNodeCount); } /** * Build KDTree by recursion, feature vectors are split in the dimension with the greatest variance. * - * @param featureVectors Feature vectors to build the tree with. + * @param indices Indices of feature vectors to build the tree with. * @return Node with its siblings. */ - private KDNode buildTreeImpl(final int[][] featureVectors) { - if (featureVectors.length <= bucketSize) { - return makeTerminalNode(featureVectors); + private KDNode buildTreeImpl(final int[] indices) { + if (indices.length <= bucketSize) { + return makeTerminalNode(indices); } - int keyIndexMSE = findDimensionWithGreatestVariance(featureVectors); - final int median = calculateKeyMedian(featureVectors, keyIndexMSE); + int dimensionIndex = findDimensionWithGreatestVariance(indices); + final int median = calculateKeyMedian(indices, dimensionIndex); // Divide records in one method to hi and lo. - final DividedRecords dividedRecords = divideRecords(featureVectors, median, keyIndexMSE); - return makeNonTerminalNode(keyIndexMSE, median, dividedRecords); + final DividedRecords dividedRecords = divideRecords(indices, median, dimensionIndex); + return makeNonTerminalNode(dimensionIndex, median, dividedRecords); } /** * Divide feature vectors into low and high subgroups. * - * @param featureVectors Feature vectors to divide. - * @param median Median in the dimension. - * @param dimension Dimension index. + * @param indices Indices of feature vectors to divide into two groups. + * @param median Median in the dimension. + * @param dimension Dimension index. * @return Divided vectors. */ - private DividedRecords divideRecords(final int[][] featureVectors, final int median, final int dimension) { - ArrayList<int[]> loRecords = new ArrayList<>(); - ArrayList<int[]> hiRecords = new ArrayList<>(); - for (final int[] record : featureVectors) { - if (record[dimension] <= median) { - loRecords.add(record); + private DividedRecords divideRecords(final int[] indices, final int median, final int dimension) { + ArrayList<Integer> loIndices = new ArrayList<>(); + ArrayList<Integer> hiIndices = new ArrayList<>(); + + for (final int fVecIndex : indices) { + if (featureVectors[fVecIndex][dimension] <= median) { + loIndices.add(fVecIndex); } else { - hiRecords.add(record); + + hiIndices.add(fVecIndex); } } - return new DividedRecords(loRecords.toArray(new int[0][]), hiRecords.toArray(new int[0][])); + + return new DividedRecords(loIndices.stream().mapToInt(Integer::intValue).toArray(), + hiIndices.stream().mapToInt(Integer::intValue).toArray()); } /** @@ -101,8 +112,8 @@ public class KDTreeBuilder { * @return New internal node. */ private KDNode makeNonTerminalNode(final int dimension, final int median, final DividedRecords dividedRecords) { - final KDNode loSon = buildTreeImpl(dividedRecords.getLoRecords()); - final KDNode hiSon = buildTreeImpl(dividedRecords.getHiRecords()); + final KDNode loSon = buildTreeImpl(dividedRecords.getLoIndices()); + final KDNode hiSon = buildTreeImpl(dividedRecords.getHiIndices()); ++nodeCount; return new KDNode(dimension, median, loSon, hiSon); } @@ -110,27 +121,27 @@ public class KDTreeBuilder { /** * Construct terminal node with bucket of feature vectors. * - * @param featureVectors Feature vectors. + * @param bucketIndices Indices of feature vectors to be stored in the leaf/terminal node. * @return New terminal node. */ - public KDNode makeTerminalNode(final int[][] featureVectors) { + public KDNode makeTerminalNode(final int[] bucketIndices) { ++nodeCount; ++terminalNodeCount; - return new TerminalKDNode(featureVectors); + return new TerminalKDNode(bucketIndices); } /** * Find the dimension with the greatest variance for the feature vectors. * - * @param featureVectors Feature vectors. + * @param indices Indices of feature vectors. * @return Index of the dimension with greatest variance/spread. */ - private int findDimensionWithGreatestVariance(final int[][] featureVectors) { + private int findDimensionWithGreatestVariance(final int[] indices) { double maxVar = -1.0; int dimension = 0; for (int j = 0; j < this.dimension; j++) { // Find coordinate with greatest spread. - final double dimVar = calculateDimensionVariance(featureVectors, j); + final double dimVar = calculateDimensionVariance(indices, j); if (dimVar > maxVar) { maxVar = dimVar; dimension = j; @@ -142,15 +153,15 @@ public class KDTreeBuilder { /** * Calculate the median in selected dimension. * - * @param featureVectors Feature vectors. - * @param dimension Dimension index. + * @param indices Indices of feature vectors. + * @param dimension Dimension index. * @return Median of the dimension. */ - private int calculateKeyMedian(final int[][] featureVectors, final int dimension) { - assert (featureVectors.length > 1); - final int[] sortedArray = new int[featureVectors.length]; - for (int i = 0; i < featureVectors.length; i++) { - sortedArray[i] = featureVectors[i][dimension]; + private int calculateKeyMedian(final int[] indices, final int dimension) { + assert (indices.length > 1); + final int[] sortedArray = new int[indices.length]; + for (int i = 0; i < indices.length; i++) { + sortedArray[i] = featureVectors[indices[i]][dimension]; } Arrays.sort(sortedArray); @@ -166,23 +177,23 @@ public class KDTreeBuilder { /** * Calculate variance of the values in selected dimension. * - * @param featureVectors Feature vectors. - * @param dimension Dimension index. + * @param indices Indices of feature vectors. + * @param dimension Dimension index. * @return Variance in the dimension. */ - private double calculateDimensionVariance(final int[][] featureVectors, final int dimension) { + private double calculateDimensionVariance(final int[] indices, final int dimension) { double mean = 0.0; - for (final int[] record : featureVectors) { - mean += record[dimension]; + for (final int fVecIndex : indices) { + mean += featureVectors[fVecIndex][dimension]; } - mean /= (double) featureVectors.length; + mean /= (double) indices.length; double var = 0.0; - for (final int[] record : featureVectors) { - var += Math.pow(((double) record[dimension] - mean), 2); + for (final int fVecIndex : indices) { + var += Math.pow(((double) featureVectors[fVecIndex][dimension] - mean), 2); } - return (var / (double) featureVectors.length); + return (var / (double) indices.length); } } diff --git a/src/main/java/azgracompress/kdtree/TerminalKDNode.java b/src/main/java/azgracompress/kdtree/TerminalKDNode.java index eed19064f2b14174264da4dfd54bc498aeb7c310..8ba4fc76c817585dd929c9e6a60ab20fe89ab706 100644 --- a/src/main/java/azgracompress/kdtree/TerminalKDNode.java +++ b/src/main/java/azgracompress/kdtree/TerminalKDNode.java @@ -4,11 +4,11 @@ import azgracompress.utilities.Utils; public class TerminalKDNode extends KDNode { - private final int[][] bucket; + private final int[] bucketIndices; - public TerminalKDNode(final int[][] records) { + public TerminalKDNode(final int[] bucketIndices) { super(); - this.bucket = records; + this.bucketIndices = bucketIndices; } @Override @@ -16,16 +16,16 @@ public class TerminalKDNode extends KDNode { return true; } - public int[][] getBucket() { - return bucket; + public int[] getBucketIndices() { + return bucketIndices; } public void findNearestNeighborInBucket(final int[] queryRecord, final KDTree.BBFSearchInfo searchInfo) { double recordDistance; - for (final int[] record : bucket) { - recordDistance = Utils.calculateEuclideanDistance(queryRecord, record); + for (final int index : bucketIndices) { + recordDistance = Utils.calculateEuclideanDistance(queryRecord, searchInfo.getFeatureVectors()[index]); if (recordDistance < searchInfo.getNearestRecordDistance()) { - searchInfo.setNearestRecord(record, recordDistance); + searchInfo.setNearestRecord(searchInfo.getFeatureVectors()[index], recordDistance); } } }