summaryrefslogtreecommitdiffstats
path: root/src/metric/vp.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/metric/vp.rs')
-rw-r--r--src/metric/vp.rs137
1 files changed, 137 insertions, 0 deletions
diff --git a/src/metric/vp.rs b/src/metric/vp.rs
new file mode 100644
index 0000000..fae62e5
--- /dev/null
+++ b/src/metric/vp.rs
@@ -0,0 +1,137 @@
+//! [Vantage-point trees](https://en.wikipedia.org/wiki/Vantage-point_tree).
+
+use super::{Metric, NearestNeighbors, Neighborhood};
+
+use std::iter::FromIterator;
+
+/// A node in a VP tree.
+#[derive(Debug)]
+struct VpNode<T> {
+ /// The vantage point itself.
+ item: T,
+ /// The radius of this node.
+ radius: f64,
+ /// The size of the subtree inside the radius.
+ inside_len: usize,
+}
+
+impl<T: Metric> VpNode<T> {
+ /// Create a new VpNode.
+ fn new(item: T) -> Self {
+ Self {
+ item,
+ radius: 0.0,
+ inside_len: 0,
+ }
+ }
+
+ /// Build a VP tree recursively.
+ fn build(slice: &mut [VpNode<T>]) {
+ if let Some((node, children)) = slice.split_first_mut() {
+ let item = &node.item;
+ children.sort_by_cached_key(|n| item.distance(&n.item));
+
+ let (inside, outside) = children.split_at_mut(children.len() / 2);
+ if let Some(last) = inside.last() {
+ node.radius = item.distance(&last.item).into();
+ }
+ node.inside_len = inside.len();
+
+ Self::build(inside);
+ Self::build(outside);
+ }
+ }
+
+ /// Recursively search for nearest neighbors.
+ fn recurse<'a, U, N>(slice: &'a [VpNode<T>], neighborhood: &mut N)
+ where
+ T: 'a,
+ U: Metric<&'a T>,
+ N: Neighborhood<&'a T, U>,
+ {
+ let (node, children) = slice.split_first().unwrap();
+ let (inside, outside) = children.split_at(node.inside_len);
+
+ let distance = neighborhood.consider(&node.item).into();
+
+ if distance <= node.radius {
+ if !inside.is_empty() && neighborhood.contains(distance - node.radius) {
+ Self::recurse(inside, neighborhood);
+ }
+ if !outside.is_empty() && neighborhood.contains(node.radius - distance) {
+ Self::recurse(outside, neighborhood);
+ }
+ } else {
+ if !outside.is_empty() && neighborhood.contains(node.radius - distance) {
+ Self::recurse(outside, neighborhood);
+ }
+ if !inside.is_empty() && neighborhood.contains(distance - node.radius) {
+ Self::recurse(inside, neighborhood);
+ }
+ }
+ }
+}
+
+/// A [vantage-point tree](https://en.wikipedia.org/wiki/Vantage-point_tree).
+#[derive(Debug)]
+pub struct VpTree<T>(Vec<VpNode<T>>);
+
+impl<T: Metric> FromIterator<T> for VpTree<T> {
+ fn from_iter<I: IntoIterator<Item = T>>(items: I) -> Self {
+ let mut nodes: Vec<_> = items.into_iter().map(VpNode::new).collect();
+ VpNode::build(nodes.as_mut_slice());
+ Self(nodes)
+ }
+}
+
+impl<T, U> NearestNeighbors<T, U> for VpTree<T>
+where
+ T: Metric,
+ U: Metric<T>,
+{
+ fn search<'a, 'b, N>(&'a self, mut neighborhood: N) -> N
+ where
+ T: 'a,
+ U: 'b,
+ N: Neighborhood<&'a T, &'b U>,
+ {
+ if !self.0.is_empty() {
+ VpNode::recurse(&self.0, &mut neighborhood);
+ }
+
+ neighborhood
+ }
+}
+
+/// An iterator that moves values out of a VP tree.
+#[derive(Debug)]
+pub struct IntoIter<T>(std::vec::IntoIter<VpNode<T>>);
+
+impl<T> Iterator for IntoIter<T> {
+ type Item = T;
+
+ fn next(&mut self) -> Option<T> {
+ self.0.next().map(|n| n.item)
+ }
+}
+
+impl<T> IntoIterator for VpTree<T> {
+ type Item = T;
+ type IntoIter = IntoIter<T>;
+
+ fn into_iter(self) -> Self::IntoIter {
+ IntoIter(self.0.into_iter())
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ use crate::metric::tests::test_nearest_neighbors;
+
+ #[test]
+ fn test_vp_tree() {
+ test_nearest_neighbors(VpTree::from_iter);
+ }
+}