From 77b0033e83f11a194651123e42b7960c6e756657 Mon Sep 17 00:00:00 2001 From: Tavian Barnes Date: Fri, 3 Jul 2020 09:06:42 -0400 Subject: Add methods for merging new neighbors into a vector in-place --- src/lib.rs | 233 +++++++++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 159 insertions(+), 74 deletions(-) (limited to 'src') diff --git a/src/lib.rs b/src/lib.rs index 4792dbb..986c1d3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -110,8 +110,6 @@ pub use coords::Coordinates; pub use distance::{Distance, Metric, Proximity}; pub use euclid::{euclidean_distance, Euclidean, EuclideanDistance}; -use std::cmp::Ordering; -use std::collections::BinaryHeap; use std::convert::TryInto; /// A nearest neighbor. @@ -167,41 +165,6 @@ pub trait Neighborhood, V> { fn consider(&mut self, item: V) -> K::Distance; } -/// A candidate nearest neighbor found during a search. -#[derive(Debug)] -struct Candidate(Neighbor); - -impl Candidate { - fn new(target: K, item: V) -> Self - where - K: Proximity, - { - let distance = target.distance(&item); - Self(Neighbor::new(item, distance)) - } -} - -impl PartialOrd for Candidate { - fn partial_cmp(&self, other: &Self) -> Option { - self.0.distance.partial_cmp(&other.0.distance) - } -} - -impl Ord for Candidate { - fn cmp(&self, other: &Self) -> Ordering { - self.partial_cmp(other) - .expect("Unordered distances found during nearest neighbor search") - } -} - -impl PartialEq for Candidate { - fn eq(&self, other: &Self) -> bool { - self.0.distance == other.0.distance - } -} - -impl Eq for Candidate {} - /// A [Neighborhood] with at most one result. #[derive(Debug)] struct SingletonNeighborhood { @@ -210,7 +173,7 @@ struct SingletonNeighborhood { /// The current threshold distance. threshold: Option, /// The current nearest neighbor, if any. 
- candidate: Option>, + neighbor: Option>, } impl SingletonNeighborhood { @@ -222,13 +185,13 @@ impl SingletonNeighborhood { Self { target, threshold, - candidate: None, + neighbor: None, } } /// Convert this result into an optional neighbor. fn into_option(self) -> Option> { - self.candidate.map(|c| c.0) + self.neighbor } } @@ -248,12 +211,11 @@ where } fn consider(&mut self, item: V) -> K::Distance { - let candidate = Candidate::new(self.target, item); - let distance = candidate.0.distance; + let distance = self.target.distance(&item); if self.contains(distance) { self.threshold = Some(distance); - self.candidate = Some(candidate); + self.neighbor = Some(Neighbor::new(item, distance)); } distance @@ -262,7 +224,7 @@ where /// A [Neighborhood] of up to `k` results, using a binary heap. #[derive(Debug)] -struct HeapNeighborhood { +struct HeapNeighborhood<'a, K, V, D> { /// The target of the nearest neighbor search. target: K, /// The number of nearest neighbors to find. @@ -270,35 +232,79 @@ struct HeapNeighborhood { /// The current threshold distance to the farthest result. threshold: Option, /// A max-heap of the best candidates found so far. - heap: BinaryHeap>, + heap: &'a mut Vec>, } -impl HeapNeighborhood { +impl<'a, K, V, D: Distance> HeapNeighborhood<'a, K, V, D> { /// Create a new HeapNeighborhood. /// /// * `target`: The search target. - /// * `k`: The number of nearest neighbors to find. + /// * `k`: The maximum number of nearest neighbors to find. /// * `threshold`: The maximum allowable distance. - fn new(target: K, k: usize, threshold: Option) -> Self { + /// * `heap`: The vector of neighbors to use as the heap. 
+ fn new( + target: K, + k: usize, + mut threshold: Option, + heap: &'a mut Vec>, + ) -> Self { + if k > 0 && heap.len() == k { + let distance = heap[0].distance; + if threshold.map_or(true, |t| distance <= t) { + threshold = Some(distance); + } + } + Self { target, k, threshold, - heap: BinaryHeap::with_capacity(k), + heap, + } + } + + /// Restore the heap property by raising an entry. + fn bubble_up(&mut self, mut i: usize) { + while i > 0 { + let parent = (i - 1) / 2; + if self.heap[i].distance <= self.heap[parent].distance { + break; + } + self.heap.swap(i, parent); + i = parent; + } + } + + /// Restore the heap property by lowering an entry. + fn bubble_down(&mut self, mut i: usize, len: usize) { + let dist = self.heap[i].distance; + + loop { + let mut child = 2 * i + 1; + let right = child + 1; + if right < len && self.heap[child].distance < self.heap[right].distance { + child = right; + } + + if child < len && dist < self.heap[child].distance { + self.heap.swap(i, child); + i = child; + } else { + break; + } } } - /// Extract the results from this neighborhood. - fn into_vec(self) -> Vec> { - self.heap - .into_sorted_vec() - .into_iter() - .map(|c| c.0) - .collect() + /// Sort the heap from smallest to largest distance. 
+ fn sort(&mut self) { + for i in (0..self.heap.len()).rev() { + self.heap.swap(0, i); + self.bubble_down(0, i); + } } } -impl Neighborhood for HeapNeighborhood +impl<'a, K, V> Neighborhood for HeapNeighborhood<'a, K, V, K::Distance> where K: Copy + Proximity, { @@ -314,20 +320,21 @@ where } fn consider(&mut self, item: V) -> K::Distance { - let candidate = Candidate::new(self.target, item); - let distance = candidate.0.distance; + let distance = self.target.distance(&item); if self.contains(distance) { - let heap = &mut self.heap; - - if heap.len() == self.k { - heap.pop(); + let neighbor = Neighbor::new(item, distance); + + if self.heap.len() < self.k { + self.heap.push(neighbor); + self.bubble_up(self.heap.len() - 1); + } else { + self.heap[0] = neighbor; + self.bubble_down(0, self.heap.len()); } - heap.push(candidate); - - if heap.len() == self.k { - self.threshold = heap.peek().map(|c| c.0.distance) + if self.heap.len() == self.k { + self.threshold = Some(self.heap[0].distance); } } @@ -370,21 +377,74 @@ pub trait NearestNeighbors, V = K> { } /// Returns the up to `k` nearest neighbors to `target`. + /// + /// The result will be sorted from nearest to farthest. fn k_nearest(&self, target: &K, k: usize) -> Vec> { - self.search(HeapNeighborhood::new(target, k, None)) - .into_vec() + let mut neighbors = Vec::with_capacity(k); + self.search(HeapNeighborhood::new(target, k, None, &mut neighbors)) + .sort(); + neighbors } /// Returns the up to `k` nearest neighbors to `target` within the distance `threshold`. - fn k_nearest_within(&self, target: &K, k: usize, threshold: D) -> Vec> + /// + /// The result will be sorted from nearest to farthest. 
+ fn k_nearest_within( + &self, + target: &K, + k: usize, + threshold: D, + ) -> Vec> where D: TryInto, { + let mut neighbors = Vec::with_capacity(k); + if let Ok(distance) = threshold.try_into() { - self.search(HeapNeighborhood::new(target, k, Some(distance))) - .into_vec() - } else { - Vec::new() + self.search(HeapNeighborhood::new( + target, + k, + Some(distance), + &mut neighbors, + )) + .sort(); + } + + neighbors + } + + /// Merges up to `k` nearest neighbors into an existing vector. + /// + /// The `neighbors` vector should either be empty, or populated by a previous call to + /// `merge_k_nearest()`. This method assumes a particular ordering that makes merging new + /// results efficient. If you want the results ordered from nearest to farthest, you must sort + /// it yourself. + fn merge_k_nearest<'v>( + &'v self, + target: &K, + k: usize, + neighbors: &mut Vec>, + ) { + self.search(HeapNeighborhood::new(target, k, None, neighbors)); + } + + /// Merges up to `k` nearest neighbors within the `threshold` into an existing vector. + /// + /// The `neighbors` vector should either be empty, or populated by a previous call to + /// `merge_k_nearest()`. This method assumes a particular ordering that makes merging new + /// results efficient. If you want the results ordered from nearest to farthest, you must sort + /// it yourself.
+ fn merge_k_nearest_within<'v, D>( + &'v self, + target: &K, + k: usize, + neighbors: &mut Vec>, + threshold: D, + ) where + D: TryInto, + { + if let Ok(distance) = threshold.try_into() { + self.search(HeapNeighborhood::new(target, k, Some(distance), neighbors)); } } @@ -404,6 +464,7 @@ pub mod tests { use super::*; use crate::exhaustive::ExhaustiveSearch; + use crate::util::Ordered; use rand::prelude::*; @@ -491,6 +552,30 @@ pub mod tests { Neighbor::new(&Euclidean([2.0, 3.0, 6.0]), 7.0), ] ); + + let mut neighbors = Vec::new(); + index.merge_k_nearest(&target, 3, &mut neighbors); + neighbors.sort_by_key(|n| Ordered::new(n.distance)); + assert_eq!( + neighbors, + vec![ + Neighbor::new(&Euclidean([1.0, 2.0, 2.0]), 3.0), + Neighbor::new(&Euclidean([3.0, 4.0, 0.0]), 5.0), + Neighbor::new(&Euclidean([2.0, 3.0, 6.0]), 7.0), + ] + ); + + neighbors.drain(0..2); + index.merge_k_nearest_within(&target, 3, &mut neighbors, 6.0); + neighbors.sort_by_key(|n| Ordered::new(n.distance)); + assert_eq!( + neighbors, + vec![ + Neighbor::new(&Euclidean([1.0, 2.0, 2.0]), 3.0), + Neighbor::new(&Euclidean([3.0, 4.0, 0.0]), 5.0), + Neighbor::new(&Euclidean([2.0, 3.0, 6.0]), 7.0), + ] + ); } fn test_random_points(from_iter: &F) -- cgit v1.2.3