Skip to content
Snippets Groups Projects
Commit b5b92ced authored by Vojtech Moravec
Browse files

Modified KDTree to work only with feature vector indices.

This included modifying KDTreeBuilder as well as TerminalKDNode.
Also, the builder had a bug where the `hi` and `lo` records were swapped in
the `divideRecords` method.
parent ec5cab32
No related branches found
No related tags found
No related merge requests found
...@@ -12,6 +12,7 @@ import java.util.PriorityQueue; ...@@ -12,6 +12,7 @@ import java.util.PriorityQueue;
// https://github.com/iwyoo/kd_tree/blob/master/kd_tree.cxx // https://github.com/iwyoo/kd_tree/blob/master/kd_tree.cxx
public class KDTree { public class KDTree {
private final int[][] featureVectors;
private final int maximumBucketSize; private final int maximumBucketSize;
private final KDNode root; private final KDNode root;
...@@ -20,14 +21,20 @@ public class KDTree { ...@@ -20,14 +21,20 @@ public class KDTree {
private final int terminalNodeCount; private final int terminalNodeCount;
public static class BBFSearchInfo { public static class BBFSearchInfo {
private final int[][] featureVectors;
private double nearestRecordDistance; private double nearestRecordDistance;
private int[] nearestRecord; private int[] nearestRecord;
public BBFSearchInfo() { public BBFSearchInfo(final int[][] featureVectors) {
this.featureVectors = featureVectors;
nearestRecord = null; nearestRecord = null;
nearestRecordDistance = Double.POSITIVE_INFINITY; nearestRecordDistance = Double.POSITIVE_INFINITY;
} }
public int[][] getFeatureVectors() {
return featureVectors;
}
public double getNearestRecordDistance() { public double getNearestRecordDistance() {
return nearestRecordDistance; return nearestRecordDistance;
} }
...@@ -65,13 +72,14 @@ public class KDTree { ...@@ -65,13 +72,14 @@ public class KDTree {
} }
} }
public KDTree(final KDNode root, public KDTree(final int[][] featureVectors,
final int dimension, final KDNode root,
final int maximumBucketSize, final int maximumBucketSize,
final int totalNodeCount, final int totalNodeCount,
final int terminalNodeCount) { final int terminalNodeCount) {
this.featureVectors = featureVectors;
this.root = root; this.root = root;
this.dimension = dimension; this.dimension = featureVectors[0].length;
this.maximumBucketSize = maximumBucketSize; this.maximumBucketSize = maximumBucketSize;
this.totalNodeCount = totalNodeCount; this.totalNodeCount = totalNodeCount;
this.terminalNodeCount = terminalNodeCount; this.terminalNodeCount = terminalNodeCount;
...@@ -82,7 +90,7 @@ public class KDTree { ...@@ -82,7 +90,7 @@ public class KDTree {
PriorityQueue<NodeWithDistance> priorityQueue = new PriorityQueue<>(); PriorityQueue<NodeWithDistance> priorityQueue = new PriorityQueue<>();
priorityQueue.add(new NodeWithDistance(root, 0.0)); priorityQueue.add(new NodeWithDistance(root, 0.0));
BBFSearchInfo searchInfo = new BBFSearchInfo(); BBFSearchInfo searchInfo = new BBFSearchInfo(featureVectors);
int tryIndex = 0; int tryIndex = 0;
int partition, discriminator; int partition, discriminator;
while (!priorityQueue.isEmpty() && tryIndex < maxE) { while (!priorityQueue.isEmpty() && tryIndex < maxE) {
......
...@@ -5,23 +5,24 @@ import java.util.Arrays; ...@@ -5,23 +5,24 @@ import java.util.Arrays;
public class KDTreeBuilder { public class KDTreeBuilder {
private static class DividedRecords { private static class DividedRecords {
private final int[][] hiRecords; private final int[] loIndices;
private final int[][] loRecords; private final int[] hiIndices;
DividedRecords(int[][] hiRecords, int[][] loRecords) { DividedRecords(final int[] loIndices, final int[] hiIndices) {
this.hiRecords = hiRecords; this.loIndices = loIndices;
this.loRecords = loRecords; this.hiIndices = hiIndices;
} }
public int[][] getHiRecords() { public int[] getHiIndices() {
return hiRecords; return hiIndices;
} }
public int[][] getLoRecords() { public int[] getLoIndices() {
return loRecords; return loIndices;
} }
} }
private int[][] featureVectors;
private final int bucketSize; private final int bucketSize;
private final int dimension; private final int dimension;
private int nodeCount = 0; private int nodeCount = 0;
...@@ -47,49 +48,59 @@ public class KDTreeBuilder { ...@@ -47,49 +48,59 @@ public class KDTreeBuilder {
public KDTree buildTree(final int[][] featureVectors) { public KDTree buildTree(final int[][] featureVectors) {
nodeCount = 0; nodeCount = 0;
terminalNodeCount = 0; terminalNodeCount = 0;
final KDNode rootNode = buildTreeImpl(featureVectors); this.featureVectors = featureVectors;
return new KDTree(rootNode, dimension, bucketSize, nodeCount, terminalNodeCount); final int[] indices = new int[featureVectors.length];
for (int i = 0; i < featureVectors.length; i++) {
indices[i] = i;
}
final KDNode rootNode = buildTreeImpl(indices);
return new KDTree(featureVectors, rootNode, bucketSize, nodeCount, terminalNodeCount);
} }
/** /**
* Build KDTree by recursion, feature vectors are split in the dimension with the greatest variance. * Build KDTree by recursion, feature vectors are split in the dimension with the greatest variance.
* *
* @param featureVectors Feature vectors to build the tree with. * @param indices Indices of feature vectors to build the tree with.
* @return Node with its siblings. * @return Node with its siblings.
*/ */
private KDNode buildTreeImpl(final int[][] featureVectors) { private KDNode buildTreeImpl(final int[] indices) {
if (featureVectors.length <= bucketSize) { if (indices.length <= bucketSize) {
return makeTerminalNode(featureVectors); return makeTerminalNode(indices);
} }
int keyIndexMSE = findDimensionWithGreatestVariance(featureVectors); int dimensionIndex = findDimensionWithGreatestVariance(indices);
final int median = calculateKeyMedian(featureVectors, keyIndexMSE); final int median = calculateKeyMedian(indices, dimensionIndex);
// Divide records in one method to hi and lo. // Divide records in one method to hi and lo.
final DividedRecords dividedRecords = divideRecords(featureVectors, median, keyIndexMSE); final DividedRecords dividedRecords = divideRecords(indices, median, dimensionIndex);
return makeNonTerminalNode(keyIndexMSE, median, dividedRecords); return makeNonTerminalNode(dimensionIndex, median, dividedRecords);
} }
/** /**
* Divide feature vectors into low and high subgroups. * Divide feature vectors into low and high subgroups.
* *
* @param featureVectors Feature vectors to divide. * @param indices Indices of feature vectors to divide into two groups.
* @param median Median in the dimension. * @param median Median in the dimension.
* @param dimension Dimension index. * @param dimension Dimension index.
* @return Divided vectors. * @return Divided vectors.
*/ */
private DividedRecords divideRecords(final int[][] featureVectors, final int median, final int dimension) { private DividedRecords divideRecords(final int[] indices, final int median, final int dimension) {
ArrayList<int[]> loRecords = new ArrayList<>(); ArrayList<Integer> loIndices = new ArrayList<>();
ArrayList<int[]> hiRecords = new ArrayList<>(); ArrayList<Integer> hiIndices = new ArrayList<>();
for (final int[] record : featureVectors) {
if (record[dimension] <= median) { for (final int fVecIndex : indices) {
loRecords.add(record); if (featureVectors[fVecIndex][dimension] <= median) {
loIndices.add(fVecIndex);
} else { } else {
hiRecords.add(record);
hiIndices.add(fVecIndex);
} }
} }
return new DividedRecords(loRecords.toArray(new int[0][]), hiRecords.toArray(new int[0][]));
return new DividedRecords(loIndices.stream().mapToInt(Integer::intValue).toArray(),
hiIndices.stream().mapToInt(Integer::intValue).toArray());
} }
/** /**
...@@ -101,8 +112,8 @@ public class KDTreeBuilder { ...@@ -101,8 +112,8 @@ public class KDTreeBuilder {
* @return New internal node. * @return New internal node.
*/ */
private KDNode makeNonTerminalNode(final int dimension, final int median, final DividedRecords dividedRecords) { private KDNode makeNonTerminalNode(final int dimension, final int median, final DividedRecords dividedRecords) {
final KDNode loSon = buildTreeImpl(dividedRecords.getLoRecords()); final KDNode loSon = buildTreeImpl(dividedRecords.getLoIndices());
final KDNode hiSon = buildTreeImpl(dividedRecords.getHiRecords()); final KDNode hiSon = buildTreeImpl(dividedRecords.getHiIndices());
++nodeCount; ++nodeCount;
return new KDNode(dimension, median, loSon, hiSon); return new KDNode(dimension, median, loSon, hiSon);
} }
...@@ -110,27 +121,27 @@ public class KDTreeBuilder { ...@@ -110,27 +121,27 @@ public class KDTreeBuilder {
/** /**
* Construct terminal node with bucket of feature vectors. * Construct terminal node with bucket of feature vectors.
* *
* @param featureVectors Feature vectors. * @param bucketIndices Indices of feature vectors to be stored in the leaf/terminal node.
* @return New terminal node. * @return New terminal node.
*/ */
public KDNode makeTerminalNode(final int[][] featureVectors) { public KDNode makeTerminalNode(final int[] bucketIndices) {
++nodeCount; ++nodeCount;
++terminalNodeCount; ++terminalNodeCount;
return new TerminalKDNode(featureVectors); return new TerminalKDNode(bucketIndices);
} }
/** /**
* Find the dimension with the greatest variance for the feature vectors. * Find the dimension with the greatest variance for the feature vectors.
* *
* @param featureVectors Feature vectors. * @param indices Indices of feature vectors.
* @return Index of the dimension with greatest variance/spread. * @return Index of the dimension with greatest variance/spread.
*/ */
private int findDimensionWithGreatestVariance(final int[][] featureVectors) { private int findDimensionWithGreatestVariance(final int[] indices) {
double maxVar = -1.0; double maxVar = -1.0;
int dimension = 0; int dimension = 0;
for (int j = 0; j < this.dimension; j++) { for (int j = 0; j < this.dimension; j++) {
// Find coordinate with greatest spread. // Find coordinate with greatest spread.
final double dimVar = calculateDimensionVariance(featureVectors, j); final double dimVar = calculateDimensionVariance(indices, j);
if (dimVar > maxVar) { if (dimVar > maxVar) {
maxVar = dimVar; maxVar = dimVar;
dimension = j; dimension = j;
...@@ -142,15 +153,15 @@ public class KDTreeBuilder { ...@@ -142,15 +153,15 @@ public class KDTreeBuilder {
/** /**
* Calculate the median in selected dimension. * Calculate the median in selected dimension.
* *
* @param featureVectors Feature vectors. * @param indices Indices of feature vectors.
* @param dimension Dimension index. * @param dimension Dimension index.
* @return Median of the dimension. * @return Median of the dimension.
*/ */
private int calculateKeyMedian(final int[][] featureVectors, final int dimension) { private int calculateKeyMedian(final int[] indices, final int dimension) {
assert (featureVectors.length > 1); assert (indices.length > 1);
final int[] sortedArray = new int[featureVectors.length]; final int[] sortedArray = new int[indices.length];
for (int i = 0; i < featureVectors.length; i++) { for (int i = 0; i < indices.length; i++) {
sortedArray[i] = featureVectors[i][dimension]; sortedArray[i] = featureVectors[indices[i]][dimension];
} }
Arrays.sort(sortedArray); Arrays.sort(sortedArray);
...@@ -166,23 +177,23 @@ public class KDTreeBuilder { ...@@ -166,23 +177,23 @@ public class KDTreeBuilder {
/** /**
* Calculate variance of the values in selected dimension. * Calculate variance of the values in selected dimension.
* *
* @param featureVectors Feature vectors. * @param indices Indices of feature vectors.
* @param dimension Dimension index. * @param dimension Dimension index.
* @return Variance in the dimension. * @return Variance in the dimension.
*/ */
private double calculateDimensionVariance(final int[][] featureVectors, final int dimension) { private double calculateDimensionVariance(final int[] indices, final int dimension) {
double mean = 0.0; double mean = 0.0;
for (final int[] record : featureVectors) { for (final int fVecIndex : indices) {
mean += record[dimension]; mean += featureVectors[fVecIndex][dimension];
} }
mean /= (double) featureVectors.length; mean /= (double) indices.length;
double var = 0.0; double var = 0.0;
for (final int[] record : featureVectors) { for (final int fVecIndex : indices) {
var += Math.pow(((double) record[dimension] - mean), 2); var += Math.pow(((double) featureVectors[fVecIndex][dimension] - mean), 2);
} }
return (var / (double) featureVectors.length); return (var / (double) indices.length);
} }
} }
...@@ -4,11 +4,11 @@ import azgracompress.utilities.Utils; ...@@ -4,11 +4,11 @@ import azgracompress.utilities.Utils;
public class TerminalKDNode extends KDNode { public class TerminalKDNode extends KDNode {
private final int[][] bucket; private final int[] bucketIndices;
public TerminalKDNode(final int[][] records) { public TerminalKDNode(final int[] bucketIndices) {
super(); super();
this.bucket = records; this.bucketIndices = bucketIndices;
} }
@Override @Override
...@@ -16,16 +16,16 @@ public class TerminalKDNode extends KDNode { ...@@ -16,16 +16,16 @@ public class TerminalKDNode extends KDNode {
return true; return true;
} }
public int[][] getBucket() { public int[] getBucketIndices() {
return bucket; return bucketIndices;
} }
public void findNearestNeighborInBucket(final int[] queryRecord, final KDTree.BBFSearchInfo searchInfo) { public void findNearestNeighborInBucket(final int[] queryRecord, final KDTree.BBFSearchInfo searchInfo) {
double recordDistance; double recordDistance;
for (final int[] record : bucket) { for (final int index : bucketIndices) {
recordDistance = Utils.calculateEuclideanDistance(queryRecord, record); recordDistance = Utils.calculateEuclideanDistance(queryRecord, searchInfo.getFeatureVectors()[index]);
if (recordDistance < searchInfo.getNearestRecordDistance()) { if (recordDistance < searchInfo.getNearestRecordDistance()) {
searchInfo.setNearestRecord(record, recordDistance); searchInfo.setNearestRecord(searchInfo.getFeatureVectors()[index], recordDistance);
} }
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment