//! [Metric spaces](https://en.wikipedia.org/wiki/Metric_space). pub mod approx; pub mod forest; pub mod kd; pub mod soft; pub mod vp; use std::cmp::Ordering; use std::collections::BinaryHeap; use std::iter::FromIterator; /// A wrapper that converts a partial ordering into a total one by panicking. #[derive(Debug, PartialEq, PartialOrd)] struct Ordered(T); impl Ord for Ordered { fn cmp(&self, other: &Self) -> Ordering { self.partial_cmp(other).unwrap() } } impl Eq for Ordered {} /// An [order embedding](https://en.wikipedia.org/wiki/Order_embedding) for distances. /// /// Implementations of this trait must satisfy, for all non-negative distances `x` and `y`: /// /// * `x == Self::from(x).into()` /// * `x <= y` iff `Self::from(x) <= Self::from(y)` /// /// This trait exists to optimize the common case where distances can be compared more efficiently /// than their exact values can be computed. For example, taking the square root can be avoided /// when comparing Euclidean distances (see [SquaredDistance]). pub trait Distance: Copy + From + Into + PartialOrd {} impl Distance for f64 {} /// A squared distance, to avoid computing square roots unless absolutely necessary. #[derive(Debug, Clone, Copy, PartialEq, PartialOrd)] pub struct SquaredDistance(f64); impl SquaredDistance { /// Create a SquaredDistance from an already squared value. pub fn from_squared(value: f64) -> Self { Self(value) } } impl From for SquaredDistance { fn from(value: f64) -> Self { Self::from_squared(value * value) } } impl From for f64 { fn from(value: SquaredDistance) -> Self { value.0.sqrt() } } impl Distance for SquaredDistance {} /// A [metric space](https://en.wikipedia.org/wiki/Metric_space). pub trait Metric { /// The type used to represent distances. type Distance: Distance; /// Computes the distance between this point and another point. This function must satisfy /// three conditions: /// /// * `x.distance(y) == 0` iff `x == y` (identity of indiscernibles) /// * `x.distance(y) == y.distance(x)` (symmetry) /// * `x.distance(z) <= x.distance(y) + y.distance(z)` (triangle inequality) fn distance(&self, other: &T) -> Self::Distance; } /// Blanket [Metric] implementation for references. impl<'a, 'b, T, U: Metric> Metric<&'a T> for &'b U { type Distance = U::Distance; fn distance(&self, other: &&'a T) -> Self::Distance { (*self).distance(other) } } /// The standard [Euclidean distance](https://en.wikipedia.org/wiki/Euclidean_distance) metric. impl Metric for [f64] { type Distance = SquaredDistance; fn distance(&self, other: &Self) -> Self::Distance { debug_assert!(self.len() == other.len()); let mut sum = 0.0; for i in 0..self.len() { let diff = self[i] - other[i]; sum += diff * diff; } Self::Distance::from_squared(sum) } } /// A nearest neighbor to a target. #[derive(Clone, Copy, Debug, PartialEq)] pub struct Neighbor { /// The found item. pub item: T, /// The distance from the target. pub distance: f64, } impl Neighbor { /// Create a new Neighbor. pub fn new(item: T, distance: f64) -> Self { Self { item, distance } } } /// A candidate nearest neighbor found during a search. #[derive(Debug)] struct Candidate { item: T, distance: D, } impl Candidate { fn new(target: U, item: T) -> Self where U: Metric, { let distance = target.distance(&item); Self { item, distance } } fn into_neighbor(self) -> Neighbor { Neighbor::new(self.item, self.distance.into()) } } impl PartialOrd for Candidate { fn partial_cmp(&self, other: &Self) -> Option { self.distance.partial_cmp(&other.distance) } } impl Ord for Candidate { fn cmp(&self, other: &Self) -> Ordering { self.partial_cmp(other).unwrap() } } impl PartialEq for Candidate { fn eq(&self, other: &Self) -> bool { self.distance.eq(&other.distance) } } impl Eq for Candidate {} /// Accumulates nearest neighbor search results. pub trait Neighborhood> { /// Returns the target of the nearest neighbor search. fn target(&self) -> U; /// Check whether a distance is within this neighborhood. fn contains(&self, distance: f64) -> bool { distance < 0.0 || self.contains_distance(distance.into()) } /// Check whether a distance is within this neighborhood. fn contains_distance(&self, distance: U::Distance) -> bool; /// Consider a new candidate neighbor. fn consider(&mut self, item: T) -> U::Distance; } /// A [Neighborhood] with at most one result. #[derive(Debug)] struct SingletonNeighborhood> { /// The target of the nearest neighbor search. target: U, /// The current threshold distance to the farthest result. threshold: Option, /// The current nearest neighbor, if any. candidate: Option>, } impl SingletonNeighborhood where U: Copy + Metric, { /// Create a new single metric result tracker. /// /// * `target`: The target fo the nearest neighbor search. /// * `threshold`: The maximum allowable distance. fn new(target: U, threshold: Option) -> Self { Self { target, threshold: threshold.map(U::Distance::from), candidate: None, } } /// Consider a candidate. fn push(&mut self, candidate: Candidate) -> U::Distance { let distance = candidate.distance; if self.contains_distance(distance) { self.threshold = Some(distance); self.candidate = Some(candidate); } distance } /// Convert this result into an optional neighbor. fn into_option(self) -> Option> { self.candidate.map(Candidate::into_neighbor) } } impl Neighborhood for SingletonNeighborhood where U: Copy + Metric, { fn target(&self) -> U { self.target } fn contains_distance(&self, distance: U::Distance) -> bool { self.threshold.map(|t| distance <= t).unwrap_or(true) } fn consider(&mut self, item: T) -> U::Distance { self.push(Candidate::new(self.target, item)) } } /// A [Neighborhood] of up to `k` results, using a binary heap. #[derive(Debug)] struct HeapNeighborhood> { /// The target of the nearest neighbor search. target: U, /// The number of nearest neighbors to find. k: usize, /// The current threshold distance to the farthest result. threshold: Option, /// A max-heap of the best candidates found so far. heap: BinaryHeap>, } impl HeapNeighborhood where U: Copy + Metric, { /// Create a new metric result tracker. /// /// * `target`: The target fo the nearest neighbor search. /// * `k`: The number of nearest neighbors to find. /// * `threshold`: The maximum allowable distance. fn new(target: U, k: usize, threshold: Option) -> Self { Self { target, k, threshold: threshold.map(U::Distance::from), heap: BinaryHeap::with_capacity(k), } } /// Consider a candidate. fn push(&mut self, candidate: Candidate) -> U::Distance { let distance = candidate.distance; if self.contains_distance(distance) { let heap = &mut self.heap; if heap.len() == self.k { heap.pop(); } heap.push(candidate); if heap.len() == self.k { self.threshold = self.heap.peek().map(|c| c.distance) } } distance } /// Convert these results into a vector of neighbors. fn into_vec(self) -> Vec> { self.heap .into_sorted_vec() .into_iter() .map(Candidate::into_neighbor) .collect() } } impl Neighborhood for HeapNeighborhood where U: Copy + Metric, { fn target(&self) -> U { self.target } fn contains_distance(&self, distance: U::Distance) -> bool { self.k > 0 && self.threshold.map(|t| distance <= t).unwrap_or(true) } fn consider(&mut self, item: T) -> U::Distance { self.push(Candidate::new(self.target, item)) } } /// A [nearest neighbor search](https://en.wikipedia.org/wiki/Nearest_neighbor_search) index. /// /// Type parameters: /// * `T`: The search result type. /// * `U`: The query type. pub trait NearestNeighbors = T> { /// Returns the nearest neighbor to `target` (or `None` if this index is empty). fn nearest(&self, target: &U) -> Option> { self.search(SingletonNeighborhood::new(target, None)) .into_option() } /// Returns the nearest neighbor to `target` within the distance `threshold`, if one exists. fn nearest_within(&self, target: &U, threshold: f64) -> Option> { self.search(SingletonNeighborhood::new(target, Some(threshold))) .into_option() } /// Returns the up to `k` nearest neighbors to `target`. fn k_nearest(&self, target: &U, k: usize) -> Vec> { self.search(HeapNeighborhood::new(target, k, None)) .into_vec() } /// Returns the up to `k` nearest neighbors to `target` within the distance `threshold`. fn k_nearest_within(&self, target: &U, k: usize, threshold: f64) -> Vec> { self.search(HeapNeighborhood::new(target, k, Some(threshold))) .into_vec() } /// Search for nearest neighbors and add them to a neighborhood. fn search<'a, 'b, N>(&'a self, neighborhood: N) -> N where T: 'a, U: 'b, N: Neighborhood<&'a T, &'b U>; } /// A [NearestNeighbors] implementation that does exhaustive search. #[derive(Debug)] pub struct ExhaustiveSearch(Vec); impl ExhaustiveSearch { /// Create an empty ExhaustiveSearch index. pub fn new() -> Self { Self(Vec::new()) } /// Add a new item to the index. pub fn push(&mut self, item: T) { self.0.push(item); } /// Get the size of this index. pub fn len(&self) -> usize { self.0.len() } /// Check if this index is empty. pub fn is_empty(&self) -> bool { self.0.is_empty() } } impl Default for ExhaustiveSearch { fn default() -> Self { Self::new() } } impl FromIterator for ExhaustiveSearch { fn from_iter>(items: I) -> Self { Self(items.into_iter().collect()) } } impl IntoIterator for ExhaustiveSearch { type Item = T; type IntoIter = std::vec::IntoIter; fn into_iter(self) -> Self::IntoIter { self.0.into_iter() } } impl Extend for ExhaustiveSearch { fn extend>(&mut self, iter: I) { for value in iter { self.push(value); } } } impl> NearestNeighbors for ExhaustiveSearch { fn search<'a, 'b, N>(&'a self, mut neighborhood: N) -> N where T: 'a, U: 'b, N: Neighborhood<&'a T, &'b U>, { for e in &self.0 { neighborhood.consider(e); } neighborhood } } #[cfg(test)] pub mod tests { use super::*; use rand::prelude::*; #[derive(Clone, Copy, Debug, PartialEq)] pub struct Point(pub [f64; 3]); impl Metric for Point { type Distance = SquaredDistance; fn distance(&self, other: &Self) -> Self::Distance { self.0.distance(&other.0) } } /// Test a [NearestNeighbors] impl. pub fn test_nearest_neighbors(from_iter: F) where T: NearestNeighbors, F: Fn(Vec) -> T, { test_empty(&from_iter); test_pythagorean(&from_iter); test_random_points(&from_iter); } fn test_empty(from_iter: &F) where T: NearestNeighbors, F: Fn(Vec) -> T, { let points = Vec::new(); let index = from_iter(points); let target = Point([0.0, 0.0, 0.0]); assert_eq!(index.nearest(&target), None); assert_eq!(index.nearest_within(&target, 1.0), None); assert!(index.k_nearest(&target, 0).is_empty()); assert!(index.k_nearest(&target, 3).is_empty()); assert!(index.k_nearest_within(&target, 0, 1.0).is_empty()); assert!(index.k_nearest_within(&target, 3, 1.0).is_empty()); } fn test_pythagorean(from_iter: &F) where T: NearestNeighbors, F: Fn(Vec) -> T, { let points = vec![ Point([3.0, 4.0, 0.0]), Point([5.0, 0.0, 12.0]), Point([0.0, 8.0, 15.0]), Point([1.0, 2.0, 2.0]), Point([2.0, 3.0, 6.0]), Point([4.0, 4.0, 7.0]), ]; let index = from_iter(points); let target = Point([0.0, 0.0, 0.0]); assert_eq!( index.nearest(&target), Some(Neighbor::new(&Point([1.0, 2.0, 2.0]), 3.0)) ); assert_eq!(index.nearest_within(&target, 2.0), None); assert_eq!( index.nearest_within(&target, 4.0), Some(Neighbor::new(&Point([1.0, 2.0, 2.0]), 3.0)) ); assert!(index.k_nearest(&target, 0).is_empty()); assert_eq!( index.k_nearest(&target, 3), vec![ Neighbor::new(&Point([1.0, 2.0, 2.0]), 3.0), Neighbor::new(&Point([3.0, 4.0, 0.0]), 5.0), Neighbor::new(&Point([2.0, 3.0, 6.0]), 7.0), ] ); assert!(index.k_nearest(&target, 0).is_empty()); assert_eq!( index.k_nearest_within(&target, 3, 6.0), vec![ Neighbor::new(&Point([1.0, 2.0, 2.0]), 3.0), Neighbor::new(&Point([3.0, 4.0, 0.0]), 5.0), ] ); assert_eq!( index.k_nearest_within(&target, 3, 8.0), vec![ Neighbor::new(&Point([1.0, 2.0, 2.0]), 3.0), Neighbor::new(&Point([3.0, 4.0, 0.0]), 5.0), Neighbor::new(&Point([2.0, 3.0, 6.0]), 7.0), ] ); } fn test_random_points(from_iter: &F) where T: NearestNeighbors, F: Fn(Vec) -> T, { let mut points = Vec::new(); for _ in 0..255 { points.push(Point([random(), random(), random()])); } let target = Point([random(), random(), random()]); let eindex = ExhaustiveSearch::from_iter(points.clone()); let index = from_iter(points); assert_eq!(index.k_nearest(&target, 3), eindex.k_nearest(&target, 3)); } #[test] fn test_exhaustive_index() { test_nearest_neighbors(ExhaustiveSearch::from_iter); } }