From 48a8abb94e1318f67bbd2809186c62009456d7c6 Mon Sep 17 00:00:00 2001 From: Tavian Barnes Date: Tue, 5 May 2020 16:30:34 -0400 Subject: metric: Relax Distances to have only a partial order --- src/metric.rs | 47 ++++++++++++++++++++--------------------------- src/metric/kd.rs | 6 ++---- src/metric/vp.rs | 4 ++-- 3 files changed, 24 insertions(+), 33 deletions(-) (limited to 'src') diff --git a/src/metric.rs b/src/metric.rs index 268aefd..ff996b9 100644 --- a/src/metric.rs +++ b/src/metric.rs @@ -6,12 +6,22 @@ pub mod kd; pub mod soft; pub mod vp; -use ordered_float::OrderedFloat; - use std::cmp::Ordering; use std::collections::BinaryHeap; use std::iter::FromIterator; +/// A wrapper that converts a partial ordering into a total one by panicking. +#[derive(Debug, PartialEq, PartialOrd)] +struct Ordered(T); + +impl Ord for Ordered { + fn cmp(&self, other: &Self) -> Ordering { + self.partial_cmp(other).unwrap() + } +} + +impl Eq for Ordered {} + /// An [order embedding](https://en.wikipedia.org/wiki/Order_embedding) for distances. /// /// Implementations of this trait must satisfy, for all non-negative distances `x` and `y`: @@ -22,34 +32,18 @@ use std::iter::FromIterator; /// This trait exists to optimize the common case where distances can be compared more efficiently /// than their exact values can be computed. For example, taking the square root can be avoided /// when comparing Euclidean distances (see [SquaredDistance]). -pub trait Distance: Copy + From + Into + Ord {} - -/// A raw numerical distance. -#[derive(Debug, Clone, Copy, Eq, Ord, PartialEq, PartialOrd)] -pub struct RawDistance(OrderedFloat); - -impl From for RawDistance { - fn from(value: f64) -> Self { - Self(value.into()) - } -} - -impl From for f64 { - fn from(value: RawDistance) -> Self { - value.0.into_inner() - } -} +pub trait Distance: Copy + From + Into + PartialOrd {} -impl Distance for RawDistance {} +impl Distance for f64 {} /// A squared distance, to avoid computing square roots unless absolutely necessary. -#[derive(Debug, Clone, Copy, Eq, Ord, PartialEq, PartialOrd)] -pub struct SquaredDistance(OrderedFloat); +#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)] +pub struct SquaredDistance(f64); impl SquaredDistance { /// Create a SquaredDistance from an already squared value. pub fn from_squared(value: f64) -> Self { - Self(value.into()) + Self(value) } } @@ -61,7 +55,7 @@ impl From for SquaredDistance { impl From for f64 { fn from(value: SquaredDistance) -> Self { - value.0.into_inner().sqrt() + value.0.sqrt() } } @@ -69,8 +63,7 @@ impl Distance for SquaredDistance {} /// A [metric space](https://en.wikipedia.org/wiki/Metric_space). pub trait Metric { - /// The type used to represent distances. Use [RawDistance] to compare the actual values - /// directly, or another type if comparisons can be implemented more efficiently. + /// The type used to represent distances. type Distance: Distance; /// Computes the distance between this point and another point. This function must satisfy @@ -153,7 +146,7 @@ impl PartialOrd for Candidate { impl Ord for Candidate { fn cmp(&self, other: &Self) -> Ordering { - self.distance.cmp(&other.distance) + self.partial_cmp(other).unwrap() } } diff --git a/src/metric/kd.rs b/src/metric/kd.rs index 2caf4a3..6ea3809 100644 --- a/src/metric/kd.rs +++ b/src/metric/kd.rs @@ -1,8 +1,6 @@ //! [k-d trees](https://en.wikipedia.org/wiki/K-d_tree). -use super::{Metric, NearestNeighbors, Neighborhood}; - -use ordered_float::OrderedFloat; +use super::{Metric, NearestNeighbors, Neighborhood, Ordered}; use std::iter::FromIterator; @@ -82,7 +80,7 @@ impl KdNode { return; } - slice.sort_unstable_by_key(|n| OrderedFloat::from(n.item.coordinate(i))); + slice.sort_unstable_by_key(|n| Ordered(n.item.coordinate(i))); let mid = slice.len() / 2; slice.swap(0, mid); diff --git a/src/metric/vp.rs b/src/metric/vp.rs index fae62e5..d6e05df 100644 --- a/src/metric/vp.rs +++ b/src/metric/vp.rs @@ -1,6 +1,6 @@ //! [Vantage-point trees](https://en.wikipedia.org/wiki/Vantage-point_tree). -use super::{Metric, NearestNeighbors, Neighborhood}; +use super::{Metric, NearestNeighbors, Neighborhood, Ordered}; use std::iter::FromIterator; @@ -29,7 +29,7 @@ impl VpNode { fn build(slice: &mut [VpNode]) { if let Some((node, children)) = slice.split_first_mut() { let item = &node.item; - children.sort_by_cached_key(|n| item.distance(&n.item)); + children.sort_by_cached_key(|n| Ordered(item.distance(&n.item))); let (inside, outside) = children.split_at_mut(children.len() / 2); if let Some(last) = inside.last() { -- cgit v1.2.3