summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTavian Barnes <tavianator@tavianator.com>2020-04-23 09:55:13 -0400
committerTavian Barnes <tavianator@tavianator.com>2020-05-02 13:19:05 -0400
commite9a81a6d0df149252164003975addf175d5c6f4b (patch)
tree226a86a89662f0bef422e5200e5b7d7b5d4aa221
parent9699f4657ecaaf4361448f249e4f2e210a854af4 (diff)
downloadkd-forest-e9a81a6d0df149252164003975addf175d5c6f4b.tar.xz
metric/kd: Flatten the tree representation
-rw-r--r--src/metric/kd.rs113
1 files changed, 54 insertions, 59 deletions
diff --git a/src/metric/kd.rs b/src/metric/kd.rs
index db1b2bd..2caf4a3 100644
--- a/src/metric/kd.rs
+++ b/src/metric/kd.rs
@@ -66,61 +66,71 @@ where
struct KdNode<T> {
/// The value stored in this node.
item: T,
- /// The left subtree, if any.
- left: Option<Box<Self>>,
- /// The right subtree, if any.
- right: Option<Box<Self>>,
+ /// The size of the left subtree.
+ left_len: usize,
}
impl<T: Cartesian> KdNode<T> {
/// Create a new KdNode.
- fn new(i: usize, mut items: Vec<T>) -> Option<Box<Self>> {
- if items.is_empty() {
- return None;
+ fn new(item: T) -> Self {
+ Self { item, left_len: 0 }
+ }
+
+ /// Build a k-d tree recursively.
+ fn build(slice: &mut [KdNode<T>], i: usize) {
+ if slice.is_empty() {
+ return;
}
- items.sort_unstable_by_key(|x| OrderedFloat::from(x.coordinate(i)));
-
- let mid = items.len() / 2;
- let right: Vec<T> = items.drain((mid + 1)..).collect();
- let item = items.pop().unwrap();
- let j = (i + 1) % item.dimensions();
- Some(Box::new(Self {
- item,
- left: Self::new(j, items),
- right: Self::new(j, right),
- }))
+ slice.sort_unstable_by_key(|n| OrderedFloat::from(n.item.coordinate(i)));
+
+ let mid = slice.len() / 2;
+ slice.swap(0, mid);
+
+ let (node, children) = slice.split_first_mut().unwrap();
+ let (left, right) = children.split_at_mut(mid);
+ node.left_len = left.len();
+
+ let j = (i + 1) % node.item.dimensions();
+ Self::build(left, j);
+ Self::build(right, j);
}
/// Recursively search for nearest neighbors.
- fn search<'a, U, N>(&'a self, i: usize, closest: &mut [f64], neighborhood: &mut N)
- where
+ fn recurse<'a, U, N>(
+ slice: &'a [KdNode<T>],
+ i: usize,
+ closest: &mut [f64],
+ neighborhood: &mut N,
+ ) where
T: 'a,
U: CartesianMetric<&'a T>,
N: Neighborhood<&'a T, U>,
{
- neighborhood.consider(&self.item);
+ let (node, children) = slice.split_first().unwrap();
+ neighborhood.consider(&node.item);
let target = neighborhood.target();
let ti = target.coordinate(i);
- let si = self.item.coordinate(i);
- let j = (i + 1) % self.item.dimensions();
+ let ni = node.item.coordinate(i);
+ let j = (i + 1) % node.item.dimensions();
- let (near, far) = if ti <= si {
- (&self.left, &self.right)
+ let (left, right) = children.split_at(node.left_len);
+ let (near, far) = if ti <= ni {
+ (left, right)
} else {
- (&self.right, &self.left)
+ (right, left)
};
- if let Some(near) = near {
- near.search(j, closest, neighborhood);
+ if !near.is_empty() {
+ Self::recurse(near, j, closest, neighborhood);
}
- if let Some(far) = far {
+ if !far.is_empty() {
let saved = closest[i];
- closest[i] = si;
+ closest[i] = ni;
if neighborhood.contains_distance(target.distance(closest)) {
- far.search(j, closest, neighborhood);
+ Self::recurse(far, j, closest, neighborhood);
}
closest[i] = saved;
}
@@ -129,16 +139,14 @@ impl<T: Cartesian> KdNode<T> {
/// A [k-d tree](https://en.wikipedia.org/wiki/K-d_tree).
#[derive(Debug)]
-pub struct KdTree<T> {
- root: Option<Box<KdNode<T>>>,
-}
+pub struct KdTree<T>(Vec<KdNode<T>>);
impl<T: Cartesian> FromIterator<T> for KdTree<T> {
/// Create a new k-d tree from a set of points.
fn from_iter<I: IntoIterator<Item = T>>(items: I) -> Self {
- Self {
- root: KdNode::new(0, items.into_iter().collect()),
- }
+ let mut nodes: Vec<_> = items.into_iter().map(KdNode::new).collect();
+ KdNode::build(nodes.as_mut_slice(), 0);
+ Self(nodes)
}
}
@@ -153,40 +161,27 @@ where
U: 'b,
N: Neighborhood<&'a T, &'b U>,
{
- let target = neighborhood.target();
- let dims = target.dimensions();
- let mut closest: Vec<_> = (0..dims).map(|i| target.coordinate(i)).collect();
+ if !self.0.is_empty() {
+ let target = neighborhood.target();
+ let dims = target.dimensions();
+ let mut closest: Vec<_> = (0..dims).map(|i| target.coordinate(i)).collect();
- if let Some(root) = &self.root {
- root.search(0, &mut closest, &mut neighborhood);
+ KdNode::recurse(&self.0, 0, &mut closest, &mut neighborhood);
}
+
neighborhood
}
}
/// An iterator that the moves values out of a k-d tree.
#[derive(Debug)]
-pub struct IntoIter<T> {
- stack: Vec<Box<KdNode<T>>>,
-}
-
-impl<T> IntoIter<T> {
- fn new(node: Option<Box<KdNode<T>>>) -> Self {
- Self {
- stack: node.into_iter().collect(),
- }
- }
-}
+pub struct IntoIter<T>(std::vec::IntoIter<KdNode<T>>);
impl<T> Iterator for IntoIter<T> {
type Item = T;
fn next(&mut self) -> Option<T> {
- self.stack.pop().map(|node| {
- self.stack.extend(node.left);
- self.stack.extend(node.right);
- node.item
- })
+ self.0.next().map(|n| n.item)
}
}
@@ -195,7 +190,7 @@ impl<T> IntoIterator for KdTree<T> {
type IntoIter = IntoIter<T>;
fn into_iter(self) -> Self::IntoIter {
- IntoIter::new(self.root)
+ IntoIter(self.0.into_iter())
}
}