Skip to content
Snippets Groups Projects
Commit 314de612 authored by Vojtech Moravec's avatar Vojtech Moravec
Browse files

Start implementing search of nearest neighbor.

parent dd23d26d
Branches
No related tags found
No related merge requests found
package azgracompress.kdtree; package azgracompress.kdtree;
public class KDNode { public class KDNode {
private final int keyIndex; private final int discriminator;
private final int median; private final int partition;
private final KDNode loSon; private final KDNode loSon;
private final KDNode hiSon; private final KDNode hiSon;
protected KDNode() { protected KDNode() {
keyIndex = -1; discriminator = -1;
median = -1; partition = -1;
loSon = null; loSon = null;
hiSon = null; hiSon = null;
} }
public KDNode(final int keyIndex, final int median, final KDNode loSon, final KDNode hiSon) { public KDNode(final int keyIndex, final int median, final KDNode loSon, final KDNode hiSon) {
this.keyIndex = keyIndex; this.discriminator = keyIndex;
this.median = median; this.partition = median;
this.loSon = loSon; this.loSon = loSon;
this.hiSon = hiSon; this.hiSon = hiSon;
} }
...@@ -30,15 +30,98 @@ public class KDNode { ...@@ -30,15 +30,98 @@ public class KDNode {
return hiSon; return hiSon;
} }
public final int getKeyIndex() { public final int getDiscriminator() {
return keyIndex; return discriminator;
} }
public final int getMedian() { public final int getPartition() {
return median; return partition;
} }
public boolean isTerminal() { public boolean isTerminal() {
return false; return false;
} }
public void findNearestNeighbor(final int[] queryRecord, final KDTree.SearchInfo searchInfo) {
if (isTerminal()) {
// TODO: Examine records in bucket(node), updating PQD , PQR > ,
((TerminalKDNode) this).findNearestNeighborInBucket(queryRecord, searchInfo);
if (!ballWithinBounds(queryRecord, searchInfo))
return;
}
assert (loSon != null && hiSon != null);
if (queryRecord[discriminator] <= partition) {
double tmp = searchInfo.getUpperBounds()[discriminator];
searchInfo.getUpperBounds()[discriminator] = partition;
loSon.findNearestNeighbor(queryRecord, searchInfo);
searchInfo.getUpperBounds()[discriminator] = tmp;
} else {
double tmp = searchInfo.getLowerBounds()[discriminator];
searchInfo.getLowerBounds()[discriminator] = partition;
hiSon.findNearestNeighbor(queryRecord, searchInfo);
searchInfo.getLowerBounds()[discriminator] = tmp;
}
if (queryRecord[discriminator] <= partition) {
double tmp = searchInfo.getLowerBounds()[discriminator];
searchInfo.getLowerBounds()[discriminator] = partition;
if (boundsOverlapBall(queryRecord, searchInfo)) {
hiSon.findNearestNeighbor(queryRecord, searchInfo);
}
searchInfo.getLowerBounds()[discriminator] = tmp;
} else {
double tmp = searchInfo.getUpperBounds()[discriminator];
searchInfo.getUpperBounds()[discriminator] = partition;
if (boundsOverlapBall(queryRecord, searchInfo)) {
loSon.findNearestNeighbor(queryRecord, searchInfo);
}
searchInfo.getUpperBounds()[discriminator] = tmp;
}
if (!ballWithinBounds(queryRecord, searchInfo)) {
return;
}
}
private static double coordinateDistance(final double x, final double y) {
return Math.pow((x - y), 2);
}
private static double dissim(final double value) {
return Math.sqrt(value);
}
@SuppressWarnings("BooleanMethodIsAlwaysInverted")
private boolean ballWithinBounds(final int[] queryRecord, final KDTree.SearchInfo searchInfo) {
double lbDist, ubDist;
for (int d = 0; d < searchInfo.getDimension(); d++) {
lbDist = coordinateDistance(searchInfo.getLowerBounds()[d], queryRecord[d]);
ubDist = coordinateDistance(searchInfo.getUpperBounds()[d], queryRecord[d]);
if ((lbDist <= searchInfo.getCurrentClosestDistance()) || (ubDist <= searchInfo.getCurrentClosestDistance())) {
return false;
}
}
return true;
}
private boolean boundsOverlapBall(final int[] queryRecord, final KDTree.SearchInfo searchInfo) {
double sum = 0.0;
for (int d = 0; d < searchInfo.getDimension(); d++) {
if (queryRecord[d] < searchInfo.getLowerBounds()[d]) {
sum += coordinateDistance(queryRecord[d], searchInfo.getLowerBounds()[d]);
if (dissim(sum) > searchInfo.getCurrentClosestDistance()) {
return true;
}
} else if (queryRecord[d] > searchInfo.getUpperBounds()[d]) {
sum += coordinateDistance(queryRecord[d], searchInfo.getUpperBounds()[d]);
if (dissim(sum) > searchInfo.getCurrentClosestDistance()) {
return true;
}
}
}
return false;
}
} }
package azgracompress.kdtree;
import java.util.Arrays;
public class KDTree {
private final int maximumBucketSize;
private final KDNode root;
private final int dimension;
private final int totalNodeCount;
private final int terminalNodeCount;
public static class SearchInfo {
private double currentClosestDistance;
private int[] currentClosestRecord = null;
private final double[] coordinateUpperBound;
private final double[] coordinateLowerBound;
private final int dimension;
public SearchInfo(final int dimension) {
this.dimension = dimension;
currentClosestDistance = Double.POSITIVE_INFINITY;
coordinateUpperBound = new double[dimension];
coordinateLowerBound = new double[dimension];
Arrays.fill(coordinateLowerBound, Double.NEGATIVE_INFINITY);
Arrays.fill(coordinateUpperBound, Double.POSITIVE_INFINITY);
}
public int getDimension() {
return dimension;
}
public double getCurrentClosestDistance() {
return currentClosestDistance;
}
public void setCurrentClosestDistance(double currentClosestDistance) {
this.currentClosestDistance = currentClosestDistance;
}
public int[] getCurrentClosestRecord() {
return currentClosestRecord;
}
public void setCurrentClosestRecord(int[] currentClosestRecord) {
this.currentClosestRecord = currentClosestRecord;
}
public double[] getUpperBounds() {
return coordinateUpperBound;
}
public double[] getLowerBounds() {
return coordinateLowerBound;
}
}
public KDTree(final KDNode root,
final int dimension,
final int maximumBucketSize,
final int totalNodeCount,
final int terminalNodeCount) {
this.root = root;
this.dimension = dimension;
this.maximumBucketSize = maximumBucketSize;
this.totalNodeCount = totalNodeCount;
this.terminalNodeCount = terminalNodeCount;
}
public int[] findNearestNeighbor(final int[] queryRecord) {
// TODO(Moravec): Read more about Ball Within Bounds and Bounds Overlap Ball
SearchInfo searchInfo = new SearchInfo(dimension);
root.findNearestNeighbor(queryRecord, searchInfo);
return searchInfo.currentClosestRecord;
}
public int getTotalNodeCount() {
return totalNodeCount;
}
public int getTerminalNodeCount() {
return terminalNodeCount;
}
}
...@@ -33,8 +33,14 @@ public class KDTreeBuilder { ...@@ -33,8 +33,14 @@ public class KDTreeBuilder {
this.dimension = dimension; this.dimension = dimension;
} }
public KDTree buildTree(final int[][] records) {
nodeCount = 0;
terminalNodeCount = 0;
final KDNode rootNode = buildTreeImpl(records);
return new KDTree(rootNode, dimension, bucketSize, nodeCount, terminalNodeCount);
}
public KDNode buildTree(final int[][] records) { public KDNode buildTreeImpl(final int[][] records) {
if (records.length <= bucketSize) { if (records.length <= bucketSize) {
return makeTerminalNode(records); return makeTerminalNode(records);
} }
...@@ -73,8 +79,8 @@ public class KDTreeBuilder { ...@@ -73,8 +79,8 @@ public class KDTreeBuilder {
} }
private KDNode makeNonTerminalNode(final int keyIndex, final int median, final DividedRecords dividedRecords) { private KDNode makeNonTerminalNode(final int keyIndex, final int median, final DividedRecords dividedRecords) {
final KDNode loSon = buildTree(dividedRecords.getLoRecords()); final KDNode loSon = buildTreeImpl(dividedRecords.getLoRecords());
final KDNode hiSon = buildTree(dividedRecords.getHiRecords()); final KDNode hiSon = buildTreeImpl(dividedRecords.getHiRecords());
++nodeCount; ++nodeCount;
return new KDNode(keyIndex, median, loSon, hiSon); return new KDNode(keyIndex, median, loSon, hiSon);
} }
...@@ -113,16 +119,9 @@ public class KDTreeBuilder { ...@@ -113,16 +119,9 @@ public class KDTreeBuilder {
for (final int[] record : records) { for (final int[] record : records) {
spread += Math.pow(((double) center - (double) record[keyIndex]), 2); spread += Math.pow(((double) center - (double) record[keyIndex]), 2);
// spread += Math.abs(center - record[keyIndex]);
} }
return Math.sqrt(spread); return Math.sqrt(spread);
} // return (spread / (double) records.length);
public int getNodeCount() {
return nodeCount;
}
public int getTerminalNodeCount() {
return terminalNodeCount;
} }
} }
...@@ -17,4 +17,8 @@ public class TerminalKDNode extends KDNode { ...@@ -17,4 +17,8 @@ public class TerminalKDNode extends KDNode {
public int[][] getBucket() { public int[][] getBucket() {
return bucket; return bucket;
} }
public void findNearestNeighborInBucket(final int[] queryRecord, final KDTree.SearchInfo searchInfo) {
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment