From 15ec99c64f65da7966b4282ff94fee0a611c23df Mon Sep 17 00:00:00 2001 From: Tavian Barnes Date: Thu, 25 Feb 2021 11:24:42 -0500 Subject: knn: Move NearestNeighbor interfaces to a submodule --- src/distance.rs | 2 +- src/exhaustive.rs | 6 +- src/kd.rs | 4 +- src/knn.rs | 491 +++++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 514 ++---------------------------------------------------- src/util.rs | 10 +- src/vp.rs | 4 +- 7 files changed, 523 insertions(+), 508 deletions(-) create mode 100644 src/knn.rs (limited to 'src') diff --git a/src/distance.rs b/src/distance.rs index e44ed03..680f11f 100644 --- a/src/distance.rs +++ b/src/distance.rs @@ -108,7 +108,7 @@ impl Distance for T { /// With those implementations available, you could use a [`NearestNeighbors`] /// instance to find the closest point(s) of interest to any GPS location. /// -/// [`NearestNeighbors`]: super::NearestNeighbors +/// [`NearestNeighbors`]: crate::knn::NearestNeighbors pub trait Proximity { /// The type that represents distances. type Distance: Distance; diff --git a/src/exhaustive.rs b/src/exhaustive.rs index 442850c..f0abf9c 100644 --- a/src/exhaustive.rs +++ b/src/exhaustive.rs @@ -1,7 +1,7 @@ -//! Exhaustive nearest neighbor search. +//! [Exhaustive nearest neighbor search](https://en.wikipedia.org/wiki/Nearest_neighbor_search#Linear_search). use crate::distance::Proximity; -use crate::{ExactNeighbors, NearestNeighbors, Neighborhood}; +use crate::knn::{ExactNeighbors, NearestNeighbors, Neighborhood}; use std::iter::FromIterator; @@ -118,7 +118,7 @@ impl, V> ExactNeighbors for ExhaustiveSearch {} pub mod tests { use super::*; - use crate::tests::test_exact_neighbors; + use crate::knn::tests::test_exact_neighbors; #[test] fn test_exhaustive_index() { diff --git a/src/kd.rs b/src/kd.rs index d37321e..bf6b7c6 100644 --- a/src/kd.rs +++ b/src/kd.rs @@ -3,8 +3,8 @@ use crate::coords::Coordinates; use crate::distance::Proximity; use crate::lp::Minkowski; +use crate::knn::{ExactNeighbors, NearestNeighbors, Neighborhood}; use crate::util::Ordered; -use crate::{ExactNeighbors, NearestNeighbors, Neighborhood}; use num_traits::Signed; @@ -541,7 +541,7 @@ where mod tests { use super::*; - use crate::tests::test_exact_neighbors; + use crate::knn::tests::test_exact_neighbors; #[test] fn test_kd_tree() { diff --git a/src/knn.rs b/src/knn.rs new file mode 100644 index 0000000..1cc1f39 --- /dev/null +++ b/src/knn.rs @@ -0,0 +1,491 @@ +//! [Nearest neighbor search](https://en.wikipedia.org/wiki/Nearest_neighbor_search) interfaces. + +use crate::distance::{Distance, Proximity}; + +use std::convert::TryInto; + +/// A nearest neighbor. +#[derive(Clone, Copy, Debug)] +pub struct Neighbor { + /// The neighbor itself. + pub item: V, + /// The distance from the target to this neighbor. + pub distance: D, +} + +impl Neighbor { + /// Create a new Neighbor. + pub fn new(item: V, distance: D) -> Self { + Self { item, distance } + } +} + +impl PartialEq> for Neighbor +where + V1: PartialEq, + D1: PartialEq, +{ + fn eq(&self, other: &Neighbor) -> bool { + self.item == other.item && self.distance == other.distance + } +} + +/// Accumulates nearest neighbor search results. +/// +/// Type parameters: +/// +/// * `K`: The type of the search target (the "key" type) +/// * `V`: The type of neighbors this contains (the "value" type) +/// +/// Neighborhood implementations keep track of the current search radius and accumulate the results, +/// work which would otherwise have to be duplicated for every nearest neighbor search algorithm. +/// They also serve as a customization point, allowing for functionality to be injected into any +/// [NearestNeighbors] implementation (for example, filtering the result set or limiting the number +/// of neighbors considered). +pub trait Neighborhood, V> { + /// Returns the target of the nearest neighbor search. + fn target(&self) -> K; + + /// Check whether a distance is within the current search radius. + fn contains(&self, distance: D) -> bool + where + D: PartialOrd; + + /// Consider a new candidate neighbor. + /// + /// Returns `self.target().distance(item)`. + fn consider(&mut self, item: V) -> K::Distance; +} + +/// A [Neighborhood] with at most one result. +#[derive(Debug)] +struct SingletonNeighborhood { + /// The search target. + target: K, + /// The current threshold distance. + threshold: Option, + /// The current nearest neighbor, if any. + neighbor: Option>, +} + +impl SingletonNeighborhood { + /// Create a new singleton neighborhood. + /// + /// * `target`: The search target. + /// * `threshold`: The maximum allowable distance. + fn new(target: K, threshold: Option) -> Self { + Self { + target, + threshold, + neighbor: None, + } + } + + /// Convert this result into an optional neighbor. + fn into_option(self) -> Option> { + self.neighbor + } +} + +impl Neighborhood for SingletonNeighborhood +where + K: Copy + Proximity, +{ + fn target(&self) -> K { + self.target + } + + fn contains(&self, distance: D) -> bool + where + D: PartialOrd, + { + self.threshold.map_or(true, |t| distance <= t) + } + + fn consider(&mut self, item: V) -> K::Distance { + let distance = self.target.distance(&item); + + if self.contains(distance) { + self.threshold = Some(distance); + self.neighbor = Some(Neighbor::new(item, distance)); + } + + distance + } +} + +/// A [Neighborhood] of up to `k` results, using a binary heap. +#[derive(Debug)] +struct HeapNeighborhood<'a, K, V, D> { + /// The target of the nearest neighbor search. + target: K, + /// The number of nearest neighbors to find. + k: usize, + /// The current threshold distance to the farthest result. + threshold: Option, + /// A max-heap of the best candidates found so far. + heap: &'a mut Vec>, +} + +impl<'a, K, V, D: Distance> HeapNeighborhood<'a, K, V, D> { + /// Create a new HeapNeighborhood. + /// + /// * `target`: The search target. + /// * `k`: The maximum number of nearest neighbors to find. + /// * `threshold`: The maximum allowable distance. + /// * `heap`: The vector of neighbors to use as the heap. + fn new( + target: K, + k: usize, + mut threshold: Option, + heap: &'a mut Vec>, + ) -> Self { + // A descending array is also a max-heap + heap.reverse(); + + if k > 0 && heap.len() == k { + let distance = heap[0].distance; + if threshold.map_or(true, |t| distance <= t) { + threshold = Some(distance); + } + } + + Self { + target, + k, + threshold, + heap, + } + } + + /// Push a new element into the heap. + fn push(&mut self, item: Neighbor) { + let mut i = self.heap.len(); + self.heap.push(item); + + while i > 0 { + let parent = (i - 1) / 2; + if self.heap[i].distance > self.heap[parent].distance { + self.heap.swap(i, parent); + i = parent; + } else { + break; + } + } + } + + /// Restore the heap property by lowering the root. + fn sink_root(&mut self, len: usize) { + let mut i = 0; + let dist = self.heap[i].distance; + + loop { + let mut child = 2 * i + 1; + let right = child + 1; + if right < len && self.heap[child].distance < self.heap[right].distance { + child = right; + } + + if child < len && dist < self.heap[child].distance { + self.heap.swap(i, child); + i = child; + } else { + break; + } + } + } + + /// Replace the root of the heap with a new element. + fn replace_root(&mut self, item: Neighbor) { + self.heap[0] = item; + self.sink_root(self.heap.len()); + } + + /// Sort the heap from smallest to largest distance. + fn sort(&mut self) { + for i in (0..self.heap.len()).rev() { + self.heap.swap(0, i); + self.sink_root(i); + } + } +} + +impl<'a, K, V> Neighborhood for HeapNeighborhood<'a, K, V, K::Distance> +where + K: Copy + Proximity, +{ + fn target(&self) -> K { + self.target + } + + fn contains(&self, distance: D) -> bool + where + D: PartialOrd, + { + self.k > 0 && self.threshold.map_or(true, |t| distance <= t) + } + + fn consider(&mut self, item: V) -> K::Distance { + let distance = self.target.distance(&item); + + if self.contains(distance) { + let neighbor = Neighbor::new(item, distance); + + if self.heap.len() < self.k { + self.push(neighbor); + } else { + self.replace_root(neighbor); + } + + if self.heap.len() == self.k { + self.threshold = Some(self.heap[0].distance); + } + } + + distance + } +} + +/// A [nearest neighbor search] index. +/// +/// Type parameters: +/// +/// * `K`: The type of the search target (the "key" type) +/// * `V`: The type of the returned neighbors (the "value" type) +/// +/// In general, exact nearest neighbor searches may be prohibitively expensive due to the [curse of +/// dimensionality]. Therefore, NearestNeighbor implementations are allowed to give approximate +/// results. The marker trait [ExactNeighbors] denotes implementations which are guaranteed to give +/// exact results. +/// +/// [nearest neighbor search]: https://en.wikipedia.org/wiki/Nearest_neighbor_search +/// [curse of dimensionality]: https://en.wikipedia.org/wiki/Curse_of_dimensionality +pub trait NearestNeighbors, V = K> { + /// Returns the nearest neighbor to `target` (or `None` if this index is empty). + fn nearest(&self, target: &K) -> Option> { + self.search(SingletonNeighborhood::new(target, None)) + .into_option() + } + + /// Returns the nearest neighbor to `target` within the distance `threshold`, if one exists. + fn nearest_within(&self, target: &K, threshold: D) -> Option> + where + D: TryInto, + { + if let Ok(distance) = threshold.try_into() { + self.search(SingletonNeighborhood::new(target, Some(distance))) + .into_option() + } else { + None + } + } + + /// Returns the up to `k` nearest neighbors to `target`. + /// + /// The result will be sorted from nearest to farthest. + fn k_nearest(&self, target: &K, k: usize) -> Vec> { + let mut neighbors = Vec::with_capacity(k); + self.merge_k_nearest(target, k, &mut neighbors); + neighbors + } + + /// Returns the up to `k` nearest neighbors to `target` within the distance `threshold`. + /// + /// The result will be sorted from nearest to farthest. + fn k_nearest_within( + &self, + target: &K, + k: usize, + threshold: D, + ) -> Vec> + where + D: TryInto, + { + let mut neighbors = Vec::with_capacity(k); + self.merge_k_nearest_within(target, k, threshold, &mut neighbors); + neighbors + } + + /// Merges up to `k` nearest neighbors into an existing sorted vector. + fn merge_k_nearest<'v>( + &'v self, + target: &K, + k: usize, + neighbors: &mut Vec>, + ) { + self.search(HeapNeighborhood::new(target, k, None, neighbors)) + .sort(); + } + + /// Merges up to `k` nearest neighbors within the `threshold` into an existing sorted vector. + fn merge_k_nearest_within<'v, D>( + &'v self, + target: &K, + k: usize, + threshold: D, + neighbors: &mut Vec>, + ) where + D: TryInto, + { + if let Ok(distance) = threshold.try_into() { + self.search(HeapNeighborhood::new(target, k, Some(distance), neighbors)) + .sort(); + } + } + + /// Search for nearest neighbors and add them to a neighborhood. + fn search<'k, 'v, N>(&'v self, neighborhood: N) -> N + where + K: 'k, + V: 'v, + N: Neighborhood<&'k K, &'v V>; +} + +/// Marker trait for [NearestNeighbors] implementations that always return exact results. +pub trait ExactNeighbors, V = K>: NearestNeighbors {} + +#[cfg(test)] +pub mod tests { + use super::*; + + use crate::euclid::{Euclidean, EuclideanDistance}; + use crate::exhaustive::ExhaustiveSearch; + + use rand::prelude::*; + + use std::iter::FromIterator; + + type Point = Euclidean<[f32; 3]>; + + /// Test an [ExactNeighbors] implementation. + pub fn test_exact_neighbors(from_iter: F) + where + T: ExactNeighbors, + F: Fn(Vec) -> T, + { + test_empty(&from_iter); + test_pythagorean(&from_iter); + test_random_points(&from_iter); + } + + fn test_empty(from_iter: &F) + where + T: NearestNeighbors, + F: Fn(Vec) -> T, + { + let points = Vec::new(); + let index = from_iter(points); + let target = Euclidean([0.0, 0.0, 0.0]); + assert_eq!(index.nearest(&target), None); + assert_eq!(index.nearest_within(&target, 1.0), None); + assert!(index.k_nearest(&target, 0).is_empty()); + assert!(index.k_nearest(&target, 3).is_empty()); + assert!(index.k_nearest_within(&target, 0, 1.0).is_empty()); + assert!(index.k_nearest_within(&target, 3, 1.0).is_empty()); + } + + fn test_pythagorean(from_iter: &F) + where + T: NearestNeighbors, + F: Fn(Vec) -> T, + { + let points = vec![ + Euclidean([3.0, 4.0, 0.0]), + Euclidean([5.0, 0.0, 12.0]), + Euclidean([0.0, 8.0, 15.0]), + Euclidean([1.0, 2.0, 2.0]), + Euclidean([2.0, 3.0, 6.0]), + Euclidean([4.0, 4.0, 7.0]), + ]; + let index = from_iter(points); + let target = Euclidean([0.0, 0.0, 0.0]); + + assert_eq!( + index.nearest(&target).expect("No nearest neighbor found"), + Neighbor::new(&Euclidean([1.0, 2.0, 2.0]), 3.0) + ); + + assert_eq!(index.nearest_within(&target, 2.0), None); + assert_eq!( + index.nearest_within(&target, 4.0).expect("No nearest neighbor found within 4.0"), + Neighbor::new(&Euclidean([1.0, 2.0, 2.0]), 3.0) + ); + + assert!(index.k_nearest(&target, 0).is_empty()); + assert_eq!( + index.k_nearest(&target, 3), + vec![ + Neighbor::new(&Euclidean([1.0, 2.0, 2.0]), 3.0), + Neighbor::new(&Euclidean([3.0, 4.0, 0.0]), 5.0), + Neighbor::new(&Euclidean([2.0, 3.0, 6.0]), 7.0), + ] + ); + + assert!(index.k_nearest(&target, 0).is_empty()); + assert_eq!( + index.k_nearest_within(&target, 3, 6.0), + vec![ + Neighbor::new(&Euclidean([1.0, 2.0, 2.0]), 3.0), + Neighbor::new(&Euclidean([3.0, 4.0, 0.0]), 5.0), + ] + ); + assert_eq!( + index.k_nearest_within(&target, 3, 8.0), + vec![ + Neighbor::new(&Euclidean([1.0, 2.0, 2.0]), 3.0), + Neighbor::new(&Euclidean([3.0, 4.0, 0.0]), 5.0), + Neighbor::new(&Euclidean([2.0, 3.0, 6.0]), 7.0), + ] + ); + + let mut neighbors = Vec::new(); + index.merge_k_nearest(&target, 3, &mut neighbors); + assert_eq!( + neighbors, + vec![ + Neighbor::new(&Euclidean([1.0, 2.0, 2.0]), 3.0), + Neighbor::new(&Euclidean([3.0, 4.0, 0.0]), 5.0), + Neighbor::new(&Euclidean([2.0, 3.0, 6.0]), 7.0), + ] + ); + + neighbors = vec![ + Neighbor::new(&target, EuclideanDistance::from_squared(0.0)), + Neighbor::new(&Euclidean([3.0, 4.0, 0.0]), EuclideanDistance::from_squared(25.0)), + Neighbor::new(&Euclidean([2.0, 3.0, 6.0]), EuclideanDistance::from_squared(49.0)), + ]; + index.merge_k_nearest_within(&target, 3, 4.0, &mut neighbors); + assert_eq!( + neighbors, + vec![ + Neighbor::new(&target, 0.0), + Neighbor::new(&Euclidean([1.0, 2.0, 2.0]), 3.0), + Neighbor::new(&Euclidean([3.0, 4.0, 0.0]), 5.0), + ] + ); + } + + fn test_random_points(from_iter: &F) + where + T: NearestNeighbors, + F: Fn(Vec) -> T, + { + let mut points = Vec::new(); + for _ in 0..256 { + points.push(Euclidean([random(), random(), random()])); + } + + let index = from_iter(points.clone()); + let eindex = ExhaustiveSearch::from_iter(points.clone()); + + let target = Euclidean([random(), random(), random()]); + + assert_eq!( + index.k_nearest(&target, 3), + eindex.k_nearest(&target, 3), + "target: {:?}, points: {:#?}", + target, + points, + ); + } +} diff --git a/src/lib.rs b/src/lib.rs index 6402da2..1e77f6b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -46,7 +46,6 @@ //! //! # use acap::euclid::Euclidean; //! use acap::vp::VpTree; -//! use acap::NearestNeighbors; //! //! let tree = VpTree::balanced(vec![ //! Euclidean([3, 4]), @@ -61,7 +60,8 @@ //! //! # use acap::euclid::Euclidean; //! # use acap::vp::VpTree; -//! # use acap::NearestNeighbors; +//! use acap::knn::NearestNeighbors; +//! //! # let tree = VpTree::balanced( //! # vec![Euclidean([3, 4]), Euclidean([5, 12]), Euclidean([8, 15]), Euclidean([7, 24])] //! # ); @@ -87,8 +87,8 @@ //! nearest neighbor index instead of having it hold the data itself: //! //! use acap::euclid::Euclidean; +//! use acap::knn::NearestNeighbors; //! use acap::vp::VpTree; -//! use acap::NearestNeighbors; //! //! let points = vec![ //! Euclidean([3, 4]), @@ -107,17 +107,20 @@ //! See the [`Proximity`] documentation. //! //! [nearest neighbor search]: https://en.wikipedia.org/wiki/Nearest_neighbor_search -//! [`distance()`]: Proximity#tymethod.distance -//! [`value()`]: Distance#method.value -//! [coordinates]: Coordinates +//! [`distance()`]: distance::Proximity#tymethod.distance +//! [`value()`]: distance::Distance#method.value +//! [coordinates]: coords::Coordinates //! [Euclidean distance]: https://en.wikipedia.org/wiki/Euclidean_distance -//! [many different similarity search data structures]: NearestNeighbors#implementors +//! [`NearestNeighbors`]: knn::NearestNeighbors +//! [many different similarity search data structures]: knn::NearestNeighbors#implementors //! [vantage-point tree]: https://en.wikipedia.org/wiki/Vantage-point_tree //! [`VpTree`]: vp::VpTree -//! [`nearest()`]: NearestNeighbors#method.nearest -//! [`k_nearest()`]: NearestNeighbors#method.k_nearest -//! [`nearest_within()`]: NearestNeighbors#method.nearest_within -//! [`k_nearest_within()`]: NearestNeighbors#method.k_nearest_within +//! [`Neighbor`]: knn::Neighbor +//! [`nearest()`]: knn::NearestNeighbors#method.nearest +//! [`k_nearest()`]: knn::NearestNeighbors#method.k_nearest +//! [`nearest_within()`]: knn::NearestNeighbors#method.nearest_within +//! [`k_nearest_within()`]: knn::NearestNeighbors#method.k_nearest_within +//! [`ExactNeighbors`]: knn::ExactNeighbors pub mod chebyshev; pub mod coords; @@ -127,6 +130,7 @@ pub mod euclid; pub mod exhaustive; pub mod hamming; pub mod kd; +pub mod knn; pub mod lp; pub mod taxi; pub mod vp; @@ -136,490 +140,4 @@ mod util; pub use coords::Coordinates; pub use distance::{Distance, Metric, Proximity}; pub use euclid::{euclidean_distance, Euclidean, EuclideanDistance}; - -use std::convert::TryInto; - -/// A nearest neighbor. -#[derive(Clone, Copy, Debug)] -pub struct Neighbor { - /// The neighbor itself. - pub item: V, - /// The distance from the target to this neighbor. - pub distance: D, -} - -impl Neighbor { - /// Create a new Neighbor. - pub fn new(item: V, distance: D) -> Self { - Self { item, distance } - } -} - -impl PartialEq> for Neighbor -where - V1: PartialEq, - D1: PartialEq, -{ - fn eq(&self, other: &Neighbor) -> bool { - self.item == other.item && self.distance == other.distance - } -} - -/// Accumulates nearest neighbor search results. -/// -/// Type parameters: -/// -/// * `K`: The type of the search target (the "key" type) -/// * `V`: The type of neighbors this contains (the "value" type) -/// -/// Neighborhood implementations keep track of the current search radius and accumulate the results, -/// work which would otherwise have to be duplicated for every nearest neighbor search algorithm. -/// They also serve as a customization point, allowing for functionality to be injected into any -/// [NearestNeighbors] implementation (for example, filtering the result set or limiting the number -/// of neighbors considered). -pub trait Neighborhood, V> { - /// Returns the target of the nearest neighbor search. - fn target(&self) -> K; - - /// Check whether a distance is within the current search radius. - fn contains(&self, distance: D) -> bool - where - D: PartialOrd; - - /// Consider a new candidate neighbor. - /// - /// Returns `self.target().distance(item)`. - fn consider(&mut self, item: V) -> K::Distance; -} - -/// A [Neighborhood] with at most one result. -#[derive(Debug)] -struct SingletonNeighborhood { - /// The search target. - target: K, - /// The current threshold distance. - threshold: Option, - /// The current nearest neighbor, if any. - neighbor: Option>, -} - -impl SingletonNeighborhood { - /// Create a new singleton neighborhood. - /// - /// * `target`: The search target. - /// * `threshold`: The maximum allowable distance. - fn new(target: K, threshold: Option) -> Self { - Self { - target, - threshold, - neighbor: None, - } - } - - /// Convert this result into an optional neighbor. - fn into_option(self) -> Option> { - self.neighbor - } -} - -impl Neighborhood for SingletonNeighborhood -where - K: Copy + Proximity, -{ - fn target(&self) -> K { - self.target - } - - fn contains(&self, distance: D) -> bool - where - D: PartialOrd, - { - self.threshold.map_or(true, |t| distance <= t) - } - - fn consider(&mut self, item: V) -> K::Distance { - let distance = self.target.distance(&item); - - if self.contains(distance) { - self.threshold = Some(distance); - self.neighbor = Some(Neighbor::new(item, distance)); - } - - distance - } -} - -/// A [Neighborhood] of up to `k` results, using a binary heap. -#[derive(Debug)] -struct HeapNeighborhood<'a, K, V, D> { - /// The target of the nearest neighbor search. - target: K, - /// The number of nearest neighbors to find. - k: usize, - /// The current threshold distance to the farthest result. - threshold: Option, - /// A max-heap of the best candidates found so far. - heap: &'a mut Vec>, -} - -impl<'a, K, V, D: Distance> HeapNeighborhood<'a, K, V, D> { - /// Create a new HeapNeighborhood. - /// - /// * `target`: The search target. - /// * `k`: The maximum number of nearest neighbors to find. - /// * `threshold`: The maximum allowable distance. - /// * `heap`: The vector of neighbors to use as the heap. - fn new( - target: K, - k: usize, - mut threshold: Option, - heap: &'a mut Vec>, - ) -> Self { - // A descending array is also a max-heap - heap.reverse(); - - if k > 0 && heap.len() == k { - let distance = heap[0].distance; - if threshold.map_or(true, |t| distance <= t) { - threshold = Some(distance); - } - } - - Self { - target, - k, - threshold, - heap, - } - } - - /// Push a new element into the heap. - fn push(&mut self, item: Neighbor) { - let mut i = self.heap.len(); - self.heap.push(item); - - while i > 0 { - let parent = (i - 1) / 2; - if self.heap[i].distance > self.heap[parent].distance { - self.heap.swap(i, parent); - i = parent; - } else { - break; - } - } - } - - /// Restore the heap property by lowering the root. - fn sink_root(&mut self, len: usize) { - let mut i = 0; - let dist = self.heap[i].distance; - - loop { - let mut child = 2 * i + 1; - let right = child + 1; - if right < len && self.heap[child].distance < self.heap[right].distance { - child = right; - } - - if child < len && dist < self.heap[child].distance { - self.heap.swap(i, child); - i = child; - } else { - break; - } - } - } - - /// Replace the root of the heap with a new element. - fn replace_root(&mut self, item: Neighbor) { - self.heap[0] = item; - self.sink_root(self.heap.len()); - } - - /// Sort the heap from smallest to largest distance. - fn sort(&mut self) { - for i in (0..self.heap.len()).rev() { - self.heap.swap(0, i); - self.sink_root(i); - } - } -} - -impl<'a, K, V> Neighborhood for HeapNeighborhood<'a, K, V, K::Distance> -where - K: Copy + Proximity, -{ - fn target(&self) -> K { - self.target - } - - fn contains(&self, distance: D) -> bool - where - D: PartialOrd, - { - self.k > 0 && self.threshold.map_or(true, |t| distance <= t) - } - - fn consider(&mut self, item: V) -> K::Distance { - let distance = self.target.distance(&item); - - if self.contains(distance) { - let neighbor = Neighbor::new(item, distance); - - if self.heap.len() < self.k { - self.push(neighbor); - } else { - self.replace_root(neighbor); - } - - if self.heap.len() == self.k { - self.threshold = Some(self.heap[0].distance); - } - } - - distance - } -} - -/// A [nearest neighbor search] index. -/// -/// Type parameters: -/// -/// * `K`: The type of the search target (the "key" type) -/// * `V`: The type of the returned neighbors (the "value" type) -/// -/// In general, exact nearest neighbor searches may be prohibitively expensive due to the [curse of -/// dimensionality]. Therefore, NearestNeighbor implementations are allowed to give approximate -/// results. The marker trait [ExactNeighbors] denotes implementations which are guaranteed to give -/// exact results. -/// -/// [nearest neighbor search]: https://en.wikipedia.org/wiki/Nearest_neighbor_search -/// [curse of dimensionality]: https://en.wikipedia.org/wiki/Curse_of_dimensionality -pub trait NearestNeighbors, V = K> { - /// Returns the nearest neighbor to `target` (or `None` if this index is empty). - fn nearest(&self, target: &K) -> Option> { - self.search(SingletonNeighborhood::new(target, None)) - .into_option() - } - - /// Returns the nearest neighbor to `target` within the distance `threshold`, if one exists. - fn nearest_within(&self, target: &K, threshold: D) -> Option> - where - D: TryInto, - { - if let Ok(distance) = threshold.try_into() { - self.search(SingletonNeighborhood::new(target, Some(distance))) - .into_option() - } else { - None - } - } - - /// Returns the up to `k` nearest neighbors to `target`. - /// - /// The result will be sorted from nearest to farthest. - fn k_nearest(&self, target: &K, k: usize) -> Vec> { - let mut neighbors = Vec::with_capacity(k); - self.merge_k_nearest(target, k, &mut neighbors); - neighbors - } - - /// Returns the up to `k` nearest neighbors to `target` within the distance `threshold`. - /// - /// The result will be sorted from nearest to farthest. - fn k_nearest_within( - &self, - target: &K, - k: usize, - threshold: D, - ) -> Vec> - where - D: TryInto, - { - let mut neighbors = Vec::with_capacity(k); - self.merge_k_nearest_within(target, k, threshold, &mut neighbors); - neighbors - } - - /// Merges up to `k` nearest neighbors into an existing sorted vector. - fn merge_k_nearest<'v>( - &'v self, - target: &K, - k: usize, - neighbors: &mut Vec>, - ) { - self.search(HeapNeighborhood::new(target, k, None, neighbors)) - .sort(); - } - - /// Merges up to `k` nearest neighbors within the `threshold` into an existing sorted vector. - fn merge_k_nearest_within<'v, D>( - &'v self, - target: &K, - k: usize, - threshold: D, - neighbors: &mut Vec>, - ) where - D: TryInto, - { - if let Ok(distance) = threshold.try_into() { - self.search(HeapNeighborhood::new(target, k, Some(distance), neighbors)) - .sort(); - } - } - - /// Search for nearest neighbors and add them to a neighborhood. - fn search<'k, 'v, N>(&'v self, neighborhood: N) -> N - where - K: 'k, - V: 'v, - N: Neighborhood<&'k K, &'v V>; -} - -/// Marker trait for [NearestNeighbors] implementations that always return exact results. -pub trait ExactNeighbors, V = K>: NearestNeighbors {} - -#[cfg(test)] -pub mod tests { - use super::*; - - use crate::exhaustive::ExhaustiveSearch; - - use rand::prelude::*; - - use std::iter::FromIterator; - - type Point = Euclidean<[f32; 3]>; - - /// Test an [ExactNeighbors] implementation. - pub fn test_exact_neighbors(from_iter: F) - where - T: ExactNeighbors, - F: Fn(Vec) -> T, - { - test_empty(&from_iter); - test_pythagorean(&from_iter); - test_random_points(&from_iter); - } - - fn test_empty(from_iter: &F) - where - T: NearestNeighbors, - F: Fn(Vec) -> T, - { - let points = Vec::new(); - let index = from_iter(points); - let target = Euclidean([0.0, 0.0, 0.0]); - assert_eq!(index.nearest(&target), None); - assert_eq!(index.nearest_within(&target, 1.0), None); - assert!(index.k_nearest(&target, 0).is_empty()); - assert!(index.k_nearest(&target, 3).is_empty()); - assert!(index.k_nearest_within(&target, 0, 1.0).is_empty()); - assert!(index.k_nearest_within(&target, 3, 1.0).is_empty()); - } - - fn test_pythagorean(from_iter: &F) - where - T: NearestNeighbors, - F: Fn(Vec) -> T, - { - let points = vec![ - Euclidean([3.0, 4.0, 0.0]), - Euclidean([5.0, 0.0, 12.0]), - Euclidean([0.0, 8.0, 15.0]), - Euclidean([1.0, 2.0, 2.0]), - Euclidean([2.0, 3.0, 6.0]), - Euclidean([4.0, 4.0, 7.0]), - ]; - let index = from_iter(points); - let target = Euclidean([0.0, 0.0, 0.0]); - - assert_eq!( - index.nearest(&target).expect("No nearest neighbor found"), - Neighbor::new(&Euclidean([1.0, 2.0, 2.0]), 3.0) - ); - - assert_eq!(index.nearest_within(&target, 2.0), None); - assert_eq!( - index.nearest_within(&target, 4.0).expect("No nearest neighbor found within 4.0"), - Neighbor::new(&Euclidean([1.0, 2.0, 2.0]), 3.0) - ); - - assert!(index.k_nearest(&target, 0).is_empty()); - assert_eq!( - index.k_nearest(&target, 3), - vec![ - Neighbor::new(&Euclidean([1.0, 2.0, 2.0]), 3.0), - Neighbor::new(&Euclidean([3.0, 4.0, 0.0]), 5.0), - Neighbor::new(&Euclidean([2.0, 3.0, 6.0]), 7.0), - ] - ); - - assert!(index.k_nearest(&target, 0).is_empty()); - assert_eq!( - index.k_nearest_within(&target, 3, 6.0), - vec![ - Neighbor::new(&Euclidean([1.0, 2.0, 2.0]), 3.0), - Neighbor::new(&Euclidean([3.0, 4.0, 0.0]), 5.0), - ] - ); - assert_eq!( - index.k_nearest_within(&target, 3, 8.0), - vec![ - Neighbor::new(&Euclidean([1.0, 2.0, 2.0]), 3.0), - Neighbor::new(&Euclidean([3.0, 4.0, 0.0]), 5.0), - Neighbor::new(&Euclidean([2.0, 3.0, 6.0]), 7.0), - ] - ); - - let mut neighbors = Vec::new(); - index.merge_k_nearest(&target, 3, &mut neighbors); - assert_eq!( - neighbors, - vec![ - Neighbor::new(&Euclidean([1.0, 2.0, 2.0]), 3.0), - Neighbor::new(&Euclidean([3.0, 4.0, 0.0]), 5.0), - Neighbor::new(&Euclidean([2.0, 3.0, 6.0]), 7.0), - ] - ); - - neighbors = vec![ - Neighbor::new(&target, EuclideanDistance::from_squared(0.0)), - Neighbor::new(&Euclidean([3.0, 4.0, 0.0]), EuclideanDistance::from_squared(25.0)), - Neighbor::new(&Euclidean([2.0, 3.0, 6.0]), EuclideanDistance::from_squared(49.0)), - ]; - index.merge_k_nearest_within(&target, 3, 4.0, &mut neighbors); - assert_eq!( - neighbors, - vec![ - Neighbor::new(&target, 0.0), - Neighbor::new(&Euclidean([1.0, 2.0, 2.0]), 3.0), - Neighbor::new(&Euclidean([3.0, 4.0, 0.0]), 5.0), - ] - ); - } - - fn test_random_points(from_iter: &F) - where - T: NearestNeighbors, - F: Fn(Vec) -> T, - { - let mut points = Vec::new(); - for _ in 0..256 { - points.push(Euclidean([random(), random(), random()])); - } - - let index = from_iter(points.clone()); - let eindex = ExhaustiveSearch::from_iter(points.clone()); - - let target = Euclidean([random(), random(), random()]); - - assert_eq!( - index.k_nearest(&target, 3), - eindex.k_nearest(&target, 3), - "target: {:?}, points: {:#?}", - target, - points, - ); - } -} +pub use knn::{ExactNeighbors, NearestNeighbors, Neighbor}; diff --git a/src/util.rs b/src/util.rs index f838a9b..0979782 100644 --- a/src/util.rs +++ b/src/util.rs @@ -3,7 +3,7 @@ use std::cmp::Ordering; /// A wrapper that converts a partial ordering into a total one by panicking. -#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)] +#[derive(Clone, Copy, Debug, PartialOrd)] pub struct Ordered(T); impl Ordered { @@ -25,7 +25,13 @@ impl Ord for Ordered { } } -impl Eq for Ordered {} +impl PartialEq for Ordered { + fn eq(&self, other: &Self) -> bool { + self.cmp(other) == Ordering::Equal + } +} + +impl Eq for Ordered {} #[cfg(test)] mod tests { diff --git a/src/vp.rs b/src/vp.rs index e0b218f..a5859ae 100644 --- a/src/vp.rs +++ b/src/vp.rs @@ -1,8 +1,8 @@ //! [Vantage-point trees](https://en.wikipedia.org/wiki/Vantage-point_tree). use crate::distance::{Distance, DistanceValue, Metric, Proximity}; +use crate::knn::{ExactNeighbors, NearestNeighbors, Neighborhood}; use crate::util::Ordered; -use crate::{ExactNeighbors, NearestNeighbors, Neighborhood}; use num_traits::zero; @@ -620,7 +620,7 @@ where mod tests { use super::*; - use crate::tests::test_exact_neighbors; + use crate::knn::tests::test_exact_neighbors; #[test] fn test_vp_tree() { -- cgit v1.2.3