diff options
author | Tavian Barnes <tavianator@tavianator.com> | 2020-05-03 10:55:16 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-05-03 10:55:16 -0400 |
commit | ce2904b4840611f769b92b55bf6d9b5afe84d3d7 (patch) | |
tree | a133319a302f95edf7a7a261262a8f24473bd21c /src/metric | |
parent | d95e93bf70f3351e6fd489284794ef7909fd94ce (diff) | |
parent | 2984e8f93fe88d0ee7eb3c0561dcd2da44807429 (diff) | |
download | kd-forest-ce2904b4840611f769b92b55bf6d9b5afe84d3d7.tar.xz |
Merge pull request #1 from tavianator/rust
Rewrite in rust
Diffstat (limited to 'src/metric')
-rw-r--r-- | src/metric/approx.rs | 131 | ||||
-rw-r--r-- | src/metric/forest.rs | 159 | ||||
-rw-r--r-- | src/metric/kd.rs | 226 | ||||
-rw-r--r-- | src/metric/soft.rs | 282 | ||||
-rw-r--r-- | src/metric/vp.rs | 137 |
5 files changed, 935 insertions, 0 deletions
diff --git a/src/metric/approx.rs b/src/metric/approx.rs new file mode 100644 index 0000000..c23f9c7 --- /dev/null +++ b/src/metric/approx.rs @@ -0,0 +1,131 @@ +//! [Approximate nearest neighbor search](https://en.wikipedia.org/wiki/Nearest_neighbor_search#Approximate_nearest_neighbor). + +use super::{Metric, NearestNeighbors, Neighborhood}; + +/// An approximate [Neighborhood], for approximate nearest neighbor searches. +#[derive(Debug)] +struct ApproximateNeighborhood<N> { + inner: N, + ratio: f64, + limit: usize, +} + +impl<N> ApproximateNeighborhood<N> { + fn new(inner: N, ratio: f64, limit: usize) -> Self { + Self { + inner, + ratio, + limit, + } + } +} + +impl<T, U, N> Neighborhood<T, U> for ApproximateNeighborhood<N> +where + U: Metric<T>, + N: Neighborhood<T, U>, +{ + fn target(&self) -> U { + self.inner.target() + } + + fn contains(&self, distance: f64) -> bool { + if self.limit > 0 { + self.inner.contains(self.ratio * distance) + } else { + false + } + } + + fn contains_distance(&self, distance: U::Distance) -> bool { + self.contains(self.ratio * distance.into()) + } + + fn consider(&mut self, item: T) -> U::Distance { + self.limit = self.limit.saturating_sub(1); + self.inner.consider(item) + } +} + +/// An [approximate nearest neighbor search](https://en.wikipedia.org/wiki/Nearest_neighbor_search#Approximate_nearest_neighbor) +/// index. +/// +/// This wrapper converts an exact nearest neighbor search algorithm into an approximate one by +/// modifying the behavior of [Neighborhood::contains]. The approximation is controlled by two +/// parameters: +/// +/// * `ratio`: The [nearest neighbor distance ratio](https://en.wikipedia.org/wiki/Nearest_neighbor_search#Nearest_neighbor_distance_ratio), +/// which controls how much closer new candidates must be to be considered. For example, a ratio +/// of 2.0 means that a neighbor must be less than half of the current threshold away to be +/// considered. A ratio of 1.0 means an exact search. 
+/// +/// * `limit`: A limit on the number of candidates to consider. Typical nearest neighbor algorithms +/// find a close match quickly, so setting a limit bounds the worst-case search time while keeping +/// good accuracy. +#[derive(Debug)] +pub struct ApproximateSearch<T> { + inner: T, + ratio: f64, + limit: usize, +} + +impl<T> ApproximateSearch<T> { + /// Create a new ApproximateSearch index. + /// + /// * `inner`: The [NearestNeighbors] implementation to wrap. + /// * `ratio`: The nearest neighbor distance ratio. + /// * `limit`: The maximum number of results to consider. + pub fn new(inner: T, ratio: f64, limit: usize) -> Self { + Self { + inner, + ratio, + limit, + } + } +} + +impl<T, U, V> NearestNeighbors<T, U> for ApproximateSearch<V> +where + U: Metric<T>, + V: NearestNeighbors<T, U>, +{ + fn search<'a, 'b, N>(&'a self, neighborhood: N) -> N + where + T: 'a, + U: 'b, + N: Neighborhood<&'a T, &'b U>, + { + self.inner + .search(ApproximateNeighborhood::new( + neighborhood, + self.ratio, + self.limit, + )) + .inner + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::metric::kd::KdTree; + use crate::metric::tests::test_nearest_neighbors; + use crate::metric::vp::VpTree; + + use std::iter::FromIterator; + + #[test] + fn test_approx_kd_tree() { + test_nearest_neighbors(|iter| { + ApproximateSearch::new(KdTree::from_iter(iter), 1.0, std::usize::MAX) + }); + } + + #[test] + fn test_approx_vp_tree() { + test_nearest_neighbors(|iter| { + ApproximateSearch::new(VpTree::from_iter(iter), 1.0, std::usize::MAX) + }); + } +} diff --git a/src/metric/forest.rs b/src/metric/forest.rs new file mode 100644 index 0000000..29b6f55 --- /dev/null +++ b/src/metric/forest.rs @@ -0,0 +1,159 @@ +//! [Dynamization](https://en.wikipedia.org/wiki/Dynamization) for nearest neighbor search. 
+ +use super::kd::KdTree; +use super::vp::VpTree; +use super::{Metric, NearestNeighbors, Neighborhood}; + +use std::iter::{self, Extend, Flatten, FromIterator}; + +/// A dynamic wrapper for a static nearest neighbor search data structure. +/// +/// This type applies [dynamization](https://en.wikipedia.org/wiki/Dynamization) to an arbitrary +/// nearest neighbor search structure `T`, allowing new items to be added dynamically. +#[derive(Debug)] +pub struct Forest<T>(Vec<Option<T>>); + +impl<T, U> Forest<U> +where + U: FromIterator<T> + IntoIterator<Item = T>, +{ + /// Create a new empty forest. + pub fn new() -> Self { + Self(Vec::new()) + } + + /// Add a new item to the forest. + pub fn push(&mut self, item: T) { + self.extend(iter::once(item)); + } + + /// Get the number of items in the forest. + pub fn len(&self) -> usize { + let mut len = 0; + for (i, slot) in self.0.iter().enumerate() { + if slot.is_some() { + len |= 1 << i; + } + } + len + } +} + +impl<T, U> Extend<T> for Forest<U> +where + U: FromIterator<T> + IntoIterator<Item = T>, +{ + fn extend<I: IntoIterator<Item = T>>(&mut self, items: I) { + let mut vec: Vec<_> = items.into_iter().collect(); + let new_len = self.len() + vec.len(); + + for i in 0.. 
{ + let bit = 1 << i; + + if bit > new_len { + break; + } + + if i >= self.0.len() { + self.0.push(None); + } + + if new_len & bit == 0 { + if let Some(tree) = self.0[i].take() { + vec.extend(tree); + } + } else if self.0[i].is_none() { + let offset = vec.len() - bit; + self.0[i] = Some(vec.drain(offset..).collect()); + } + } + + debug_assert!(vec.is_empty()); + debug_assert!(self.len() == new_len); + } +} + +impl<T, U> FromIterator<T> for Forest<U> +where + U: FromIterator<T> + IntoIterator<Item = T>, +{ + fn from_iter<I: IntoIterator<Item = T>>(items: I) -> Self { + let mut forest = Self::new(); + forest.extend(items); + forest + } +} + +type IntoIterImpl<T> = Flatten<Flatten<std::vec::IntoIter<Option<T>>>>; + +/// An iterator that moves items out of a forest. +pub struct IntoIter<T: IntoIterator>(IntoIterImpl<T>); + +impl<T: IntoIterator> Iterator for IntoIter<T> { + type Item = T::Item; + + fn next(&mut self) -> Option<Self::Item> { + self.0.next() + } +} + +impl<T: IntoIterator> IntoIterator for Forest<T> { + type Item = T::Item; + type IntoIter = IntoIter<T>; + + fn into_iter(self) -> Self::IntoIter { + IntoIter(self.0.into_iter().flatten().flatten()) + } +} + +impl<T, U, V> NearestNeighbors<T, U> for Forest<V> +where + U: Metric<T>, + V: NearestNeighbors<T, U>, +{ + fn search<'a, 'b, N>(&'a self, neighborhood: N) -> N + where + T: 'a, + U: 'b, + N: Neighborhood<&'a T, &'b U>, + { + self.0 + .iter() + .flatten() + .fold(neighborhood, |n, t| t.search(n)) + } +} + +/// A forest of k-d trees. +pub type KdForest<T> = Forest<KdTree<T>>; + +/// A forest of vantage-point trees. 
+pub type VpForest<T> = Forest<VpTree<T>>; + +#[cfg(test)] +mod tests { + use super::*; + + use crate::metric::tests::test_nearest_neighbors; + use crate::metric::ExhaustiveSearch; + + #[test] + fn test_exhaustive_forest() { + test_nearest_neighbors(Forest::<ExhaustiveSearch<_>>::from_iter); + } + + #[test] + fn test_forest_forest() { + test_nearest_neighbors(Forest::<Forest<ExhaustiveSearch<_>>>::from_iter); + } + + #[test] + fn test_kd_forest() { + test_nearest_neighbors(KdForest::from_iter); + } + + #[test] + fn test_vp_forest() { + test_nearest_neighbors(VpForest::from_iter); + } +} diff --git a/src/metric/kd.rs b/src/metric/kd.rs new file mode 100644 index 0000000..2caf4a3 --- /dev/null +++ b/src/metric/kd.rs @@ -0,0 +1,226 @@ +//! [k-d trees](https://en.wikipedia.org/wiki/K-d_tree). + +use super::{Metric, NearestNeighbors, Neighborhood}; + +use ordered_float::OrderedFloat; + +use std::iter::FromIterator; + +/// A point in Cartesian space. +pub trait Cartesian: Metric<[f64]> { + /// Returns the number of dimensions necessary to describe this point. + fn dimensions(&self) -> usize; + + /// Returns the value of the `i`th coordinate of this point (`i < self.dimensions()`). + fn coordinate(&self, i: usize) -> f64; +} + +/// Blanket [Cartesian] implementation for references. +impl<'a, T: Cartesian> Cartesian for &'a T { + fn dimensions(&self) -> usize { + (*self).dimensions() + } + + fn coordinate(&self, i: usize) -> f64 { + (*self).coordinate(i) + } +} + +/// Blanket [Metric<[f64]>](Metric) implementation for [Cartesian] references. +impl<'a, T: Cartesian> Metric<[f64]> for &'a T { + type Distance = T::Distance; + + fn distance(&self, other: &[f64]) -> Self::Distance { + (*self).distance(other) + } +} + +/// Standard cartesian space. +impl Cartesian for [f64] { + fn dimensions(&self) -> usize { + self.len() + } + + fn coordinate(&self, i: usize) -> f64 { + self[i] + } +} + +/// Marker trait for cartesian metric spaces. 
+pub trait CartesianMetric<T: ?Sized = Self>: + Cartesian + Metric<T, Distance = <Self as Metric<[f64]>>::Distance> +{ +} + +/// Blanket [CartesianMetric] implementation for cartesian spaces with compatible metric distance +/// types. +impl<T, U> CartesianMetric<T> for U +where + T: ?Sized, + U: ?Sized + Cartesian + Metric<T, Distance = <U as Metric<[f64]>>::Distance>, +{ +} + +/// A node in a k-d tree. +#[derive(Debug)] +struct KdNode<T> { + /// The value stored in this node. + item: T, + /// The size of the left subtree. + left_len: usize, +} + +impl<T: Cartesian> KdNode<T> { + /// Create a new KdNode. + fn new(item: T) -> Self { + Self { item, left_len: 0 } + } + + /// Build a k-d tree recursively. + fn build(slice: &mut [KdNode<T>], i: usize) { + if slice.is_empty() { + return; + } + + slice.sort_unstable_by_key(|n| OrderedFloat::from(n.item.coordinate(i))); + + let mid = slice.len() / 2; + slice.swap(0, mid); + + let (node, children) = slice.split_first_mut().unwrap(); + let (left, right) = children.split_at_mut(mid); + node.left_len = left.len(); + + let j = (i + 1) % node.item.dimensions(); + Self::build(left, j); + Self::build(right, j); + } + + /// Recursively search for nearest neighbors. 
+ fn recurse<'a, U, N>( + slice: &'a [KdNode<T>], + i: usize, + closest: &mut [f64], + neighborhood: &mut N, + ) where + T: 'a, + U: CartesianMetric<&'a T>, + N: Neighborhood<&'a T, U>, + { + let (node, children) = slice.split_first().unwrap(); + neighborhood.consider(&node.item); + + let target = neighborhood.target(); + let ti = target.coordinate(i); + let ni = node.item.coordinate(i); + let j = (i + 1) % node.item.dimensions(); + + let (left, right) = children.split_at(node.left_len); + let (near, far) = if ti <= ni { + (left, right) + } else { + (right, left) + }; + + if !near.is_empty() { + Self::recurse(near, j, closest, neighborhood); + } + + if !far.is_empty() { + let saved = closest[i]; + closest[i] = ni; + if neighborhood.contains_distance(target.distance(closest)) { + Self::recurse(far, j, closest, neighborhood); + } + closest[i] = saved; + } + } +} + +/// A [k-d tree](https://en.wikipedia.org/wiki/K-d_tree). +#[derive(Debug)] +pub struct KdTree<T>(Vec<KdNode<T>>); + +impl<T: Cartesian> FromIterator<T> for KdTree<T> { + /// Create a new k-d tree from a set of points. + fn from_iter<I: IntoIterator<Item = T>>(items: I) -> Self { + let mut nodes: Vec<_> = items.into_iter().map(KdNode::new).collect(); + KdNode::build(nodes.as_mut_slice(), 0); + Self(nodes) + } +} + +impl<T, U> NearestNeighbors<T, U> for KdTree<T> +where + T: Cartesian, + U: CartesianMetric<T>, +{ + fn search<'a, 'b, N>(&'a self, mut neighborhood: N) -> N + where + T: 'a, + U: 'b, + N: Neighborhood<&'a T, &'b U>, + { + if !self.0.is_empty() { + let target = neighborhood.target(); + let dims = target.dimensions(); + let mut closest: Vec<_> = (0..dims).map(|i| target.coordinate(i)).collect(); + + KdNode::recurse(&self.0, 0, &mut closest, &mut neighborhood); + } + + neighborhood + } +} + +/// An iterator that moves values out of a k-d tree. 
+#[derive(Debug)] +pub struct IntoIter<T>(std::vec::IntoIter<KdNode<T>>); + +impl<T> Iterator for IntoIter<T> { + type Item = T; + + fn next(&mut self) -> Option<T> { + self.0.next().map(|n| n.item) + } +} + +impl<T> IntoIterator for KdTree<T> { + type Item = T; + type IntoIter = IntoIter<T>; + + fn into_iter(self) -> Self::IntoIter { + IntoIter(self.0.into_iter()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::metric::tests::{test_nearest_neighbors, Point}; + use crate::metric::SquaredDistance; + + impl Metric<[f64]> for Point { + type Distance = SquaredDistance; + + fn distance(&self, other: &[f64]) -> Self::Distance { + self.0.distance(other) + } + } + + impl Cartesian for Point { + fn dimensions(&self) -> usize { + self.0.dimensions() + } + + fn coordinate(&self, i: usize) -> f64 { + self.0.coordinate(i) + } + } + + #[test] + fn test_kd_tree() { + test_nearest_neighbors(KdTree::from_iter); + } +} diff --git a/src/metric/soft.rs b/src/metric/soft.rs new file mode 100644 index 0000000..0d7dcdb --- /dev/null +++ b/src/metric/soft.rs @@ -0,0 +1,282 @@ +//! [Soft deletion](https://en.wiktionary.org/wiki/soft_deletion) for nearest neighbor search. + +use super::forest::{KdForest, VpForest}; +use super::kd::KdTree; +use super::vp::VpTree; +use super::{Metric, NearestNeighbors, Neighborhood}; + +use std::iter; +use std::iter::FromIterator; +use std::mem; + +/// A trait for objects that can be soft-deleted. +pub trait SoftDelete { + /// Check whether this item is deleted. + fn is_deleted(&self) -> bool; +} + +/// Blanket [SoftDelete] implementation for references. +impl<'a, T: SoftDelete> SoftDelete for &'a T { + fn is_deleted(&self) -> bool { + (*self).is_deleted() + } +} + +/// [Neighborhood] wrapper that ignores soft-deleted items. 
+#[derive(Debug)] +struct SoftNeighborhood<N>(N); + +impl<T, U, N> Neighborhood<T, U> for SoftNeighborhood<N> +where + T: SoftDelete, + U: Metric<T>, + N: Neighborhood<T, U>, +{ + fn target(&self) -> U { + self.0.target() + } + + fn contains(&self, distance: f64) -> bool { + self.0.contains(distance) + } + + fn contains_distance(&self, distance: U::Distance) -> bool { + self.0.contains_distance(distance) + } + + fn consider(&mut self, item: T) -> U::Distance { + if item.is_deleted() { + self.target().distance(&item) + } else { + self.0.consider(item) + } + } +} + +/// A [NearestNeighbors] implementation that supports [soft deletes](https://en.wiktionary.org/wiki/soft_deletion). +#[derive(Debug)] +pub struct SoftSearch<T>(T); + +impl<T, U> SoftSearch<U> +where + T: SoftDelete, + U: FromIterator<T> + IntoIterator<Item = T>, +{ + /// Create a new empty soft index. + pub fn new() -> Self { + Self(iter::empty().collect()) + } + + /// Push a new item into this index. + pub fn push(&mut self, item: T) + where + U: Extend<T>, + { + self.0.extend(iter::once(item)); + } + + /// Rebuild this index, discarding deleted items. 
+ pub fn rebuild(&mut self) { + let items = mem::replace(&mut self.0, iter::empty().collect()); + self.0 = items.into_iter().filter(|e| !e.is_deleted()).collect(); + } +} + +impl<T, U: Extend<T>> Extend<T> for SoftSearch<U> { + fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) { + self.0.extend(iter); + } +} + +impl<T, U: FromIterator<T>> FromIterator<T> for SoftSearch<U> { + fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self { + Self(U::from_iter(iter)) + } +} + +impl<T: IntoIterator> IntoIterator for SoftSearch<T> { + type Item = T::Item; + type IntoIter = T::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl<T, U, V> NearestNeighbors<T, U> for SoftSearch<V> +where + T: SoftDelete, + U: Metric<T>, + V: NearestNeighbors<T, U>, +{ + fn search<'a, 'b, N>(&'a self, neighborhood: N) -> N + where + T: 'a, + U: 'b, + N: Neighborhood<&'a T, &'b U>, + { + self.0.search(SoftNeighborhood(neighborhood)).0 + } +} + +/// A k-d forest that supports soft deletes. +pub type SoftKdForest<T> = SoftSearch<KdForest<T>>; + +/// A k-d tree that supports soft deletes. +pub type SoftKdTree<T> = SoftSearch<KdTree<T>>; + +/// A VP forest that supports soft deletes. +pub type SoftVpForest<T> = SoftSearch<VpForest<T>>; + +/// A VP tree that supports soft deletes. 
+pub type SoftVpTree<T> = SoftSearch<VpTree<T>>; + +#[cfg(test)] +mod tests { + use super::*; + + use crate::metric::kd::Cartesian; + use crate::metric::tests::Point; + use crate::metric::Neighbor; + + #[derive(Debug, PartialEq)] + struct SoftPoint { + point: Point, + deleted: bool, + } + + impl SoftPoint { + fn new(x: f64, y: f64, z: f64) -> Self { + Self { + point: Point([x, y, z]), + deleted: false, + } + } + + fn deleted(x: f64, y: f64, z: f64) -> Self { + Self { + point: Point([x, y, z]), + deleted: true, + } + } + } + + impl SoftDelete for SoftPoint { + fn is_deleted(&self) -> bool { + self.deleted + } + } + + impl Metric for SoftPoint { + type Distance = <Point as Metric>::Distance; + + fn distance(&self, other: &Self) -> Self::Distance { + self.point.distance(&other.point) + } + } + + impl Metric<[f64]> for SoftPoint { + type Distance = <Point as Metric>::Distance; + + fn distance(&self, other: &[f64]) -> Self::Distance { + self.point.distance(other) + } + } + + impl Cartesian for SoftPoint { + fn dimensions(&self) -> usize { + self.point.dimensions() + } + + fn coordinate(&self, i: usize) -> f64 { + self.point.coordinate(i) + } + } + + impl Metric<SoftPoint> for Point { + type Distance = <Point as Metric>::Distance; + + fn distance(&self, other: &SoftPoint) -> Self::Distance { + self.distance(&other.point) + } + } + + fn test_index<T>(index: &T) + where + T: NearestNeighbors<SoftPoint, Point>, + { + let target = Point([0.0, 0.0, 0.0]); + + assert_eq!( + index.nearest(&target), + Some(Neighbor::new(&SoftPoint::new(1.0, 2.0, 2.0), 3.0)) + ); + + assert_eq!(index.nearest_within(&target, 2.0), None); + assert_eq!( + index.nearest_within(&target, 4.0), + Some(Neighbor::new(&SoftPoint::new(1.0, 2.0, 2.0), 3.0)) + ); + + assert_eq!( + index.k_nearest(&target, 3), + vec![ + Neighbor::new(&SoftPoint::new(1.0, 2.0, 2.0), 3.0), + Neighbor::new(&SoftPoint::new(3.0, 4.0, 0.0), 5.0), + Neighbor::new(&SoftPoint::new(2.0, 3.0, 6.0), 7.0), + ] + ); + + assert_eq!( + 
index.k_nearest_within(&target, 3, 6.0), + vec![ + Neighbor::new(&SoftPoint::new(1.0, 2.0, 2.0), 3.0), + Neighbor::new(&SoftPoint::new(3.0, 4.0, 0.0), 5.0), + ] + ); + assert_eq!( + index.k_nearest_within(&target, 3, 8.0), + vec![ + Neighbor::new(&SoftPoint::new(1.0, 2.0, 2.0), 3.0), + Neighbor::new(&SoftPoint::new(3.0, 4.0, 0.0), 5.0), + Neighbor::new(&SoftPoint::new(2.0, 3.0, 6.0), 7.0), + ] + ); + } + + fn test_soft_index<T>(index: &mut SoftSearch<T>) + where + T: Extend<SoftPoint>, + T: FromIterator<SoftPoint>, + T: IntoIterator<Item = SoftPoint>, + T: NearestNeighbors<SoftPoint, Point>, + { + let points = vec![ + SoftPoint::deleted(0.0, 0.0, 0.0), + SoftPoint::new(3.0, 4.0, 0.0), + SoftPoint::new(5.0, 0.0, 12.0), + SoftPoint::new(0.0, 8.0, 15.0), + SoftPoint::new(1.0, 2.0, 2.0), + SoftPoint::new(2.0, 3.0, 6.0), + SoftPoint::new(4.0, 4.0, 7.0), + ]; + + for point in points { + index.push(point); + } + test_index(index); + + index.rebuild(); + test_index(index); + } + + #[test] + fn test_soft_kd_forest() { + test_soft_index(&mut SoftKdForest::new()); + } + + #[test] + fn test_soft_vp_forest() { + test_soft_index(&mut SoftVpForest::new()); + } +} diff --git a/src/metric/vp.rs b/src/metric/vp.rs new file mode 100644 index 0000000..fae62e5 --- /dev/null +++ b/src/metric/vp.rs @@ -0,0 +1,137 @@ +//! [Vantage-point trees](https://en.wikipedia.org/wiki/Vantage-point_tree). + +use super::{Metric, NearestNeighbors, Neighborhood}; + +use std::iter::FromIterator; + +/// A node in a VP tree. +#[derive(Debug)] +struct VpNode<T> { + /// The vantage point itself. + item: T, + /// The radius of this node. + radius: f64, + /// The size of the subtree inside the radius. + inside_len: usize, +} + +impl<T: Metric> VpNode<T> { + /// Create a new VpNode. + fn new(item: T) -> Self { + Self { + item, + radius: 0.0, + inside_len: 0, + } + } + + /// Build a VP tree recursively. 
+ fn build(slice: &mut [VpNode<T>]) { + if let Some((node, children)) = slice.split_first_mut() { + let item = &node.item; + children.sort_by_cached_key(|n| item.distance(&n.item)); + + let (inside, outside) = children.split_at_mut(children.len() / 2); + if let Some(last) = inside.last() { + node.radius = item.distance(&last.item).into(); + } + node.inside_len = inside.len(); + + Self::build(inside); + Self::build(outside); + } + } + + /// Recursively search for nearest neighbors. + fn recurse<'a, U, N>(slice: &'a [VpNode<T>], neighborhood: &mut N) + where + T: 'a, + U: Metric<&'a T>, + N: Neighborhood<&'a T, U>, + { + let (node, children) = slice.split_first().unwrap(); + let (inside, outside) = children.split_at(node.inside_len); + + let distance = neighborhood.consider(&node.item).into(); + + if distance <= node.radius { + if !inside.is_empty() && neighborhood.contains(distance - node.radius) { + Self::recurse(inside, neighborhood); + } + if !outside.is_empty() && neighborhood.contains(node.radius - distance) { + Self::recurse(outside, neighborhood); + } + } else { + if !outside.is_empty() && neighborhood.contains(node.radius - distance) { + Self::recurse(outside, neighborhood); + } + if !inside.is_empty() && neighborhood.contains(distance - node.radius) { + Self::recurse(inside, neighborhood); + } + } + } +} + +/// A [vantage-point tree](https://en.wikipedia.org/wiki/Vantage-point_tree). 
+#[derive(Debug)] +pub struct VpTree<T>(Vec<VpNode<T>>); + +impl<T: Metric> FromIterator<T> for VpTree<T> { + fn from_iter<I: IntoIterator<Item = T>>(items: I) -> Self { + let mut nodes: Vec<_> = items.into_iter().map(VpNode::new).collect(); + VpNode::build(nodes.as_mut_slice()); + Self(nodes) + } +} + +impl<T, U> NearestNeighbors<T, U> for VpTree<T> +where + T: Metric, + U: Metric<T>, +{ + fn search<'a, 'b, N>(&'a self, mut neighborhood: N) -> N + where + T: 'a, + U: 'b, + N: Neighborhood<&'a T, &'b U>, + { + if !self.0.is_empty() { + VpNode::recurse(&self.0, &mut neighborhood); + } + + neighborhood + } +} + +/// An iterator that moves values out of a VP tree. +#[derive(Debug)] +pub struct IntoIter<T>(std::vec::IntoIter<VpNode<T>>); + +impl<T> Iterator for IntoIter<T> { + type Item = T; + + fn next(&mut self) -> Option<T> { + self.0.next().map(|n| n.item) + } +} + +impl<T> IntoIterator for VpTree<T> { + type Item = T; + type IntoIter = IntoIter<T>; + + fn into_iter(self) -> Self::IntoIter { + IntoIter(self.0.into_iter()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::metric::tests::test_nearest_neighbors; + + #[test] + fn test_vp_tree() { + test_nearest_neighbors(VpTree::from_iter); + } +} |