From e272ede323d74d1dd30d28fd862099376b49265b Mon Sep 17 00:00:00 2001 From: Tavian Barnes Date: Tue, 26 May 2020 12:54:26 -0400 Subject: vp: Implement flat VP trees --- benches/benches.rs | 8 ++- src/vp.rs | 183 ++++++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 189 insertions(+), 2 deletions(-) diff --git a/benches/benches.rs b/benches/benches.rs index 8791845..a8676ba 100644 --- a/benches/benches.rs +++ b/benches/benches.rs @@ -2,7 +2,7 @@ use acap::euclid::Euclidean; use acap::exhaustive::ExhaustiveSearch; -use acap::vp::VpTree; +use acap::vp::{FlatVpTree, VpTree}; use acap::NearestNeighbors; use criterion::{black_box, criterion_group, criterion_main, Criterion}; @@ -36,6 +36,7 @@ fn bench_from_iter(c: &mut Criterion) { let mut group = c.benchmark_group("from_iter"); group.bench_function("ExhaustiveSearch", |b| b.iter(|| ExhaustiveSearch::from_iter(points.clone()))); group.bench_function("VpTree", |b| b.iter(|| VpTree::from_iter(points.clone()))); + group.bench_function("FlatVpTree", |b| b.iter(|| FlatVpTree::from_iter(points.clone()))); group.finish(); } @@ -45,25 +46,30 @@ fn bench_nearest_neighbors(c: &mut Criterion) { let exhaustive = ExhaustiveSearch::from_iter(points.clone()); let vp_tree = VpTree::from_iter(points.clone()); + let flat_vp_tree = FlatVpTree::from_iter(points.clone()); let mut nearest = c.benchmark_group("NearestNeighbors::nearest"); nearest.bench_function("ExhaustiveSearch", |b| b.iter(|| exhaustive.nearest(&target))); nearest.bench_function("VpTree", |b| b.iter(|| vp_tree.nearest(&target))); + nearest.bench_function("FlatVpTree", |b| b.iter(|| flat_vp_tree.nearest(&target))); nearest.finish(); let mut nearest_within = c.benchmark_group("NearestNeighbors::nearest_within"); nearest_within.bench_function("ExhaustiveSearch", |b| b.iter(|| exhaustive.nearest_within(&target, 0.1))); nearest_within.bench_function("VpTree", |b| b.iter(|| vp_tree.nearest_within(&target, 0.1))); + nearest_within.bench_function("FlatVpTree", |b| b.iter(|| flat_vp_tree.nearest_within(&target, 0.1))); nearest_within.finish(); let mut k_nearest = c.benchmark_group("NearestNeighbors::k_nearest"); k_nearest.bench_function("ExhaustiveSearch", |b| b.iter(|| exhaustive.k_nearest(&target, 3))); k_nearest.bench_function("VpTree", |b| b.iter(|| vp_tree.k_nearest(&target, 3))); + k_nearest.bench_function("FlatVpTree", |b| b.iter(|| flat_vp_tree.k_nearest(&target, 3))); k_nearest.finish(); let mut k_nearest_within = c.benchmark_group("NearestNeighbors::k_nearest_within"); k_nearest_within.bench_function("ExhaustiveSearch", |b| b.iter(|| exhaustive.k_nearest_within(&target, 3, 0.1))); k_nearest_within.bench_function("VpTree", |b| b.iter(|| vp_tree.k_nearest_within(&target, 3, 0.1))); + k_nearest_within.bench_function("FlatVpTree", |b| b.iter(|| flat_vp_tree.k_nearest_within(&target, 3, 0.1))); k_nearest_within.finish(); } diff --git a/src/vp.rs b/src/vp.rs index 120e13b..e0645de 100644 --- a/src/vp.rs +++ b/src/vp.rs @@ -347,6 +347,183 @@ where V: Metric, {} +/// A node in a flat VP tree. +#[derive(Debug)] +struct FlatVpNode> { + /// The vantage point itself. + item: T, + /// The radius of this node. + radius: R, + /// The size of the inside subtree. + inside_len: usize, +} + +impl FlatVpNode { + /// Create a new FlatVpNode. + fn new(item: T) -> Self { + Self { + item, + radius: zero(), + inside_len: 0, + } + } + + /// Create a balanced tree. + fn balanced>(items: I) -> Vec { + let mut nodes: Vec<_> = items + .into_iter() + .map(Self::new) + .collect(); + + Self::balance_recursive(&mut nodes); + + nodes + } + + /// Create a balanced subtree. + fn balance_recursive(nodes: &mut [Self]) { + if let Some((node, children)) = nodes.split_first_mut() { + children.sort_by_cached_key(|x| Ordered::new(node.item.distance(&x.item))); + + let (inside, outside) = children.split_at_mut(children.len() / 2); + if let Some(last) = inside.last() { + node.radius = node.item.distance(&last.item).into(); + } + + node.inside_len = inside.len(); + + Self::balance_recursive(inside); + Self::balance_recursive(outside); + } + } +} + +impl<'a, K, V, N> VpSearch for &'a [FlatVpNode] +where + K: Proximity<&'a V, Distance = V::Distance>, + V: Proximity, + N: Neighborhood, +{ + fn item(self) -> &'a V { + &self[0].item + } + + fn radius(self) -> DistanceValue { + self[0].radius + } + + fn inside(self) -> Option { + let end = self[0].inside_len + 1; + if end > 1 { + Some(&self[1..end]) + } else { + None + } + } + + fn outside(self) -> Option { + let start = self[0].inside_len + 1; + if start < self.len() { + Some(&self[start..]) + } else { + None + } + } +} + +/// A [vantage-point tree] stored as a flat array. +/// +/// A FlatVpTree is always balanced and usually more efficient than a [VpTree], but doesn't support +/// dynamic updates. +/// +/// [vantage-point tree]: https://en.wikipedia.org/wiki/Vantage-point_tree +pub struct FlatVpTree { + nodes: Vec>, +} + +impl FlatVpTree { + /// Create a balanced tree out of a sequence of items. + pub fn balanced>(items: I) -> Self { + Self { + nodes: FlatVpNode::balanced(items), + } + } +} + +impl Debug for FlatVpTree +where + T: Proximity + Debug, + DistanceValue: Debug, +{ + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("FlatVpTree") + .field("node", &self.nodes) + .finish() + } +} + +impl FromIterator for FlatVpTree { + fn from_iter>(items: I) -> Self { + Self::balanced(items) + } +} + +/// An iterator that moves values out of a flat VP tree. +pub struct FlatIntoIter(std::vec::IntoIter>); + +impl Debug for FlatIntoIter +where + T: Proximity + Debug, + DistanceValue: Debug, +{ + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_tuple("FlatIntoIter") + .field(&self.0) + .finish() + } +} + +impl Iterator for FlatIntoIter { + type Item = T; + + fn next(&mut self) -> Option { + self.0.next().map(|n| n.item) + } +} + +impl IntoIterator for FlatVpTree { + type Item = T; + type IntoIter = FlatIntoIter; + + fn into_iter(self) -> Self::IntoIter { + FlatIntoIter(self.nodes.into_iter()) + } +} + +impl NearestNeighbors for FlatVpTree +where + K: Proximity, + V: Proximity, +{ + fn search<'k, 'v, N>(&'v self, mut neighborhood: N) -> N + where + K: 'k, + V: 'v, + N: Neighborhood<&'k K, &'v V>, + { + if !self.nodes.is_empty() { + self.nodes.as_slice().search(&mut neighborhood); + } + neighborhood + } +} + +impl ExactNeighbors for FlatVpTree +where + K: Metric, + V: Metric, +{} + #[cfg(test)] mod tests { use super::*; @@ -368,5 +545,9 @@ mod tests { tree }); } -} + #[test] + fn test_flat_vp_tree() { + test_nearest_neighbors(FlatVpTree::from_iter); + } +} -- cgit v1.2.3