summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTavian Barnes <tavianator@tavianator.com>2020-04-23 09:55:26 -0400
committerTavian Barnes <tavianator@tavianator.com>2020-05-02 13:19:05 -0400
commit9699f4657ecaaf4361448f249e4f2e210a854af4 (patch)
tree35fe3e9172d72a749e17cb96e99228e268cd8cd7
parent5de377b2b00a927a4f6463c1c5a5fd18606ad006 (diff)
downloadkd-forest-9699f4657ecaaf4361448f249e4f2e210a854af4.tar.xz
metric/vp: Flatten the tree representation
-rw-r--r--src/metric/vp.rs133
1 files changed, 51 insertions, 82 deletions
diff --git a/src/metric/vp.rs b/src/metric/vp.rs
index 8d5b091..fae62e5 100644
--- a/src/metric/vp.rs
+++ b/src/metric/vp.rs
@@ -11,78 +11,62 @@ struct VpNode<T> {
item: T,
/// The radius of this node.
radius: f64,
- /// The subtree inside the radius, if any.
- inside: Option<Box<Self>>,
- /// The subtree outside the radius, if any.
- outside: Option<Box<Self>>,
+ /// The size of the subtree inside the radius.
+ inside_len: usize,
}
impl<T: Metric> VpNode<T> {
/// Create a new VpNode.
- fn new(mut items: Vec<T>) -> Option<Box<Self>> {
- if items.is_empty() {
- return None;
+ fn new(item: T) -> Self {
+ Self {
+ item,
+ radius: 0.0,
+ inside_len: 0,
}
+ }
- let item = items.pop().unwrap();
-
- items.sort_by_cached_key(|a| item.distance(a));
-
- let mid = items.len() / 2;
- let outside: Vec<T> = items.drain(mid..).collect();
+ /// Build a VP tree recursively.
+ fn build(slice: &mut [VpNode<T>]) {
+ if let Some((node, children)) = slice.split_first_mut() {
+ let item = &node.item;
+ children.sort_by_cached_key(|n| item.distance(&n.item));
- let radius = items.last().map(|l| item.distance(l).into()).unwrap_or(0.0);
+ let (inside, outside) = children.split_at_mut(children.len() / 2);
+ if let Some(last) = inside.last() {
+ node.radius = item.distance(&last.item).into();
+ }
+ node.inside_len = inside.len();
- Some(Box::new(Self {
- item,
- radius,
- inside: Self::new(items),
- outside: Self::new(outside),
- }))
+ Self::build(inside);
+ Self::build(outside);
+ }
}
-}
-trait VpSearch<'a, T, U, N> {
/// Recursively search for nearest neighbors.
- fn search(&'a self, neighborhood: &mut N);
-
- /// Search the inside subtree.
- fn search_inside(&'a self, distance: f64, neighborhood: &mut N);
-
- /// Search the outside subtree.
- fn search_outside(&'a self, distance: f64, neighborhood: &mut N);
-}
+ fn recurse<'a, U, N>(slice: &'a [VpNode<T>], neighborhood: &mut N)
+ where
+ T: 'a,
+ U: Metric<&'a T>,
+ N: Neighborhood<&'a T, U>,
+ {
+ let (node, children) = slice.split_first().unwrap();
+ let (inside, outside) = children.split_at(node.inside_len);
-impl<'a, T, U, N> VpSearch<'a, T, U, N> for VpNode<T>
-where
- T: 'a,
- U: Metric<&'a T>,
- N: Neighborhood<&'a T, U>,
-{
- fn search(&'a self, neighborhood: &mut N) {
- let distance = neighborhood.consider(&self.item).into();
+ let distance = neighborhood.consider(&node.item).into();
- if distance <= self.radius {
- self.search_inside(distance, neighborhood);
- self.search_outside(distance, neighborhood);
+ if distance <= node.radius {
+ if !inside.is_empty() && neighborhood.contains(distance - node.radius) {
+ Self::recurse(inside, neighborhood);
+ }
+ if !outside.is_empty() && neighborhood.contains(node.radius - distance) {
+ Self::recurse(outside, neighborhood);
+ }
} else {
- self.search_outside(distance, neighborhood);
- self.search_inside(distance, neighborhood);
- }
- }
-
- fn search_inside(&'a self, distance: f64, neighborhood: &mut N) {
- if let Some(inside) = &self.inside {
- if neighborhood.contains(distance - self.radius) {
- inside.search(neighborhood);
+ if !outside.is_empty() && neighborhood.contains(node.radius - distance) {
+ Self::recurse(outside, neighborhood);
}
- }
- }
-
- fn search_outside(&'a self, distance: f64, neighborhood: &mut N) {
- if let Some(outside) = &self.outside {
- if neighborhood.contains(self.radius - distance) {
- outside.search(neighborhood);
+ if !inside.is_empty() && neighborhood.contains(distance - node.radius) {
+ Self::recurse(inside, neighborhood);
}
}
}
@@ -90,15 +74,13 @@ where
/// A [vantage-point tree](https://en.wikipedia.org/wiki/Vantage-point_tree).
#[derive(Debug)]
-pub struct VpTree<T> {
- root: Option<Box<VpNode<T>>>,
-}
+pub struct VpTree<T>(Vec<VpNode<T>>);
impl<T: Metric> FromIterator<T> for VpTree<T> {
fn from_iter<I: IntoIterator<Item = T>>(items: I) -> Self {
- Self {
- root: VpNode::new(items.into_iter().collect::<Vec<_>>()),
- }
+ let mut nodes: Vec<_> = items.into_iter().map(VpNode::new).collect();
+ VpNode::build(nodes.as_mut_slice());
+ Self(nodes)
}
}
@@ -113,36 +95,23 @@ where
U: 'b,
N: Neighborhood<&'a T, &'b U>,
{
- if let Some(root) = &self.root {
- root.search(&mut neighborhood);
+ if !self.0.is_empty() {
+ VpNode::recurse(&self.0, &mut neighborhood);
}
+
neighborhood
}
}
/// An iterator that moves values out of a VP tree.
#[derive(Debug)]
-pub struct IntoIter<T> {
- stack: Vec<Box<VpNode<T>>>,
-}
-
-impl<T> IntoIter<T> {
- fn new(node: Option<Box<VpNode<T>>>) -> Self {
- Self {
- stack: node.into_iter().collect(),
- }
- }
-}
+pub struct IntoIter<T>(std::vec::IntoIter<VpNode<T>>);
impl<T> Iterator for IntoIter<T> {
type Item = T;
fn next(&mut self) -> Option<T> {
- self.stack.pop().map(|node| {
- self.stack.extend(node.inside);
- self.stack.extend(node.outside);
- node.item
- })
+ self.0.next().map(|n| n.item)
}
}
@@ -151,7 +120,7 @@ impl<T> IntoIterator for VpTree<T> {
type IntoIter = IntoIter<T>;
fn into_iter(self) -> Self::IntoIter {
- IntoIter::new(self.root)
+ IntoIter(self.0.into_iter())
}
}