From 825515439247853af3714d3135051a83bd84d2e0 Mon Sep 17 00:00:00 2001 From: Tavian Barnes Date: Sat, 9 May 2020 16:34:54 -0400 Subject: metric/forest: Use a flat staging buffer to reduce tree building overhead --- src/metric/forest.rs | 86 +++++++++++++++++++++++++++++----------------------- 1 file changed, 48 insertions(+), 38 deletions(-) (limited to 'src/metric') diff --git a/src/metric/forest.rs b/src/metric/forest.rs index 29b6f55..47eb413 100644 --- a/src/metric/forest.rs +++ b/src/metric/forest.rs @@ -4,14 +4,24 @@ use super::kd::KdTree; use super::vp::VpTree; use super::{Metric, NearestNeighbors, Neighborhood}; -use std::iter::{self, Extend, Flatten, FromIterator}; +use std::iter::{self, Extend, FromIterator}; + +/// The number of bits dedicated to the flat buffer. +const BUFFER_BITS: usize = 6; +/// The maximum size of the buffer. +const BUFFER_SIZE: usize = 1 << BUFFER_BITS; /// A dynamic wrapper for a static nearest neighbor search data structure. /// /// This type applies [dynamization](https://en.wikipedia.org/wiki/Dynamization) to an arbitrary /// nearest neighbor search structure `T`, allowing new items to be added dynamically. #[derive(Debug)] -pub struct Forest(Vec>); +pub struct Forest { + /// A flat buffer used for the first few items, to avoid repeatedly rebuilding small trees. + buffer: Vec, + /// The trees of the forest, with sizes in geometric progression. + trees: Vec>, +} impl Forest where @@ -19,7 +29,10 @@ where { /// Create a new empty forest. pub fn new() -> Self { - Self(Vec::new()) + Self { + buffer: Vec::new(), + trees: Vec::new(), + } } /// Add a new item to the forest. @@ -29,10 +42,10 @@ where /// Get the number of items in the forest. pub fn len(&self) -> usize { - let mut len = 0; - for (i, slot) in self.0.iter().enumerate() { + let mut len = self.buffer.len(); + for (i, slot) in self.trees.iter().enumerate() { if slot.is_some() { - len |= 1 << i; + len += 1 << (i + BUFFER_BITS); } } len @@ -44,32 +57,36 @@ where U: FromIterator + IntoIterator, { fn extend>(&mut self, items: I) { - let mut vec: Vec<_> = items.into_iter().collect(); - let new_len = self.len() + vec.len(); + self.buffer.extend(items); + if self.buffer.len() < BUFFER_SIZE { + return; + } + + let len = self.len(); for i in 0.. { - let bit = 1 << i; + let bit = 1 << (i + BUFFER_BITS); - if bit > new_len { + if bit > len { break; } - if i >= self.0.len() { - self.0.push(None); + if i >= self.trees.len() { + self.trees.push(None); } - if new_len & bit == 0 { - if let Some(tree) = self.0[i].take() { - vec.extend(tree); + if len & bit == 0 { + if let Some(tree) = self.trees[i].take() { + self.buffer.extend(tree); } - } else if self.0[i].is_none() { - let offset = vec.len() - bit; - self.0[i] = Some(vec.drain(offset..).collect()); + } else if self.trees[i].is_none() { + let offset = self.buffer.len() - bit; + self.trees[i] = Some(self.buffer.drain(offset..).collect()); } } - debug_assert!(vec.is_empty()); - debug_assert!(self.len() == new_len); + debug_assert!(self.buffer.len() < BUFFER_SIZE); + debug_assert!(self.len() == len); } } @@ -84,25 +101,13 @@ where } } -type IntoIterImpl = Flatten>>>; - -/// An iterator that moves items out of a forest. -pub struct IntoIter(IntoIterImpl); - -impl Iterator for IntoIter { - type Item = T::Item; - - fn next(&mut self) -> Option { - self.0.next() - } -} - impl IntoIterator for Forest { type Item = T::Item; - type IntoIter = IntoIter; + type IntoIter = std::vec::IntoIter; - fn into_iter(self) -> Self::IntoIter { - IntoIter(self.0.into_iter().flatten().flatten()) + fn into_iter(mut self) -> Self::IntoIter { + self.buffer.extend(self.trees.into_iter().flatten().flatten()); + self.buffer.into_iter() } } @@ -110,14 +115,19 @@ impl NearestNeighbors for Forest where U: Metric, V: NearestNeighbors, + V: IntoIterator, { - fn search<'a, 'b, N>(&'a self, neighborhood: N) -> N + fn search<'a, 'b, N>(&'a self, mut neighborhood: N) -> N where T: 'a, U: 'b, N: Neighborhood<&'a T, &'b U>, { - self.0 + for item in &self.buffer { + neighborhood.consider(item); + } + + self.trees .iter() .flatten() .fold(neighborhood, |n, t| t.search(n)) -- cgit v1.2.3