summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTavian Barnes <tavianator@tavianator.com>2020-07-03 09:06:42 -0400
committerTavian Barnes <tavianator@tavianator.com>2020-07-03 09:06:42 -0400
commit77b0033e83f11a194651123e42b7960c6e756657 (patch)
tree8899bd51e18d4a79872ac77ebc6eeadce2cd9d54
parentd314f33de39fd3b32d093c2fc2d59c0aeb92c655 (diff)
downloadacap-77b0033e83f11a194651123e42b7960c6e756657.tar.xz
Add methods for merging new neighbors into a vector in-place
-rw-r--r--benches/benches.rs33
-rw-r--r--src/lib.rs233
2 files changed, 184 insertions, 82 deletions
diff --git a/benches/benches.rs b/benches/benches.rs
index 39368d2..caed17c 100644
--- a/benches/benches.rs
+++ b/benches/benches.rs
@@ -65,17 +65,34 @@ fn bench_nearest_neighbors(c: &mut Criterion) {
let index = $type::from_iter(points.clone());
group.bench_function("nearest", |b| b.iter(
- || index.nearest(&target))
- );
+ || index.nearest(&target)
+ ));
group.bench_function("nearest_within", |b| b.iter(
- || index.nearest_within(&target, 0.1))
- );
+ || index.nearest_within(&target, 0.1)
+ ));
group.bench_function("k_nearest", |b| b.iter(
- || index.k_nearest(&target, 3))
- );
+ || index.k_nearest(&target, 3)
+ ));
group.bench_function("k_nearest_within", |b| b.iter(
- || index.k_nearest_within(&target, 3, 0.1))
- );
+ || index.k_nearest_within(&target, 3, 0.1)
+ ));
+
+ group.bench_function("merge_k_nearest", |b| b.iter_batched(
+ || Vec::with_capacity(3),
+ |mut n| {
+ index.merge_k_nearest(&target, 3, &mut n);
+ n
+ },
+ BatchSize::SmallInput,
+ ));
+ group.bench_function("merge_k_nearest_within", |b| b.iter_batched(
+ || Vec::with_capacity(3),
+ |mut n| {
+ index.merge_k_nearest_within(&target, 3, &mut n, 0.1);
+ n
+ },
+ BatchSize::SmallInput,
+ ));
group.finish();
};
diff --git a/src/lib.rs b/src/lib.rs
index 4792dbb..986c1d3 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -110,8 +110,6 @@ pub use coords::Coordinates;
pub use distance::{Distance, Metric, Proximity};
pub use euclid::{euclidean_distance, Euclidean, EuclideanDistance};
-use std::cmp::Ordering;
-use std::collections::BinaryHeap;
use std::convert::TryInto;
/// A nearest neighbor.
@@ -167,41 +165,6 @@ pub trait Neighborhood<K: Proximity<V>, V> {
fn consider(&mut self, item: V) -> K::Distance;
}
-/// A candidate nearest neighbor found during a search.
-#[derive(Debug)]
-struct Candidate<V, D>(Neighbor<V, D>);
-
-impl<V, D: Distance> Candidate<V, D> {
- fn new<K>(target: K, item: V) -> Self
- where
- K: Proximity<V, Distance = D>,
- {
- let distance = target.distance(&item);
- Self(Neighbor::new(item, distance))
- }
-}
-
-impl<V, D: PartialOrd> PartialOrd for Candidate<V, D> {
- fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
- self.0.distance.partial_cmp(&other.0.distance)
- }
-}
-
-impl<V, D: PartialOrd> Ord for Candidate<V, D> {
- fn cmp(&self, other: &Self) -> Ordering {
- self.partial_cmp(other)
- .expect("Unordered distances found during nearest neighbor search")
- }
-}
-
-impl<V, D: PartialEq> PartialEq for Candidate<V, D> {
- fn eq(&self, other: &Self) -> bool {
- self.0.distance == other.0.distance
- }
-}
-
-impl<V, D: PartialEq> Eq for Candidate<V, D> {}
-
/// A [Neighborhood] with at most one result.
#[derive(Debug)]
struct SingletonNeighborhood<K, V, D> {
@@ -210,7 +173,7 @@ struct SingletonNeighborhood<K, V, D> {
/// The current threshold distance.
threshold: Option<D>,
/// The current nearest neighbor, if any.
- candidate: Option<Candidate<V, D>>,
+ neighbor: Option<Neighbor<V, D>>,
}
impl<K, V, D> SingletonNeighborhood<K, V, D> {
@@ -222,13 +185,13 @@ impl<K, V, D> SingletonNeighborhood<K, V, D> {
Self {
target,
threshold,
- candidate: None,
+ neighbor: None,
}
}
/// Convert this result into an optional neighbor.
fn into_option(self) -> Option<Neighbor<V, D>> {
- self.candidate.map(|c| c.0)
+ self.neighbor
}
}
@@ -248,12 +211,11 @@ where
}
fn consider(&mut self, item: V) -> K::Distance {
- let candidate = Candidate::new(self.target, item);
- let distance = candidate.0.distance;
+ let distance = self.target.distance(&item);
if self.contains(distance) {
self.threshold = Some(distance);
- self.candidate = Some(candidate);
+ self.neighbor = Some(Neighbor::new(item, distance));
}
distance
@@ -262,7 +224,7 @@ where
/// A [Neighborhood] of up to `k` results, using a binary heap.
#[derive(Debug)]
-struct HeapNeighborhood<K, V, D> {
+struct HeapNeighborhood<'a, K, V, D> {
/// The target of the nearest neighbor search.
target: K,
/// The number of nearest neighbors to find.
@@ -270,35 +232,79 @@ struct HeapNeighborhood<K, V, D> {
/// The current threshold distance to the farthest result.
threshold: Option<D>,
/// A max-heap of the best candidates found so far.
- heap: BinaryHeap<Candidate<V, D>>,
+ heap: &'a mut Vec<Neighbor<V, D>>,
}
-impl<K, V, D: PartialOrd> HeapNeighborhood<K, V, D> {
+impl<'a, K, V, D: Distance> HeapNeighborhood<'a, K, V, D> {
/// Create a new HeapNeighborhood.
///
/// * `target`: The search target.
- /// * `k`: The number of nearest neighbors to find.
+ /// * `k`: The maximum number of nearest neighbors to find.
/// * `threshold`: The maximum allowable distance.
- fn new(target: K, k: usize, threshold: Option<D>) -> Self {
+ /// * `heap`: The vector of neighbors to use as the heap.
+ fn new(
+ target: K,
+ k: usize,
+ mut threshold: Option<D>,
+ heap: &'a mut Vec<Neighbor<V, D>>,
+ ) -> Self {
+ if k > 0 && heap.len() == k {
+ let distance = heap[0].distance;
+ if threshold.map_or(true, |t| distance <= t) {
+ threshold = Some(distance);
+ }
+ }
+
Self {
target,
k,
threshold,
- heap: BinaryHeap::with_capacity(k),
+ heap,
+ }
+ }
+
+ /// Restore the heap property by raising an entry.
+ fn bubble_up(&mut self, mut i: usize) {
+ while i > 0 {
+ let parent = (i - 1) / 2;
+ if self.heap[i].distance <= self.heap[parent].distance {
+ break;
+ }
+ self.heap.swap(i, parent);
+ i = parent;
+ }
+ }
+
+ /// Restore the heap property by lowering an entry.
+ fn bubble_down(&mut self, mut i: usize, len: usize) {
+ let dist = self.heap[i].distance;
+
+ loop {
+ let mut child = 2 * i + 1;
+ let right = child + 1;
+ if right < len && self.heap[child].distance < self.heap[right].distance {
+ child = right;
+ }
+
+ if child < len && dist < self.heap[child].distance {
+ self.heap.swap(i, child);
+ i = child;
+ } else {
+ break;
+ }
}
}
- /// Extract the results from this neighborhood.
- fn into_vec(self) -> Vec<Neighbor<V, D>> {
- self.heap
- .into_sorted_vec()
- .into_iter()
- .map(|c| c.0)
- .collect()
+ /// Sort the heap from smallest to largest distance.
+ fn sort(&mut self) {
+ for i in (0..self.heap.len()).rev() {
+ self.heap.swap(0, i);
+ self.bubble_down(0, i);
+ }
}
}
-impl<K, V> Neighborhood<K, V> for HeapNeighborhood<K, V, K::Distance>
+impl<'a, K, V> Neighborhood<K, V> for HeapNeighborhood<'a, K, V, K::Distance>
where
K: Copy + Proximity<V>,
{
@@ -314,20 +320,21 @@ where
}
fn consider(&mut self, item: V) -> K::Distance {
- let candidate = Candidate::new(self.target, item);
- let distance = candidate.0.distance;
+ let distance = self.target.distance(&item);
if self.contains(distance) {
- let heap = &mut self.heap;
-
- if heap.len() == self.k {
- heap.pop();
+ let neighbor = Neighbor::new(item, distance);
+
+ if self.heap.len() < self.k {
+ self.heap.push(neighbor);
+ self.bubble_up(self.heap.len() - 1);
+ } else {
+ self.heap[0] = neighbor;
+ self.bubble_down(0, self.heap.len());
}
- heap.push(candidate);
-
- if heap.len() == self.k {
- self.threshold = heap.peek().map(|c| c.0.distance)
+ if self.heap.len() == self.k {
+ self.threshold = Some(self.heap[0].distance);
}
}
@@ -370,21 +377,74 @@ pub trait NearestNeighbors<K: Proximity<V>, V = K> {
}
/// Returns the up to `k` nearest neighbors to `target`.
+ ///
+ /// The result will be sorted from nearest to farthest.
fn k_nearest(&self, target: &K, k: usize) -> Vec<Neighbor<&V, K::Distance>> {
- self.search(HeapNeighborhood::new(target, k, None))
- .into_vec()
+ let mut neighbors = Vec::with_capacity(k);
+ self.search(HeapNeighborhood::new(target, k, None, &mut neighbors))
+ .sort();
+ neighbors
}
/// Returns the up to `k` nearest neighbors to `target` within the distance `threshold`.
- fn k_nearest_within<D>(&self, target: &K, k: usize, threshold: D) -> Vec<Neighbor<&V, K::Distance>>
+ ///
+ /// The result will be sorted from nearest to farthest.
+ fn k_nearest_within<D>(
+ &self,
+ target: &K,
+ k: usize,
+ threshold: D,
+ ) -> Vec<Neighbor<&V, K::Distance>>
where
D: TryInto<K::Distance>,
{
+ let mut neighbors = Vec::with_capacity(k);
+
if let Ok(distance) = threshold.try_into() {
- self.search(HeapNeighborhood::new(target, k, Some(distance)))
- .into_vec()
- } else {
- Vec::new()
+ self.search(HeapNeighborhood::new(
+ target,
+ k,
+ Some(distance),
+ &mut neighbors,
+ ))
+ .sort();
+ }
+
+ neighbors
+ }
+
+ /// Merges up to `k` nearest neighbors into an existing vector.
+ ///
+ /// The `neighbors` vector should either be empty, or populated by a previous call to
+ /// `merge_k_nearest()`. This method assumes a particular ordering that makes merging new
+ /// results efficient. If you want the results ordered from nearest to farthest, you must sort
+ /// it yourself.
+ fn merge_k_nearest<'v>(
+ &'v self,
+ target: &K,
+ k: usize,
+ neighbors: &mut Vec<Neighbor<&'v V, K::Distance>>,
+ ) {
+ self.search(HeapNeighborhood::new(target, k, None, neighbors));
+ }
+
+ /// Merges up to `k` nearest neighbors within the `threshold` into an existing vector.
+ ///
+ /// The `neighbors` vector should either be empty, or populated by a previous call to
+ /// `merge_k_nearest()`. This method assumes a particular ordering that makes merging new
+ /// results efficient. If you want the results ordered from nearest to farthest, you must sort
+ /// it yourself.
+ fn merge_k_nearest_within<'v, D>(
+ &'v self,
+ target: &K,
+ k: usize,
+ neighbors: &mut Vec<Neighbor<&'v V, K::Distance>>,
+ threshold: D,
+ ) where
+ D: TryInto<K::Distance>,
+ {
+ if let Ok(distance) = threshold.try_into() {
+ self.search(HeapNeighborhood::new(target, k, Some(distance), neighbors));
}
}
@@ -404,6 +464,7 @@ pub mod tests {
use super::*;
use crate::exhaustive::ExhaustiveSearch;
+ use crate::util::Ordered;
use rand::prelude::*;
@@ -491,6 +552,30 @@ pub mod tests {
Neighbor::new(&Euclidean([2.0, 3.0, 6.0]), 7.0),
]
);
+
+ let mut neighbors = Vec::new();
+ index.merge_k_nearest(&target, 3, &mut neighbors);
+ neighbors.sort_by_key(|n| Ordered::new(n.distance));
+ assert_eq!(
+ neighbors,
+ vec![
+ Neighbor::new(&Euclidean([1.0, 2.0, 2.0]), 3.0),
+ Neighbor::new(&Euclidean([3.0, 4.0, 0.0]), 5.0),
+ Neighbor::new(&Euclidean([2.0, 3.0, 6.0]), 7.0),
+ ]
+ );
+
+ neighbors.drain(0..2);
+ index.merge_k_nearest_within(&target, 3, &mut neighbors, 6.0);
+ neighbors.sort_by_key(|n| Ordered::new(n.distance));
+ assert_eq!(
+ neighbors,
+ vec![
+ Neighbor::new(&Euclidean([1.0, 2.0, 2.0]), 3.0),
+ Neighbor::new(&Euclidean([3.0, 4.0, 0.0]), 5.0),
+ Neighbor::new(&Euclidean([2.0, 3.0, 6.0]), 7.0),
+ ]
+ );
}
fn test_random_points<T, F>(from_iter: &F)