From 1c560791902a4ef72efa671106d8f6d97fea50c1 Mon Sep 17 00:00:00 2001 From: Tavian Barnes Date: Sun, 19 Apr 2020 16:40:29 -0400 Subject: metric/kd: Implement k-d trees --- src/metric/kd.rs | 217 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 217 insertions(+) create mode 100644 src/metric/kd.rs (limited to 'src/metric/kd.rs') diff --git a/src/metric/kd.rs b/src/metric/kd.rs new file mode 100644 index 0000000..ab0cd2e --- /dev/null +++ b/src/metric/kd.rs @@ -0,0 +1,217 @@ +//! [k-d trees](https://en.wikipedia.org/wiki/K-d_tree). + +use super::{Metric, NearestNeighbors, Neighborhood}; + +use ordered_float::OrderedFloat; + +use std::iter::FromIterator; + +/// A point in Cartesian space. +pub trait Cartesian { + /// Returns the number of dimensions necessary to describe this point. + fn dimensions(&self) -> usize; + + /// Returns the value of the `i`th coordinate of this point (`i < self.dimensions()`). + fn coordinate(&self, i: usize) -> f64; +} + +/// Blanket [Cartesian] implementation for references. +impl<'a, T: Cartesian> Cartesian for &'a T { + fn dimensions(&self) -> usize { + (*self).dimensions() + } + + fn coordinate(&self, i: usize) -> f64 { + (*self).coordinate(i) + } +} + +/// Standard cartesian space. +impl Cartesian for [f64] { + fn dimensions(&self) -> usize { + self.len() + } + + fn coordinate(&self, i: usize) -> f64 { + self[i] + } +} + +/// A node in a k-d tree. +#[derive(Debug)] +struct KdNode { + /// The value stored in this node. + item: T, + /// The left subtree, if any. + left: Option>, + /// The right subtree, if any. + right: Option>, +} + +trait KdSearch<'a, T, U, N> { + /// Recursively search for nearest neighbors. + fn search(&'a self, i: usize, neighborhood: &mut N); + + /// Search the left subtree. + fn search_left(&'a self, i: usize, distance: f64, neighborhood: &mut N); + + /// Search the right subtree. + fn search_right(&'a self, i: usize, distance: f64, neighborhood: &mut N); +} + +impl<'a, T, U, N> KdSearch<'a, T, U, N> for KdNode +where + T: 'a + Cartesian, + U: Cartesian + Metric<&'a T>, + N: Neighborhood<&'a T, U>, +{ + fn search(&'a self, i: usize, neighborhood: &mut N) { + neighborhood.consider(&self.item); + + let distance = neighborhood.target().coordinate(i) - self.item.coordinate(i); + let j = (i + 1) % self.item.dimensions(); + if distance <= 0.0 { + self.search_left(j, distance, neighborhood); + self.search_right(j, -distance, neighborhood); + } else { + self.search_right(j, -distance, neighborhood); + self.search_left(j, distance, neighborhood); + } + } + + fn search_left(&'a self, i: usize, distance: f64, neighborhood: &mut N) { + if let Some(left) = &self.left { + if neighborhood.contains(distance) { + left.search(i, neighborhood); + } + } + } + + fn search_right(&'a self, i: usize, distance: f64, neighborhood: &mut N) { + if let Some(right) = &self.right { + if neighborhood.contains(distance) { + right.search(i, neighborhood); + } + } + } +} + +impl KdNode { + /// Create a new KdNode. + fn new(i: usize, mut items: Vec) -> Option> { + if items.is_empty() { + return None; + } + + items.sort_unstable_by_key(|x| OrderedFloat::from(x.coordinate(i))); + + let mid = items.len() / 2; + let right: Vec = items.drain((mid + 1)..).collect(); + let item = items.pop().unwrap(); + let j = (i + 1) % item.dimensions(); + Some(Box::new(Self { + item, + left: Self::new(j, items), + right: Self::new(j, right), + })) + } +} + +/// A [k-d tree](https://en.wikipedia.org/wiki/K-d_tree). +#[derive(Debug)] +pub struct KdTree { + root: Option>>, +} + +impl FromIterator for KdTree { + /// Create a new k-d tree from a set of points. + fn from_iter>(items: I) -> Self { + Self { + root: KdNode::new(0, items.into_iter().collect()), + } + } +} + +impl NearestNeighbors for KdTree +where + T: Cartesian, + U: Cartesian + Metric, +{ + fn search<'a, 'b, N>(&'a self, mut neighborhood: N) -> N + where + T: 'a, + U: 'b, + N: Neighborhood<&'a T, &'b U>, + { + if let Some(root) = &self.root { + root.search(0, &mut neighborhood); + } + neighborhood + } +} + +/// An iterator that the moves values out of a k-d tree. +#[derive(Debug)] +pub struct IntoIter { + stack: Vec>>, +} + +impl IntoIter { + fn new(node: Option>>) -> Self { + Self { + stack: node.into_iter().collect(), + } + } +} + +impl Iterator for IntoIter { + type Item = T; + + fn next(&mut self) -> Option { + self.stack.pop().map(|node| { + self.stack.extend(node.left); + self.stack.extend(node.right); + node.item + }) + } +} + +impl IntoIterator for KdTree { + type Item = T; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + IntoIter::new(self.root) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::metric::tests::{test_nearest_neighbors, Point}; + use crate::metric::SquaredDistance; + + impl Metric<[f64]> for Point { + type Distance = SquaredDistance; + + fn distance(&self, other: &[f64]) -> Self::Distance { + self.0.distance(other) + } + } + + impl Cartesian for Point { + fn dimensions(&self) -> usize { + self.0.dimensions() + } + + fn coordinate(&self, i: usize) -> f64 { + self.0.coordinate(i) + } + } + + #[test] + fn test_kd_tree() { + test_nearest_neighbors(KdTree::from_iter); + } +} -- cgit v1.2.3 From 5de377b2b00a927a4f6463c1c5a5fd18606ad006 Mon Sep 17 00:00:00 2001 From: Tavian Barnes Date: Thu, 30 Apr 2020 22:51:06 -0400 Subject: metric/kd: Prune k-d tree searches more aggressively --- src/metric/kd.rs | 116 +++++++++++++++++++++++++++++++------------------------ 1 file changed, 65 insertions(+), 51 deletions(-) (limited to 'src/metric/kd.rs') diff --git a/src/metric/kd.rs b/src/metric/kd.rs index ab0cd2e..db1b2bd 100644 --- a/src/metric/kd.rs +++ b/src/metric/kd.rs @@ -7,7 +7,7 @@ use ordered_float::OrderedFloat; use std::iter::FromIterator; /// A point in Cartesian space. -pub trait Cartesian { +pub trait Cartesian: Metric<[f64]> { /// Returns the number of dimensions necessary to describe this point. fn dimensions(&self) -> usize; @@ -26,6 +26,15 @@ impl<'a, T: Cartesian> Cartesian for &'a T { } } +/// Blanket [Metric<[f64]>](Metric) implementation for [Cartesian] references. +impl<'a, T: Cartesian> Metric<[f64]> for &'a T { + type Distance = T::Distance; + + fn distance(&self, other: &[f64]) -> Self::Distance { + (*self).distance(other) + } +} + /// Standard cartesian space. impl Cartesian for [f64] { fn dimensions(&self) -> usize { @@ -37,6 +46,21 @@ impl Cartesian for [f64] { } } +/// Marker trait for cartesian metric spaces. +pub trait CartesianMetric: + Cartesian + Metric>::Distance> +{ +} + +/// Blanket [CartesianMetric] implementation for cartesian spaces with compatible metric distance +/// types. +impl CartesianMetric for U +where + T: ?Sized, + U: ?Sized + Cartesian + Metric>::Distance>, +{ +} + /// A node in a k-d tree. #[derive(Debug)] struct KdNode { @@ -48,54 +72,6 @@ struct KdNode { right: Option>, } -trait KdSearch<'a, T, U, N> { - /// Recursively search for nearest neighbors. - fn search(&'a self, i: usize, neighborhood: &mut N); - - /// Search the left subtree. - fn search_left(&'a self, i: usize, distance: f64, neighborhood: &mut N); - - /// Search the right subtree. - fn search_right(&'a self, i: usize, distance: f64, neighborhood: &mut N); -} - -impl<'a, T, U, N> KdSearch<'a, T, U, N> for KdNode -where - T: 'a + Cartesian, - U: Cartesian + Metric<&'a T>, - N: Neighborhood<&'a T, U>, -{ - fn search(&'a self, i: usize, neighborhood: &mut N) { - neighborhood.consider(&self.item); - - let distance = neighborhood.target().coordinate(i) - self.item.coordinate(i); - let j = (i + 1) % self.item.dimensions(); - if distance <= 0.0 { - self.search_left(j, distance, neighborhood); - self.search_right(j, -distance, neighborhood); - } else { - self.search_right(j, -distance, neighborhood); - self.search_left(j, distance, neighborhood); - } - } - - fn search_left(&'a self, i: usize, distance: f64, neighborhood: &mut N) { - if let Some(left) = &self.left { - if neighborhood.contains(distance) { - left.search(i, neighborhood); - } - } - } - - fn search_right(&'a self, i: usize, distance: f64, neighborhood: &mut N) { - if let Some(right) = &self.right { - if neighborhood.contains(distance) { - right.search(i, neighborhood); - } - } - } -} - impl KdNode { /// Create a new KdNode. fn new(i: usize, mut items: Vec) -> Option> { @@ -115,6 +91,40 @@ impl KdNode { right: Self::new(j, right), })) } + + /// Recursively search for nearest neighbors. + fn search<'a, U, N>(&'a self, i: usize, closest: &mut [f64], neighborhood: &mut N) + where + T: 'a, + U: CartesianMetric<&'a T>, + N: Neighborhood<&'a T, U>, + { + neighborhood.consider(&self.item); + + let target = neighborhood.target(); + let ti = target.coordinate(i); + let si = self.item.coordinate(i); + let j = (i + 1) % self.item.dimensions(); + + let (near, far) = if ti <= si { + (&self.left, &self.right) + } else { + (&self.right, &self.left) + }; + + if let Some(near) = near { + near.search(j, closest, neighborhood); + } + + if let Some(far) = far { + let saved = closest[i]; + closest[i] = si; + if neighborhood.contains_distance(target.distance(closest)) { + far.search(j, closest, neighborhood); + } + closest[i] = saved; + } + } } /// A [k-d tree](https://en.wikipedia.org/wiki/K-d_tree). @@ -135,7 +145,7 @@ impl FromIterator for KdTree { impl NearestNeighbors for KdTree where T: Cartesian, - U: Cartesian + Metric, + U: CartesianMetric, { fn search<'a, 'b, N>(&'a self, mut neighborhood: N) -> N where @@ -143,8 +153,12 @@ where U: 'b, N: Neighborhood<&'a T, &'b U>, { + let target = neighborhood.target(); + let dims = target.dimensions(); + let mut closest: Vec<_> = (0..dims).map(|i| target.coordinate(i)).collect(); + if let Some(root) = &self.root { - root.search(0, &mut neighborhood); + root.search(0, &mut closest, &mut neighborhood); } neighborhood } -- cgit v1.2.3 From e9a81a6d0df149252164003975addf175d5c6f4b Mon Sep 17 00:00:00 2001 From: Tavian Barnes Date: Thu, 23 Apr 2020 09:55:13 -0400 Subject: metric/kd: Flatten the tree representation --- src/metric/kd.rs | 113 ++++++++++++++++++++++++++----------------------------- 1 file changed, 54 insertions(+), 59 deletions(-) (limited to 'src/metric/kd.rs') diff --git a/src/metric/kd.rs b/src/metric/kd.rs index db1b2bd..2caf4a3 100644 --- a/src/metric/kd.rs +++ b/src/metric/kd.rs @@ -66,61 +66,71 @@ where struct KdNode { /// The value stored in this node. item: T, - /// The left subtree, if any. - left: Option>, - /// The right subtree, if any. - right: Option>, + /// The size of the left subtree. + left_len: usize, } impl KdNode { /// Create a new KdNode. - fn new(i: usize, mut items: Vec) -> Option> { - if items.is_empty() { - return None; + fn new(item: T) -> Self { + Self { item, left_len: 0 } + } + + /// Build a k-d tree recursively. + fn build(slice: &mut [KdNode], i: usize) { + if slice.is_empty() { + return; } - items.sort_unstable_by_key(|x| OrderedFloat::from(x.coordinate(i))); - - let mid = items.len() / 2; - let right: Vec = items.drain((mid + 1)..).collect(); - let item = items.pop().unwrap(); - let j = (i + 1) % item.dimensions(); - Some(Box::new(Self { - item, - left: Self::new(j, items), - right: Self::new(j, right), - })) + slice.sort_unstable_by_key(|n| OrderedFloat::from(n.item.coordinate(i))); + + let mid = slice.len() / 2; + slice.swap(0, mid); + + let (node, children) = slice.split_first_mut().unwrap(); + let (left, right) = children.split_at_mut(mid); + node.left_len = left.len(); + + let j = (i + 1) % node.item.dimensions(); + Self::build(left, j); + Self::build(right, j); } /// Recursively search for nearest neighbors. - fn search<'a, U, N>(&'a self, i: usize, closest: &mut [f64], neighborhood: &mut N) - where + fn recurse<'a, U, N>( + slice: &'a [KdNode], + i: usize, + closest: &mut [f64], + neighborhood: &mut N, + ) where T: 'a, U: CartesianMetric<&'a T>, N: Neighborhood<&'a T, U>, { - neighborhood.consider(&self.item); + let (node, children) = slice.split_first().unwrap(); + neighborhood.consider(&node.item); let target = neighborhood.target(); let ti = target.coordinate(i); - let si = self.item.coordinate(i); - let j = (i + 1) % self.item.dimensions(); + let ni = node.item.coordinate(i); + let j = (i + 1) % node.item.dimensions(); - let (near, far) = if ti <= si { - (&self.left, &self.right) + let (left, right) = children.split_at(node.left_len); + let (near, far) = if ti <= ni { + (left, right) } else { - (&self.right, &self.left) + (right, left) }; - if let Some(near) = near { - near.search(j, closest, neighborhood); + if !near.is_empty() { + Self::recurse(near, j, closest, neighborhood); } - if let Some(far) = far { + if !far.is_empty() { let saved = closest[i]; - closest[i] = si; + closest[i] = ni; if neighborhood.contains_distance(target.distance(closest)) { - far.search(j, closest, neighborhood); + Self::recurse(far, j, closest, neighborhood); } closest[i] = saved; } @@ -129,16 +139,14 @@ impl KdNode { /// A [k-d tree](https://en.wikipedia.org/wiki/K-d_tree). #[derive(Debug)] -pub struct KdTree { - root: Option>>, -} +pub struct KdTree(Vec>); impl FromIterator for KdTree { /// Create a new k-d tree from a set of points. fn from_iter>(items: I) -> Self { - Self { - root: KdNode::new(0, items.into_iter().collect()), - } + let mut nodes: Vec<_> = items.into_iter().map(KdNode::new).collect(); + KdNode::build(nodes.as_mut_slice(), 0); + Self(nodes) } } @@ -153,40 +161,27 @@ where U: 'b, N: Neighborhood<&'a T, &'b U>, { - let target = neighborhood.target(); - let dims = target.dimensions(); - let mut closest: Vec<_> = (0..dims).map(|i| target.coordinate(i)).collect(); + if !self.0.is_empty() { + let target = neighborhood.target(); + let dims = target.dimensions(); + let mut closest: Vec<_> = (0..dims).map(|i| target.coordinate(i)).collect(); - if let Some(root) = &self.root { - root.search(0, &mut closest, &mut neighborhood); + KdNode::recurse(&self.0, 0, &mut closest, &mut neighborhood); } + neighborhood } } /// An iterator that the moves values out of a k-d tree. #[derive(Debug)] -pub struct IntoIter { - stack: Vec>>, -} - -impl IntoIter { - fn new(node: Option>>) -> Self { - Self { - stack: node.into_iter().collect(), - } - } -} +pub struct IntoIter(std::vec::IntoIter>); impl Iterator for IntoIter { type Item = T; fn next(&mut self) -> Option { - self.stack.pop().map(|node| { - self.stack.extend(node.left); - self.stack.extend(node.right); - node.item - }) + self.0.next().map(|n| n.item) } } @@ -195,7 +190,7 @@ impl IntoIterator for KdTree { type IntoIter = IntoIter; fn into_iter(self) -> Self::IntoIter { - IntoIter::new(self.root) + IntoIter(self.0.into_iter()) } } -- cgit v1.2.3