Skip to content
Snippets Groups Projects
Commit e75cc073 authored by Vojtěch Moravec's avatar Vojtěch Moravec
Browse files

Clean KDTree building code.

Refactored the build code, with new proper variable names and some
simplifications.
Also the dimension to split on is calculated as the variance in the
dimension.
parent 2ca3d033
Branches
No related tags found
No related merge requests found
......@@ -4,7 +4,6 @@ import java.util.ArrayList;
import java.util.Arrays;
public class KDTreeBuilder {
private static class DividedRecords {
private final int[][] hiRecords;
private final int[][] loRecords;
......@@ -28,48 +27,63 @@ public class KDTreeBuilder {
private int nodeCount = 0;
private int terminalNodeCount = 0;
/**
* Create KDTree builder.
*
* @param dimension Dimension of the feature vectors.
* @param bucketSize Bucket size for the terminal nodes.
*/
public KDTreeBuilder(final int dimension, final int bucketSize) {
this.bucketSize = bucketSize;
this.dimension = dimension;
}
public KDTree buildTree(final int[][] records) {
/**
* Construct the KDTree for provided feature vectors.
*
* @param featureVectors Feature vectors to build the tree with
* @return KDTree.
*/
public KDTree buildTree(final int[][] featureVectors) {
nodeCount = 0;
terminalNodeCount = 0;
final KDNode rootNode = buildTreeImpl(records);
final KDNode rootNode = buildTreeImpl(featureVectors);
return new KDTree(rootNode, dimension, bucketSize, nodeCount, terminalNodeCount);
}
public KDNode buildTreeImpl(final int[][] records) {
if (records.length <= bucketSize) {
return makeTerminalNode(records);
}
double maxSpread = -1.0;
int keyIndex = 0;
for (int j = 0; j < dimension; j++) {
// Find coordinate with greatest spread.
final double greatestSpread = calculateKeySpread(records, j);
if (greatestSpread > maxSpread) {
maxSpread = greatestSpread;
keyIndex = j;
}
/**
* Build KDTree by recursion, feature vectors are split in the dimension with the greatest variance.
*
* @param featureVectors Feature vectors to build the tree with.
* @return Node with its siblings.
*/
public KDNode buildTreeImpl(final int[][] featureVectors) {
if (featureVectors.length <= bucketSize) {
return makeTerminalNode(featureVectors);
}
final int median = calculateKeyMedian(records, keyIndex);
int keyIndexMSE = findDimensionWithGreatestVariance(featureVectors);
final int median = calculateKeyMedian(featureVectors, keyIndexMSE);
// Divide records in one method to hi and lo.
final DividedRecords dividedRecords = divideRecords(records, median, keyIndex);
return makeNonTerminalNode(keyIndex, median, dividedRecords);
final DividedRecords dividedRecords = divideRecords(featureVectors, median, keyIndexMSE);
return makeNonTerminalNode(keyIndexMSE, median, dividedRecords);
}
private DividedRecords divideRecords(final int[][] records, final int median, final int keyIndex) {
/**
* Divide feature vectors into low and high subgroups.
*
* @param featureVectors Feature vectors to divide.
* @param median Median in the dimension.
* @param dimension Dimension index.
* @return Divided vectors.
*/
private DividedRecords divideRecords(final int[][] featureVectors, final int median, final int dimension) {
ArrayList<int[]> loRecords = new ArrayList<>();
ArrayList<int[]> hiRecords = new ArrayList<>();
for (final int[] record : records) {
if (record[keyIndex] <= median) {
for (final int[] record : featureVectors) {
if (record[dimension] <= median) {
loRecords.add(record);
} else {
hiRecords.add(record);
......@@ -78,24 +92,66 @@ public class KDTreeBuilder {
return new DividedRecords(loRecords.toArray(new int[0][]), hiRecords.toArray(new int[0][]));
}
private KDNode makeNonTerminalNode(final int keyIndex, final int median, final DividedRecords dividedRecords) {
/**
* Create internal tree node by recursion on buildTreeImpl.
*
* @param dimension Dimension to split at.
* @param median Median in the selected dimension.
* @param dividedRecords Records divided by the median.
* @return New internal node.
*/
private KDNode makeNonTerminalNode(final int dimension, final int median, final DividedRecords dividedRecords) {
final KDNode loSon = buildTreeImpl(dividedRecords.getLoRecords());
final KDNode hiSon = buildTreeImpl(dividedRecords.getHiRecords());
++nodeCount;
return new KDNode(keyIndex, median, loSon, hiSon);
return new KDNode(dimension, median, loSon, hiSon);
}
public KDNode makeTerminalNode(final int[][] records) {
/**
* Construct terminal node with bucket of feature vectors.
*
* @param featureVectors Feature vectors.
* @return New terminal node.
*/
public KDNode makeTerminalNode(final int[][] featureVectors) {
++nodeCount;
++terminalNodeCount;
return new TerminalKDNode(records);
System.out.printf("Terminal node bucket size: %d\n", featureVectors.length);
return new TerminalKDNode(featureVectors);
}
private int calculateKeyMedian(final int[][] records, final int keyIndex) {
assert (records.length > 1);
final int[] sortedArray = new int[records.length];
for (int i = 0; i < records.length; i++) {
sortedArray[i] = records[i][keyIndex];
/**
* Find the dimension with the greatest variance for the feature vectors.
*
* @param featureVectors Feature vectors.
* @return Index of the dimension with greatest variance/spread.
*/
private int findDimensionWithGreatestVariance(final int[][] featureVectors) {
double maxVar = -1.0;
int dimension = 0;
for (int j = 0; j < this.dimension; j++) {
// Find coordinate with greatest spread.
final double dimVar = calculateDimensionVariance(featureVectors, j);
if (dimVar > maxVar) {
maxVar = dimVar;
dimension = j;
}
}
return dimension;
}
/**
* Calculate the median in selected dimension.
*
* @param featureVectors Feature vectors.
* @param dimension Dimension index.
* @return Median of the dimension.
*/
private int calculateKeyMedian(final int[][] featureVectors, final int dimension) {
assert (featureVectors.length > 1);
final int[] sortedArray = new int[featureVectors.length];
for (int i = 0; i < featureVectors.length; i++) {
sortedArray[i] = featureVectors[i][dimension];
}
Arrays.sort(sortedArray);
......@@ -108,20 +164,26 @@ public class KDTreeBuilder {
}
private double calculateKeySpread(final int[][] records, final int keyIndex) {
double center = 0.0;
for (final int[] record : records) {
center += record[keyIndex];
/**
* Calculate variance of the values in selected dimension.
*
* @param featureVectors Feature vectors.
* @param dimension Dimension index.
* @return Variance in the dimension.
*/
private double calculateDimensionVariance(final int[][] featureVectors, final int dimension) {
double mean = 0.0;
for (final int[] record : featureVectors) {
mean += record[dimension];
}
center /= (double) records.length;
mean /= (double) featureVectors.length;
double spread = 0.0;
double var = 0.0;
for (final int[] record : records) {
spread += Math.pow(((double) center - (double) record[keyIndex]), 2);
// spread += Math.abs(center - record[keyIndex]);
for (final int[] record : featureVectors) {
var += Math.pow(((double) record[dimension] - mean), 2);
}
return Math.sqrt(spread);
// return (spread / (double) records.length);
return (var / (double) featureVectors.length);
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment