From 8d9de0e1028daed981246174182a39dd917b72bc Mon Sep 17 00:00:00 2001
From: Tavian Barnes
Date: Sun, 19 Apr 2020 16:38:24 -0400
Subject: metric/vp: Implement vantage-point trees

---
 src/metric.rs    |   2 +
 src/metric/vp.rs | 168 +++++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 170 insertions(+)
 create mode 100644 src/metric/vp.rs

diff --git a/src/metric.rs b/src/metric.rs
index e067771..549db67 100644
--- a/src/metric.rs
+++ b/src/metric.rs
@@ -1,5 +1,7 @@
 //! [Metric spaces](https://en.wikipedia.org/wiki/Metric_space).
 
+pub mod vp;
+
 use ordered_float::OrderedFloat;
 use std::cmp::Ordering;
 
diff --git a/src/metric/vp.rs b/src/metric/vp.rs
new file mode 100644
index 0000000..8d5b091
--- /dev/null
+++ b/src/metric/vp.rs
@@ -0,0 +1,168 @@
+//! [Vantage-point trees](https://en.wikipedia.org/wiki/Vantage-point_tree).
+
+use super::{Metric, NearestNeighbors, Neighborhood};
+
+use std::iter::FromIterator;
+
+/// A node in a VP tree.
+#[derive(Debug)]
+struct VpNode<T> {
+    /// The vantage point itself.
+    item: T,
+    /// The radius of this node.
+    radius: f64,
+    /// The subtree inside the radius, if any.
+    inside: Option<Box<Self>>,
+    /// The subtree outside the radius, if any.
+    outside: Option<Box<Self>>,
+}
+
+impl<T: Metric> VpNode<T> {
+    /// Create a new VpNode.
+    fn new(mut items: Vec<T>) -> Option<Box<Self>> {
+        if items.is_empty() {
+            return None;
+        }
+
+        let item = items.pop().unwrap();
+
+        items.sort_by_cached_key(|a| item.distance(a));
+
+        let mid = items.len() / 2;
+        let outside: Vec<T> = items.drain(mid..).collect();
+
+        let radius = items.last().map(|l| item.distance(l).into()).unwrap_or(0.0);
+
+        Some(Box::new(Self {
+            item,
+            radius,
+            inside: Self::new(items),
+            outside: Self::new(outside),
+        }))
+    }
+}
+
+trait VpSearch<'a, T, U, N> {
+    /// Recursively search for nearest neighbors.
+    fn search(&'a self, neighborhood: &mut N);
+
+    /// Search the inside subtree.
+    fn search_inside(&'a self, distance: f64, neighborhood: &mut N);
+
+    /// Search the outside subtree.
+    fn search_outside(&'a self, distance: f64, neighborhood: &mut N);
+}
+
+impl<'a, T, U, N> VpSearch<'a, T, U, N> for VpNode<T>
+where
+    T: 'a,
+    U: Metric<&'a T>,
+    N: Neighborhood<&'a T, U>,
+{
+    fn search(&'a self, neighborhood: &mut N) {
+        let distance = neighborhood.consider(&self.item).into();
+
+        if distance <= self.radius {
+            self.search_inside(distance, neighborhood);
+            self.search_outside(distance, neighborhood);
+        } else {
+            self.search_outside(distance, neighborhood);
+            self.search_inside(distance, neighborhood);
+        }
+    }
+
+    fn search_inside(&'a self, distance: f64, neighborhood: &mut N) {
+        if let Some(inside) = &self.inside {
+            if neighborhood.contains(distance - self.radius) {
+                inside.search(neighborhood);
+            }
+        }
+    }
+
+    fn search_outside(&'a self, distance: f64, neighborhood: &mut N) {
+        if let Some(outside) = &self.outside {
+            if neighborhood.contains(self.radius - distance) {
+                outside.search(neighborhood);
+            }
+        }
+    }
+}
+
+/// A [vantage-point tree](https://en.wikipedia.org/wiki/Vantage-point_tree).
+#[derive(Debug)]
+pub struct VpTree<T> {
+    root: Option<Box<VpNode<T>>>,
+}
+
+impl<T: Metric> FromIterator<T> for VpTree<T> {
+    fn from_iter<I: IntoIterator<Item = T>>(items: I) -> Self {
+        Self {
+            root: VpNode::new(items.into_iter().collect::<Vec<_>>()),
+        }
+    }
+}
+
+impl<T, U> NearestNeighbors<T, U> for VpTree<T>
+where
+    T: Metric,
+    U: Metric<T>,
+{
+    fn search<'a, 'b, N>(&'a self, mut neighborhood: N) -> N
+    where
+        T: 'a,
+        U: 'b,
+        N: Neighborhood<&'a T, &'b U>,
+    {
+        if let Some(root) = &self.root {
+            root.search(&mut neighborhood);
+        }
+        neighborhood
+    }
+}
+
+/// An iterator that moves values out of a VP tree.
+#[derive(Debug)]
+pub struct IntoIter<T> {
+    stack: Vec<Box<VpNode<T>>>,
+}
+
+impl<T> IntoIter<T> {
+    fn new(node: Option<Box<VpNode<T>>>) -> Self {
+        Self {
+            stack: node.into_iter().collect(),
+        }
+    }
+}
+
+impl<T> Iterator for IntoIter<T> {
+    type Item = T;
+
+    fn next(&mut self) -> Option<T> {
+        self.stack.pop().map(|node| {
+            self.stack.extend(node.inside);
+            self.stack.extend(node.outside);
+            node.item
+        })
+    }
+}
+
+impl<T> IntoIterator for VpTree<T> {
+    type Item = T;
+    type IntoIter = IntoIter<T>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        IntoIter::new(self.root)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    use crate::metric::tests::test_nearest_neighbors;
+
+    #[test]
+    fn test_vp_tree() {
+        test_nearest_neighbors(VpTree::from_iter);
+    }
+}
-- 
cgit v1.2.3