summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTavian Barnes <tavianator@tavianator.com>2020-04-30 22:51:06 -0400
committerTavian Barnes <tavianator@tavianator.com>2020-05-02 13:19:05 -0400
commit5de377b2b00a927a4f6463c1c5a5fd18606ad006 (patch)
treefec5304ea55f418e862bbb211d6e99c0f9921ada
parent1c560791902a4ef72efa671106d8f6d97fea50c1 (diff)
downloadkd-forest-5de377b2b00a927a4f6463c1c5a5fd18606ad006.tar.xz
metric/kd: Prune k-d tree searches more aggressively
-rw-r--r--src/metric/kd.rs116
1 files changed, 65 insertions, 51 deletions
diff --git a/src/metric/kd.rs b/src/metric/kd.rs
index ab0cd2e..db1b2bd 100644
--- a/src/metric/kd.rs
+++ b/src/metric/kd.rs
@@ -7,7 +7,7 @@ use ordered_float::OrderedFloat;
use std::iter::FromIterator;
/// A point in Cartesian space.
-pub trait Cartesian {
+pub trait Cartesian: Metric<[f64]> {
/// Returns the number of dimensions necessary to describe this point.
fn dimensions(&self) -> usize;
@@ -26,6 +26,15 @@ impl<'a, T: Cartesian> Cartesian for &'a T {
}
}
+/// Blanket [Metric<[f64]>](Metric) implementation for [Cartesian] references.
+impl<'a, T: Cartesian> Metric<[f64]> for &'a T {
+ type Distance = T::Distance;
+
+ fn distance(&self, other: &[f64]) -> Self::Distance {
+ (*self).distance(other)
+ }
+}
+
/// Standard cartesian space.
impl Cartesian for [f64] {
fn dimensions(&self) -> usize {
@@ -37,6 +46,21 @@ impl Cartesian for [f64] {
}
}
+/// Marker trait for cartesian metric spaces.
+pub trait CartesianMetric<T: ?Sized = Self>:
+ Cartesian + Metric<T, Distance = <Self as Metric<[f64]>>::Distance>
+{
+}
+
+/// Blanket [CartesianMetric] implementation for cartesian spaces with compatible metric distance
+/// types.
+impl<T, U> CartesianMetric<T> for U
+where
+ T: ?Sized,
+ U: ?Sized + Cartesian + Metric<T, Distance = <U as Metric<[f64]>>::Distance>,
+{
+}
+
/// A node in a k-d tree.
#[derive(Debug)]
struct KdNode<T> {
@@ -48,54 +72,6 @@ struct KdNode<T> {
right: Option<Box<Self>>,
}
-trait KdSearch<'a, T, U, N> {
- /// Recursively search for nearest neighbors.
- fn search(&'a self, i: usize, neighborhood: &mut N);
-
- /// Search the left subtree.
- fn search_left(&'a self, i: usize, distance: f64, neighborhood: &mut N);
-
- /// Search the right subtree.
- fn search_right(&'a self, i: usize, distance: f64, neighborhood: &mut N);
-}
-
-impl<'a, T, U, N> KdSearch<'a, T, U, N> for KdNode<T>
-where
- T: 'a + Cartesian,
- U: Cartesian + Metric<&'a T>,
- N: Neighborhood<&'a T, U>,
-{
- fn search(&'a self, i: usize, neighborhood: &mut N) {
- neighborhood.consider(&self.item);
-
- let distance = neighborhood.target().coordinate(i) - self.item.coordinate(i);
- let j = (i + 1) % self.item.dimensions();
- if distance <= 0.0 {
- self.search_left(j, distance, neighborhood);
- self.search_right(j, -distance, neighborhood);
- } else {
- self.search_right(j, -distance, neighborhood);
- self.search_left(j, distance, neighborhood);
- }
- }
-
- fn search_left(&'a self, i: usize, distance: f64, neighborhood: &mut N) {
- if let Some(left) = &self.left {
- if neighborhood.contains(distance) {
- left.search(i, neighborhood);
- }
- }
- }
-
- fn search_right(&'a self, i: usize, distance: f64, neighborhood: &mut N) {
- if let Some(right) = &self.right {
- if neighborhood.contains(distance) {
- right.search(i, neighborhood);
- }
- }
- }
-}
-
impl<T: Cartesian> KdNode<T> {
/// Create a new KdNode.
fn new(i: usize, mut items: Vec<T>) -> Option<Box<Self>> {
@@ -115,6 +91,40 @@ impl<T: Cartesian> KdNode<T> {
right: Self::new(j, right),
}))
}
+
+ /// Recursively search for nearest neighbors.
+ fn search<'a, U, N>(&'a self, i: usize, closest: &mut [f64], neighborhood: &mut N)
+ where
+ T: 'a,
+ U: CartesianMetric<&'a T>,
+ N: Neighborhood<&'a T, U>,
+ {
+ neighborhood.consider(&self.item);
+
+ let target = neighborhood.target();
+ let ti = target.coordinate(i);
+ let si = self.item.coordinate(i);
+ let j = (i + 1) % self.item.dimensions();
+
+ let (near, far) = if ti <= si {
+ (&self.left, &self.right)
+ } else {
+ (&self.right, &self.left)
+ };
+
+ if let Some(near) = near {
+ near.search(j, closest, neighborhood);
+ }
+
+ if let Some(far) = far {
+ let saved = closest[i];
+ closest[i] = si;
+ if neighborhood.contains_distance(target.distance(closest)) {
+ far.search(j, closest, neighborhood);
+ }
+ closest[i] = saved;
+ }
+ }
}
/// A [k-d tree](https://en.wikipedia.org/wiki/K-d_tree).
@@ -135,7 +145,7 @@ impl<T: Cartesian> FromIterator<T> for KdTree<T> {
impl<T, U> NearestNeighbors<T, U> for KdTree<T>
where
T: Cartesian,
- U: Cartesian + Metric<T>,
+ U: CartesianMetric<T>,
{
fn search<'a, 'b, N>(&'a self, mut neighborhood: N) -> N
where
@@ -143,8 +153,12 @@ where
U: 'b,
N: Neighborhood<&'a T, &'b U>,
{
+ let target = neighborhood.target();
+ let dims = target.dimensions();
+ let mut closest: Vec<_> = (0..dims).map(|i| target.coordinate(i)).collect();
+
if let Some(root) = &self.root {
- root.search(0, &mut neighborhood);
+ root.search(0, &mut closest, &mut neighborhood);
}
neighborhood
}