diff options
author | Tavian Barnes <tavianator@tavianator.com> | 2020-06-24 15:20:02 -0400 |
---|---|---|
committer | Tavian Barnes <tavianator@tavianator.com> | 2020-06-24 15:44:14 -0400 |
commit | 39c0348c9f98b4dd29bd112a0a2a42faa67c92d4 (patch) | |
tree | 6c8ed80bd8cbbb0af79c9ac57bdb39634fa178fd /src/metric | |
parent | adaafdd7043507cbceae65e78c38954e47103b5c (diff) | |
download | kd-forest-master.tar.xz |
Diffstat (limited to 'src/metric')
-rw-r--r-- | src/metric/approx.rs | 131 | ||||
-rw-r--r-- | src/metric/forest.rs | 187 | ||||
-rw-r--r-- | src/metric/kd.rs | 224 | ||||
-rw-r--r-- | src/metric/soft.rs | 292 | ||||
-rw-r--r-- | src/metric/vp.rs | 137 |
5 files changed, 0 insertions, 971 deletions
diff --git a/src/metric/approx.rs b/src/metric/approx.rs deleted file mode 100644 index c23f9c7..0000000 --- a/src/metric/approx.rs +++ /dev/null @@ -1,131 +0,0 @@ -//! [Approximate nearest neighbor search](https://en.wikipedia.org/wiki/Nearest_neighbor_search#Approximate_nearest_neighbor). - -use super::{Metric, NearestNeighbors, Neighborhood}; - -/// An approximate [Neighborhood], for approximate nearest neighbor searches. -#[derive(Debug)] -struct ApproximateNeighborhood<N> { - inner: N, - ratio: f64, - limit: usize, -} - -impl<N> ApproximateNeighborhood<N> { - fn new(inner: N, ratio: f64, limit: usize) -> Self { - Self { - inner, - ratio, - limit, - } - } -} - -impl<T, U, N> Neighborhood<T, U> for ApproximateNeighborhood<N> -where - U: Metric<T>, - N: Neighborhood<T, U>, -{ - fn target(&self) -> U { - self.inner.target() - } - - fn contains(&self, distance: f64) -> bool { - if self.limit > 0 { - self.inner.contains(self.ratio * distance) - } else { - false - } - } - - fn contains_distance(&self, distance: U::Distance) -> bool { - self.contains(self.ratio * distance.into()) - } - - fn consider(&mut self, item: T) -> U::Distance { - self.limit = self.limit.saturating_sub(1); - self.inner.consider(item) - } -} - -/// An [approximate nearest neighbor search](https://en.wikipedia.org/wiki/Nearest_neighbor_search#Approximate_nearest_neighbor) -/// index. -/// -/// This wrapper converts an exact nearest neighbor search algorithm into an approximate one by -/// modifying the behavior of [Neighborhood::contains]. The approximation is controlled by two -/// parameters: -/// -/// * `ratio`: The [nearest neighbor distance ratio](https://en.wikipedia.org/wiki/Nearest_neighbor_search#Nearest_neighbor_distance_ratio), -/// which controls how much closer new candidates must be to be considered. For example, a ratio -/// of 2.0 means that a neighbor must be less than half of the current threshold away to be -/// considered. A ratio of 1.0 means an exact search. -/// -/// * `limit`: A limit on the number of candidates to consider. Typical nearest neighbor algorithms -/// find a close match quickly, so setting a limit bounds the worst-case search time while keeping -/// good accuracy. -#[derive(Debug)] -pub struct ApproximateSearch<T> { - inner: T, - ratio: f64, - limit: usize, -} - -impl<T> ApproximateSearch<T> { - /// Create a new ApproximateSearch index. - /// - /// * `inner`: The [NearestNeighbors] implementation to wrap. - /// * `ratio`: The nearest neighbor distance ratio. - /// * `limit`: The maximum number of results to consider. - pub fn new(inner: T, ratio: f64, limit: usize) -> Self { - Self { - inner, - ratio, - limit, - } - } -} - -impl<T, U, V> NearestNeighbors<T, U> for ApproximateSearch<V> -where - U: Metric<T>, - V: NearestNeighbors<T, U>, -{ - fn search<'a, 'b, N>(&'a self, neighborhood: N) -> N - where - T: 'a, - U: 'b, - N: Neighborhood<&'a T, &'b U>, - { - self.inner - .search(ApproximateNeighborhood::new( - neighborhood, - self.ratio, - self.limit, - )) - .inner - } -} - -#[cfg(test)] -mod tests { - use super::*; - - use crate::metric::kd::KdTree; - use crate::metric::tests::test_nearest_neighbors; - use crate::metric::vp::VpTree; - - use std::iter::FromIterator; - - #[test] - fn test_approx_kd_tree() { - test_nearest_neighbors(|iter| { - ApproximateSearch::new(KdTree::from_iter(iter), 1.0, std::usize::MAX) - }); - } - - #[test] - fn test_approx_vp_tree() { - test_nearest_neighbors(|iter| { - ApproximateSearch::new(VpTree::from_iter(iter), 1.0, std::usize::MAX) - }); - } -} diff --git a/src/metric/forest.rs b/src/metric/forest.rs deleted file mode 100644 index 887ff12..0000000 --- a/src/metric/forest.rs +++ /dev/null @@ -1,187 +0,0 @@ -//! [Dynamization](https://en.wikipedia.org/wiki/Dynamization) for nearest neighbor search. - -use super::kd::KdTree; -use super::vp::VpTree; -use super::{Metric, NearestNeighbors, Neighborhood}; - -use std::iter::{self, Extend, FromIterator}; - -/// The number of bits dedicated to the flat buffer. -const BUFFER_BITS: usize = 6; -/// The maximum size of the buffer. -const BUFFER_SIZE: usize = 1 << BUFFER_BITS; - -/// A dynamic wrapper for a static nearest neighbor search data structure. -/// -/// This type applies [dynamization](https://en.wikipedia.org/wiki/Dynamization) to an arbitrary -/// nearest neighbor search structure `T`, allowing new items to be added dynamically. -#[derive(Debug)] -pub struct Forest<T: IntoIterator> { - /// A flat buffer used for the first few items, to avoid repeatedly rebuilding small trees. - buffer: Vec<T::Item>, - /// The trees of the forest, with sizes in geometric progression. - trees: Vec<Option<T>>, -} - -impl<T, U> Forest<U> -where - U: FromIterator<T> + IntoIterator<Item = T>, -{ - /// Create a new empty forest. - pub fn new() -> Self { - Self { - buffer: Vec::new(), - trees: Vec::new(), - } - } - - /// Add a new item to the forest. - pub fn push(&mut self, item: T) { - self.extend(iter::once(item)); - } - - /// Get the number of items in the forest. - pub fn len(&self) -> usize { - let mut len = self.buffer.len(); - for (i, slot) in self.trees.iter().enumerate() { - if slot.is_some() { - len += 1 << (i + BUFFER_BITS); - } - } - len - } - - /// Check if this forest is empty. - pub fn is_empty(&self) -> bool { - if !self.buffer.is_empty() { - return false; - } - - self.trees.iter().flatten().next().is_none() - } -} - -impl<T, U> Default for Forest<U> -where - U: FromIterator<T> + IntoIterator<Item = T>, -{ - fn default() -> Self { - Self::new() - } -} - -impl<T, U> Extend<T> for Forest<U> -where - U: FromIterator<T> + IntoIterator<Item = T>, -{ - fn extend<I: IntoIterator<Item = T>>(&mut self, items: I) { - self.buffer.extend(items); - if self.buffer.len() < BUFFER_SIZE { - return; - } - - let len = self.len(); - - for i in 0.. { - let bit = 1 << (i + BUFFER_BITS); - - if bit > len { - break; - } - - if i >= self.trees.len() { - self.trees.push(None); - } - - if len & bit == 0 { - if let Some(tree) = self.trees[i].take() { - self.buffer.extend(tree); - } - } else if self.trees[i].is_none() { - let offset = self.buffer.len() - bit; - self.trees[i] = Some(self.buffer.drain(offset..).collect()); - } - } - - debug_assert!(self.buffer.len() < BUFFER_SIZE); - debug_assert!(self.len() == len); - } -} - -impl<T, U> FromIterator<T> for Forest<U> -where - U: FromIterator<T> + IntoIterator<Item = T>, -{ - fn from_iter<I: IntoIterator<Item = T>>(items: I) -> Self { - let mut forest = Self::new(); - forest.extend(items); - forest - } -} - -impl<T: IntoIterator> IntoIterator for Forest<T> { - type Item = T::Item; - type IntoIter = std::vec::IntoIter<T::Item>; - - fn into_iter(mut self) -> Self::IntoIter { - self.buffer.extend(self.trees.into_iter().flatten().flatten()); - self.buffer.into_iter() - } -} - -impl<T, U, V> NearestNeighbors<T, U> for Forest<V> -where - U: Metric<T>, - V: NearestNeighbors<T, U>, - V: IntoIterator<Item = T>, -{ - fn search<'a, 'b, N>(&'a self, mut neighborhood: N) -> N - where - T: 'a, - U: 'b, - N: Neighborhood<&'a T, &'b U>, - { - for item in &self.buffer { - neighborhood.consider(item); - } - - self.trees - .iter() - .flatten() - .fold(neighborhood, |n, t| t.search(n)) - } -} - -/// A forest of k-d trees. -pub type KdForest<T> = Forest<KdTree<T>>; - -/// A forest of vantage-point trees. -pub type VpForest<T> = Forest<VpTree<T>>; - -#[cfg(test)] -mod tests { - use super::*; - - use crate::metric::tests::test_nearest_neighbors; - use crate::metric::ExhaustiveSearch; - - #[test] - fn test_exhaustive_forest() { - test_nearest_neighbors(Forest::<ExhaustiveSearch<_>>::from_iter); - } - - #[test] - fn test_forest_forest() { - test_nearest_neighbors(Forest::<Forest<ExhaustiveSearch<_>>>::from_iter); - } - - #[test] - fn test_kd_forest() { - test_nearest_neighbors(KdForest::from_iter); - } - - #[test] - fn test_vp_forest() { - test_nearest_neighbors(VpForest::from_iter); - } -} diff --git a/src/metric/kd.rs b/src/metric/kd.rs deleted file mode 100644 index 6ea3809..0000000 --- a/src/metric/kd.rs +++ /dev/null @@ -1,224 +0,0 @@ -//! [k-d trees](https://en.wikipedia.org/wiki/K-d_tree). - -use super::{Metric, NearestNeighbors, Neighborhood, Ordered}; - -use std::iter::FromIterator; - -/// A point in Cartesian space. -pub trait Cartesian: Metric<[f64]> { - /// Returns the number of dimensions necessary to describe this point. - fn dimensions(&self) -> usize; - - /// Returns the value of the `i`th coordinate of this point (`i < self.dimensions()`). - fn coordinate(&self, i: usize) -> f64; -} - -/// Blanket [Cartesian] implementation for references. -impl<'a, T: Cartesian> Cartesian for &'a T { - fn dimensions(&self) -> usize { - (*self).dimensions() - } - - fn coordinate(&self, i: usize) -> f64 { - (*self).coordinate(i) - } -} - -/// Blanket [Metric<[f64]>](Metric) implementation for [Cartesian] references. -impl<'a, T: Cartesian> Metric<[f64]> for &'a T { - type Distance = T::Distance; - - fn distance(&self, other: &[f64]) -> Self::Distance { - (*self).distance(other) - } -} - -/// Standard cartesian space. -impl Cartesian for [f64] { - fn dimensions(&self) -> usize { - self.len() - } - - fn coordinate(&self, i: usize) -> f64 { - self[i] - } -} - -/// Marker trait for cartesian metric spaces. -pub trait CartesianMetric<T: ?Sized = Self>: - Cartesian + Metric<T, Distance = <Self as Metric<[f64]>>::Distance> -{ -} - -/// Blanket [CartesianMetric] implementation for cartesian spaces with compatible metric distance -/// types. -impl<T, U> CartesianMetric<T> for U -where - T: ?Sized, - U: ?Sized + Cartesian + Metric<T, Distance = <U as Metric<[f64]>>::Distance>, -{ -} - -/// A node in a k-d tree. -#[derive(Debug)] -struct KdNode<T> { - /// The value stored in this node. - item: T, - /// The size of the left subtree. - left_len: usize, -} - -impl<T: Cartesian> KdNode<T> { - /// Create a new KdNode. - fn new(item: T) -> Self { - Self { item, left_len: 0 } - } - - /// Build a k-d tree recursively. - fn build(slice: &mut [KdNode<T>], i: usize) { - if slice.is_empty() { - return; - } - - slice.sort_unstable_by_key(|n| Ordered(n.item.coordinate(i))); - - let mid = slice.len() / 2; - slice.swap(0, mid); - - let (node, children) = slice.split_first_mut().unwrap(); - let (left, right) = children.split_at_mut(mid); - node.left_len = left.len(); - - let j = (i + 1) % node.item.dimensions(); - Self::build(left, j); - Self::build(right, j); - } - - /// Recursively search for nearest neighbors. - fn recurse<'a, U, N>( - slice: &'a [KdNode<T>], - i: usize, - closest: &mut [f64], - neighborhood: &mut N, - ) where - T: 'a, - U: CartesianMetric<&'a T>, - N: Neighborhood<&'a T, U>, - { - let (node, children) = slice.split_first().unwrap(); - neighborhood.consider(&node.item); - - let target = neighborhood.target(); - let ti = target.coordinate(i); - let ni = node.item.coordinate(i); - let j = (i + 1) % node.item.dimensions(); - - let (left, right) = children.split_at(node.left_len); - let (near, far) = if ti <= ni { - (left, right) - } else { - (right, left) - }; - - if !near.is_empty() { - Self::recurse(near, j, closest, neighborhood); - } - - if !far.is_empty() { - let saved = closest[i]; - closest[i] = ni; - if neighborhood.contains_distance(target.distance(closest)) { - Self::recurse(far, j, closest, neighborhood); - } - closest[i] = saved; - } - } -} - -/// A [k-d tree](https://en.wikipedia.org/wiki/K-d_tree). -#[derive(Debug)] -pub struct KdTree<T>(Vec<KdNode<T>>); - -impl<T: Cartesian> FromIterator<T> for KdTree<T> { - /// Create a new k-d tree from a set of points. - fn from_iter<I: IntoIterator<Item = T>>(items: I) -> Self { - let mut nodes: Vec<_> = items.into_iter().map(KdNode::new).collect(); - KdNode::build(nodes.as_mut_slice(), 0); - Self(nodes) - } -} - -impl<T, U> NearestNeighbors<T, U> for KdTree<T> -where - T: Cartesian, - U: CartesianMetric<T>, -{ - fn search<'a, 'b, N>(&'a self, mut neighborhood: N) -> N - where - T: 'a, - U: 'b, - N: Neighborhood<&'a T, &'b U>, - { - if !self.0.is_empty() { - let target = neighborhood.target(); - let dims = target.dimensions(); - let mut closest: Vec<_> = (0..dims).map(|i| target.coordinate(i)).collect(); - - KdNode::recurse(&self.0, 0, &mut closest, &mut neighborhood); - } - - neighborhood - } -} - -/// An iterator that the moves values out of a k-d tree. -#[derive(Debug)] -pub struct IntoIter<T>(std::vec::IntoIter<KdNode<T>>); - -impl<T> Iterator for IntoIter<T> { - type Item = T; - - fn next(&mut self) -> Option<T> { - self.0.next().map(|n| n.item) - } -} - -impl<T> IntoIterator for KdTree<T> { - type Item = T; - type IntoIter = IntoIter<T>; - - fn into_iter(self) -> Self::IntoIter { - IntoIter(self.0.into_iter()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - use crate::metric::tests::{test_nearest_neighbors, Point}; - use crate::metric::SquaredDistance; - - impl Metric<[f64]> for Point { - type Distance = SquaredDistance; - - fn distance(&self, other: &[f64]) -> Self::Distance { - self.0.distance(other) - } - } - - impl Cartesian for Point { - fn dimensions(&self) -> usize { - self.0.dimensions() - } - - fn coordinate(&self, i: usize) -> f64 { - self.0.coordinate(i) - } - } - - #[test] - fn test_kd_tree() { - test_nearest_neighbors(KdTree::from_iter); - } -} diff --git a/src/metric/soft.rs b/src/metric/soft.rs deleted file mode 100644 index d443bfd..0000000 --- a/src/metric/soft.rs +++ /dev/null @@ -1,292 +0,0 @@ -//! [Soft deletion](https://en.wiktionary.org/wiki/soft_deletion) for nearest neighbor search. - -use super::forest::{KdForest, VpForest}; -use super::kd::KdTree; -use super::vp::VpTree; -use super::{Metric, NearestNeighbors, Neighborhood}; - -use std::iter; -use std::iter::FromIterator; -use std::mem; - -/// A trait for objects that can be soft-deleted. -pub trait SoftDelete { - /// Check whether this item is deleted. - fn is_deleted(&self) -> bool; -} - -/// Blanket [SoftDelete] implementation for references. -impl<'a, T: SoftDelete> SoftDelete for &'a T { - fn is_deleted(&self) -> bool { - (*self).is_deleted() - } -} - -/// [Neighborhood] wrapper that ignores soft-deleted items. -#[derive(Debug)] -struct SoftNeighborhood<N>(N); - -impl<T, U, N> Neighborhood<T, U> for SoftNeighborhood<N> -where - T: SoftDelete, - U: Metric<T>, - N: Neighborhood<T, U>, -{ - fn target(&self) -> U { - self.0.target() - } - - fn contains(&self, distance: f64) -> bool { - self.0.contains(distance) - } - - fn contains_distance(&self, distance: U::Distance) -> bool { - self.0.contains_distance(distance) - } - - fn consider(&mut self, item: T) -> U::Distance { - if item.is_deleted() { - self.target().distance(&item) - } else { - self.0.consider(item) - } - } -} - -/// A [NearestNeighbors] implementation that supports [soft deletes](https://en.wiktionary.org/wiki/soft_deletion). -#[derive(Debug)] -pub struct SoftSearch<T>(T); - -impl<T, U> SoftSearch<U> -where - T: SoftDelete, - U: FromIterator<T> + IntoIterator<Item = T>, -{ - /// Create a new empty soft index. - pub fn new() -> Self { - Self(iter::empty().collect()) - } - - /// Push a new item into this index. - pub fn push(&mut self, item: T) - where - U: Extend<T>, - { - self.0.extend(iter::once(item)); - } - - /// Rebuild this index, discarding deleted items. - pub fn rebuild(&mut self) { - let items = mem::replace(&mut self.0, iter::empty().collect()); - self.0 = items.into_iter().filter(|e| !e.is_deleted()).collect(); - } -} - -impl<T, U> Default for SoftSearch<U> -where - T: SoftDelete, - U: FromIterator<T> + IntoIterator<Item = T>, -{ - fn default() -> Self { - Self::new() - } -} - -impl<T, U: Extend<T>> Extend<T> for SoftSearch<U> { - fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) { - self.0.extend(iter); - } -} - -impl<T, U: FromIterator<T>> FromIterator<T> for SoftSearch<U> { - fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self { - Self(U::from_iter(iter)) - } -} - -impl<T: IntoIterator> IntoIterator for SoftSearch<T> { - type Item = T::Item; - type IntoIter = T::IntoIter; - - fn into_iter(self) -> Self::IntoIter { - self.0.into_iter() - } -} - -impl<T, U, V> NearestNeighbors<T, U> for SoftSearch<V> -where - T: SoftDelete, - U: Metric<T>, - V: NearestNeighbors<T, U>, -{ - fn search<'a, 'b, N>(&'a self, neighborhood: N) -> N - where - T: 'a, - U: 'b, - N: Neighborhood<&'a T, &'b U>, - { - self.0.search(SoftNeighborhood(neighborhood)).0 - } -} - -/// A k-d forest that supports soft deletes. -pub type SoftKdForest<T> = SoftSearch<KdForest<T>>; - -/// A k-d tree that supports soft deletes. -pub type SoftKdTree<T> = SoftSearch<KdTree<T>>; - -/// A VP forest that supports soft deletes. -pub type SoftVpForest<T> = SoftSearch<VpForest<T>>; - -/// A VP tree that supports soft deletes. -pub type SoftVpTree<T> = SoftSearch<VpTree<T>>; - -#[cfg(test)] -mod tests { - use super::*; - - use crate::metric::kd::Cartesian; - use crate::metric::tests::Point; - use crate::metric::Neighbor; - - #[derive(Debug, PartialEq)] - struct SoftPoint { - point: Point, - deleted: bool, - } - - impl SoftPoint { - fn new(x: f64, y: f64, z: f64) -> Self { - Self { - point: Point([x, y, z]), - deleted: false, - } - } - - fn deleted(x: f64, y: f64, z: f64) -> Self { - Self { - point: Point([x, y, z]), - deleted: true, - } - } - } - - impl SoftDelete for SoftPoint { - fn is_deleted(&self) -> bool { - self.deleted - } - } - - impl Metric for SoftPoint { - type Distance = <Point as Metric>::Distance; - - fn distance(&self, other: &Self) -> Self::Distance { - self.point.distance(&other.point) - } - } - - impl Metric<[f64]> for SoftPoint { - type Distance = <Point as Metric>::Distance; - - fn distance(&self, other: &[f64]) -> Self::Distance { - self.point.distance(other) - } - } - - impl Cartesian for SoftPoint { - fn dimensions(&self) -> usize { - self.point.dimensions() - } - - fn coordinate(&self, i: usize) -> f64 { - self.point.coordinate(i) - } - } - - impl Metric<SoftPoint> for Point { - type Distance = <Point as Metric>::Distance; - - fn distance(&self, other: &SoftPoint) -> Self::Distance { - self.distance(&other.point) - } - } - - fn test_index<T>(index: &T) - where - T: NearestNeighbors<SoftPoint, Point>, - { - let target = Point([0.0, 0.0, 0.0]); - - assert_eq!( - index.nearest(&target), - Some(Neighbor::new(&SoftPoint::new(1.0, 2.0, 2.0), 3.0)) - ); - - assert_eq!(index.nearest_within(&target, 2.0), None); - assert_eq!( - index.nearest_within(&target, 4.0), - Some(Neighbor::new(&SoftPoint::new(1.0, 2.0, 2.0), 3.0)) - ); - - assert_eq!( - index.k_nearest(&target, 3), - vec![ - Neighbor::new(&SoftPoint::new(1.0, 2.0, 2.0), 3.0), - Neighbor::new(&SoftPoint::new(3.0, 4.0, 0.0), 5.0), - Neighbor::new(&SoftPoint::new(2.0, 3.0, 6.0), 7.0), - ] - ); - - assert_eq!( - index.k_nearest_within(&target, 3, 6.0), - vec![ - Neighbor::new(&SoftPoint::new(1.0, 2.0, 2.0), 3.0), - Neighbor::new(&SoftPoint::new(3.0, 4.0, 0.0), 5.0), - ] - ); - assert_eq!( - index.k_nearest_within(&target, 3, 8.0), - vec![ - Neighbor::new(&SoftPoint::new(1.0, 2.0, 2.0), 3.0), - Neighbor::new(&SoftPoint::new(3.0, 4.0, 0.0), 5.0), - Neighbor::new(&SoftPoint::new(2.0, 3.0, 6.0), 7.0), - ] - ); - } - - fn test_soft_index<T>(index: &mut SoftSearch<T>) - where - T: Extend<SoftPoint>, - T: FromIterator<SoftPoint>, - T: IntoIterator<Item = SoftPoint>, - T: NearestNeighbors<SoftPoint, Point>, - { - let points = vec![ - SoftPoint::deleted(0.0, 0.0, 0.0), - SoftPoint::new(3.0, 4.0, 0.0), - SoftPoint::new(5.0, 0.0, 12.0), - SoftPoint::new(0.0, 8.0, 15.0), - SoftPoint::new(1.0, 2.0, 2.0), - SoftPoint::new(2.0, 3.0, 6.0), - SoftPoint::new(4.0, 4.0, 7.0), - ]; - - for point in points { - index.push(point); - } - test_index(index); - - index.rebuild(); - test_index(index); - } - - #[test] - fn test_soft_kd_forest() { - test_soft_index(&mut SoftKdForest::new()); - } - - #[test] - fn test_soft_vp_forest() { - test_soft_index(&mut SoftVpForest::new()); - } -} diff --git a/src/metric/vp.rs b/src/metric/vp.rs deleted file mode 100644 index d6e05df..0000000 --- a/src/metric/vp.rs +++ /dev/null @@ -1,137 +0,0 @@ -//! [Vantage-point trees](https://en.wikipedia.org/wiki/Vantage-point_tree). - -use super::{Metric, NearestNeighbors, Neighborhood, Ordered}; - -use std::iter::FromIterator; - -/// A node in a VP tree. -#[derive(Debug)] -struct VpNode<T> { - /// The vantage point itself. - item: T, - /// The radius of this node. - radius: f64, - /// The size of the subtree inside the radius. - inside_len: usize, -} - -impl<T: Metric> VpNode<T> { - /// Create a new VpNode. - fn new(item: T) -> Self { - Self { - item, - radius: 0.0, - inside_len: 0, - } - } - - /// Build a VP tree recursively. - fn build(slice: &mut [VpNode<T>]) { - if let Some((node, children)) = slice.split_first_mut() { - let item = &node.item; - children.sort_by_cached_key(|n| Ordered(item.distance(&n.item))); - - let (inside, outside) = children.split_at_mut(children.len() / 2); - if let Some(last) = inside.last() { - node.radius = item.distance(&last.item).into(); - } - node.inside_len = inside.len(); - - Self::build(inside); - Self::build(outside); - } - } - - /// Recursively search for nearest neighbors. - fn recurse<'a, U, N>(slice: &'a [VpNode<T>], neighborhood: &mut N) - where - T: 'a, - U: Metric<&'a T>, - N: Neighborhood<&'a T, U>, - { - let (node, children) = slice.split_first().unwrap(); - let (inside, outside) = children.split_at(node.inside_len); - - let distance = neighborhood.consider(&node.item).into(); - - if distance <= node.radius { - if !inside.is_empty() && neighborhood.contains(distance - node.radius) { - Self::recurse(inside, neighborhood); - } - if !outside.is_empty() && neighborhood.contains(node.radius - distance) { - Self::recurse(outside, neighborhood); - } - } else { - if !outside.is_empty() && neighborhood.contains(node.radius - distance) { - Self::recurse(outside, neighborhood); - } - if !inside.is_empty() && neighborhood.contains(distance - node.radius) { - Self::recurse(inside, neighborhood); - } - } - } -} - -/// A [vantage-point tree](https://en.wikipedia.org/wiki/Vantage-point_tree). -#[derive(Debug)] -pub struct VpTree<T>(Vec<VpNode<T>>); - -impl<T: Metric> FromIterator<T> for VpTree<T> { - fn from_iter<I: IntoIterator<Item = T>>(items: I) -> Self { - let mut nodes: Vec<_> = items.into_iter().map(VpNode::new).collect(); - VpNode::build(nodes.as_mut_slice()); - Self(nodes) - } -} - -impl<T, U> NearestNeighbors<T, U> for VpTree<T> -where - T: Metric, - U: Metric<T>, -{ - fn search<'a, 'b, N>(&'a self, mut neighborhood: N) -> N - where - T: 'a, - U: 'b, - N: Neighborhood<&'a T, &'b U>, - { - if !self.0.is_empty() { - VpNode::recurse(&self.0, &mut neighborhood); - } - - neighborhood - } -} - -/// An iterator that moves values out of a VP tree. -#[derive(Debug)] -pub struct IntoIter<T>(std::vec::IntoIter<VpNode<T>>); - -impl<T> Iterator for IntoIter<T> { - type Item = T; - - fn next(&mut self) -> Option<T> { - self.0.next().map(|n| n.item) - } -} - -impl<T> IntoIterator for VpTree<T> { - type Item = T; - type IntoIter = IntoIter<T>; - - fn into_iter(self) -> Self::IntoIter { - IntoIter(self.0.into_iter()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - use crate::metric::tests::test_nearest_neighbors; - - #[test] - fn test_vp_tree() { - test_nearest_neighbors(VpTree::from_iter); - } -} |