From 5de377b2b00a927a4f6463c1c5a5fd18606ad006 Mon Sep 17 00:00:00 2001 From: Tavian Barnes Date: Thu, 30 Apr 2020 22:51:06 -0400 Subject: metric/kd: Prune k-d tree searches more aggressively --- src/metric/kd.rs | 116 +++++++++++++++++++++++++++++++------------------------ 1 file changed, 65 insertions(+), 51 deletions(-) diff --git a/src/metric/kd.rs b/src/metric/kd.rs index ab0cd2e..db1b2bd 100644 --- a/src/metric/kd.rs +++ b/src/metric/kd.rs @@ -7,7 +7,7 @@ use ordered_float::OrderedFloat; use std::iter::FromIterator; /// A point in Cartesian space. -pub trait Cartesian { +pub trait Cartesian: Metric<[f64]> { /// Returns the number of dimensions necessary to describe this point. fn dimensions(&self) -> usize; @@ -26,6 +26,15 @@ impl<'a, T: Cartesian> Cartesian for &'a T { } } +/// Blanket [Metric<[f64]>](Metric) implementation for [Cartesian] references. +impl<'a, T: Cartesian> Metric<[f64]> for &'a T { + type Distance = T::Distance; + + fn distance(&self, other: &[f64]) -> Self::Distance { + (*self).distance(other) + } +} + /// Standard cartesian space. impl Cartesian for [f64] { fn dimensions(&self) -> usize { @@ -37,6 +46,21 @@ impl Cartesian for [f64] { } } +/// Marker trait for cartesian metric spaces. +pub trait CartesianMetric: + Cartesian + Metric>::Distance> +{ +} + +/// Blanket [CartesianMetric] implementation for cartesian spaces with compatible metric distance +/// types. +impl CartesianMetric for U +where + T: ?Sized, + U: ?Sized + Cartesian + Metric>::Distance>, +{ +} + /// A node in a k-d tree. #[derive(Debug)] struct KdNode { @@ -48,54 +72,6 @@ struct KdNode { right: Option>, } -trait KdSearch<'a, T, U, N> { - /// Recursively search for nearest neighbors. - fn search(&'a self, i: usize, neighborhood: &mut N); - - /// Search the left subtree. - fn search_left(&'a self, i: usize, distance: f64, neighborhood: &mut N); - - /// Search the right subtree. - fn search_right(&'a self, i: usize, distance: f64, neighborhood: &mut N); -} - -impl<'a, T, U, N> KdSearch<'a, T, U, N> for KdNode -where - T: 'a + Cartesian, - U: Cartesian + Metric<&'a T>, - N: Neighborhood<&'a T, U>, -{ - fn search(&'a self, i: usize, neighborhood: &mut N) { - neighborhood.consider(&self.item); - - let distance = neighborhood.target().coordinate(i) - self.item.coordinate(i); - let j = (i + 1) % self.item.dimensions(); - if distance <= 0.0 { - self.search_left(j, distance, neighborhood); - self.search_right(j, -distance, neighborhood); - } else { - self.search_right(j, -distance, neighborhood); - self.search_left(j, distance, neighborhood); - } - } - - fn search_left(&'a self, i: usize, distance: f64, neighborhood: &mut N) { - if let Some(left) = &self.left { - if neighborhood.contains(distance) { - left.search(i, neighborhood); - } - } - } - - fn search_right(&'a self, i: usize, distance: f64, neighborhood: &mut N) { - if let Some(right) = &self.right { - if neighborhood.contains(distance) { - right.search(i, neighborhood); - } - } - } -} - impl KdNode { /// Create a new KdNode. fn new(i: usize, mut items: Vec) -> Option> { @@ -115,6 +91,40 @@ impl KdNode { right: Self::new(j, right), })) } + + /// Recursively search for nearest neighbors. + fn search<'a, U, N>(&'a self, i: usize, closest: &mut [f64], neighborhood: &mut N) + where + T: 'a, + U: CartesianMetric<&'a T>, + N: Neighborhood<&'a T, U>, + { + neighborhood.consider(&self.item); + + let target = neighborhood.target(); + let ti = target.coordinate(i); + let si = self.item.coordinate(i); + let j = (i + 1) % self.item.dimensions(); + + let (near, far) = if ti <= si { + (&self.left, &self.right) + } else { + (&self.right, &self.left) + }; + + if let Some(near) = near { + near.search(j, closest, neighborhood); + } + + if let Some(far) = far { + let saved = closest[i]; + closest[i] = si; + if neighborhood.contains_distance(target.distance(closest)) { + far.search(j, closest, neighborhood); + } + closest[i] = saved; + } + } } /// A [k-d tree](https://en.wikipedia.org/wiki/K-d_tree). @@ -135,7 +145,7 @@ impl FromIterator for KdTree { impl NearestNeighbors for KdTree where T: Cartesian, - U: Cartesian + Metric, + U: CartesianMetric, { fn search<'a, 'b, N>(&'a self, mut neighborhood: N) -> N where @@ -143,8 +153,12 @@ where U: 'b, N: Neighborhood<&'a T, &'b U>, { + let target = neighborhood.target(); + let dims = target.dimensions(); + let mut closest: Vec<_> = (0..dims).map(|i| target.coordinate(i)).collect(); + if let Some(root) = &self.root { - root.search(0, &mut neighborhood); + root.search(0, &mut closest, &mut neighborhood); } neighborhood } -- cgit v1.2.3