From 77b0033e83f11a194651123e42b7960c6e756657 Mon Sep 17 00:00:00 2001
From: Tavian Barnes <tavianator@tavianator.com>
Date: Fri, 3 Jul 2020 09:06:42 -0400
Subject: Add methods for merging new neighbors into a vector in-place

---
 benches/benches.rs |  33 ++++++--
 src/lib.rs         | 233 ++++++++++++++++++++++++++++++++++++-----------------
 2 files changed, 184 insertions(+), 82 deletions(-)

diff --git a/benches/benches.rs b/benches/benches.rs
index 39368d2..caed17c 100644
--- a/benches/benches.rs
+++ b/benches/benches.rs
@@ -65,17 +65,34 @@ fn bench_nearest_neighbors(c: &mut Criterion) {
             let index = $type::from_iter(points.clone());
 
             group.bench_function("nearest", |b| b.iter(
-                || index.nearest(&target))
-            );
+                || index.nearest(&target)
+            ));
             group.bench_function("nearest_within", |b| b.iter(
-                || index.nearest_within(&target, 0.1))
-            );
+                || index.nearest_within(&target, 0.1)
+            ));
             group.bench_function("k_nearest", |b| b.iter(
-                || index.k_nearest(&target, 3))
-            );
+                || index.k_nearest(&target, 3)
+            ));
             group.bench_function("k_nearest_within", |b| b.iter(
-                || index.k_nearest_within(&target, 3, 0.1))
-            );
+                || index.k_nearest_within(&target, 3, 0.1)
+            ));
+
+            group.bench_function("merge_k_nearest", |b| b.iter_batched(
+                || Vec::with_capacity(3),
+                |mut n| {
+                    index.merge_k_nearest(&target, 3, &mut n);
+                    n
+                },
+                BatchSize::SmallInput,
+            ));
+            group.bench_function("merge_k_nearest_within", |b| b.iter_batched(
+                || Vec::with_capacity(3),
+                |mut n| {
+                    index.merge_k_nearest_within(&target, 3, &mut n, 0.1);
+                    n
+                },
+                BatchSize::SmallInput,
+            ));
 
             group.finish();
         };
diff --git a/src/lib.rs b/src/lib.rs
index 4792dbb..986c1d3 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -110,8 +110,6 @@ pub use coords::Coordinates;
 pub use distance::{Distance, Metric, Proximity};
 pub use euclid::{euclidean_distance, Euclidean, EuclideanDistance};
 
-use std::cmp::Ordering;
-use std::collections::BinaryHeap;
 use std::convert::TryInto;
 
 /// A nearest neighbor.
@@ -167,41 +165,6 @@ pub trait Neighborhood<K: Proximity<V>, V> {
     fn consider(&mut self, item: V) -> K::Distance;
 }
 
