From 9699f4657ecaaf4361448f249e4f2e210a854af4 Mon Sep 17 00:00:00 2001 From: Tavian Barnes Date: Thu, 23 Apr 2020 09:55:26 -0400 Subject: metric/vp: Flatten the tree representation --- src/metric/vp.rs | 133 +++++++++++++++++++++---------------------------------- 1 file changed, 51 insertions(+), 82 deletions(-) (limited to 'src') diff --git a/src/metric/vp.rs b/src/metric/vp.rs index 8d5b091..fae62e5 100644 --- a/src/metric/vp.rs +++ b/src/metric/vp.rs @@ -11,78 +11,62 @@ struct VpNode { item: T, /// The radius of this node. radius: f64, - /// The subtree inside the radius, if any. - inside: Option>, - /// The subtree outside the radius, if any. - outside: Option>, + /// The size of the subtree inside the radius. + inside_len: usize, } impl VpNode { /// Create a new VpNode. - fn new(mut items: Vec) -> Option> { - if items.is_empty() { - return None; + fn new(item: T) -> Self { + Self { + item, + radius: 0.0, + inside_len: 0, } + } - let item = items.pop().unwrap(); - - items.sort_by_cached_key(|a| item.distance(a)); - - let mid = items.len() / 2; - let outside: Vec = items.drain(mid..).collect(); + /// Build a VP tree recursively. + fn build(slice: &mut [VpNode]) { + if let Some((node, children)) = slice.split_first_mut() { + let item = &node.item; + children.sort_by_cached_key(|n| item.distance(&n.item)); - let radius = items.last().map(|l| item.distance(l).into()).unwrap_or(0.0); + let (inside, outside) = children.split_at_mut(children.len() / 2); + if let Some(last) = inside.last() { + node.radius = item.distance(&last.item).into(); + } + node.inside_len = inside.len(); - Some(Box::new(Self { - item, - radius, - inside: Self::new(items), - outside: Self::new(outside), - })) + Self::build(inside); + Self::build(outside); + } } -} -trait VpSearch<'a, T, U, N> { /// Recursively search for nearest neighbors. - fn search(&'a self, neighborhood: &mut N); - - /// Search the inside subtree. - fn search_inside(&'a self, distance: f64, neighborhood: &mut N); - - /// Search the outside subtree. - fn search_outside(&'a self, distance: f64, neighborhood: &mut N); -} + fn recurse<'a, U, N>(slice: &'a [VpNode], neighborhood: &mut N) + where + T: 'a, + U: Metric<&'a T>, + N: Neighborhood<&'a T, U>, + { + let (node, children) = slice.split_first().unwrap(); + let (inside, outside) = children.split_at(node.inside_len); -impl<'a, T, U, N> VpSearch<'a, T, U, N> for VpNode -where - T: 'a, - U: Metric<&'a T>, - N: Neighborhood<&'a T, U>, -{ - fn search(&'a self, neighborhood: &mut N) { - let distance = neighborhood.consider(&self.item).into(); + let distance = neighborhood.consider(&node.item).into(); - if distance <= self.radius { - self.search_inside(distance, neighborhood); - self.search_outside(distance, neighborhood); + if distance <= node.radius { + if !inside.is_empty() && neighborhood.contains(distance - node.radius) { + Self::recurse(inside, neighborhood); + } + if !outside.is_empty() && neighborhood.contains(node.radius - distance) { + Self::recurse(outside, neighborhood); + } } else { - self.search_outside(distance, neighborhood); - self.search_inside(distance, neighborhood); - } - } - - fn search_inside(&'a self, distance: f64, neighborhood: &mut N) { - if let Some(inside) = &self.inside { - if neighborhood.contains(distance - self.radius) { - inside.search(neighborhood); + if !outside.is_empty() && neighborhood.contains(node.radius - distance) { + Self::recurse(outside, neighborhood); } - } - } - - fn search_outside(&'a self, distance: f64, neighborhood: &mut N) { - if let Some(outside) = &self.outside { - if neighborhood.contains(self.radius - distance) { - outside.search(neighborhood); + if !inside.is_empty() && neighborhood.contains(distance - node.radius) { + Self::recurse(inside, neighborhood); } } } @@ -90,15 +74,13 @@ where /// A [vantage-point tree](https://en.wikipedia.org/wiki/Vantage-point_tree). #[derive(Debug)] -pub struct VpTree { - root: Option>>, -} +pub struct VpTree(Vec>); impl FromIterator for VpTree { fn from_iter>(items: I) -> Self { - Self { - root: VpNode::new(items.into_iter().collect::>()), - } + let mut nodes: Vec<_> = items.into_iter().map(VpNode::new).collect(); + VpNode::build(nodes.as_mut_slice()); + Self(nodes) } } @@ -113,36 +95,23 @@ where U: 'b, N: Neighborhood<&'a T, &'b U>, { - if let Some(root) = &self.root { - root.search(&mut neighborhood); + if !self.0.is_empty() { + VpNode::recurse(&self.0, &mut neighborhood); } + neighborhood } } /// An iterator that moves values out of a VP tree. #[derive(Debug)] -pub struct IntoIter { - stack: Vec>>, -} - -impl IntoIter { - fn new(node: Option>>) -> Self { - Self { - stack: node.into_iter().collect(), - } - } -} +pub struct IntoIter(std::vec::IntoIter>); impl Iterator for IntoIter { type Item = T; fn next(&mut self) -> Option { - self.stack.pop().map(|node| { - self.stack.extend(node.inside); - self.stack.extend(node.outside); - node.item - }) + self.0.next().map(|n| n.item) } } @@ -151,7 +120,7 @@ impl IntoIterator for VpTree { type IntoIter = IntoIter; fn into_iter(self) -> Self::IntoIter { - IntoIter::new(self.root) + IntoIter(self.0.into_iter()) } } -- cgit v1.2.3