From 8aa2911b4c978ad7131a430d07a580cedf6f8f65 Mon Sep 17 00:00:00 2001 From: Tavian Barnes Date: Sun, 19 Apr 2020 16:37:16 -0400 Subject: metric: Add some general interfaces for metric spaces --- src/metric.rs | 531 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 531 insertions(+) create mode 100644 src/metric.rs (limited to 'src/metric.rs') diff --git a/src/metric.rs b/src/metric.rs new file mode 100644 index 0000000..e067771 --- /dev/null +++ b/src/metric.rs @@ -0,0 +1,531 @@ +//! [Metric spaces](https://en.wikipedia.org/wiki/Metric_space). + +use ordered_float::OrderedFloat; + +use std::cmp::Ordering; +use std::collections::BinaryHeap; +use std::iter::FromIterator; + +/// An [order embedding](https://en.wikipedia.org/wiki/Order_embedding) for distances. +/// +/// Implementations of this trait must satisfy, for all non-negative distances `x` and `y`: +/// +/// * `x == Self::from(x).into()` +/// * `x <= y` iff `Self::from(x) <= Self::from(y)` +/// +/// This trait exists to optimize the common case where distances can be compared more efficiently +/// than their exact values can be computed. For example, taking the square root can be avoided +/// when comparing Euclidean distances (see [SquaredDistance]). +pub trait Distance: Copy + From + Into + Ord {} + +/// A raw numerical distance. +#[derive(Debug, Clone, Copy, Eq, Ord, PartialEq, PartialOrd)] +pub struct RawDistance(OrderedFloat); + +impl From for RawDistance { + fn from(value: f64) -> Self { + Self(value.into()) + } +} + +impl From for f64 { + fn from(value: RawDistance) -> Self { + value.0.into_inner() + } +} + +impl Distance for RawDistance {} + +/// A squared distance, to avoid computing square roots unless absolutely necessary. +#[derive(Debug, Clone, Copy, Eq, Ord, PartialEq, PartialOrd)] +pub struct SquaredDistance(OrderedFloat); + +impl SquaredDistance { + /// Create a SquaredDistance from an already squared value. + pub fn from_squared(value: f64) -> Self { + Self(value.into()) + } +} + +impl From for SquaredDistance { + fn from(value: f64) -> Self { + Self::from_squared(value * value) + } +} + +impl From for f64 { + fn from(value: SquaredDistance) -> Self { + value.0.into_inner().sqrt() + } +} + +impl Distance for SquaredDistance {} + +/// A [metric space](https://en.wikipedia.org/wiki/Metric_space). +pub trait Metric { + /// The type used to represent distances. Use [RawDistance] to compare the actual values + /// directly, or another type if comparisons can be implemented more efficiently. + type Distance: Distance; + + /// Computes the distance between this point and another point. This function must satisfy + /// three conditions: + /// + /// * `x.distance(y) == 0` iff `x == y` (identity of indiscernibles) + /// * `x.distance(y) == y.distance(x)` (symmetry) + /// * `x.distance(z) <= x.distance(y) + y.distance(z)` (triangle inequality) + fn distance(&self, other: &T) -> Self::Distance; +} + +/// Blanket [Metric] implementation for references. +impl<'a, 'b, T, U: Metric> Metric<&'a T> for &'b U { + type Distance = U::Distance; + + fn distance(&self, other: &&'a T) -> Self::Distance { + (*self).distance(other) + } +} + +/// The standard [Euclidean distance](https://en.wikipedia.org/wiki/Euclidean_distance) metric. +impl Metric for [f64] { + type Distance = SquaredDistance; + + fn distance(&self, other: &Self) -> Self::Distance { + debug_assert!(self.len() == other.len()); + + let mut sum = 0.0; + for i in 0..self.len() { + let diff = self[i] - other[i]; + sum += diff * diff; + } + + Self::Distance::from_squared(sum) + } +} + +/// A nearest neighbor to a target. +#[derive(Clone, Copy, Debug, PartialEq)] +pub struct Neighbor { + /// The found item. + pub item: T, + /// The distance from the target. + pub distance: f64, +} + +impl Neighbor { + /// Create a new Neighbor. + pub fn new(item: T, distance: f64) -> Self { + Self { item, distance } + } +} + +/// A candidate nearest neighbor found during a search. +#[derive(Debug)] +struct Candidate { + item: T, + distance: D, +} + +impl Candidate { + fn new(target: U, item: T) -> Self + where + U: Metric, + { + let distance = target.distance(&item); + Self { item, distance } + } + + fn into_neighbor(self) -> Neighbor { + Neighbor::new(self.item, self.distance.into()) + } +} + +impl PartialOrd for Candidate { + fn partial_cmp(&self, other: &Self) -> Option { + self.distance.partial_cmp(&other.distance) + } +} + +impl Ord for Candidate { + fn cmp(&self, other: &Self) -> Ordering { + self.distance.cmp(&other.distance) + } +} + +impl PartialEq for Candidate { + fn eq(&self, other: &Self) -> bool { + self.distance.eq(&other.distance) + } +} + +impl Eq for Candidate {} + +/// Accumulates nearest neighbor search results. +pub trait Neighborhood> { + /// Returns the target of the nearest neighbor search. + fn target(&self) -> U; + + /// Check whether a distance is within this neighborhood. + fn contains(&self, distance: f64) -> bool { + distance < 0.0 || self.contains_distance(distance.into()) + } + + /// Check whether a distance is within this neighborhood. + fn contains_distance(&self, distance: U::Distance) -> bool; + + /// Consider a new candidate neighbor. + fn consider(&mut self, item: T) -> U::Distance; +} + +/// A [Neighborhood] with at most one result. +#[derive(Debug)] +struct SingletonNeighborhood> { + /// The target of the nearest neighbor search. + target: U, + /// The current threshold distance to the farthest result. + threshold: Option, + /// The current nearest neighbor, if any. + candidate: Option>, +} + +impl SingletonNeighborhood +where + U: Copy + Metric, +{ + /// Create a new single metric result tracker. + /// + /// * `target`: The target fo the nearest neighbor search. + /// * `threshold`: The maximum allowable distance. + fn new(target: U, threshold: Option) -> Self { + Self { + target, + threshold: threshold.map(U::Distance::from), + candidate: None, + } + } + + /// Consider a candidate. + fn push(&mut self, candidate: Candidate) -> U::Distance { + let distance = candidate.distance; + + if self.contains_distance(distance) { + self.threshold = Some(distance); + self.candidate = Some(candidate); + } + + distance + } + + /// Convert this result into an optional neighbor. + fn into_option(self) -> Option> { + self.candidate.map(Candidate::into_neighbor) + } +} + +impl Neighborhood for SingletonNeighborhood +where + U: Copy + Metric, +{ + fn target(&self) -> U { + self.target + } + + fn contains_distance(&self, distance: U::Distance) -> bool { + self.threshold.map(|t| distance <= t).unwrap_or(true) + } + + fn consider(&mut self, item: T) -> U::Distance { + self.push(Candidate::new(self.target, item)) + } +} + +/// A [Neighborhood] of up to `k` results, using a binary heap. +#[derive(Debug)] +struct HeapNeighborhood> { + /// The target of the nearest neighbor search. + target: U, + /// The number of nearest neighbors to find. + k: usize, + /// The current threshold distance to the farthest result. + threshold: Option, + /// A max-heap of the best candidates found so far. + heap: BinaryHeap>, +} + +impl HeapNeighborhood +where + U: Copy + Metric, +{ + /// Create a new metric result tracker. + /// + /// * `target`: The target fo the nearest neighbor search. + /// * `k`: The number of nearest neighbors to find. + /// * `threshold`: The maximum allowable distance. + fn new(target: U, k: usize, threshold: Option) -> Self { + Self { + target, + k, + threshold: threshold.map(U::Distance::from), + heap: BinaryHeap::with_capacity(k), + } + } + + /// Consider a candidate. + fn push(&mut self, candidate: Candidate) -> U::Distance { + let distance = candidate.distance; + + if self.contains_distance(distance) { + let heap = &mut self.heap; + + if heap.len() == self.k { + heap.pop(); + } + + heap.push(candidate); + + if heap.len() == self.k { + self.threshold = self.heap.peek().map(|c| c.distance) + } + } + + distance + } + + /// Convert these results into a vector of neighbors. + fn into_vec(self) -> Vec> { + self.heap + .into_sorted_vec() + .into_iter() + .map(Candidate::into_neighbor) + .collect() + } +} + +impl Neighborhood for HeapNeighborhood +where + U: Copy + Metric, +{ + fn target(&self) -> U { + self.target + } + + fn contains_distance(&self, distance: U::Distance) -> bool { + self.k > 0 && self.threshold.map(|t| distance <= t).unwrap_or(true) + } + + fn consider(&mut self, item: T) -> U::Distance { + self.push(Candidate::new(self.target, item)) + } +} + +/// A [nearest neighbor search](https://en.wikipedia.org/wiki/Nearest_neighbor_search) index. +/// +/// Type parameters: +/// * `T`: The search result type. +/// * `U`: The query type. +pub trait NearestNeighbors = T> { + /// Returns the nearest neighbor to `target` (or `None` if this index is empty). + fn nearest(&self, target: &U) -> Option> { + self.search(SingletonNeighborhood::new(target, None)) + .into_option() + } + + /// Returns the nearest neighbor to `target` within the distance `threshold`, if one exists. + fn nearest_within(&self, target: &U, threshold: f64) -> Option> { + self.search(SingletonNeighborhood::new(target, Some(threshold))) + .into_option() + } + + /// Returns the up to `k` nearest neighbors to `target`. + fn k_nearest(&self, target: &U, k: usize) -> Vec> { + self.search(HeapNeighborhood::new(target, k, None)) + .into_vec() + } + + /// Returns the up to `k` nearest neighbors to `target` within the distance `threshold`. + fn k_nearest_within(&self, target: &U, k: usize, threshold: f64) -> Vec> { + self.search(HeapNeighborhood::new(target, k, Some(threshold))) + .into_vec() + } + + /// Search for nearest neighbors and add them to a neighborhood. + fn search<'a, 'b, N>(&'a self, neighborhood: N) -> N + where + T: 'a, + U: 'b, + N: Neighborhood<&'a T, &'b U>; +} + +/// A [NearestNeighbors] implementation that does exhaustive search. +#[derive(Debug)] +pub struct ExhaustiveSearch(Vec); + +impl ExhaustiveSearch { + /// Create an empty ExhaustiveSearch index. + pub fn new() -> Self { + Self(Vec::new()) + } + + /// Add a new item to the index. + pub fn push(&mut self, item: T) { + self.0.push(item); + } +} + +impl FromIterator for ExhaustiveSearch { + fn from_iter>(items: I) -> Self { + Self(items.into_iter().collect()) + } +} + +impl IntoIterator for ExhaustiveSearch { + type Item = T; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl Extend for ExhaustiveSearch { + fn extend>(&mut self, iter: I) { + for value in iter { + self.push(value); + } + } +} + +impl> NearestNeighbors for ExhaustiveSearch { + fn search<'a, 'b, N>(&'a self, mut neighborhood: N) -> N + where + T: 'a, + U: 'b, + N: Neighborhood<&'a T, &'b U>, + { + for e in &self.0 { + neighborhood.consider(e); + } + neighborhood + } +} + +#[cfg(test)] +pub mod tests { + use super::*; + + use rand::prelude::*; + + #[derive(Clone, Copy, Debug, PartialEq)] + pub struct Point(pub [f64; 3]); + + impl Metric for Point { + type Distance = SquaredDistance; + + fn distance(&self, other: &Self) -> Self::Distance { + self.0.distance(&other.0) + } + } + + /// Test a [NearestNeighbors] impl. + pub fn test_nearest_neighbors(from_iter: F) + where + T: NearestNeighbors, + F: Fn(Vec) -> T, + { + test_empty(&from_iter); + test_pythagorean(&from_iter); + test_random_points(&from_iter); + } + + fn test_empty(from_iter: &F) + where + T: NearestNeighbors, + F: Fn(Vec) -> T, + { + let points = Vec::new(); + let index = from_iter(points); + let target = Point([0.0, 0.0, 0.0]); + assert_eq!(index.nearest(&target), None); + assert_eq!(index.nearest_within(&target, 1.0), None); + assert!(index.k_nearest(&target, 0).is_empty()); + assert!(index.k_nearest(&target, 3).is_empty()); + assert!(index.k_nearest_within(&target, 0, 1.0).is_empty()); + assert!(index.k_nearest_within(&target, 3, 1.0).is_empty()); + } + + fn test_pythagorean(from_iter: &F) + where + T: NearestNeighbors, + F: Fn(Vec) -> T, + { + let points = vec![ + Point([3.0, 4.0, 0.0]), + Point([5.0, 0.0, 12.0]), + Point([0.0, 8.0, 15.0]), + Point([1.0, 2.0, 2.0]), + Point([2.0, 3.0, 6.0]), + Point([4.0, 4.0, 7.0]), + ]; + let index = from_iter(points); + let target = Point([0.0, 0.0, 0.0]); + + assert_eq!( + index.nearest(&target), + Some(Neighbor::new(&Point([1.0, 2.0, 2.0]), 3.0)) + ); + + assert_eq!(index.nearest_within(&target, 2.0), None); + assert_eq!( + index.nearest_within(&target, 4.0), + Some(Neighbor::new(&Point([1.0, 2.0, 2.0]), 3.0)) + ); + + assert!(index.k_nearest(&target, 0).is_empty()); + assert_eq!( + index.k_nearest(&target, 3), + vec![ + Neighbor::new(&Point([1.0, 2.0, 2.0]), 3.0), + Neighbor::new(&Point([3.0, 4.0, 0.0]), 5.0), + Neighbor::new(&Point([2.0, 3.0, 6.0]), 7.0), + ] + ); + + assert!(index.k_nearest(&target, 0).is_empty()); + assert_eq!( + index.k_nearest_within(&target, 3, 6.0), + vec![ + Neighbor::new(&Point([1.0, 2.0, 2.0]), 3.0), + Neighbor::new(&Point([3.0, 4.0, 0.0]), 5.0), + ] + ); + assert_eq!( + index.k_nearest_within(&target, 3, 8.0), + vec![ + Neighbor::new(&Point([1.0, 2.0, 2.0]), 3.0), + Neighbor::new(&Point([3.0, 4.0, 0.0]), 5.0), + Neighbor::new(&Point([2.0, 3.0, 6.0]), 7.0), + ] + ); + } + + fn test_random_points(from_iter: &F) + where + T: NearestNeighbors, + F: Fn(Vec) -> T, + { + let mut points = Vec::new(); + for _ in 0..255 { + points.push(Point([random(), random(), random()])); + } + let target = Point([random(), random(), random()]); + + let eindex = ExhaustiveSearch::from_iter(points.clone()); + let index = from_iter(points); + + assert_eq!(index.k_nearest(&target, 3), eindex.k_nearest(&target, 3)); + } + + #[test] + fn test_exhaustive_index() { + test_nearest_neighbors(ExhaustiveSearch::from_iter); + } +} -- cgit v1.2.3 From 8d9de0e1028daed981246174182a39dd917b72bc Mon Sep 17 00:00:00 2001 From: Tavian Barnes Date: Sun, 19 Apr 2020 16:38:24 -0400 Subject: metric/vp: Implement vantage-point trees --- src/metric.rs | 2 + src/metric/vp.rs | 168 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 170 insertions(+) create mode 100644 src/metric/vp.rs (limited to 'src/metric.rs') diff --git a/src/metric.rs b/src/metric.rs index e067771..549db67 100644 --- a/src/metric.rs +++ b/src/metric.rs @@ -1,5 +1,7 @@ //! [Metric spaces](https://en.wikipedia.org/wiki/Metric_space). +pub mod vp; + use ordered_float::OrderedFloat; use std::cmp::Ordering; diff --git a/src/metric/vp.rs b/src/metric/vp.rs new file mode 100644 index 0000000..8d5b091 --- /dev/null +++ b/src/metric/vp.rs @@ -0,0 +1,168 @@ +//! [Vantage-point trees](https://en.wikipedia.org/wiki/Vantage-point_tree). + +use super::{Metric, NearestNeighbors, Neighborhood}; + +use std::iter::FromIterator; + +/// A node in a VP tree. +#[derive(Debug)] +struct VpNode { + /// The vantage point itself. + item: T, + /// The radius of this node. + radius: f64, + /// The subtree inside the radius, if any. + inside: Option>, + /// The subtree outside the radius, if any. + outside: Option>, +} + +impl VpNode { + /// Create a new VpNode. + fn new(mut items: Vec) -> Option> { + if items.is_empty() { + return None; + } + + let item = items.pop().unwrap(); + + items.sort_by_cached_key(|a| item.distance(a)); + + let mid = items.len() / 2; + let outside: Vec = items.drain(mid..).collect(); + + let radius = items.last().map(|l| item.distance(l).into()).unwrap_or(0.0); + + Some(Box::new(Self { + item, + radius, + inside: Self::new(items), + outside: Self::new(outside), + })) + } +} + +trait VpSearch<'a, T, U, N> { + /// Recursively search for nearest neighbors. + fn search(&'a self, neighborhood: &mut N); + + /// Search the inside subtree. + fn search_inside(&'a self, distance: f64, neighborhood: &mut N); + + /// Search the outside subtree. + fn search_outside(&'a self, distance: f64, neighborhood: &mut N); +} + +impl<'a, T, U, N> VpSearch<'a, T, U, N> for VpNode +where + T: 'a, + U: Metric<&'a T>, + N: Neighborhood<&'a T, U>, +{ + fn search(&'a self, neighborhood: &mut N) { + let distance = neighborhood.consider(&self.item).into(); + + if distance <= self.radius { + self.search_inside(distance, neighborhood); + self.search_outside(distance, neighborhood); + } else { + self.search_outside(distance, neighborhood); + self.search_inside(distance, neighborhood); + } + } + + fn search_inside(&'a self, distance: f64, neighborhood: &mut N) { + if let Some(inside) = &self.inside { + if neighborhood.contains(distance - self.radius) { + inside.search(neighborhood); + } + } + } + + fn search_outside(&'a self, distance: f64, neighborhood: &mut N) { + if let Some(outside) = &self.outside { + if neighborhood.contains(self.radius - distance) { + outside.search(neighborhood); + } + } + } +} + +/// A [vantage-point tree](https://en.wikipedia.org/wiki/Vantage-point_tree). +#[derive(Debug)] +pub struct VpTree { + root: Option>>, +} + +impl FromIterator for VpTree { + fn from_iter>(items: I) -> Self { + Self { + root: VpNode::new(items.into_iter().collect::>()), + } + } +} + +impl NearestNeighbors for VpTree +where + T: Metric, + U: Metric, +{ + fn search<'a, 'b, N>(&'a self, mut neighborhood: N) -> N + where + T: 'a, + U: 'b, + N: Neighborhood<&'a T, &'b U>, + { + if let Some(root) = &self.root { + root.search(&mut neighborhood); + } + neighborhood + } +} + +/// An iterator that moves values out of a VP tree. +#[derive(Debug)] +pub struct IntoIter { + stack: Vec>>, +} + +impl IntoIter { + fn new(node: Option>>) -> Self { + Self { + stack: node.into_iter().collect(), + } + } +} + +impl Iterator for IntoIter { + type Item = T; + + fn next(&mut self) -> Option { + self.stack.pop().map(|node| { + self.stack.extend(node.inside); + self.stack.extend(node.outside); + node.item + }) + } +} + +impl IntoIterator for VpTree { + type Item = T; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + IntoIter::new(self.root) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::metric::tests::test_nearest_neighbors; + + #[test] + fn test_vp_tree() { + test_nearest_neighbors(VpTree::from_iter); + } +} -- cgit v1.2.3 From 1c560791902a4ef72efa671106d8f6d97fea50c1 Mon Sep 17 00:00:00 2001 From: Tavian Barnes Date: Sun, 19 Apr 2020 16:40:29 -0400 Subject: metric/kd: Implement k-d trees --- src/metric.rs | 1 + src/metric/kd.rs | 217 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 218 insertions(+) create mode 100644 src/metric/kd.rs (limited to 'src/metric.rs') diff --git a/src/metric.rs b/src/metric.rs index 549db67..95191da 100644 --- a/src/metric.rs +++ b/src/metric.rs @@ -1,5 +1,6 @@ //! [Metric spaces](https://en.wikipedia.org/wiki/Metric_space). +pub mod kd; pub mod vp; use ordered_float::OrderedFloat; diff --git a/src/metric/kd.rs b/src/metric/kd.rs new file mode 100644 index 0000000..ab0cd2e --- /dev/null +++ b/src/metric/kd.rs @@ -0,0 +1,217 @@ +//! [k-d trees](https://en.wikipedia.org/wiki/K-d_tree). + +use super::{Metric, NearestNeighbors, Neighborhood}; + +use ordered_float::OrderedFloat; + +use std::iter::FromIterator; + +/// A point in Cartesian space. +pub trait Cartesian { + /// Returns the number of dimensions necessary to describe this point. + fn dimensions(&self) -> usize; + + /// Returns the value of the `i`th coordinate of this point (`i < self.dimensions()`). + fn coordinate(&self, i: usize) -> f64; +} + +/// Blanket [Cartesian] implementation for references. +impl<'a, T: Cartesian> Cartesian for &'a T { + fn dimensions(&self) -> usize { + (*self).dimensions() + } + + fn coordinate(&self, i: usize) -> f64 { + (*self).coordinate(i) + } +} + +/// Standard cartesian space. +impl Cartesian for [f64] { + fn dimensions(&self) -> usize { + self.len() + } + + fn coordinate(&self, i: usize) -> f64 { + self[i] + } +} + +/// A node in a k-d tree. +#[derive(Debug)] +struct KdNode { + /// The value stored in this node. + item: T, + /// The left subtree, if any. + left: Option>, + /// The right subtree, if any. + right: Option>, +} + +trait KdSearch<'a, T, U, N> { + /// Recursively search for nearest neighbors. + fn search(&'a self, i: usize, neighborhood: &mut N); + + /// Search the left subtree. + fn search_left(&'a self, i: usize, distance: f64, neighborhood: &mut N); + + /// Search the right subtree. + fn search_right(&'a self, i: usize, distance: f64, neighborhood: &mut N); +} + +impl<'a, T, U, N> KdSearch<'a, T, U, N> for KdNode +where + T: 'a + Cartesian, + U: Cartesian + Metric<&'a T>, + N: Neighborhood<&'a T, U>, +{ + fn search(&'a self, i: usize, neighborhood: &mut N) { + neighborhood.consider(&self.item); + + let distance = neighborhood.target().coordinate(i) - self.item.coordinate(i); + let j = (i + 1) % self.item.dimensions(); + if distance <= 0.0 { + self.search_left(j, distance, neighborhood); + self.search_right(j, -distance, neighborhood); + } else { + self.search_right(j, -distance, neighborhood); + self.search_left(j, distance, neighborhood); + } + } + + fn search_left(&'a self, i: usize, distance: f64, neighborhood: &mut N) { + if let Some(left) = &self.left { + if neighborhood.contains(distance) { + left.search(i, neighborhood); + } + } + } + + fn search_right(&'a self, i: usize, distance: f64, neighborhood: &mut N) { + if let Some(right) = &self.right { + if neighborhood.contains(distance) { + right.search(i, neighborhood); + } + } + } +} + +impl KdNode { + /// Create a new KdNode. + fn new(i: usize, mut items: Vec) -> Option> { + if items.is_empty() { + return None; + } + + items.sort_unstable_by_key(|x| OrderedFloat::from(x.coordinate(i))); + + let mid = items.len() / 2; + let right: Vec = items.drain((mid + 1)..).collect(); + let item = items.pop().unwrap(); + let j = (i + 1) % item.dimensions(); + Some(Box::new(Self { + item, + left: Self::new(j, items), + right: Self::new(j, right), + })) + } +} + +/// A [k-d tree](https://en.wikipedia.org/wiki/K-d_tree). +#[derive(Debug)] +pub struct KdTree { + root: Option>>, +} + +impl FromIterator for KdTree { + /// Create a new k-d tree from a set of points. + fn from_iter>(items: I) -> Self { + Self { + root: KdNode::new(0, items.into_iter().collect()), + } + } +} + +impl NearestNeighbors for KdTree +where + T: Cartesian, + U: Cartesian + Metric, +{ + fn search<'a, 'b, N>(&'a self, mut neighborhood: N) -> N + where + T: 'a, + U: 'b, + N: Neighborhood<&'a T, &'b U>, + { + if let Some(root) = &self.root { + root.search(0, &mut neighborhood); + } + neighborhood + } +} + +/// An iterator that the moves values out of a k-d tree. +#[derive(Debug)] +pub struct IntoIter { + stack: Vec>>, +} + +impl IntoIter { + fn new(node: Option>>) -> Self { + Self { + stack: node.into_iter().collect(), + } + } +} + +impl Iterator for IntoIter { + type Item = T; + + fn next(&mut self) -> Option { + self.stack.pop().map(|node| { + self.stack.extend(node.left); + self.stack.extend(node.right); + node.item + }) + } +} + +impl IntoIterator for KdTree { + type Item = T; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + IntoIter::new(self.root) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::metric::tests::{test_nearest_neighbors, Point}; + use crate::metric::SquaredDistance; + + impl Metric<[f64]> for Point { + type Distance = SquaredDistance; + + fn distance(&self, other: &[f64]) -> Self::Distance { + self.0.distance(other) + } + } + + impl Cartesian for Point { + fn dimensions(&self) -> usize { + self.0.dimensions() + } + + fn coordinate(&self, i: usize) -> f64 { + self.0.coordinate(i) + } + } + + #[test] + fn test_kd_tree() { + test_nearest_neighbors(KdTree::from_iter); + } +} -- cgit v1.2.3 From a4a75059f302de2a00971f1f485fcf4389710628 Mon Sep 17 00:00:00 2001 From: Tavian Barnes Date: Sun, 19 Apr 2020 16:55:17 -0400 Subject: metric/forest: Implement dynamized forests --- src/metric.rs | 1 + src/metric/forest.rs | 152 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 153 insertions(+) create mode 100644 src/metric/forest.rs (limited to 'src/metric.rs') diff --git a/src/metric.rs b/src/metric.rs index 95191da..7a5f5f7 100644 --- a/src/metric.rs +++ b/src/metric.rs @@ -1,5 +1,6 @@ //! [Metric spaces](https://en.wikipedia.org/wiki/Metric_space). +pub mod forest; pub mod kd; pub mod vp; diff --git a/src/metric/forest.rs b/src/metric/forest.rs new file mode 100644 index 0000000..f23c451 --- /dev/null +++ b/src/metric/forest.rs @@ -0,0 +1,152 @@ +//! [Dynamization](https://en.wikipedia.org/wiki/Dynamization) for nearest neighbor search. + +use super::kd::KdTree; +use super::vp::VpTree; +use super::{Metric, NearestNeighbors, Neighborhood}; + +use std::iter::{Extend, Flatten, FromIterator}; + +/// A dynamic wrapper for a static nearest neighbor search data structure. +/// +/// This type applies [dynamization](https://en.wikipedia.org/wiki/Dynamization) to an arbitrary +/// nearest neighbor search structure `T`, allowing new items to be added dynamically. +#[derive(Debug)] +pub struct Forest(Vec>); + +impl Forest +where + U: FromIterator + IntoIterator, +{ + /// Create a new empty forest. + pub fn new() -> Self { + Self(Vec::new()) + } + + /// Add a new item to the forest. + pub fn push(&mut self, item: T) { + let mut items = vec![item]; + + for slot in &mut self.0 { + match slot.take() { + // Collect the items from any trees we encounter... + Some(tree) => { + items.extend(tree); + } + // ... and put them all in the first empty slot + None => { + *slot = Some(items.into_iter().collect()); + return; + } + } + } + + self.0.push(Some(items.into_iter().collect())); + } + + /// Get the number of items in the forest. + pub fn len(&self) -> usize { + let mut len = 0; + for (i, slot) in self.0.iter().enumerate() { + if slot.is_some() { + len |= 1 << i; + } + } + len + } +} + +impl Extend for Forest +where + U: FromIterator + IntoIterator, +{ + fn extend>(&mut self, items: I) { + for item in items { + self.push(item); + } + } +} + +impl FromIterator for Forest +where + U: FromIterator + IntoIterator, +{ + fn from_iter>(items: I) -> Self { + let mut forest = Self::new(); + forest.extend(items); + forest + } +} + +type IntoIterImpl = Flatten>>>; + +/// An iterator that moves items out of a forest. +pub struct IntoIter(IntoIterImpl); + +impl Iterator for IntoIter { + type Item = T::Item; + + fn next(&mut self) -> Option { + self.0.next() + } +} + +impl IntoIterator for Forest { + type Item = T::Item; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + IntoIter(self.0.into_iter().flatten().flatten()) + } +} + +impl NearestNeighbors for Forest +where + U: Metric, + V: NearestNeighbors, +{ + fn search<'a, 'b, N>(&'a self, neighborhood: N) -> N + where + T: 'a, + U: 'b, + N: Neighborhood<&'a T, &'b U>, + { + self.0 + .iter() + .flatten() + .fold(neighborhood, |n, t| t.search(n)) + } +} + +/// A forest of k-d trees. +pub type KdForest = Forest>; + +/// A forest of vantage-point trees. +pub type VpForest = Forest>; + +#[cfg(test)] +mod tests { + use super::*; + + use crate::metric::tests::test_nearest_neighbors; + use crate::metric::ExhaustiveSearch; + + #[test] + fn test_exhaustive_forest() { + test_nearest_neighbors(Forest::>::from_iter); + } + + #[test] + fn test_forest_forest() { + test_nearest_neighbors(Forest::>>::from_iter); + } + + #[test] + fn test_kd_forest() { + test_nearest_neighbors(KdForest::from_iter); + } + + #[test] + fn test_vp_forest() { + test_nearest_neighbors(VpForest::from_iter); + } +} -- cgit v1.2.3 From e53ace3e69ed6bacedb7de345df10d3e575a291e Mon Sep 17 00:00:00 2001 From: Tavian Barnes Date: Sun, 19 Apr 2020 16:58:52 -0400 Subject: metric/soft: Implement soft deletes --- src/metric.rs | 1 + src/metric/soft.rs | 282 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 283 insertions(+) create mode 100644 src/metric/soft.rs (limited to 'src/metric.rs') diff --git a/src/metric.rs b/src/metric.rs index 7a5f5f7..b46c8da 100644 --- a/src/metric.rs +++ b/src/metric.rs @@ -2,6 +2,7 @@ pub mod forest; pub mod kd; +pub mod soft; pub mod vp; use ordered_float::OrderedFloat; diff --git a/src/metric/soft.rs b/src/metric/soft.rs new file mode 100644 index 0000000..0d7dcdb --- /dev/null +++ b/src/metric/soft.rs @@ -0,0 +1,282 @@ +//! [Soft deletion](https://en.wiktionary.org/wiki/soft_deletion) for nearest neighbor search. + +use super::forest::{KdForest, VpForest}; +use super::kd::KdTree; +use super::vp::VpTree; +use super::{Metric, NearestNeighbors, Neighborhood}; + +use std::iter; +use std::iter::FromIterator; +use std::mem; + +/// A trait for objects that can be soft-deleted. +pub trait SoftDelete { + /// Check whether this item is deleted. + fn is_deleted(&self) -> bool; +} + +/// Blanket [SoftDelete] implementation for references. +impl<'a, T: SoftDelete> SoftDelete for &'a T { + fn is_deleted(&self) -> bool { + (*self).is_deleted() + } +} + +/// [Neighborhood] wrapper that ignores soft-deleted items. +#[derive(Debug)] +struct SoftNeighborhood(N); + +impl Neighborhood for SoftNeighborhood +where + T: SoftDelete, + U: Metric, + N: Neighborhood, +{ + fn target(&self) -> U { + self.0.target() + } + + fn contains(&self, distance: f64) -> bool { + self.0.contains(distance) + } + + fn contains_distance(&self, distance: U::Distance) -> bool { + self.0.contains_distance(distance) + } + + fn consider(&mut self, item: T) -> U::Distance { + if item.is_deleted() { + self.target().distance(&item) + } else { + self.0.consider(item) + } + } +} + +/// A [NearestNeighbors] implementation that supports [soft deletes](https://en.wiktionary.org/wiki/soft_deletion). +#[derive(Debug)] +pub struct SoftSearch(T); + +impl SoftSearch +where + T: SoftDelete, + U: FromIterator + IntoIterator, +{ + /// Create a new empty soft index. + pub fn new() -> Self { + Self(iter::empty().collect()) + } + + /// Push a new item into this index. + pub fn push(&mut self, item: T) + where + U: Extend, + { + self.0.extend(iter::once(item)); + } + + /// Rebuild this index, discarding deleted items. + pub fn rebuild(&mut self) { + let items = mem::replace(&mut self.0, iter::empty().collect()); + self.0 = items.into_iter().filter(|e| !e.is_deleted()).collect(); + } +} + +impl> Extend for SoftSearch { + fn extend>(&mut self, iter: I) { + self.0.extend(iter); + } +} + +impl> FromIterator for SoftSearch { + fn from_iter>(iter: I) -> Self { + Self(U::from_iter(iter)) + } +} + +impl IntoIterator for SoftSearch { + type Item = T::Item; + type IntoIter = T::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl NearestNeighbors for SoftSearch +where + T: SoftDelete, + U: Metric, + V: NearestNeighbors, +{ + fn search<'a, 'b, N>(&'a self, neighborhood: N) -> N + where + T: 'a, + U: 'b, + N: Neighborhood<&'a T, &'b U>, + { + self.0.search(SoftNeighborhood(neighborhood)).0 + } +} + +/// A k-d forest that supports soft deletes. +pub type SoftKdForest = SoftSearch>; + +/// A k-d tree that supports soft deletes. +pub type SoftKdTree = SoftSearch>; + +/// A VP forest that supports soft deletes. +pub type SoftVpForest = SoftSearch>; + +/// A VP tree that supports soft deletes. +pub type SoftVpTree = SoftSearch>; + +#[cfg(test)] +mod tests { + use super::*; + + use crate::metric::kd::Cartesian; + use crate::metric::tests::Point; + use crate::metric::Neighbor; + + #[derive(Debug, PartialEq)] + struct SoftPoint { + point: Point, + deleted: bool, + } + + impl SoftPoint { + fn new(x: f64, y: f64, z: f64) -> Self { + Self { + point: Point([x, y, z]), + deleted: false, + } + } + + fn deleted(x: f64, y: f64, z: f64) -> Self { + Self { + point: Point([x, y, z]), + deleted: true, + } + } + } + + impl SoftDelete for SoftPoint { + fn is_deleted(&self) -> bool { + self.deleted + } + } + + impl Metric for SoftPoint { + type Distance = ::Distance; + + fn distance(&self, other: &Self) -> Self::Distance { + self.point.distance(&other.point) + } + } + + impl Metric<[f64]> for SoftPoint { + type Distance = ::Distance; + + fn distance(&self, other: &[f64]) -> Self::Distance { + self.point.distance(other) + } + } + + impl Cartesian for SoftPoint { + fn dimensions(&self) -> usize { + self.point.dimensions() + } + + fn coordinate(&self, i: usize) -> f64 { + self.point.coordinate(i) + } + } + + impl Metric for Point { + type Distance = ::Distance; + + fn distance(&self, other: &SoftPoint) -> Self::Distance { + self.distance(&other.point) + } + } + + fn test_index(index: &T) + where + T: NearestNeighbors, + { + let target = Point([0.0, 0.0, 0.0]); + + assert_eq!( + index.nearest(&target), + Some(Neighbor::new(&SoftPoint::new(1.0, 2.0, 2.0), 3.0)) + ); + + assert_eq!(index.nearest_within(&target, 2.0), None); + assert_eq!( + index.nearest_within(&target, 4.0), + Some(Neighbor::new(&SoftPoint::new(1.0, 2.0, 2.0), 3.0)) + ); + + assert_eq!( + index.k_nearest(&target, 3), + vec![ + Neighbor::new(&SoftPoint::new(1.0, 2.0, 2.0), 3.0), + Neighbor::new(&SoftPoint::new(3.0, 4.0, 0.0), 5.0), + Neighbor::new(&SoftPoint::new(2.0, 3.0, 6.0), 7.0), + ] + ); + + assert_eq!( + index.k_nearest_within(&target, 3, 6.0), + vec![ + Neighbor::new(&SoftPoint::new(1.0, 2.0, 2.0), 3.0), + Neighbor::new(&SoftPoint::new(3.0, 4.0, 0.0), 5.0), + ] + ); + assert_eq!( + index.k_nearest_within(&target, 3, 8.0), + vec![ + Neighbor::new(&SoftPoint::new(1.0, 2.0, 2.0), 3.0), + Neighbor::new(&SoftPoint::new(3.0, 4.0, 0.0), 5.0), + Neighbor::new(&SoftPoint::new(2.0, 3.0, 6.0), 7.0), + ] + ); + } + + fn test_soft_index(index: &mut SoftSearch) + where + T: Extend, + T: FromIterator, + T: IntoIterator, + T: NearestNeighbors, + { + let points = vec![ + SoftPoint::deleted(0.0, 0.0, 0.0), + SoftPoint::new(3.0, 4.0, 0.0), + SoftPoint::new(5.0, 0.0, 12.0), + SoftPoint::new(0.0, 8.0, 15.0), + SoftPoint::new(1.0, 2.0, 2.0), + SoftPoint::new(2.0, 3.0, 6.0), + SoftPoint::new(4.0, 4.0, 7.0), + ]; + + for point in points { + index.push(point); + } + test_index(index); + + index.rebuild(); + test_index(index); + } + + #[test] + fn test_soft_kd_forest() { + test_soft_index(&mut SoftKdForest::new()); + } + + #[test] + fn test_soft_vp_forest() { + test_soft_index(&mut SoftVpForest::new()); + } +} -- cgit v1.2.3 From 462daaffd5ec720ed80a2e7b1f445a73cabf5833 Mon Sep 17 00:00:00 2001 From: Tavian Barnes Date: Sun, 19 Apr 2020 16:59:24 -0400 Subject: metric/approx: Implement approximate nearest neighbor search --- src/metric.rs | 1 + src/metric/approx.rs | 131 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 132 insertions(+) create mode 100644 src/metric/approx.rs (limited to 'src/metric.rs') diff --git a/src/metric.rs b/src/metric.rs index b46c8da..268aefd 100644 --- a/src/metric.rs +++ b/src/metric.rs @@ -1,5 +1,6 @@ //! [Metric spaces](https://en.wikipedia.org/wiki/Metric_space). +pub mod approx; pub mod forest; pub mod kd; pub mod soft; diff --git a/src/metric/approx.rs b/src/metric/approx.rs new file mode 100644 index 0000000..c23f9c7 --- /dev/null +++ b/src/metric/approx.rs @@ -0,0 +1,131 @@ +//! [Approximate nearest neighbor search](https://en.wikipedia.org/wiki/Nearest_neighbor_search#Approximate_nearest_neighbor). + +use super::{Metric, NearestNeighbors, Neighborhood}; + +/// An approximate [Neighborhood], for approximate nearest neighbor searches. +#[derive(Debug)] +struct ApproximateNeighborhood { + inner: N, + ratio: f64, + limit: usize, +} + +impl ApproximateNeighborhood { + fn new(inner: N, ratio: f64, limit: usize) -> Self { + Self { + inner, + ratio, + limit, + } + } +} + +impl Neighborhood for ApproximateNeighborhood +where + U: Metric, + N: Neighborhood, +{ + fn target(&self) -> U { + self.inner.target() + } + + fn contains(&self, distance: f64) -> bool { + if self.limit > 0 { + self.inner.contains(self.ratio * distance) + } else { + false + } + } + + fn contains_distance(&self, distance: U::Distance) -> bool { + self.contains(self.ratio * distance.into()) + } + + fn consider(&mut self, item: T) -> U::Distance { + self.limit = self.limit.saturating_sub(1); + self.inner.consider(item) + } +} + +/// An [approximate nearest neighbor search](https://en.wikipedia.org/wiki/Nearest_neighbor_search#Approximate_nearest_neighbor) +/// index. +/// +/// This wrapper converts an exact nearest neighbor search algorithm into an approximate one by +/// modifying the behavior of [Neighborhood::contains]. The approximation is controlled by two +/// parameters: +/// +/// * `ratio`: The [nearest neighbor distance ratio](https://en.wikipedia.org/wiki/Nearest_neighbor_search#Nearest_neighbor_distance_ratio), +/// which controls how much closer new candidates must be to be considered. For example, a ratio +/// of 2.0 means that a neighbor must be less than half of the current threshold away to be +/// considered. A ratio of 1.0 means an exact search. +/// +/// * `limit`: A limit on the number of candidates to consider. Typical nearest neighbor algorithms +/// find a close match quickly, so setting a limit bounds the worst-case search time while keeping +/// good accuracy. +#[derive(Debug)] +pub struct ApproximateSearch { + inner: T, + ratio: f64, + limit: usize, +} + +impl ApproximateSearch { + /// Create a new ApproximateSearch index. + /// + /// * `inner`: The [NearestNeighbors] implementation to wrap. + /// * `ratio`: The nearest neighbor distance ratio. + /// * `limit`: The maximum number of results to consider. + pub fn new(inner: T, ratio: f64, limit: usize) -> Self { + Self { + inner, + ratio, + limit, + } + } +} + +impl NearestNeighbors for ApproximateSearch +where + U: Metric, + V: NearestNeighbors, +{ + fn search<'a, 'b, N>(&'a self, neighborhood: N) -> N + where + T: 'a, + U: 'b, + N: Neighborhood<&'a T, &'b U>, + { + self.inner + .search(ApproximateNeighborhood::new( + neighborhood, + self.ratio, + self.limit, + )) + .inner + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::metric::kd::KdTree; + use crate::metric::tests::test_nearest_neighbors; + use crate::metric::vp::VpTree; + + use std::iter::FromIterator; + + #[test] + fn test_approx_kd_tree() { + test_nearest_neighbors(|iter| { + ApproximateSearch::new(KdTree::from_iter(iter), 1.0, std::usize::MAX) + }); + } + + #[test] + fn test_approx_vp_tree() { + test_nearest_neighbors(|iter| { + ApproximateSearch::new(VpTree::from_iter(iter), 1.0, std::usize::MAX) + }); + } +} -- cgit v1.2.3