-/// A candidate nearest neighbor found during a search.
-#[derive(Debug)]
-struct Candidate<V, D>(Neighbor<V, D>);
-
-impl<V, D: Distance> Candidate<V, D> {
-    fn new<K>(target: K, item: V) -> Self
-    where
-        K: Proximity<V, Distance = D>,
-    {
-        let distance = target.distance(&item);
-        Self(Neighbor::new(item, distance))
-    }
-}
-
-impl<V, D: PartialOrd> PartialOrd for Candidate<V, D> {
-    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
-        self.0.distance.partial_cmp(&other.0.distance)
-    }
-}
-
-impl<V, D: PartialOrd> Ord for Candidate<V, D> {
-    fn cmp(&self, other: &Self) -> Ordering {
-        self.partial_cmp(other)
-            .expect("Unordered distances found during nearest neighbor search")
-    }
-}
-
-impl<V, D: PartialEq> PartialEq for Candidate<V, D> {
-    fn eq(&self, other: &Self) -> bool {
-        self.0.distance == other.0.distance
-    }
-}
-
-impl<V, D: PartialEq> Eq for Candidate<V, D> {}
-
 /// A [Neighborhood] with at most one result.
 #[derive(Debug)]
 struct SingletonNeighborhood<K, V, D> {
@@ -210,7 +173,7 @@ struct SingletonNeighborhood<K, V, D> {
     /// The current threshold distance.
     threshold: Option<D>,
     /// The current nearest neighbor, if any.
-    candidate: Option<Candidate<V, D>>,
+    neighbor: Option<Neighbor<V, D>>,
 }
 
 impl<K, V, D> SingletonNeighborhood<K, V, D> {
@@ -222,13 +185,13 @@ impl<K, V, D> SingletonNeighborhood<K, V, D> {
         Self {
             target,
             threshold,
-            candidate: None,
+            neighbor: None,
         }
     }
 
     /// Convert this result into an optional neighbor.
     fn into_option(self) -> Option<Neighbor<V, D>> {
-        self.candidate.map(|c| c.0)
+        self.neighbor
     }
 }
 
@@ -248,12 +211,11 @@ where
     }
 
     fn consider(&mut self, item: V) -> K::Distance {
-        let candidate = Candidate::new(self.target, item);
-        let distance = candidate.0.distance;
+        let distance = self.target.distance(&item);
 
         if self.contains(distance) {
             self.threshold = Some(distance);
-            self.candidate = Some(candidate);
+            self.neighbor = Some(Neighbor::new(item, distance));
         }
 
         distance
@@ -262,7 +224,7 @@ where
 
 /// A [Neighborhood] of up to `k` results, using a binary heap.
 #[derive(Debug)]
-struct HeapNeighborhood<K, V, D> {
+struct HeapNeighborhood<'a, K, V, D> {
     /// The target of the nearest neighbor search.
     target: K,
     /// The number of nearest neighbors to find.
@@ -270,35 +232,79 @@ struct HeapNeighborhood<K, V, D> {
     /// The current threshold distance to the farthest result.
     threshold: Option<D>,
     /// A max-heap of the best candidates found so far.
-    heap: BinaryHeap<Candidate<V, D>>,
+    heap: &'a mut Vec<Neighbor<V, D>>,
 }
 
-impl<K, V, D: PartialOrd> HeapNeighborhood<K, V, D> {
+impl<'a, K, V, D: Distance> HeapNeighborhood<'a, K, V, D> {
     /// Create a new HeapNeighborhood.
     ///
     /// * `target`: The search target.
-    /// * `k`: The number of nearest neighbors to find.
+    /// * `k`: The maximum number of nearest neighbors to find.
     /// * `threshold`: The maximum allowable distance.
-    fn new(target: K, k: usize, threshold: Option<D>) -> Self {
+    /// * `heap`: The vector to use as the max-heap (empty, or filled by an earlier search).
+    fn new(
+        target: K,
+        k: usize,
+        mut threshold: Option<D>,
+        heap: &'a mut Vec<Neighbor<V, D>>,
+    ) -> Self {
+        if k > 0 && heap.len() == k {
+            let distance = heap[0].distance; // root of a full max-heap = farthest kept neighbor
+            if threshold.map_or(true, |t| distance <= t) { // only ever tighten the threshold
+                threshold = Some(distance);
+            }
+        }
+
         Self {
             target,
             k,
             threshold,
-            heap: BinaryHeap::with_capacity(k),
+            heap,
+        }
+    }
+
+    /// Restore the max-heap property by moving entry `i` up while it is farther than its parent.
+    fn bubble_up(&mut self, mut i: usize) {
+        while i > 0 {
+            let parent = (i - 1) / 2; // index of the parent in the implicit binary tree
+            if self.heap[i].distance <= self.heap[parent].distance {
+                break; // parent is at least as far away: heap order holds
+            }
+            self.heap.swap(i, parent);
+            i = parent;
+        }
+    }
+
+    /// Restore the max-heap property by moving entry `i` down within the first `len` entries.
+    fn bubble_down(&mut self, mut i: usize, len: usize) {
+        let dist = self.heap[i].distance; // distance of the entry being lowered
+
+        loop {
+            let mut child = 2 * i + 1; // left child of `i`
+            let right = child + 1;
+            if right < len && self.heap[child].distance < self.heap[right].distance {
+                child = right; // compare against the farther of the two children
+            }
+
+            if child < len && dist < self.heap[child].distance {
+                self.heap.swap(i, child);
+                i = child;
+            } else {
+                break;
+            }
         }
     }
 
-    /// Extract the results from this neighborhood.
-    fn into_vec(self) -> Vec<Neighbor<V, D>> {
-        self.heap
-            .into_sorted_vec()
-            .into_iter()
-            .map(|c| c.0)
-            .collect()
+    /// Sort the heap in place from smallest to largest distance (heapsort over the max-heap).
+    fn sort(&mut self) {
+        for i in (0..self.heap.len()).rev() {
+            self.heap.swap(0, i); // move the current maximum into its final slot
+            self.bubble_down(0, i); // restore heap order in the remaining prefix `0..i`
+        }
     }
 }
 
-impl<K, V> Neighborhood<K, V> for HeapNeighborhood<K, V, K::Distance>
+impl<'a, K, V> Neighborhood<K, V> for HeapNeighborhood<'a, K, V, K::Distance>
 where
     K: Copy + Proximity<V>,
 {
@@ -314,20 +320,21 @@ where
     }
 
     fn consider(&mut self, item: V) -> K::Distance {
-        let candidate = Candidate::new(self.target, item);
-        let distance = candidate.0.distance;
+        let distance = self.target.distance(&item);
 
         if self.contains(distance) {
-            let heap = &mut self.heap;
-
-            if heap.len() == self.k {
-                heap.pop();
+            let neighbor = Neighbor::new(item, distance);
+
+            if self.heap.len() < self.k {
+                self.heap.push(neighbor);
+                self.bubble_up(self.heap.len() - 1);
+            } else {
+                self.heap[0] = neighbor;
+                self.bubble_down(0, self.heap.len());
             }
 
-            heap.push(candidate);
-
-            if heap.len() == self.k {
-                self.threshold = heap.peek().map(|c| c.0.distance)
+            if self.heap.len() == self.k {
+                self.threshold = Some(self.heap[0].distance);
             }
         }
 
@@ -370,21 +377,74 @@ pub trait NearestNeighbors<K: Proximity<V>, V = K> {
     }
 
     /// Returns the up to `k` nearest neighbors to `target`.
+    ///
+    /// The result will be sorted from nearest to farthest.
     fn k_nearest(&self, target: &K, k: usize) -> Vec<Neighbor<&V, K::Distance>> {
-        self.search(HeapNeighborhood::new(target, k, None))
-            .into_vec()
+        let mut neighbors = Vec::with_capacity(k); // backing storage for the heap
+        self.search(HeapNeighborhood::new(target, k, None, &mut neighbors))
+            .sort(); // heap order -> nearest-to-farthest
+        neighbors
     }
 
     /// Returns the up to `k` nearest neighbors to `target` within the distance `threshold`.
-    fn k_nearest_within<D>(&self, target: &K, k: usize, threshold: D) -> Vec<Neighbor<&V, K::Distance>>
+    ///
+    /// The result will be sorted from nearest to farthest.
+    fn k_nearest_within<D>(
+        &self,
+        target: &K,
+        k: usize,
+        threshold: D,
+    ) -> Vec<Neighbor<&V, K::Distance>>
     where
         D: TryInto<K::Distance>,
     {
+        let mut neighbors = Vec::with_capacity(k); // stays empty if `threshold` does not convert
+
         if let Ok(distance) = threshold.try_into() {
-            self.search(HeapNeighborhood::new(target, k, Some(distance)))
-                .into_vec()
-        } else {
-            Vec::new()
+            self.search(HeapNeighborhood::new(
+                target,
+                k,
+                Some(distance),
+                &mut neighbors,
+            ))
+            .sort(); // heap order -> nearest-to-farthest
+        }
+
+        neighbors
+    }
+
+    /// Merges up to `k` nearest neighbors into an existing vector.
+    ///
+    /// The `neighbors` vector should either be empty, or populated by a previous call to
+    /// `merge_k_nearest()`.  This method relies on the max-heap order left by that call to merge
+    /// new results efficiently.  If you want the results ordered from nearest to farthest, you
+    /// must sort them yourself; note that sorting discards the heap order expected here.
+    fn merge_k_nearest<'v>(
+        &'v self,
+        target: &K,
+        k: usize,
+        neighbors: &mut Vec<Neighbor<&'v V, K::Distance>>,
+    ) {
+        self.search(HeapNeighborhood::new(target, k, None, neighbors));
+    }
+
+    /// Merges up to `k` nearest neighbors within the `threshold` into an existing vector.
+    ///
+    /// The `neighbors` vector should either be empty, or populated by a previous call to
+    /// `merge_k_nearest()` or `merge_k_nearest_within()`.  This method relies on the max-heap
+    /// order left by those calls to merge new results efficiently.  If you want the results
+    /// ordered from nearest to farthest, you must sort them yourself (which discards heap order).
+    fn merge_k_nearest_within<'v, D>(
+        &'v self,
+        target: &K,
+        k: usize,
+        neighbors: &mut Vec<Neighbor<&'v V, K::Distance>>,
+        threshold: D,
+    ) where
+        D: TryInto<K::Distance>,
+    {
+        if let Ok(distance) = threshold.try_into() { // an unconvertible threshold merges nothing
+            self.search(HeapNeighborhood::new(target, k, Some(distance), neighbors));
         }
     }
 
@@ -404,6 +464,7 @@ pub mod tests {
     use super::*;
 
     use crate::exhaustive::ExhaustiveSearch;
+    use crate::util::Ordered;
 
     use rand::prelude::*;
 
@@ -491,6 +552,30 @@ pub mod tests {
                 Neighbor::new(&Euclidean([2.0, 3.0, 6.0]), 7.0),
             ]
         );
+
+        let mut neighbors = Vec::new();
+        index.merge_k_nearest(&target, 3, &mut neighbors);
+        neighbors.sort_by_key(|n| Ordered::new(n.distance));
+        assert_eq!(
+            neighbors,
+            vec![
+                Neighbor::new(&Euclidean([1.0, 2.0, 2.0]), 3.0),
+                Neighbor::new(&Euclidean([3.0, 4.0, 0.0]), 5.0),
+                Neighbor::new(&Euclidean([2.0, 3.0, 6.0]), 7.0),
+            ]
+        );
+
+        neighbors.drain(0..2);
+        index.merge_k_nearest_within(&target, 3, &mut neighbors, 6.0);
+        neighbors.sort_by_key(|n| Ordered::new(n.distance));
+        assert_eq!(
+            neighbors,
+            vec![
+                Neighbor::new(&Euclidean([1.0, 2.0, 2.0]), 3.0),
+                Neighbor::new(&Euclidean([3.0, 4.0, 0.0]), 5.0),
+                Neighbor::new(&Euclidean([2.0, 3.0, 6.0]), 7.0),
+            ]
+        );
     }
 
     fn test_random_points<T, F>(from_iter: &F)
-- 
cgit v1.2.3