summaryrefslogtreecommitdiffstats
path: root/src/kd.rs
diff options
context:
space:
mode:
authorTavian Barnes <tavianator@tavianator.com>2020-07-06 22:24:02 -0400
committerTavian Barnes <tavianator@tavianator.com>2020-07-06 22:33:10 -0400
commit5f85a59d4be37d350bcf1ee62c25ac1f84d71770 (patch)
tree8fc7ea8e59226c5e677d7b9aef39b0b2be5f28b7 /src/kd.rs
parented4d7b7143f1a8a9602698ca3e60e18bbb4dd226 (diff)
downloadacap-5f85a59d4be37d350bcf1ee62c25ac1f84d71770.tar.xz
kd: Use a more traditional k-d tree implementation
The slight extra pruning possible in the previous implementation didn't seem to be worth it. The new, simpler implementation is also about 30% faster in most of the benchmarks. This gets rid of Coordinate{Proximity,Metric} as they're not necessary any more (and the old ExactNeighbors impl was too restrictive anyway).
Diffstat (limited to 'src/kd.rs')
-rw-r--r--src/kd.rs92
1 files changed, 35 insertions, 57 deletions
diff --git a/src/kd.rs b/src/kd.rs
index 291028e..dae73ec 100644
--- a/src/kd.rs
+++ b/src/kd.rs
@@ -1,10 +1,13 @@
//! [k-d trees](https://en.wikipedia.org/wiki/K-d_tree).
-use crate::coords::{CoordinateMetric, CoordinateProximity, Coordinates};
-use crate::distance::{Metric, Proximity};
+use crate::coords::Coordinates;
+use crate::distance::Proximity;
+use crate::lp::Minkowski;
use crate::util::Ordered;
use crate::{ExactNeighbors, NearestNeighbors, Neighborhood};
+use num_traits::Signed;
+
use std::iter::FromIterator;
use std::ops::Deref;
@@ -86,7 +89,7 @@ pub trait KdProximity<V: ?Sized = Self>
where
Self: Coordinates<Value = V::Value>,
Self: Proximity<V>,
- Self: CoordinateProximity<V::Value, Distance = <Self as Proximity<V>>::Distance>,
+ Self::Value: PartialOrd<Self::Distance>,
V: Coordinates,
{}
@@ -95,31 +98,14 @@ impl<K, V> KdProximity<V> for K
where
K: Coordinates<Value = V::Value>,
K: Proximity<V>,
- K: CoordinateProximity<V::Value, Distance = <K as Proximity<V>>::Distance>,
- V: Coordinates,
-{}
-
-/// Marker trait for [`Metric`] implementations that are compatible with k-d tree.
-pub trait KdMetric<V: ?Sized = Self>
-where
- Self: KdProximity<V>,
- Self: Metric<V>,
- Self: CoordinateMetric<V::Value>,
- V: Coordinates,
-{}
-
-/// Blanket [`KdMetric`] implementation.
-impl<K, V> KdMetric<V> for K
-where
- K: KdProximity<V>,
- K: Metric<V>,
- K: CoordinateMetric<V::Value>,
+ K::Value: PartialOrd<K::Distance>,
V: Coordinates,
{}
trait KdSearch<K, V, N>: Copy
where
K: KdProximity<V>,
+ K::Value: PartialOrd<K::Distance>,
V: Coordinates + Copy,
N: Neighborhood<K, V>,
{
@@ -133,41 +119,29 @@ where
fn right(self) -> Option<Self>;
/// Recursively search for nearest neighbors.
- fn search(self, level: usize, closest: &mut [V::Value], neighborhood: &mut N) {
+ fn search(self, level: usize, neighborhood: &mut N) {
let item = self.item();
neighborhood.consider(item);
let target = neighborhood.target();
- if target.coord(level) <= item.coord(level) {
- self.search_near(self.left(), level, closest, neighborhood);
- self.search_far(self.right(), level, closest, neighborhood);
+ let bound = target.coord(level) - item.coord(level);
+ let (near, far) = if bound.is_negative() {
+ (self.left(), self.right())
} else {
- self.search_near(self.right(), level, closest, neighborhood);
- self.search_far(self.left(), level, closest, neighborhood);
- }
- }
+ (self.right(), self.left())
+ };
+
+ let next = (level + 1) % self.item().dims();
- /// Search the subtree closest to the target.
- fn search_near(self, near: Option<Self>, level: usize, closest: &mut [V::Value], neighborhood: &mut N) {
if let Some(near) = near {
- let next = (level + 1) % self.item().dims();
- near.search(next, closest, neighborhood);
+ near.search(next, neighborhood);
}
- }
- /// Search the subtree farthest from the target.
- fn search_far(self, far: Option<Self>, level: usize, closest: &mut [V::Value], neighborhood: &mut N) {
if let Some(far) = far {
- // Update the closest possible point
- let item = self.item();
- let target = neighborhood.target();
- let saved = std::mem::replace(&mut closest[level], item.coord(level));
- if neighborhood.contains(target.distance_to_coords(closest)) {
- let next = (level + 1) % item.dims();
- far.search(next, closest, neighborhood);
+ if neighborhood.contains(bound.abs()) {
+ far.search(next, neighborhood);
}
- closest[level] = saved;
}
}
}
@@ -175,6 +149,7 @@ where
impl<'a, K, V, N> KdSearch<K, &'a V, N> for &'a KdNode<V>
where
K: KdProximity<&'a V>,
+ K::Value: PartialOrd<K::Distance>,
V: Coordinates,
N: Neighborhood<K, &'a V>,
{
@@ -315,6 +290,7 @@ impl<T> IntoIterator for KdTree<T> {
impl<K, V> NearestNeighbors<K, V> for KdTree<V>
where
K: KdProximity<V>,
+ K::Value: PartialOrd<K::Distance>,
V: Coordinates,
{
fn search<'k, 'v, N>(&'v self, mut neighborhood: N) -> N
@@ -324,16 +300,17 @@ where
N: Neighborhood<&'k K, &'v V>,
{
if let Some(root) = &self.root {
- let mut closest = neighborhood.target().as_vec();
- root.search(0, &mut closest, &mut neighborhood);
+ root.search(0, &mut neighborhood);
}
neighborhood
}
}
+/// k-d trees are exact for [Minkowski] distances.
impl<K, V> ExactNeighbors<K, V> for KdTree<V>
where
- K: KdMetric<V>,
+ K: KdProximity<V> + Minkowski<V>,
+ K::Value: PartialOrd<K::Distance>,
V: Coordinates,
{}
@@ -389,6 +366,7 @@ impl<T: Coordinates> FlatKdNode<T> {
impl<'a, K, V, N> KdSearch<K, &'a V, N> for &'a [FlatKdNode<V>]
where
K: KdProximity<&'a V>,
+ K::Value: PartialOrd<K::Distance>,
V: Coordinates,
N: Neighborhood<K, &'a V>,
{
@@ -465,6 +443,7 @@ impl<T> IntoIterator for FlatKdTree<T> {
impl<K, V> NearestNeighbors<K, V> for FlatKdTree<V>
where
K: KdProximity<V>,
+ K::Value: PartialOrd<K::Distance>,
V: Coordinates,
{
fn search<'k, 'v, N>(&'v self, mut neighborhood: N) -> N
@@ -474,18 +453,17 @@ where
N: Neighborhood<&'k K, &'v V>,
{
if !self.nodes.is_empty() {
- let mut closest = neighborhood.target().as_vec();
- self.nodes
- .as_slice()
- .search(0, &mut closest, &mut neighborhood);
+ self.nodes.as_slice().search(0, &mut neighborhood);
}
neighborhood
}
}
+/// k-d trees are exact for [Minkowski] distances.
impl<K, V> ExactNeighbors<K, V> for FlatKdTree<V>
where
- K: KdMetric<V>,
+ K: KdProximity<V> + Minkowski<V>,
+ K::Value: PartialOrd<K::Distance>,
V: Coordinates,
{}
@@ -493,16 +471,16 @@ where
mod tests {
use super::*;
- use crate::tests::test_nearest_neighbors;
+ use crate::tests::test_exact_neighbors;
#[test]
fn test_kd_tree() {
- test_nearest_neighbors(KdTree::from_iter);
+ test_exact_neighbors(KdTree::from_iter);
}
#[test]
fn test_unbalanced_kd_tree() {
- test_nearest_neighbors(|points| {
+ test_exact_neighbors(|points| {
let mut tree = KdTree::new();
for point in points {
tree.push(point);
@@ -513,6 +491,6 @@ mod tests {
#[test]
fn test_flat_kd_tree() {
- test_nearest_neighbors(FlatKdTree::from_iter);
+ test_exact_neighbors(FlatKdTree::from_iter);
}
}