summaryrefslogtreecommitdiffstats
path: root/src/metric
diff options
context:
space:
mode:
authorTavian Barnes <tavianator@tavianator.com>2020-05-03 10:55:16 -0400
committerGitHub <noreply@github.com>2020-05-03 10:55:16 -0400
commitce2904b4840611f769b92b55bf6d9b5afe84d3d7 (patch)
treea133319a302f95edf7a7a261262a8f24473bd21c /src/metric
parentd95e93bf70f3351e6fd489284794ef7909fd94ce (diff)
parent2984e8f93fe88d0ee7eb3c0561dcd2da44807429 (diff)
downloadkd-forest-ce2904b4840611f769b92b55bf6d9b5afe84d3d7.tar.xz
Merge pull request #1 from tavianator/rust
Rewrite in rust
Diffstat (limited to 'src/metric')
-rw-r--r--src/metric/approx.rs131
-rw-r--r--src/metric/forest.rs159
-rw-r--r--src/metric/kd.rs226
-rw-r--r--src/metric/soft.rs282
-rw-r--r--src/metric/vp.rs137
5 files changed, 935 insertions, 0 deletions
diff --git a/src/metric/approx.rs b/src/metric/approx.rs
new file mode 100644
index 0000000..c23f9c7
--- /dev/null
+++ b/src/metric/approx.rs
@@ -0,0 +1,131 @@
+//! [Approximate nearest neighbor search](https://en.wikipedia.org/wiki/Nearest_neighbor_search#Approximate_nearest_neighbor).
+
+use super::{Metric, NearestNeighbors, Neighborhood};
+
+/// An approximate [Neighborhood], for approximate nearest neighbor searches.
+#[derive(Debug)]
+struct ApproximateNeighborhood<N> {
+ inner: N,
+ ratio: f64,
+ limit: usize,
+}
+
+impl<N> ApproximateNeighborhood<N> {
+ fn new(inner: N, ratio: f64, limit: usize) -> Self {
+ Self {
+ inner,
+ ratio,
+ limit,
+ }
+ }
+}
+
+impl<T, U, N> Neighborhood<T, U> for ApproximateNeighborhood<N>
+where
+ U: Metric<T>,
+ N: Neighborhood<T, U>,
+{
+ fn target(&self) -> U {
+ self.inner.target()
+ }
+
+ fn contains(&self, distance: f64) -> bool {
+ if self.limit > 0 {
+ self.inner.contains(self.ratio * distance)
+ } else {
+ false
+ }
+ }
+
+ fn contains_distance(&self, distance: U::Distance) -> bool {
+ self.contains(self.ratio * distance.into())
+ }
+
+ fn consider(&mut self, item: T) -> U::Distance {
+ self.limit = self.limit.saturating_sub(1);
+ self.inner.consider(item)
+ }
+}
+
+/// An [approximate nearest neighbor search](https://en.wikipedia.org/wiki/Nearest_neighbor_search#Approximate_nearest_neighbor)
+/// index.
+///
+/// This wrapper converts an exact nearest neighbor search algorithm into an approximate one by
+/// modifying the behavior of [Neighborhood::contains]. The approximation is controlled by two
+/// parameters:
+///
+/// * `ratio`: The [nearest neighbor distance ratio](https://en.wikipedia.org/wiki/Nearest_neighbor_search#Nearest_neighbor_distance_ratio),
+/// which controls how much closer new candidates must be to be considered. For example, a ratio
+/// of 2.0 means that a neighbor must be less than half of the current threshold away to be
+/// considered. A ratio of 1.0 means an exact search.
+///
+/// * `limit`: A limit on the number of candidates to consider. Typical nearest neighbor algorithms
+/// find a close match quickly, so setting a limit bounds the worst-case search time while keeping
+/// good accuracy.
+#[derive(Debug)]
+pub struct ApproximateSearch<T> {
+ inner: T,
+ ratio: f64,
+ limit: usize,
+}
+
+impl<T> ApproximateSearch<T> {
+ /// Create a new ApproximateSearch index.
+ ///
+ /// * `inner`: The [NearestNeighbors] implementation to wrap.
+ /// * `ratio`: The nearest neighbor distance ratio.
+ /// * `limit`: The maximum number of results to consider.
+ pub fn new(inner: T, ratio: f64, limit: usize) -> Self {
+ Self {
+ inner,
+ ratio,
+ limit,
+ }
+ }
+}
+
+impl<T, U, V> NearestNeighbors<T, U> for ApproximateSearch<V>
+where
+ U: Metric<T>,
+ V: NearestNeighbors<T, U>,
+{
+ fn search<'a, 'b, N>(&'a self, neighborhood: N) -> N
+ where
+ T: 'a,
+ U: 'b,
+ N: Neighborhood<&'a T, &'b U>,
+ {
+ self.inner
+ .search(ApproximateNeighborhood::new(
+ neighborhood,
+ self.ratio,
+ self.limit,
+ ))
+ .inner
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ use crate::metric::kd::KdTree;
+ use crate::metric::tests::test_nearest_neighbors;
+ use crate::metric::vp::VpTree;
+
+ use std::iter::FromIterator;
+
+ #[test]
+ fn test_approx_kd_tree() {
+ test_nearest_neighbors(|iter| {
+ ApproximateSearch::new(KdTree::from_iter(iter), 1.0, std::usize::MAX)
+ });
+ }
+
+ #[test]
+ fn test_approx_vp_tree() {
+ test_nearest_neighbors(|iter| {
+ ApproximateSearch::new(VpTree::from_iter(iter), 1.0, std::usize::MAX)
+ });
+ }
+}
diff --git a/src/metric/forest.rs b/src/metric/forest.rs
new file mode 100644
index 0000000..29b6f55
--- /dev/null
+++ b/src/metric/forest.rs
@@ -0,0 +1,159 @@
+//! [Dynamization](https://en.wikipedia.org/wiki/Dynamization) for nearest neighbor search.
+
+use super::kd::KdTree;
+use super::vp::VpTree;
+use super::{Metric, NearestNeighbors, Neighborhood};
+
+use std::iter::{self, Extend, Flatten, FromIterator};
+
+/// A dynamic wrapper for a static nearest neighbor search data structure.
+///
+/// This type applies [dynamization](https://en.wikipedia.org/wiki/Dynamization) to an arbitrary
+/// nearest neighbor search structure `T`, allowing new items to be added dynamically.
+#[derive(Debug)]
+pub struct Forest<T>(Vec<Option<T>>);
+
+impl<T, U> Forest<U>
+where
+ U: FromIterator<T> + IntoIterator<Item = T>,
+{
+ /// Create a new empty forest.
+ pub fn new() -> Self {
+ Self(Vec::new())
+ }
+
+ /// Add a new item to the forest.
+ pub fn push(&mut self, item: T) {
+ self.extend(iter::once(item));
+ }
+
+ /// Get the number of items in the forest.
+ pub fn len(&self) -> usize {
+ let mut len = 0;
+ for (i, slot) in self.0.iter().enumerate() {
+ if slot.is_some() {
+ len |= 1 << i;
+ }
+ }
+ len
+ }
+}
+
+impl<T, U> Extend<T> for Forest<U>
+where
+ U: FromIterator<T> + IntoIterator<Item = T>,
+{
+ fn extend<I: IntoIterator<Item = T>>(&mut self, items: I) {
+ let mut vec: Vec<_> = items.into_iter().collect();
+ let new_len = self.len() + vec.len();
+
+ for i in 0.. {
+ let bit = 1 << i;
+
+ if bit > new_len {
+ break;
+ }
+
+ if i >= self.0.len() {
+ self.0.push(None);
+ }
+
+ if new_len & bit == 0 {
+ if let Some(tree) = self.0[i].take() {
+ vec.extend(tree);
+ }
+ } else if self.0[i].is_none() {
+ let offset = vec.len() - bit;
+ self.0[i] = Some(vec.drain(offset..).collect());
+ }
+ }
+
+ debug_assert!(vec.is_empty());
+ debug_assert!(self.len() == new_len);
+ }
+}
+
+impl<T, U> FromIterator<T> for Forest<U>
+where
+ U: FromIterator<T> + IntoIterator<Item = T>,
+{
+ fn from_iter<I: IntoIterator<Item = T>>(items: I) -> Self {
+ let mut forest = Self::new();
+ forest.extend(items);
+ forest
+ }
+}
+
+type IntoIterImpl<T> = Flatten<Flatten<std::vec::IntoIter<Option<T>>>>;
+
+/// An iterator that moves items out of a forest.
+pub struct IntoIter<T: IntoIterator>(IntoIterImpl<T>);
+
+impl<T: IntoIterator> Iterator for IntoIter<T> {
+ type Item = T::Item;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ self.0.next()
+ }
+}
+
+impl<T: IntoIterator> IntoIterator for Forest<T> {
+ type Item = T::Item;
+ type IntoIter = IntoIter<T>;
+
+ fn into_iter(self) -> Self::IntoIter {
+ IntoIter(self.0.into_iter().flatten().flatten())
+ }
+}
+
+impl<T, U, V> NearestNeighbors<T, U> for Forest<V>
+where
+ U: Metric<T>,
+ V: NearestNeighbors<T, U>,
+{
+ fn search<'a, 'b, N>(&'a self, neighborhood: N) -> N
+ where
+ T: 'a,
+ U: 'b,
+ N: Neighborhood<&'a T, &'b U>,
+ {
+ self.0
+ .iter()
+ .flatten()
+ .fold(neighborhood, |n, t| t.search(n))
+ }
+}
+
+/// A forest of k-d trees.
+pub type KdForest<T> = Forest<KdTree<T>>;
+
+/// A forest of vantage-point trees.
+pub type VpForest<T> = Forest<VpTree<T>>;
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ use crate::metric::tests::test_nearest_neighbors;
+ use crate::metric::ExhaustiveSearch;
+
+ #[test]
+ fn test_exhaustive_forest() {
+ test_nearest_neighbors(Forest::<ExhaustiveSearch<_>>::from_iter);
+ }
+
+ #[test]
+ fn test_forest_forest() {
+ test_nearest_neighbors(Forest::<Forest<ExhaustiveSearch<_>>>::from_iter);
+ }
+
+ #[test]
+ fn test_kd_forest() {
+ test_nearest_neighbors(KdForest::from_iter);
+ }
+
+ #[test]
+ fn test_vp_forest() {
+ test_nearest_neighbors(VpForest::from_iter);
+ }
+}
diff --git a/src/metric/kd.rs b/src/metric/kd.rs
new file mode 100644
index 0000000..2caf4a3
--- /dev/null
+++ b/src/metric/kd.rs
@@ -0,0 +1,226 @@
+//! [k-d trees](https://en.wikipedia.org/wiki/K-d_tree).
+
+use super::{Metric, NearestNeighbors, Neighborhood};
+
+use ordered_float::OrderedFloat;
+
+use std::iter::FromIterator;
+
+/// A point in Cartesian space.
+pub trait Cartesian: Metric<[f64]> {
+ /// Returns the number of dimensions necessary to describe this point.
+ fn dimensions(&self) -> usize;
+
+ /// Returns the value of the `i`th coordinate of this point (`i < self.dimensions()`).
+ fn coordinate(&self, i: usize) -> f64;
+}
+
+/// Blanket [Cartesian] implementation for references.
+impl<'a, T: Cartesian> Cartesian for &'a T {
+ fn dimensions(&self) -> usize {
+ (*self).dimensions()
+ }
+
+ fn coordinate(&self, i: usize) -> f64 {
+ (*self).coordinate(i)
+ }
+}
+
+/// Blanket [Metric<[f64]>](Metric) implementation for [Cartesian] references.
+impl<'a, T: Cartesian> Metric<[f64]> for &'a T {
+ type Distance = T::Distance;
+
+ fn distance(&self, other: &[f64]) -> Self::Distance {
+ (*self).distance(other)
+ }
+}
+
+/// Standard cartesian space.
+impl Cartesian for [f64] {
+ fn dimensions(&self) -> usize {
+ self.len()
+ }
+
+ fn coordinate(&self, i: usize) -> f64 {
+ self[i]
+ }
+}
+
+/// Marker trait for cartesian metric spaces.
+pub trait CartesianMetric<T: ?Sized = Self>:
+ Cartesian + Metric<T, Distance = <Self as Metric<[f64]>>::Distance>
+{
+}
+
+/// Blanket [CartesianMetric] implementation for cartesian spaces with compatible metric distance
+/// types.
+impl<T, U> CartesianMetric<T> for U
+where
+ T: ?Sized,
+ U: ?Sized + Cartesian + Metric<T, Distance = <U as Metric<[f64]>>::Distance>,
+{
+}
+
+/// A node in a k-d tree.
+#[derive(Debug)]
+struct KdNode<T> {
+ /// The value stored in this node.
+ item: T,
+ /// The size of the left subtree.
+ left_len: usize,
+}
+
+impl<T: Cartesian> KdNode<T> {
+ /// Create a new KdNode.
+ fn new(item: T) -> Self {
+ Self { item, left_len: 0 }
+ }
+
+ /// Build a k-d tree recursively.
+ fn build(slice: &mut [KdNode<T>], i: usize) {
+ if slice.is_empty() {
+ return;
+ }
+
+ slice.sort_unstable_by_key(|n| OrderedFloat::from(n.item.coordinate(i)));
+
+ let mid = slice.len() / 2;
+ slice.swap(0, mid);
+
+ let (node, children) = slice.split_first_mut().unwrap();
+ let (left, right) = children.split_at_mut(mid);
+ node.left_len = left.len();
+
+ let j = (i + 1) % node.item.dimensions();
+ Self::build(left, j);
+ Self::build(right, j);
+ }
+
+ /// Recursively search for nearest neighbors.
+ fn recurse<'a, U, N>(
+ slice: &'a [KdNode<T>],
+ i: usize,
+ closest: &mut [f64],
+ neighborhood: &mut N,
+ ) where
+ T: 'a,
+ U: CartesianMetric<&'a T>,
+ N: Neighborhood<&'a T, U>,
+ {
+ let (node, children) = slice.split_first().unwrap();
+ neighborhood.consider(&node.item);
+
+ let target = neighborhood.target();
+ let ti = target.coordinate(i);
+ let ni = node.item.coordinate(i);
+ let j = (i + 1) % node.item.dimensions();
+
+ let (left, right) = children.split_at(node.left_len);
+ let (near, far) = if ti <= ni {
+ (left, right)
+ } else {
+ (right, left)
+ };
+
+ if !near.is_empty() {
+ Self::recurse(near, j, closest, neighborhood);
+ }
+
+ if !far.is_empty() {
+ let saved = closest[i];
+ closest[i] = ni;
+ if neighborhood.contains_distance(target.distance(closest)) {
+ Self::recurse(far, j, closest, neighborhood);
+ }
+ closest[i] = saved;
+ }
+ }
+}
+
+/// A [k-d tree](https://en.wikipedia.org/wiki/K-d_tree).
+#[derive(Debug)]
+pub struct KdTree<T>(Vec<KdNode<T>>);
+
+impl<T: Cartesian> FromIterator<T> for KdTree<T> {
+ /// Create a new k-d tree from a set of points.
+ fn from_iter<I: IntoIterator<Item = T>>(items: I) -> Self {
+ let mut nodes: Vec<_> = items.into_iter().map(KdNode::new).collect();
+ KdNode::build(nodes.as_mut_slice(), 0);
+ Self(nodes)
+ }
+}
+
+impl<T, U> NearestNeighbors<T, U> for KdTree<T>
+where
+ T: Cartesian,
+ U: CartesianMetric<T>,
+{
+ fn search<'a, 'b, N>(&'a self, mut neighborhood: N) -> N
+ where
+ T: 'a,
+ U: 'b,
+ N: Neighborhood<&'a T, &'b U>,
+ {
+ if !self.0.is_empty() {
+ let target = neighborhood.target();
+ let dims = target.dimensions();
+ let mut closest: Vec<_> = (0..dims).map(|i| target.coordinate(i)).collect();
+
+ KdNode::recurse(&self.0, 0, &mut closest, &mut neighborhood);
+ }
+
+ neighborhood
+ }
+}
+
+/// An iterator that the moves values out of a k-d tree.
+#[derive(Debug)]
+pub struct IntoIter<T>(std::vec::IntoIter<KdNode<T>>);
+
+impl<T> Iterator for IntoIter<T> {
+ type Item = T;
+
+ fn next(&mut self) -> Option<T> {
+ self.0.next().map(|n| n.item)
+ }
+}
+
+impl<T> IntoIterator for KdTree<T> {
+ type Item = T;
+ type IntoIter = IntoIter<T>;
+
+ fn into_iter(self) -> Self::IntoIter {
+ IntoIter(self.0.into_iter())
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ use crate::metric::tests::{test_nearest_neighbors, Point};
+ use crate::metric::SquaredDistance;
+
+ impl Metric<[f64]> for Point {
+ type Distance = SquaredDistance;
+
+ fn distance(&self, other: &[f64]) -> Self::Distance {
+ self.0.distance(other)
+ }
+ }
+
+ impl Cartesian for Point {
+ fn dimensions(&self) -> usize {
+ self.0.dimensions()
+ }
+
+ fn coordinate(&self, i: usize) -> f64 {
+ self.0.coordinate(i)
+ }
+ }
+
+ #[test]
+ fn test_kd_tree() {
+ test_nearest_neighbors(KdTree::from_iter);
+ }
+}
diff --git a/src/metric/soft.rs b/src/metric/soft.rs
new file mode 100644
index 0000000..0d7dcdb
--- /dev/null
+++ b/src/metric/soft.rs
@@ -0,0 +1,282 @@
+//! [Soft deletion](https://en.wiktionary.org/wiki/soft_deletion) for nearest neighbor search.
+
+use super::forest::{KdForest, VpForest};
+use super::kd::KdTree;
+use super::vp::VpTree;
+use super::{Metric, NearestNeighbors, Neighborhood};
+
+use std::iter;
+use std::iter::FromIterator;
+use std::mem;
+
+/// A trait for objects that can be soft-deleted.
+pub trait SoftDelete {
+ /// Check whether this item is deleted.
+ fn is_deleted(&self) -> bool;
+}
+
+/// Blanket [SoftDelete] implementation for references.
+impl<'a, T: SoftDelete> SoftDelete for &'a T {
+ fn is_deleted(&self) -> bool {
+ (*self).is_deleted()
+ }
+}
+
+/// [Neighborhood] wrapper that ignores soft-deleted items.
+#[derive(Debug)]
+struct SoftNeighborhood<N>(N);
+
+impl<T, U, N> Neighborhood<T, U> for SoftNeighborhood<N>
+where
+ T: SoftDelete,
+ U: Metric<T>,
+ N: Neighborhood<T, U>,
+{
+ fn target(&self) -> U {
+ self.0.target()
+ }
+
+ fn contains(&self, distance: f64) -> bool {
+ self.0.contains(distance)
+ }
+
+ fn contains_distance(&self, distance: U::Distance) -> bool {
+ self.0.contains_distance(distance)
+ }
+
+ fn consider(&mut self, item: T) -> U::Distance {
+ if item.is_deleted() {
+ self.target().distance(&item)
+ } else {
+ self.0.consider(item)
+ }
+ }
+}
+
+/// A [NearestNeighbors] implementation that supports [soft deletes](https://en.wiktionary.org/wiki/soft_deletion).
+#[derive(Debug)]
+pub struct SoftSearch<T>(T);
+
+impl<T, U> SoftSearch<U>
+where
+ T: SoftDelete,
+ U: FromIterator<T> + IntoIterator<Item = T>,
+{
+ /// Create a new empty soft index.
+ pub fn new() -> Self {
+ Self(iter::empty().collect())
+ }
+
+ /// Push a new item into this index.
+ pub fn push(&mut self, item: T)
+ where
+ U: Extend<T>,
+ {
+ self.0.extend(iter::once(item));
+ }
+
+ /// Rebuild this index, discarding deleted items.
+ pub fn rebuild(&mut self) {
+ let items = mem::replace(&mut self.0, iter::empty().collect());
+ self.0 = items.into_iter().filter(|e| !e.is_deleted()).collect();
+ }
+}
+
+impl<T, U: Extend<T>> Extend<T> for SoftSearch<U> {
+ fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
+ self.0.extend(iter);
+ }
+}
+
+impl<T, U: FromIterator<T>> FromIterator<T> for SoftSearch<U> {
+ fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
+ Self(U::from_iter(iter))
+ }
+}
+
+impl<T: IntoIterator> IntoIterator for SoftSearch<T> {
+ type Item = T::Item;
+ type IntoIter = T::IntoIter;
+
+ fn into_iter(self) -> Self::IntoIter {
+ self.0.into_iter()
+ }
+}
+
+impl<T, U, V> NearestNeighbors<T, U> for SoftSearch<V>
+where
+ T: SoftDelete,
+ U: Metric<T>,
+ V: NearestNeighbors<T, U>,
+{
+ fn search<'a, 'b, N>(&'a self, neighborhood: N) -> N
+ where
+ T: 'a,
+ U: 'b,
+ N: Neighborhood<&'a T, &'b U>,
+ {
+ self.0.search(SoftNeighborhood(neighborhood)).0
+ }
+}
+
+/// A k-d forest that supports soft deletes.
+pub type SoftKdForest<T> = SoftSearch<KdForest<T>>;
+
+/// A k-d tree that supports soft deletes.
+pub type SoftKdTree<T> = SoftSearch<KdTree<T>>;
+
+/// A VP forest that supports soft deletes.
+pub type SoftVpForest<T> = SoftSearch<VpForest<T>>;
+
+/// A VP tree that supports soft deletes.
+pub type SoftVpTree<T> = SoftSearch<VpTree<T>>;
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ use crate::metric::kd::Cartesian;
+ use crate::metric::tests::Point;
+ use crate::metric::Neighbor;
+
+ #[derive(Debug, PartialEq)]
+ struct SoftPoint {
+ point: Point,
+ deleted: bool,
+ }
+
+ impl SoftPoint {
+ fn new(x: f64, y: f64, z: f64) -> Self {
+ Self {
+ point: Point([x, y, z]),
+ deleted: false,
+ }
+ }
+
+ fn deleted(x: f64, y: f64, z: f64) -> Self {
+ Self {
+ point: Point([x, y, z]),
+ deleted: true,
+ }
+ }
+ }
+
+ impl SoftDelete for SoftPoint {
+ fn is_deleted(&self) -> bool {
+ self.deleted
+ }
+ }
+
+ impl Metric for SoftPoint {
+ type Distance = <Point as Metric>::Distance;
+
+ fn distance(&self, other: &Self) -> Self::Distance {
+ self.point.distance(&other.point)
+ }
+ }
+
+ impl Metric<[f64]> for SoftPoint {
+ type Distance = <Point as Metric>::Distance;
+
+ fn distance(&self, other: &[f64]) -> Self::Distance {
+ self.point.distance(other)
+ }
+ }
+
+ impl Cartesian for SoftPoint {
+ fn dimensions(&self) -> usize {
+ self.point.dimensions()
+ }
+
+ fn coordinate(&self, i: usize) -> f64 {
+ self.point.coordinate(i)
+ }
+ }
+
+ impl Metric<SoftPoint> for Point {
+ type Distance = <Point as Metric>::Distance;
+
+ fn distance(&self, other: &SoftPoint) -> Self::Distance {
+ self.distance(&other.point)
+ }
+ }
+
+ fn test_index<T>(index: &T)
+ where
+ T: NearestNeighbors<SoftPoint, Point>,
+ {
+ let target = Point([0.0, 0.0, 0.0]);
+
+ assert_eq!(
+ index.nearest(&target),
+ Some(Neighbor::new(&SoftPoint::new(1.0, 2.0, 2.0), 3.0))
+ );
+
+ assert_eq!(index.nearest_within(&target, 2.0), None);
+ assert_eq!(
+ index.nearest_within(&target, 4.0),
+ Some(Neighbor::new(&SoftPoint::new(1.0, 2.0, 2.0), 3.0))
+ );
+
+ assert_eq!(
+ index.k_nearest(&target, 3),
+ vec![
+ Neighbor::new(&SoftPoint::new(1.0, 2.0, 2.0), 3.0),
+ Neighbor::new(&SoftPoint::new(3.0, 4.0, 0.0), 5.0),
+ Neighbor::new(&SoftPoint::new(2.0, 3.0, 6.0), 7.0),
+ ]
+ );
+
+ assert_eq!(
+ index.k_nearest_within(&target, 3, 6.0),
+ vec![
+ Neighbor::new(&SoftPoint::new(1.0, 2.0, 2.0), 3.0),
+ Neighbor::new(&SoftPoint::new(3.0, 4.0, 0.0), 5.0),
+ ]
+ );
+ assert_eq!(
+ index.k_nearest_within(&target, 3, 8.0),
+ vec![
+ Neighbor::new(&SoftPoint::new(1.0, 2.0, 2.0), 3.0),
+ Neighbor::new(&SoftPoint::new(3.0, 4.0, 0.0), 5.0),
+ Neighbor::new(&SoftPoint::new(2.0, 3.0, 6.0), 7.0),
+ ]
+ );
+ }
+
+ fn test_soft_index<T>(index: &mut SoftSearch<T>)
+ where
+ T: Extend<SoftPoint>,
+ T: FromIterator<SoftPoint>,
+ T: IntoIterator<Item = SoftPoint>,
+ T: NearestNeighbors<SoftPoint, Point>,
+ {
+ let points = vec![
+ SoftPoint::deleted(0.0, 0.0, 0.0),
+ SoftPoint::new(3.0, 4.0, 0.0),
+ SoftPoint::new(5.0, 0.0, 12.0),
+ SoftPoint::new(0.0, 8.0, 15.0),
+ SoftPoint::new(1.0, 2.0, 2.0),
+ SoftPoint::new(2.0, 3.0, 6.0),
+ SoftPoint::new(4.0, 4.0, 7.0),
+ ];
+
+ for point in points {
+ index.push(point);
+ }
+ test_index(index);
+
+ index.rebuild();
+ test_index(index);
+ }
+
+ #[test]
+ fn test_soft_kd_forest() {
+ test_soft_index(&mut SoftKdForest::new());
+ }
+
+ #[test]
+ fn test_soft_vp_forest() {
+ test_soft_index(&mut SoftVpForest::new());
+ }
+}
diff --git a/src/metric/vp.rs b/src/metric/vp.rs
new file mode 100644
index 0000000..fae62e5
--- /dev/null
+++ b/src/metric/vp.rs
@@ -0,0 +1,137 @@
+//! [Vantage-point trees](https://en.wikipedia.org/wiki/Vantage-point_tree).
+
+use super::{Metric, NearestNeighbors, Neighborhood};
+
+use std::iter::FromIterator;
+
+/// A node in a VP tree.
+#[derive(Debug)]
+struct VpNode<T> {
+ /// The vantage point itself.
+ item: T,
+ /// The radius of this node.
+ radius: f64,
+ /// The size of the subtree inside the radius.
+ inside_len: usize,
+}
+
+impl<T: Metric> VpNode<T> {
+ /// Create a new VpNode.
+ fn new(item: T) -> Self {
+ Self {
+ item,
+ radius: 0.0,
+ inside_len: 0,
+ }
+ }
+
+ /// Build a VP tree recursively.
+ fn build(slice: &mut [VpNode<T>]) {
+ if let Some((node, children)) = slice.split_first_mut() {
+ let item = &node.item;
+ children.sort_by_cached_key(|n| item.distance(&n.item));
+
+ let (inside, outside) = children.split_at_mut(children.len() / 2);
+ if let Some(last) = inside.last() {
+ node.radius = item.distance(&last.item).into();
+ }
+ node.inside_len = inside.len();
+
+ Self::build(inside);
+ Self::build(outside);
+ }
+ }
+
+ /// Recursively search for nearest neighbors.
+ fn recurse<'a, U, N>(slice: &'a [VpNode<T>], neighborhood: &mut N)
+ where
+ T: 'a,
+ U: Metric<&'a T>,
+ N: Neighborhood<&'a T, U>,
+ {
+ let (node, children) = slice.split_first().unwrap();
+ let (inside, outside) = children.split_at(node.inside_len);
+
+ let distance = neighborhood.consider(&node.item).into();
+
+ if distance <= node.radius {
+ if !inside.is_empty() && neighborhood.contains(distance - node.radius) {
+ Self::recurse(inside, neighborhood);
+ }
+ if !outside.is_empty() && neighborhood.contains(node.radius - distance) {
+ Self::recurse(outside, neighborhood);
+ }
+ } else {
+ if !outside.is_empty() && neighborhood.contains(node.radius - distance) {
+ Self::recurse(outside, neighborhood);
+ }
+ if !inside.is_empty() && neighborhood.contains(distance - node.radius) {
+ Self::recurse(inside, neighborhood);
+ }
+ }
+ }
+}
+
+/// A [vantage-point tree](https://en.wikipedia.org/wiki/Vantage-point_tree).
+#[derive(Debug)]
+pub struct VpTree<T>(Vec<VpNode<T>>);
+
+impl<T: Metric> FromIterator<T> for VpTree<T> {
+ fn from_iter<I: IntoIterator<Item = T>>(items: I) -> Self {
+ let mut nodes: Vec<_> = items.into_iter().map(VpNode::new).collect();
+ VpNode::build(nodes.as_mut_slice());
+ Self(nodes)
+ }
+}
+
+impl<T, U> NearestNeighbors<T, U> for VpTree<T>
+where
+ T: Metric,
+ U: Metric<T>,
+{
+ fn search<'a, 'b, N>(&'a self, mut neighborhood: N) -> N
+ where
+ T: 'a,
+ U: 'b,
+ N: Neighborhood<&'a T, &'b U>,
+ {
+ if !self.0.is_empty() {
+ VpNode::recurse(&self.0, &mut neighborhood);
+ }
+
+ neighborhood
+ }
+}
+
+/// An iterator that moves values out of a VP tree.
+#[derive(Debug)]
+pub struct IntoIter<T>(std::vec::IntoIter<VpNode<T>>);
+
+impl<T> Iterator for IntoIter<T> {
+ type Item = T;
+
+ fn next(&mut self) -> Option<T> {
+ self.0.next().map(|n| n.item)
+ }
+}
+
+impl<T> IntoIterator for VpTree<T> {
+ type Item = T;
+ type IntoIter = IntoIter<T>;
+
+ fn into_iter(self) -> Self::IntoIter {
+ IntoIter(self.0.into_iter())
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ use crate::metric::tests::test_nearest_neighbors;
+
+ #[test]
+ fn test_vp_tree() {
+ test_nearest_neighbors(VpTree::from_iter);
+ }
+}