From e9a81a6d0df149252164003975addf175d5c6f4b Mon Sep 17 00:00:00 2001
From: Tavian Barnes <tavianator@tavianator.com>
Date: Thu, 23 Apr 2020 09:55:13 -0400
Subject: metric/kd: Flatten the tree representation

---
 src/metric/kd.rs | 113 ++++++++++++++++++++++++++-----------------------------
 1 file changed, 54 insertions(+), 59 deletions(-)

(limited to 'src')

diff --git a/src/metric/kd.rs b/src/metric/kd.rs
index db1b2bd..2caf4a3 100644
--- a/src/metric/kd.rs
+++ b/src/metric/kd.rs
@@ -66,61 +66,71 @@ where
 struct KdNode<T> {
     /// The value stored in this node.
     item: T,
-    /// The left subtree, if any.
-    left: Option<Box<Self>>,
-    /// The right subtree, if any.
-    right: Option<Box<Self>>,
+    /// The size of the left subtree.
+    left_len: usize,
 }
 
 impl<T: Cartesian> KdNode<T> {
     /// Create a new KdNode.
-    fn new(i: usize, mut items: Vec<T>) -> Option<Box<Self>> {
-        if items.is_empty() {
-            return None;
+    fn new(item: T) -> Self {
+        Self { item, left_len: 0 }
+    }
+
+    /// Build a k-d tree recursively.
+    fn build(slice: &mut [KdNode<T>], i: usize) {
+        if slice.is_empty() {
+            return;
         }
 
-        items.sort_unstable_by_key(|x| OrderedFloat::from(x.coordinate(i)));
-
-        let mid = items.len() / 2;
-        let right: Vec<T> = items.drain((mid + 1)..).collect();
-        let item = items.pop().unwrap();
-        let j = (i + 1) % item.dimensions();
-        Some(Box::new(Self {
-            item,
-            left: Self::new(j, items),
-            right: Self::new(j, right),
-        }))
+        slice.sort_unstable_by_key(|n| OrderedFloat::from(n.item.coordinate(i)));
+
+        let mid = slice.len() / 2;
+        slice.swap(0, mid);
+
+        let (node, children) = slice.split_first_mut().unwrap();
+        let (left, right) = children.split_at_mut(mid);
+        node.left_len = left.len();
+
+        let j = (i + 1) % node.item.dimensions();
+        Self::build(left, j);
+        Self::build(right, j);
     }
 
     /// Recursively search for nearest neighbors.
-    fn search<'a, U, N>(&'a self, i: usize, closest: &mut [f64], neighborhood: &mut N)
-    where
+    fn recurse<'a, U, N>(
+        slice: &'a [KdNode<T>],
+        i: usize,
+        closest: &mut [f64],
+        neighborhood: &mut N,
+    ) where
         T: 'a,
         U: CartesianMetric<&'a T>,
         N: Neighborhood<&'a T, U>,
     {
-        neighborhood.consider(&self.item);
+        let (node, children) = slice.split_first().unwrap();
+        neighborhood.consider(&node.item);
 
         let target = neighborhood.target();
         let ti = target.coordinate(i);
-        let si = self.item.coordinate(i);
-        let j = (i + 1) % self.item.dimensions();
+        let ni = node.item.coordinate(i);
+        let j = (i + 1) % node.item.dimensions();
 
-        let (near, far) = if ti <= si {
-            (&self.left, &self.right)
+        let (left, right) = children.split_at(node.left_len);
+        let (near, far) = if ti <= ni {
+            (left, right)
         } else {
-            (&self.right, &self.left)
+            (right, left)
         };
 
-        if let Some(near) = near {
-            near.search(j, closest, neighborhood);
+        if !near.is_empty() {
+            Self::recurse(near, j, closest, neighborhood);
         }
 
-        if let Some(far) = far {
+        if !far.is_empty() {
             let saved = closest[i];
-            closest[i] = si;
+            closest[i] = ni;
             if neighborhood.contains_distance(target.distance(closest)) {
-                far.search(j, closest, neighborhood);
+                Self::recurse(far, j, closest, neighborhood);
             }
             closest[i] = saved;
         }
@@ -129,16 +139,14 @@ impl<T: Cartesian> KdNode<T> {
 
 /// A [k-d tree](https://en.wikipedia.org/wiki/K-d_tree).
 #[derive(Debug)]
-pub struct KdTree<T> {
-    root: Option<Box<KdNode<T>>>,
-}
+pub struct KdTree<T>(Vec<KdNode<T>>);
 
 impl<T: Cartesian> FromIterator<T> for KdTree<T> {
     /// Create a new k-d tree from a set of points.
     fn from_iter<I: IntoIterator<Item = T>>(items: I) -> Self {
-        Self {
-            root: KdNode::new(0, items.into_iter().collect()),
-        }
+        let mut nodes: Vec<_> = items.into_iter().map(KdNode::new).collect();
+        KdNode::build(nodes.as_mut_slice(), 0);
+        Self(nodes)
     }
 }
 
@@ -153,40 +161,27 @@ where
         U: 'b,
         N: Neighborhood<&'a T, &'b U>,
     {
-        let target = neighborhood.target();
-        let dims = target.dimensions();
-        let mut closest: Vec<_> = (0..dims).map(|i| target.coordinate(i)).collect();
+        if !self.0.is_empty() {
+            let target = neighborhood.target();
+            let dims = target.dimensions();
+            let mut closest: Vec<_> = (0..dims).map(|i| target.coordinate(i)).collect();
 
-        if let Some(root) = &self.root {
-            root.search(0, &mut closest, &mut neighborhood);
+            KdNode::recurse(&self.0, 0, &mut closest, &mut neighborhood);
         }
+
         neighborhood
     }
 }
 
 /// An iterator that the moves values out of a k-d tree.
 #[derive(Debug)]
-pub struct IntoIter<T> {
-    stack: Vec<Box<KdNode<T>>>,
-}
-
-impl<T> IntoIter<T> {
-    fn new(node: Option<Box<KdNode<T>>>) -> Self {
-        Self {
-            stack: node.into_iter().collect(),
-        }
-    }
-}
+pub struct IntoIter<T>(std::vec::IntoIter<KdNode<T>>);
 
 impl<T> Iterator for IntoIter<T> {
     type Item = T;
 
     fn next(&mut self) -> Option<T> {
-        self.stack.pop().map(|node| {
-            self.stack.extend(node.left);
-            self.stack.extend(node.right);
-            node.item
-        })
+        self.0.next().map(|n| n.item)
     }
 }
 
@@ -195,7 +190,7 @@ impl<T> IntoIterator for KdTree<T> {
     type IntoIter = IntoIter<T>;
 
     fn into_iter(self) -> Self::IntoIter {
-        IntoIter::new(self.root)
+        IntoIter(self.0.into_iter())
     }
 }
 
-- 
cgit v1.2.3