summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/kd.rs155
1 files changed, 155 insertions, 0 deletions
diff --git a/src/kd.rs b/src/kd.rs
index 97616e7..4f591f9 100644
--- a/src/kd.rs
+++ b/src/kd.rs
@@ -333,6 +333,156 @@ where
V: Coordinates,
{}
+/// A node in a flat k-d tree.
+#[derive(Debug)]
+struct FlatKdNode<T> {
+ /// The vantage point itself.
+ item: T,
+ /// The size of the left subtree.
+ left_len: usize,
+}
+
+impl<T: Coordinates> FlatKdNode<T> {
+ /// Create a new FlatKdNode.
+ fn new(item: T) -> Self {
+ Self {
+ item,
+ left_len: 0,
+ }
+ }
+
+ /// Create a balanced tree.
+ fn balanced<I: IntoIterator<Item = T>>(items: I) -> Vec<Self> {
+ let mut nodes: Vec<_> = items
+ .into_iter()
+ .map(Self::new)
+ .collect();
+
+ Self::balance_recursive(&mut nodes, 0);
+
+ nodes
+ }
+
+ /// Create a balanced subtree.
+ fn balance_recursive(nodes: &mut [Self], level: usize) {
+ if !nodes.is_empty() {
+ nodes.sort_by_cached_key(|x| Ordered::new(x.item.coord(level)));
+
+ let mid = nodes.len() / 2;
+ nodes.swap(0, mid);
+
+ let (node, children) = nodes.split_first_mut().unwrap();
+ let (left, right) = children.split_at_mut(mid);
+ node.left_len = left.len();
+
+ let next = (level + 1) % node.item.dims();
+ Self::balance_recursive(left, next);
+ Self::balance_recursive(right, next);
+ }
+ }
+}
+
+impl<'a, K, V, N> KdSearch<K, &'a V, N> for &'a [FlatKdNode<V>]
+where
+ K: KdProximity<&'a V>,
+ V: Coordinates,
+ N: Neighborhood<K, &'a V>,
+{
+ fn item(self) -> &'a V {
+ &self[0].item
+ }
+
+ fn left(self) -> Option<Self> {
+ let end = self[0].left_len + 1;
+ if end > 1 {
+ Some(&self[1..end])
+ } else {
+ None
+ }
+ }
+
+ fn right(self) -> Option<Self> {
+ let start = self[0].left_len + 1;
+ if start < self.len() {
+ Some(&self[start..])
+ } else {
+ None
+ }
+ }
+}
+
+/// A [k-d tree] stored as a flat array.
+///
+/// A FlatKdTree is always balanced and usually more efficient than a [KdTree], but doesn't support
+/// dynamic updates.
+///
+/// [k-d tree]: https://en.wikipedia.org/wiki/K-d_tree
+#[derive(Debug)]
+pub struct FlatKdTree<T> {
+ nodes: Vec<FlatKdNode<T>>,
+}
+
+impl<T: Coordinates> FlatKdTree<T> {
+ /// Create a balanced tree out of a sequence of items.
+ pub fn balanced<I: IntoIterator<Item = T>>(items: I) -> Self {
+ Self {
+ nodes: FlatKdNode::balanced(items),
+ }
+ }
+}
+
+impl<T: Coordinates> FromIterator<T> for FlatKdTree<T> {
+ fn from_iter<I: IntoIterator<Item = T>>(items: I) -> Self {
+ Self::balanced(items)
+ }
+}
+
+/// An iterator that moves values out of a flat k-d tree.
+#[derive(Debug)]
+pub struct FlatIntoIter<T>(std::vec::IntoIter<FlatKdNode<T>>);
+
+impl<T> Iterator for FlatIntoIter<T> {
+ type Item = T;
+
+ fn next(&mut self) -> Option<T> {
+ self.0.next().map(|n| n.item)
+ }
+}
+
+impl<T> IntoIterator for FlatKdTree<T> {
+ type Item = T;
+ type IntoIter = FlatIntoIter<T>;
+
+ fn into_iter(self) -> Self::IntoIter {
+ FlatIntoIter(self.nodes.into_iter())
+ }
+}
+
+impl<K, V> NearestNeighbors<K, V> for FlatKdTree<V>
+where
+ K: KdProximity<V>,
+ V: Coordinates,
+{
+ fn search<'k, 'v, N>(&'v self, mut neighborhood: N) -> N
+ where
+ K: 'k,
+ V: 'v,
+ N: Neighborhood<&'k K, &'v V>,
+ {
+ if !self.nodes.is_empty() {
+ let mut closest = neighborhood.target().as_vec();
+ self.nodes.as_slice().search(0, &mut closest, &mut neighborhood);
+ }
+ neighborhood
+ }
+}
+
+impl<K, V> ExactNeighbors<K, V> for FlatKdTree<V>
+where
+ K: KdMetric<V>,
+ V: Coordinates,
+{}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -354,4 +504,9 @@ mod tests {
tree
});
}
+
+ #[test]
+ fn test_flat_kd_tree() {
+ test_nearest_neighbors(FlatKdTree::from_iter);
+ }
}