summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorTavian Barnes <tavianator@tavianator.com>2020-05-05 16:30:34 -0400
committerTavian Barnes <tavianator@tavianator.com>2020-05-05 16:30:34 -0400
commit48a8abb94e1318f67bbd2809186c62009456d7c6 (patch)
tree7977a218ef2e935a4f4f43b65c27cb4377dd0f2f /src
parentbae2b127e377842a8131901cdb83ed4598bb3f21 (diff)
downloadkd-forest-48a8abb94e1318f67bbd2809186c62009456d7c6.tar.xz
metric: Relax Distances to have only a partial order
Diffstat (limited to 'src')
-rw-r--r--src/metric.rs47
-rw-r--r--src/metric/kd.rs6
-rw-r--r--src/metric/vp.rs4
3 files changed, 24 insertions, 33 deletions
diff --git a/src/metric.rs b/src/metric.rs
index 268aefd..ff996b9 100644
--- a/src/metric.rs
+++ b/src/metric.rs
@@ -6,12 +6,22 @@ pub mod kd;
pub mod soft;
pub mod vp;
-use ordered_float::OrderedFloat;
-
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::iter::FromIterator;
+/// A wrapper that converts a partial ordering into a total one by panicking.
+#[derive(Debug, PartialEq, PartialOrd)]
+struct Ordered<T>(T);
+
+impl<T: PartialOrd> Ord for Ordered<T> {
+ fn cmp(&self, other: &Self) -> Ordering {
+ self.partial_cmp(other).unwrap()
+ }
+}
+
+impl<T: PartialEq> Eq for Ordered<T> {}
+
/// An [order embedding](https://en.wikipedia.org/wiki/Order_embedding) for distances.
///
/// Implementations of this trait must satisfy, for all non-negative distances `x` and `y`:
@@ -22,34 +32,18 @@ use std::iter::FromIterator;
/// This trait exists to optimize the common case where distances can be compared more efficiently
/// than their exact values can be computed. For example, taking the square root can be avoided
/// when comparing Euclidean distances (see [SquaredDistance]).
-pub trait Distance: Copy + From<f64> + Into<f64> + Ord {}
-
-/// A raw numerical distance.
-#[derive(Debug, Clone, Copy, Eq, Ord, PartialEq, PartialOrd)]
-pub struct RawDistance(OrderedFloat<f64>);
-
-impl From<f64> for RawDistance {
- fn from(value: f64) -> Self {
- Self(value.into())
- }
-}
-
-impl From<RawDistance> for f64 {
- fn from(value: RawDistance) -> Self {
- value.0.into_inner()
- }
-}
+pub trait Distance: Copy + From<f64> + Into<f64> + PartialOrd {}
-impl Distance for RawDistance {}
+impl Distance for f64 {}
/// A squared distance, to avoid computing square roots unless absolutely necessary.
-#[derive(Debug, Clone, Copy, Eq, Ord, PartialEq, PartialOrd)]
-pub struct SquaredDistance(OrderedFloat<f64>);
+#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
+pub struct SquaredDistance(f64);
impl SquaredDistance {
/// Create a SquaredDistance from an already squared value.
pub fn from_squared(value: f64) -> Self {
- Self(value.into())
+ Self(value)
}
}
@@ -61,7 +55,7 @@ impl From<f64> for SquaredDistance {
impl From<SquaredDistance> for f64 {
fn from(value: SquaredDistance) -> Self {
- value.0.into_inner().sqrt()
+ value.0.sqrt()
}
}
@@ -69,8 +63,7 @@ impl Distance for SquaredDistance {}
/// A [metric space](https://en.wikipedia.org/wiki/Metric_space).
pub trait Metric<T: ?Sized = Self> {
- /// The type used to represent distances. Use [RawDistance] to compare the actual values
- /// directly, or another type if comparisons can be implemented more efficiently.
+ /// The type used to represent distances.
type Distance: Distance;
/// Computes the distance between this point and another point. This function must satisfy
@@ -153,7 +146,7 @@ impl<T, D: Distance> PartialOrd for Candidate<T, D> {
impl<T, D: Distance> Ord for Candidate<T, D> {
fn cmp(&self, other: &Self) -> Ordering {
- self.distance.cmp(&other.distance)
+ self.partial_cmp(other).unwrap()
}
}
diff --git a/src/metric/kd.rs b/src/metric/kd.rs
index 2caf4a3..6ea3809 100644
--- a/src/metric/kd.rs
+++ b/src/metric/kd.rs
@@ -1,8 +1,6 @@
//! [k-d trees](https://en.wikipedia.org/wiki/K-d_tree).
-use super::{Metric, NearestNeighbors, Neighborhood};
-
-use ordered_float::OrderedFloat;
+use super::{Metric, NearestNeighbors, Neighborhood, Ordered};
use std::iter::FromIterator;
@@ -82,7 +80,7 @@ impl<T: Cartesian> KdNode<T> {
return;
}
- slice.sort_unstable_by_key(|n| OrderedFloat::from(n.item.coordinate(i)));
+ slice.sort_unstable_by_key(|n| Ordered(n.item.coordinate(i)));
let mid = slice.len() / 2;
slice.swap(0, mid);
diff --git a/src/metric/vp.rs b/src/metric/vp.rs
index fae62e5..d6e05df 100644
--- a/src/metric/vp.rs
+++ b/src/metric/vp.rs
@@ -1,6 +1,6 @@
//! [Vantage-point trees](https://en.wikipedia.org/wiki/Vantage-point_tree).
-use super::{Metric, NearestNeighbors, Neighborhood};
+use super::{Metric, NearestNeighbors, Neighborhood, Ordered};
use std::iter::FromIterator;
@@ -29,7 +29,7 @@ impl<T: Metric> VpNode<T> {
fn build(slice: &mut [VpNode<T>]) {
if let Some((node, children)) = slice.split_first_mut() {
let item = &node.item;
- children.sort_by_cached_key(|n| item.distance(&n.item));
+ children.sort_by_cached_key(|n| Ordered(item.distance(&n.item)));
let (inside, outside) = children.split_at_mut(children.len() / 2);
if let Some(last) = inside.last() {