diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/vp.rs | 183 |
1 files changed, 182 insertions, 1 deletions
@@ -347,6 +347,183 @@ where V: Metric, {} +/// A node in a flat VP tree. +#[derive(Debug)] +struct FlatVpNode<T, R = DistanceValue<T>> { + /// The vantage point itself. + item: T, + /// The radius of this node. + radius: R, + /// The size of the inside subtree. + inside_len: usize, +} + +impl<T: Proximity> FlatVpNode<T> { + /// Create a new FlatVpNode. + fn new(item: T) -> Self { + Self { + item, + radius: zero(), + inside_len: 0, + } + } + + /// Create a balanced tree. + fn balanced<I: IntoIterator<Item = T>>(items: I) -> Vec<Self> { + let mut nodes: Vec<_> = items + .into_iter() + .map(Self::new) + .collect(); + + Self::balance_recursive(&mut nodes); + + nodes + } + + /// Create a balanced subtree. + fn balance_recursive(nodes: &mut [Self]) { + if let Some((node, children)) = nodes.split_first_mut() { + children.sort_by_cached_key(|x| Ordered::new(node.item.distance(&x.item))); + + let (inside, outside) = children.split_at_mut(children.len() / 2); + if let Some(last) = inside.last() { + node.radius = node.item.distance(&last.item).into(); + } + + node.inside_len = inside.len(); + + Self::balance_recursive(inside); + Self::balance_recursive(outside); + } + } +} + +impl<'a, K, V, N> VpSearch<K, &'a V, N> for &'a [FlatVpNode<V>] +where + K: Proximity<&'a V, Distance = V::Distance>, + V: Proximity, + N: Neighborhood<K, &'a V>, +{ + fn item(self) -> &'a V { + &self[0].item + } + + fn radius(self) -> DistanceValue<V> { + self[0].radius + } + + fn inside(self) -> Option<Self> { + let end = self[0].inside_len + 1; + if end > 1 { + Some(&self[1..end]) + } else { + None + } + } + + fn outside(self) -> Option<Self> { + let start = self[0].inside_len + 1; + if start < self.len() { + Some(&self[start..]) + } else { + None + } + } +} + +/// A [vantage-point tree] stored as a flat array. +/// +/// A FlatVpTree is always balanced and usually more efficient than a [VpTree], but doesn't support +/// dynamic updates. +/// +/// [vantage-point tree]: https://en.wikipedia.org/wiki/Vantage-point_tree +pub struct FlatVpTree<T: Proximity> { + nodes: Vec<FlatVpNode<T>>, +} + +impl<T: Proximity> FlatVpTree<T> { + /// Create a balanced tree out of a sequence of items. + pub fn balanced<I: IntoIterator<Item = T>>(items: I) -> Self { + Self { + nodes: FlatVpNode::balanced(items), + } + } +} + +impl<T> Debug for FlatVpTree<T> +where + T: Proximity + Debug, + DistanceValue<T>: Debug, +{ + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("FlatVpTree") + .field("node", &self.nodes) + .finish() + } +} + +impl<T: Proximity> FromIterator<T> for FlatVpTree<T> { + fn from_iter<I: IntoIterator<Item = T>>(items: I) -> Self { + Self::balanced(items) + } +} + +/// An iterator that moves values out of a flat VP tree. +pub struct FlatIntoIter<T: Proximity>(std::vec::IntoIter<FlatVpNode<T>>); + +impl<T> Debug for FlatIntoIter<T> +where + T: Proximity + Debug, + DistanceValue<T>: Debug, +{ + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_tuple("FlatIntoIter") + .field(&self.0) + .finish() + } +} + +impl<T: Proximity> Iterator for FlatIntoIter<T> { + type Item = T; + + fn next(&mut self) -> Option<T> { + self.0.next().map(|n| n.item) + } +} + +impl<T: Proximity> IntoIterator for FlatVpTree<T> { + type Item = T; + type IntoIter = FlatIntoIter<T>; + + fn into_iter(self) -> Self::IntoIter { + FlatIntoIter(self.nodes.into_iter()) + } +} + +impl<K, V> NearestNeighbors<K, V> for FlatVpTree<V> +where + K: Proximity<V, Distance = V::Distance>, + V: Proximity, +{ + fn search<'k, 'v, N>(&'v self, mut neighborhood: N) -> N + where + K: 'k, + V: 'v, + N: Neighborhood<&'k K, &'v V>, + { + if !self.nodes.is_empty() { + self.nodes.as_slice().search(&mut neighborhood); + } + neighborhood + } +} + +impl<K, V> ExactNeighbors<K, V> for FlatVpTree<V> +where + K: Metric<V, Distance = V::Distance>, + V: Metric, +{} + #[cfg(test)] mod tests { use super::*; @@ -368,5 +545,9 @@ mod tests { tree }); } -} + #[test] + fn test_flat_vp_tree() { + test_nearest_neighbors(FlatVpTree::from_iter); + } +} |