Skip to content
Snippets Groups Projects
Commit 6c6e06e1 authored by Vojtech Moravec's avatar Vojtech Moravec
Browse files

Finish first implementation of searching in k-DTree.

parent 314de612
Branches
No related tags found
No related merge requests found
...@@ -44,26 +44,34 @@ public class KDNode { ...@@ -44,26 +44,34 @@ public class KDNode {
public void findNearestNeighbor(final int[] queryRecord, final KDTree.SearchInfo searchInfo) { public void findNearestNeighbor(final int[] queryRecord, final KDTree.SearchInfo searchInfo) {
if (searchInfo.stopSearching())
return;
if (isTerminal()) { if (isTerminal()) {
// TODO: Examine records in bucket(node), updating PQD , PQR > ,
((TerminalKDNode) this).findNearestNeighborInBucket(queryRecord, searchInfo); ((TerminalKDNode) this).findNearestNeighborInBucket(queryRecord, searchInfo);
if (!ballWithinBounds(queryRecord, searchInfo)) if (ballWithinBounds(queryRecord, searchInfo)) {
searchInfo.setContinueSearching(false);
return; return;
}
} }
assert (loSon != null && hiSon != null);
assert (loSon != null && hiSon != null);
if (queryRecord[discriminator] <= partition) { if (queryRecord[discriminator] <= partition) {
double tmp = searchInfo.getUpperBounds()[discriminator]; double tmp = searchInfo.getUpperBounds()[discriminator];
searchInfo.getUpperBounds()[discriminator] = partition; searchInfo.getUpperBounds()[discriminator] = partition;
loSon.findNearestNeighbor(queryRecord, searchInfo); loSon.findNearestNeighbor(queryRecord, searchInfo);
searchInfo.getUpperBounds()[discriminator] = tmp; searchInfo.getUpperBounds()[discriminator] = tmp;
} else { } else {
double tmp = searchInfo.getLowerBounds()[discriminator]; double tmp = searchInfo.getLowerBounds()[discriminator];
searchInfo.getLowerBounds()[discriminator] = partition; searchInfo.getLowerBounds()[discriminator] = partition;
hiSon.findNearestNeighbor(queryRecord, searchInfo); hiSon.findNearestNeighbor(queryRecord, searchInfo);
searchInfo.getLowerBounds()[discriminator] = tmp; searchInfo.getLowerBounds()[discriminator] = tmp;
} }
if (searchInfo.stopSearching())
return;
if (queryRecord[discriminator] <= partition) { if (queryRecord[discriminator] <= partition) {
double tmp = searchInfo.getLowerBounds()[discriminator]; double tmp = searchInfo.getLowerBounds()[discriminator];
...@@ -80,9 +88,11 @@ public class KDNode { ...@@ -80,9 +88,11 @@ public class KDNode {
} }
searchInfo.getUpperBounds()[discriminator] = tmp; searchInfo.getUpperBounds()[discriminator] = tmp;
} }
if (searchInfo.stopSearching())
if (!ballWithinBounds(queryRecord, searchInfo)) {
return; return;
if (ballWithinBounds(queryRecord, searchInfo)) {
searchInfo.setContinueSearching(false);
} }
} }
...@@ -90,17 +100,16 @@ public class KDNode { ...@@ -90,17 +100,16 @@ public class KDNode {
return Math.pow((x - y), 2); return Math.pow((x - y), 2);
} }
private static double dissim(final double value) { private static double dissimilarity(final double value) {
return Math.sqrt(value); return Math.sqrt(value);
} }
@SuppressWarnings("BooleanMethodIsAlwaysInverted")
private boolean ballWithinBounds(final int[] queryRecord, final KDTree.SearchInfo searchInfo) { private boolean ballWithinBounds(final int[] queryRecord, final KDTree.SearchInfo searchInfo) {
double lbDist, ubDist; double lbDist, ubDist;
for (int d = 0; d < searchInfo.getDimension(); d++) { for (int d = 0; d < searchInfo.getDimension(); d++) {
lbDist = coordinateDistance(searchInfo.getLowerBounds()[d], queryRecord[d]); lbDist = coordinateDistance(searchInfo.getLowerBounds()[d], queryRecord[d]);
ubDist = coordinateDistance(searchInfo.getUpperBounds()[d], queryRecord[d]); ubDist = coordinateDistance(searchInfo.getUpperBounds()[d], queryRecord[d]);
if ((lbDist <= searchInfo.getCurrentClosestDistance()) || (ubDist <= searchInfo.getCurrentClosestDistance())) { if ((lbDist <= searchInfo.getNearestRecordDistance()) || (ubDist <= searchInfo.getNearestRecordDistance())) {
return false; return false;
} }
} }
...@@ -112,12 +121,12 @@ public class KDNode { ...@@ -112,12 +121,12 @@ public class KDNode {
for (int d = 0; d < searchInfo.getDimension(); d++) { for (int d = 0; d < searchInfo.getDimension(); d++) {
if (queryRecord[d] < searchInfo.getLowerBounds()[d]) { if (queryRecord[d] < searchInfo.getLowerBounds()[d]) {
sum += coordinateDistance(queryRecord[d], searchInfo.getLowerBounds()[d]); sum += coordinateDistance(queryRecord[d], searchInfo.getLowerBounds()[d]);
if (dissim(sum) > searchInfo.getCurrentClosestDistance()) { if (dissimilarity(sum) > searchInfo.getNearestRecordDistance()) {
return true; return true;
} }
} else if (queryRecord[d] > searchInfo.getUpperBounds()[d]) { } else if (queryRecord[d] > searchInfo.getUpperBounds()[d]) {
sum += coordinateDistance(queryRecord[d], searchInfo.getUpperBounds()[d]); sum += coordinateDistance(queryRecord[d], searchInfo.getUpperBounds()[d]);
if (dissim(sum) > searchInfo.getCurrentClosestDistance()) { if (dissimilarity(sum) > searchInfo.getNearestRecordDistance()) {
return true; return true;
} }
} }
......
...@@ -11,15 +11,16 @@ public class KDTree { ...@@ -11,15 +11,16 @@ public class KDTree {
private final int terminalNodeCount; private final int terminalNodeCount;
public static class SearchInfo { public static class SearchInfo {
private double currentClosestDistance; private boolean continueSearching = true;
private int[] currentClosestRecord = null; private double nearestRecordDistance;
private int[] nearestRecord = null;
private final double[] coordinateUpperBound; private final double[] coordinateUpperBound;
private final double[] coordinateLowerBound; private final double[] coordinateLowerBound;
private final int dimension; private final int dimension;
public SearchInfo(final int dimension) { public SearchInfo(final int dimension) {
this.dimension = dimension; this.dimension = dimension;
currentClosestDistance = Double.POSITIVE_INFINITY; nearestRecordDistance = Double.POSITIVE_INFINITY;
coordinateUpperBound = new double[dimension]; coordinateUpperBound = new double[dimension];
coordinateLowerBound = new double[dimension]; coordinateLowerBound = new double[dimension];
Arrays.fill(coordinateLowerBound, Double.NEGATIVE_INFINITY); Arrays.fill(coordinateLowerBound, Double.NEGATIVE_INFINITY);
...@@ -30,20 +31,12 @@ public class KDTree { ...@@ -30,20 +31,12 @@ public class KDTree {
return dimension; return dimension;
} }
public double getCurrentClosestDistance() { public double getNearestRecordDistance() {
return currentClosestDistance; return nearestRecordDistance;
} }
public void setCurrentClosestDistance(double currentClosestDistance) { public int[] getNearestRecord() {
this.currentClosestDistance = currentClosestDistance; return nearestRecord;
}
public int[] getCurrentClosestRecord() {
return currentClosestRecord;
}
public void setCurrentClosestRecord(int[] currentClosestRecord) {
this.currentClosestRecord = currentClosestRecord;
} }
public double[] getUpperBounds() { public double[] getUpperBounds() {
...@@ -53,6 +46,19 @@ public class KDTree { ...@@ -53,6 +46,19 @@ public class KDTree {
public double[] getLowerBounds() { public double[] getLowerBounds() {
return coordinateLowerBound; return coordinateLowerBound;
} }
public boolean stopSearching() {
return !continueSearching;
}
public void setContinueSearching(boolean continueSearching) {
this.continueSearching = continueSearching;
}
public void setNearestRecord(final int[] record, final double recordDistance) {
this.nearestRecord = record;
this.nearestRecordDistance = recordDistance;
}
} }
public KDTree(final KDNode root, public KDTree(final KDNode root,
...@@ -68,10 +74,9 @@ public class KDTree { ...@@ -68,10 +74,9 @@ public class KDTree {
} }
public int[] findNearestNeighbor(final int[] queryRecord) { public int[] findNearestNeighbor(final int[] queryRecord) {
// TODO(Moravec): Read more about Ball Within Bounds and Bounds Overlap Ball
SearchInfo searchInfo = new SearchInfo(dimension); SearchInfo searchInfo = new SearchInfo(dimension);
root.findNearestNeighbor(queryRecord, searchInfo); root.findNearestNeighbor(queryRecord, searchInfo);
return searchInfo.currentClosestRecord; return searchInfo.nearestRecord;
} }
public int getTotalNodeCount() { public int getTotalNodeCount() {
......
package azgracompress.kdtree; package azgracompress.kdtree;
import azgracompress.utilities.Utils;
public class TerminalKDNode extends KDNode { public class TerminalKDNode extends KDNode {
private final int[][] bucket; private final int[][] bucket;
...@@ -19,6 +21,12 @@ public class TerminalKDNode extends KDNode { ...@@ -19,6 +21,12 @@ public class TerminalKDNode extends KDNode {
} }
public void findNearestNeighborInBucket(final int[] queryRecord, final KDTree.SearchInfo searchInfo) { public void findNearestNeighborInBucket(final int[] queryRecord, final KDTree.SearchInfo searchInfo) {
double recordDistance;
for (final int[] record : bucket) {
recordDistance = Utils.calculateEuclideanDistance(queryRecord, record);
if (recordDistance < searchInfo.getNearestRecordDistance()) {
searchInfo.setNearestRecord(record, recordDistance);
}
}
} }
} }
package azgracompress.quantization.vector; package azgracompress.quantization.vector;
import azgracompress.utilities.Utils;
public class VectorQuantizer { public class VectorQuantizer {
private final VectorDistanceMetric metric = VectorDistanceMetric.Euclidean; private final VectorDistanceMetric metric = VectorDistanceMetric.Euclidean;
...@@ -104,11 +106,7 @@ public class VectorQuantizer { ...@@ -104,11 +106,7 @@ public class VectorQuantizer {
return sum; return sum;
} }
case Euclidean: { case Euclidean: {
double sum = 0.0; return Utils.calculateEuclideanDistance(originalDataVector, codebookEntry);
for (int i = 0; i < originalDataVector.length; i++) {
sum += Math.pow(((double) originalDataVector[i] - (double) codebookEntry[i]), 2);
}
return Math.sqrt(sum);
} }
case MaxDiff: { case MaxDiff: {
double maxDiff = Double.MIN_VALUE; double maxDiff = Double.MIN_VALUE;
......
...@@ -139,4 +139,20 @@ public class Utils { ...@@ -139,4 +139,20 @@ public class Utils {
final double mse = (sum / (double) difference.length); final double mse = (sum / (double) difference.length);
return mse; return mse;
} }
/**
* Calculate the euclidean distance between two vectors.
*
* @param v1 First vector.
* @param v2 Second vector.
* @return Euclidean distance.
*/
public static double calculateEuclideanDistance(final int[] v1, final int[] v2) {
assert (v1.length == v2.length);
double sum = 0.0;
for (int i = 0; i < v1.length; i++) {
sum += Math.pow(((double) v1[i] - (double) v2[i]), 2);
}
return Math.sqrt(sum);
}
} }
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment