summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/vp.rs183
1 files changed, 182 insertions, 1 deletions
diff --git a/src/vp.rs b/src/vp.rs
index 120e13b..e0645de 100644
--- a/src/vp.rs
+++ b/src/vp.rs
@@ -347,6 +347,183 @@ where
V: Metric,
{}
+/// A node in a flat VP tree.
+#[derive(Debug)]
+struct FlatVpNode<T, R = DistanceValue<T>> {
+ /// The vantage point itself.
+ item: T,
+ /// The radius of this node.
+ radius: R,
+ /// The size of the inside subtree.
+ inside_len: usize,
+}
+
+impl<T: Proximity> FlatVpNode<T> {
+ /// Create a new FlatVpNode.
+ fn new(item: T) -> Self {
+ Self {
+ item,
+ radius: zero(),
+ inside_len: 0,
+ }
+ }
+
+ /// Create a balanced tree.
+ fn balanced<I: IntoIterator<Item = T>>(items: I) -> Vec<Self> {
+ let mut nodes: Vec<_> = items
+ .into_iter()
+ .map(Self::new)
+ .collect();
+
+ Self::balance_recursive(&mut nodes);
+
+ nodes
+ }
+
+ /// Create a balanced subtree.
+ fn balance_recursive(nodes: &mut [Self]) {
+ if let Some((node, children)) = nodes.split_first_mut() {
+ children.sort_by_cached_key(|x| Ordered::new(node.item.distance(&x.item)));
+
+ let (inside, outside) = children.split_at_mut(children.len() / 2);
+ if let Some(last) = inside.last() {
+ node.radius = node.item.distance(&last.item).into();
+ }
+
+ node.inside_len = inside.len();
+
+ Self::balance_recursive(inside);
+ Self::balance_recursive(outside);
+ }
+ }
+}
+
+impl<'a, K, V, N> VpSearch<K, &'a V, N> for &'a [FlatVpNode<V>]
+where
+ K: Proximity<&'a V, Distance = V::Distance>,
+ V: Proximity,
+ N: Neighborhood<K, &'a V>,
+{
+ fn item(self) -> &'a V {
+ &self[0].item
+ }
+
+ fn radius(self) -> DistanceValue<V> {
+ self[0].radius
+ }
+
+ fn inside(self) -> Option<Self> {
+ let end = self[0].inside_len + 1;
+ if end > 1 {
+ Some(&self[1..end])
+ } else {
+ None
+ }
+ }
+
+ fn outside(self) -> Option<Self> {
+ let start = self[0].inside_len + 1;
+ if start < self.len() {
+ Some(&self[start..])
+ } else {
+ None
+ }
+ }
+}
+
+/// A [vantage-point tree] stored as a flat array.
+///
+/// A FlatVpTree is always balanced and usually more efficient than a [VpTree], but doesn't support
+/// dynamic updates.
+///
+/// [vantage-point tree]: https://en.wikipedia.org/wiki/Vantage-point_tree
+pub struct FlatVpTree<T: Proximity> {
+ nodes: Vec<FlatVpNode<T>>,
+}
+
+impl<T: Proximity> FlatVpTree<T> {
+ /// Create a balanced tree out of a sequence of items.
+ pub fn balanced<I: IntoIterator<Item = T>>(items: I) -> Self {
+ Self {
+ nodes: FlatVpNode::balanced(items),
+ }
+ }
+}
+
+impl<T> Debug for FlatVpTree<T>
+where
+ T: Proximity + Debug,
+ DistanceValue<T>: Debug,
+{
+ fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+ f.debug_struct("FlatVpTree")
+ .field("node", &self.nodes)
+ .finish()
+ }
+}
+
+impl<T: Proximity> FromIterator<T> for FlatVpTree<T> {
+ fn from_iter<I: IntoIterator<Item = T>>(items: I) -> Self {
+ Self::balanced(items)
+ }
+}
+
+/// An iterator that moves values out of a flat VP tree.
+pub struct FlatIntoIter<T: Proximity>(std::vec::IntoIter<FlatVpNode<T>>);
+
+impl<T> Debug for FlatIntoIter<T>
+where
+ T: Proximity + Debug,
+ DistanceValue<T>: Debug,
+{
+ fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+ f.debug_tuple("FlatIntoIter")
+ .field(&self.0)
+ .finish()
+ }
+}
+
+impl<T: Proximity> Iterator for FlatIntoIter<T> {
+ type Item = T;
+
+ fn next(&mut self) -> Option<T> {
+ self.0.next().map(|n| n.item)
+ }
+}
+
+impl<T: Proximity> IntoIterator for FlatVpTree<T> {
+ type Item = T;
+ type IntoIter = FlatIntoIter<T>;
+
+ fn into_iter(self) -> Self::IntoIter {
+ FlatIntoIter(self.nodes.into_iter())
+ }
+}
+
+impl<K, V> NearestNeighbors<K, V> for FlatVpTree<V>
+where
+ K: Proximity<V, Distance = V::Distance>,
+ V: Proximity,
+{
+ fn search<'k, 'v, N>(&'v self, mut neighborhood: N) -> N
+ where
+ K: 'k,
+ V: 'v,
+ N: Neighborhood<&'k K, &'v V>,
+ {
+ if !self.nodes.is_empty() {
+ self.nodes.as_slice().search(&mut neighborhood);
+ }
+ neighborhood
+ }
+}
+
+impl<K, V> ExactNeighbors<K, V> for FlatVpTree<V>
+where
+ K: Metric<V, Distance = V::Distance>,
+ V: Metric,
+{}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -368,5 +545,9 @@ mod tests {
tree
});
}
-}
+ #[test]
+ fn test_flat_vp_tree() {
+ test_nearest_neighbors(FlatVpTree::from_iter);
+ }
+}