From 5f85a59d4be37d350bcf1ee62c25ac1f84d71770 Mon Sep 17 00:00:00 2001 From: Tavian Barnes Date: Mon, 6 Jul 2020 22:24:02 -0400 Subject: kd: Use a more traditional k-d tree implementation The slight extra pruning possible in the previous implementation didn't seem to be worth it. The new, simpler implementation is also about 30% faster in most of the benchmarks. This gets rid of Coordinate{Proximity,Metric} as they're not necessary any more (and the old ExactNeighbors impl was too restrictive anyway). --- src/kd.rs | 92 ++++++++++++++++++++++++--------------------------------------- 1 file changed, 35 insertions(+), 57 deletions(-) (limited to 'src/kd.rs') diff --git a/src/kd.rs b/src/kd.rs index 291028e..dae73ec 100644 --- a/src/kd.rs +++ b/src/kd.rs @@ -1,10 +1,13 @@ //! [k-d trees](https://en.wikipedia.org/wiki/K-d_tree). -use crate::coords::{CoordinateMetric, CoordinateProximity, Coordinates}; -use crate::distance::{Metric, Proximity}; +use crate::coords::Coordinates; +use crate::distance::Proximity; +use crate::lp::Minkowski; use crate::util::Ordered; use crate::{ExactNeighbors, NearestNeighbors, Neighborhood}; +use num_traits::Signed; + use std::iter::FromIterator; use std::ops::Deref; @@ -86,7 +89,7 @@ pub trait KdProximity where Self: Coordinates, Self: Proximity, - Self: CoordinateProximity>::Distance>, + Self::Value: PartialOrd, V: Coordinates, {} @@ -95,31 +98,14 @@ impl KdProximity for K where K: Coordinates, K: Proximity, - K: CoordinateProximity>::Distance>, - V: Coordinates, -{} - -/// Marker trait for [`Metric`] implementations that are compatible with k-d tree. -pub trait KdMetric -where - Self: KdProximity, - Self: Metric, - Self: CoordinateMetric, - V: Coordinates, -{} - -/// Blanket [`KdMetric`] implementation. -impl KdMetric for K -where - K: KdProximity, - K: Metric, - K: CoordinateMetric, + K::Value: PartialOrd, V: Coordinates, {} trait KdSearch: Copy where K: KdProximity, + K::Value: PartialOrd, V: Coordinates + Copy, N: Neighborhood, { @@ -133,41 +119,29 @@ where fn right(self) -> Option; /// Recursively search for nearest neighbors. - fn search(self, level: usize, closest: &mut [V::Value], neighborhood: &mut N) { + fn search(self, level: usize, neighborhood: &mut N) { let item = self.item(); neighborhood.consider(item); let target = neighborhood.target(); - if target.coord(level) <= item.coord(level) { - self.search_near(self.left(), level, closest, neighborhood); - self.search_far(self.right(), level, closest, neighborhood); + let bound = target.coord(level) - item.coord(level); + let (near, far) = if bound.is_negative() { + (self.left(), self.right()) } else { - self.search_near(self.right(), level, closest, neighborhood); - self.search_far(self.left(), level, closest, neighborhood); - } - } + (self.right(), self.left()) + }; + + let next = (level + 1) % self.item().dims(); - /// Search the subtree closest to the target. - fn search_near(self, near: Option, level: usize, closest: &mut [V::Value], neighborhood: &mut N) { if let Some(near) = near { - let next = (level + 1) % self.item().dims(); - near.search(next, closest, neighborhood); + near.search(next, neighborhood); } - } - /// Search the subtree farthest from the target. - fn search_far(self, far: Option, level: usize, closest: &mut [V::Value], neighborhood: &mut N) { if let Some(far) = far { - // Update the closest possible point - let item = self.item(); - let target = neighborhood.target(); - let saved = std::mem::replace(&mut closest[level], item.coord(level)); - if neighborhood.contains(target.distance_to_coords(closest)) { - let next = (level + 1) % item.dims(); - far.search(next, closest, neighborhood); + if neighborhood.contains(bound.abs()) { + far.search(next, neighborhood); } - closest[level] = saved; } } } @@ -175,6 +149,7 @@ where impl<'a, K, V, N> KdSearch for &'a KdNode where K: KdProximity<&'a V>, + K::Value: PartialOrd, V: Coordinates, N: Neighborhood, { @@ -315,6 +290,7 @@ impl IntoIterator for KdTree { impl NearestNeighbors for KdTree where K: KdProximity, + K::Value: PartialOrd, V: Coordinates, { fn search<'k, 'v, N>(&'v self, mut neighborhood: N) -> N @@ -324,16 +300,17 @@ where N: Neighborhood<&'k K, &'v V>, { if let Some(root) = &self.root { - let mut closest = neighborhood.target().as_vec(); - root.search(0, &mut closest, &mut neighborhood); + root.search(0, &mut neighborhood); } neighborhood } } +/// k-d trees are exact for [Minkowski] distances. impl ExactNeighbors for KdTree where - K: KdMetric, + K: KdProximity + Minkowski, + K::Value: PartialOrd, V: Coordinates, {} @@ -389,6 +366,7 @@ impl FlatKdNode { impl<'a, K, V, N> KdSearch for &'a [FlatKdNode] where K: KdProximity<&'a V>, + K::Value: PartialOrd, V: Coordinates, N: Neighborhood, { @@ -465,6 +443,7 @@ impl IntoIterator for FlatKdTree { impl NearestNeighbors for FlatKdTree where K: KdProximity, + K::Value: PartialOrd, V: Coordinates, { fn search<'k, 'v, N>(&'v self, mut neighborhood: N) -> N @@ -474,18 +453,17 @@ where N: Neighborhood<&'k K, &'v V>, { if !self.nodes.is_empty() { - let mut closest = neighborhood.target().as_vec(); - self.nodes - .as_slice() - .search(0, &mut closest, &mut neighborhood); + self.nodes.as_slice().search(0, &mut neighborhood); } neighborhood } } +/// k-d trees are exact for [Minkowski] distances. impl ExactNeighbors for FlatKdTree where - K: KdMetric, + K: KdProximity + Minkowski, + K::Value: PartialOrd, V: Coordinates, {} @@ -493,16 +471,16 @@ where mod tests { use super::*; - use crate::tests::test_nearest_neighbors; + use crate::tests::test_exact_neighbors; #[test] fn test_kd_tree() { - test_nearest_neighbors(KdTree::from_iter); + test_exact_neighbors(KdTree::from_iter); } #[test] fn test_unbalanced_kd_tree() { - test_nearest_neighbors(|points| { + test_exact_neighbors(|points| { let mut tree = KdTree::new(); for point in points { tree.push(point); @@ -513,6 +491,6 @@ mod tests { #[test] fn test_flat_kd_tree() { - test_nearest_neighbors(FlatKdTree::from_iter); + test_exact_neighbors(FlatKdTree::from_iter); } } -- cgit v1.2.3