summaryrefslogtreecommitdiffstats
path: root/src/metric
diff options
context:
space:
mode:
authorTavian Barnes <tavianator@tavianator.com>2020-04-19 16:55:17 -0400
committerTavian Barnes <tavianator@tavianator.com>2020-05-02 13:19:05 -0400
commita4a75059f302de2a00971f1f485fcf4389710628 (patch)
tree76ffd65c5cc2e77d563cc813b9e3b9136d29d7a4 /src/metric
parente9a81a6d0df149252164003975addf175d5c6f4b (diff)
downloadkd-forest-a4a75059f302de2a00971f1f485fcf4389710628.tar.xz
metric/forest: Implement dynamized forests
Diffstat (limited to 'src/metric')
-rw-r--r--src/metric/forest.rs152
1 files changed, 152 insertions, 0 deletions
diff --git a/src/metric/forest.rs b/src/metric/forest.rs
new file mode 100644
index 0000000..f23c451
--- /dev/null
+++ b/src/metric/forest.rs
@@ -0,0 +1,152 @@
+//! [Dynamization](https://en.wikipedia.org/wiki/Dynamization) for nearest neighbor search.
+
+use super::kd::KdTree;
+use super::vp::VpTree;
+use super::{Metric, NearestNeighbors, Neighborhood};
+
+use std::iter::{Extend, Flatten, FromIterator};
+
+/// A dynamic wrapper for a static nearest neighbor search data structure.
+///
+/// This type applies [dynamization](https://en.wikipedia.org/wiki/Dynamization) to an arbitrary
+/// nearest neighbor search structure `T`, allowing new items to be added dynamically.
+#[derive(Debug)]
+pub struct Forest<T>(Vec<Option<T>>);
+
+impl<T, U> Forest<U>
+where
+ U: FromIterator<T> + IntoIterator<Item = T>,
+{
+ /// Create a new empty forest.
+ pub fn new() -> Self {
+ Self(Vec::new())
+ }
+
+ /// Add a new item to the forest.
+ pub fn push(&mut self, item: T) {
+ let mut items = vec![item];
+
+ for slot in &mut self.0 {
+ match slot.take() {
+ // Collect the items from any trees we encounter...
+ Some(tree) => {
+ items.extend(tree);
+ }
+ // ... and put them all in the first empty slot
+ None => {
+ *slot = Some(items.into_iter().collect());
+ return;
+ }
+ }
+ }
+
+ self.0.push(Some(items.into_iter().collect()));
+ }
+
+ /// Get the number of items in the forest.
+ pub fn len(&self) -> usize {
+ let mut len = 0;
+ for (i, slot) in self.0.iter().enumerate() {
+ if slot.is_some() {
+ len |= 1 << i;
+ }
+ }
+ len
+ }
+}
+
+impl<T, U> Extend<T> for Forest<U>
+where
+ U: FromIterator<T> + IntoIterator<Item = T>,
+{
+ fn extend<I: IntoIterator<Item = T>>(&mut self, items: I) {
+ for item in items {
+ self.push(item);
+ }
+ }
+}
+
+impl<T, U> FromIterator<T> for Forest<U>
+where
+ U: FromIterator<T> + IntoIterator<Item = T>,
+{
+ fn from_iter<I: IntoIterator<Item = T>>(items: I) -> Self {
+ let mut forest = Self::new();
+ forest.extend(items);
+ forest
+ }
+}
+
+type IntoIterImpl<T> = Flatten<Flatten<std::vec::IntoIter<Option<T>>>>;
+
+/// An iterator that moves items out of a forest.
+pub struct IntoIter<T: IntoIterator>(IntoIterImpl<T>);
+
+impl<T: IntoIterator> Iterator for IntoIter<T> {
+ type Item = T::Item;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ self.0.next()
+ }
+}
+
+impl<T: IntoIterator> IntoIterator for Forest<T> {
+ type Item = T::Item;
+ type IntoIter = IntoIter<T>;
+
+ fn into_iter(self) -> Self::IntoIter {
+ IntoIter(self.0.into_iter().flatten().flatten())
+ }
+}
+
+impl<T, U, V> NearestNeighbors<T, U> for Forest<V>
+where
+ U: Metric<T>,
+ V: NearestNeighbors<T, U>,
+{
+ fn search<'a, 'b, N>(&'a self, neighborhood: N) -> N
+ where
+ T: 'a,
+ U: 'b,
+ N: Neighborhood<&'a T, &'b U>,
+ {
+ self.0
+ .iter()
+ .flatten()
+ .fold(neighborhood, |n, t| t.search(n))
+ }
+}
+
+/// A forest of k-d trees.
+pub type KdForest<T> = Forest<KdTree<T>>;
+
+/// A forest of vantage-point trees.
+pub type VpForest<T> = Forest<VpTree<T>>;
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ use crate::metric::tests::test_nearest_neighbors;
+ use crate::metric::ExhaustiveSearch;
+
+ #[test]
+ fn test_exhaustive_forest() {
+ test_nearest_neighbors(Forest::<ExhaustiveSearch<_>>::from_iter);
+ }
+
+ #[test]
+ fn test_forest_forest() {
+ test_nearest_neighbors(Forest::<Forest<ExhaustiveSearch<_>>>::from_iter);
+ }
+
+ #[test]
+ fn test_kd_forest() {
+ test_nearest_neighbors(KdForest::from_iter);
+ }
+
+ #[test]
+ fn test_vp_forest() {
+ test_nearest_neighbors(VpForest::from_iter);
+ }
+}