//! [Vantage-point trees](https://en.wikipedia.org/wiki/Vantage-point_tree). use crate::distance::{Distance, DistanceValue, Metric, Proximity}; use crate::knn::{ExactNeighbors, NearestNeighbors, Neighborhood}; use crate::util::Ordered; use num_traits::zero; use std::fmt::{self, Debug, Formatter}; /// A node in a VP tree. #[derive(Debug)] struct VpNode> { /// The vantage point itself. item: T, /// The radius of this node. radius: R, /// The subtree inside the radius, if any. inside: Option>, /// The subtree outside the radius, if any. outside: Option>, } impl VpNode { /// Create a new VpNode. fn new(item: T) -> Self { Self { item, radius: zero(), inside: None, outside: None, } } /// Create a balanced tree. fn balanced>(items: I) -> Option { let mut nodes: Vec<_> = items .into_iter() .map(Self::new) .map(Box::new) .map(Some) .collect(); Self::balanced_recursive(&mut nodes) .map(|node| *node) } /// Create a balanced subtree. fn balanced_recursive(nodes: &mut [Option>]) -> Option> { if let Some((node, children)) = nodes.split_first_mut() { let mut node = node.take().unwrap(); children.sort_by_cached_key(|x| Ordered::new(node.distance_to_box(x))); let (inside, outside) = children.split_at_mut(children.len() / 2); if let Some(last) = inside.last() { node.radius = node.distance_to_box(last).value(); } node.inside = Self::balanced_recursive(inside); node.outside = Self::balanced_recursive(outside); Some(node) } else { None } } /// Get the distance between to boxed nodes. fn distance_to_box(&self, child: &Option>) -> T::Distance { self.item.distance(&child.as_ref().unwrap().item) } /// Push a new item into this subtree. fn push(&mut self, item: T) { match (&mut self.inside, &mut self.outside) { (None, None) => { self.outside = Some(Box::new(Self::new(item))); } (Some(inside), Some(outside)) => { if self.item.distance(&item) <= self.radius { inside.push(item); } else { outside.push(item); } } _ => { let node = Box::new(Self::new(item)); let other = self.inside.take().xor(self.outside.take()).unwrap(); let r1 = self.item.distance(&node.item); let r2 = self.item.distance(&other.item); if r1 <= r2 { self.radius = r2.into(); self.inside = Some(node); self.outside = Some(other); } else { self.radius = r1.into(); self.inside = Some(other); self.outside = Some(node); } } } } } trait VpSearch: Copy where K: Proximity, V: Proximity, N: Neighborhood, { /// Get the vantage point of this node. fn item(self) -> V; /// Get the radius of this node. fn radius(self) -> DistanceValue; /// Get the inside subtree. fn inside(self) -> Option; /// Get the outside subtree. fn outside(self) -> Option; /// Recursively search for nearest neighbors. fn search(self, neighborhood: &mut N) { let distance = neighborhood.consider(self.item()).into(); if distance <= self.radius() { self.search_inside(distance, neighborhood); self.search_outside(distance, neighborhood); } else { self.search_outside(distance, neighborhood); self.search_inside(distance, neighborhood); } } /// Search the inside subtree. fn search_inside(self, distance: DistanceValue, neighborhood: &mut N) { if let Some(inside) = self.inside() { if neighborhood.contains(distance - self.radius()) { inside.search(neighborhood); } } } /// Search the outside subtree. fn search_outside(self, distance: DistanceValue, neighborhood: &mut N) { if let Some(outside) = self.outside() { if neighborhood.contains(self.radius() - distance) { outside.search(neighborhood); } } } } impl<'a, K, V, N> VpSearch for &'a VpNode where K: Proximity<&'a V, Distance = V::Distance>, V: Proximity, N: Neighborhood, { fn item(self) -> &'a V { &self.item } fn radius(self) -> DistanceValue { self.radius } fn inside(self) -> Option { self.inside.as_deref() } fn outside(self) -> Option { self.outside.as_deref() } } /// A [vantage-point tree](https://en.wikipedia.org/wiki/Vantage-point_tree). pub struct VpTree { root: Option>, } impl VpTree { /// Create an empty tree. pub fn new() -> Self { Self { root: None } } /// Create a balanced tree out of a sequence of items. pub fn balanced>(items: I) -> Self { Self { root: VpNode::balanced(items), } } /// Iterate over the items stored in this tree. pub fn iter(&self) -> Iter<'_, T> { self.into_iter() } /// Rebalance this VP tree. pub fn balance(&mut self) { let mut nodes = Vec::new(); if let Some(root) = self.root.take() { nodes.push(Some(Box::new(root))); } let mut i = 0; while i < nodes.len() { let node = nodes[i].as_mut().unwrap(); let inside = node.inside.take(); let outside = node.outside.take(); if inside.is_some() { nodes.push(inside); } if outside.is_some() { nodes.push(outside); } i += 1; } self.root = VpNode::balanced_recursive(&mut nodes) .map(|node| *node); } /// Push a new item into the tree. /// /// Inserting elements individually tends to unbalance the tree. Use [VpTree::balanced] if /// possible to create a balanced tree from a batch of items. pub fn push(&mut self, item: T) { if let Some(root) = &mut self.root { root.push(item); } else { self.root = Some(VpNode::new(item)); } } } // Can't derive(Debug) due to https://github.com/rust-lang/rust/issues/26925 impl Debug for VpTree where T: Proximity + Debug, DistanceValue: Debug, { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("VpTree") .field("root", &self.root) .finish() } } impl Default for VpTree { fn default() -> Self { Self::new() } } impl Extend for VpTree { fn extend>(&mut self, items: I) { if self.root.is_some() { for item in items { self.push(item); } } else { self.root = VpNode::balanced(items); } } } impl FromIterator for VpTree { fn from_iter>(items: I) -> Self { Self::balanced(items) } } /// An iterator that moves values out of a VP tree. pub struct IntoIter { stack: Vec>, } impl IntoIter { fn new(node: Option>) -> Self { Self { stack: node.into_iter().collect(), } } } impl Debug for IntoIter where T: Proximity + Debug, DistanceValue: Debug, { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("IntoIter") .field("stack", &self.stack) .finish() } } impl Iterator for IntoIter { type Item = T; fn next(&mut self) -> Option { self.stack.pop().map(|node| { if let Some(inside) = node.inside { self.stack.push(*inside); } if let Some(outside) = node.outside { self.stack.push(*outside); } node.item }) } } impl IntoIterator for VpTree { type Item = T; type IntoIter = IntoIter; fn into_iter(self) -> Self::IntoIter { IntoIter::new(self.root) } } /// An iterator over the values in a VP tree. pub struct Iter<'a, T: Proximity> { stack: Vec<&'a VpNode>, } impl<'a, T: Proximity> Iter<'a, T> { fn new(node: &'a Option>) -> Self { Self { stack: node.as_ref().into_iter().collect(), } } } impl<'a, T> Debug for Iter<'a, T> where T: Proximity + Debug, DistanceValue: Debug, { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("Iter") .field("stack", &self.stack) .finish() } } impl<'a, T: Proximity> Iterator for Iter<'a, T> { type Item = &'a T; fn next(&mut self) -> Option<&'a T> { self.stack.pop().map(|node| { if let Some(inside) = &node.inside { self.stack.push(inside); } if let Some(outside) = &node.outside { self.stack.push(outside); } &node.item }) } } impl<'a, T: Proximity> IntoIterator for &'a VpTree { type Item = &'a T; type IntoIter = Iter<'a, T>; fn into_iter(self) -> Self::IntoIter { Iter::new(&self.root) } } impl NearestNeighbors for VpTree where K: Proximity, V: Proximity, { fn search<'k, 'v, N>(&'v self, mut neighborhood: N) -> N where K: 'k, V: 'v, N: Neighborhood<&'k K, &'v V>, { if let Some(root) = &self.root { root.search(&mut neighborhood); } neighborhood } } impl ExactNeighbors for VpTree where K: Metric, V: Metric, {} /// A node in a flat VP tree. #[derive(Debug)] struct FlatVpNode> { /// The vantage point itself. item: T, /// The radius of this node. radius: R, /// The size of the inside subtree. inside_len: usize, } impl FlatVpNode { /// Create a new FlatVpNode. fn new(item: T) -> Self { Self { item, radius: zero(), inside_len: 0, } } /// Create a balanced tree. fn balanced>(items: I) -> Vec { let mut nodes: Vec<_> = items .into_iter() .map(Self::new) .collect(); Self::balance_recursive(&mut nodes); nodes } /// Create a balanced subtree. fn balance_recursive(nodes: &mut [Self]) { if let Some((node, children)) = nodes.split_first_mut() { children.sort_by_cached_key(|x| Ordered::new(node.item.distance(&x.item))); let (inside, outside) = children.split_at_mut(children.len() / 2); if let Some(last) = inside.last() { node.radius = node.item.distance(&last.item).into(); } node.inside_len = inside.len(); Self::balance_recursive(inside); Self::balance_recursive(outside); } } } impl<'a, K, V, N> VpSearch for &'a [FlatVpNode] where K: Proximity<&'a V, Distance = V::Distance>, V: Proximity, N: Neighborhood, { fn item(self) -> &'a V { &self[0].item } fn radius(self) -> DistanceValue { self[0].radius } fn inside(self) -> Option { let end = self[0].inside_len + 1; if end > 1 { Some(&self[1..end]) } else { None } } fn outside(self) -> Option { let start = self[0].inside_len + 1; if start < self.len() { Some(&self[start..]) } else { None } } } /// A [vantage-point tree] stored as a flat array. /// /// A FlatVpTree is always balanced and usually more efficient than a [VpTree], but doesn't support /// dynamic updates. /// /// [vantage-point tree]: https://en.wikipedia.org/wiki/Vantage-point_tree pub struct FlatVpTree { nodes: Vec>, } impl FlatVpTree { /// Create a balanced tree out of a sequence of items. pub fn balanced>(items: I) -> Self { Self { nodes: FlatVpNode::balanced(items), } } /// Iterate over the items stored in this tree. pub fn iter(&self) -> FlatIter<'_, T> { self.into_iter() } } impl Debug for FlatVpTree where T: Proximity + Debug, DistanceValue: Debug, { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("FlatVpTree") .field("nodes", &self.nodes) .finish() } } impl FromIterator for FlatVpTree { fn from_iter>(items: I) -> Self { Self::balanced(items) } } /// An iterator that moves values out of a flat VP tree. pub struct FlatIntoIter(std::vec::IntoIter>); impl Debug for FlatIntoIter where T: Proximity + Debug, DistanceValue: Debug, { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_tuple("FlatIntoIter") .field(&self.0) .finish() } } impl Iterator for FlatIntoIter { type Item = T; fn next(&mut self) -> Option { self.0.next().map(|n| n.item) } } impl IntoIterator for FlatVpTree { type Item = T; type IntoIter = FlatIntoIter; fn into_iter(self) -> Self::IntoIter { FlatIntoIter(self.nodes.into_iter()) } } /// An iterator over the values in a flat VP tree. pub struct FlatIter<'a, T: Proximity>(std::slice::Iter<'a, FlatVpNode>); impl<'a, T> Debug for FlatIter<'a, T> where T: Proximity + Debug, DistanceValue: Debug, { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_tuple("FlatIter") .field(&self.0) .finish() } } impl<'a, T: Proximity> Iterator for FlatIter<'a, T> { type Item = &'a T; fn next(&mut self) -> Option<&'a T> { self.0.next().map(|n| &n.item) } } impl<'a, T: Proximity> IntoIterator for &'a FlatVpTree { type Item = &'a T; type IntoIter = FlatIter<'a, T>; fn into_iter(self) -> Self::IntoIter { FlatIter(self.nodes.iter()) } } impl NearestNeighbors for FlatVpTree where K: Proximity, V: Proximity, { fn search<'k, 'v, N>(&'v self, mut neighborhood: N) -> N where K: 'k, V: 'v, N: Neighborhood<&'k K, &'v V>, { if !self.nodes.is_empty() { self.nodes.as_slice().search(&mut neighborhood); } neighborhood } } impl ExactNeighbors for FlatVpTree where K: Metric, V: Metric, {} #[cfg(test)] mod tests { use super::*; use crate::knn::tests::test_exact_neighbors; #[test] fn test_vp_tree() { test_exact_neighbors(VpTree::from_iter); } #[test] fn test_unbalanced_vp_tree() { test_exact_neighbors(|points| { let mut tree = VpTree::new(); for point in points { tree.push(point); } tree }); } #[test] fn test_flat_vp_tree() { test_exact_neighbors(FlatVpTree::from_iter); } }