From 57f4d9dbe851439b24e31977b8c5dc60e246dda3 Mon Sep 17 00:00:00 2001 From: Tavian Barnes Date: Thu, 25 Jun 2020 11:44:47 -0400 Subject: cos: Add prenormalized cosine/angular distances, and an order embedding --- src/cos.rs | 348 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 332 insertions(+), 16 deletions(-) diff --git a/src/cos.rs b/src/cos.rs index 2fde4ce..3d3219c 100644 --- a/src/cos.rs +++ b/src/cos.rs @@ -1,17 +1,19 @@ //! [Cosine distance](https://en.wikipedia.org/wiki/Cosine_similarity). use crate::coords::Coordinates; -use crate::distance::{Metric, Proximity}; +use crate::distance::{Distance, Metric, Proximity, Value}; use num_traits::real::Real; use num_traits::{one, zero}; +use std::cmp::Ordering; + /// Compute the [cosine *similarity*] between two points. /// -/// This is not suitable for implementing [`Proximity::distance()`] because the result is reversed +/// Use [cosine_distance] instead if you are implementing [Proximity::distance()]. /// /// [cosine *similarity*]: https://en.wikipedia.org/wiki/Cosine_similarity -/// [`Proximity::distance()`]: Proximity#method.distance +/// [Proximity::distance()]: Proximity#method.distance pub fn cosine_similarity(x: T, y: U) -> T::Value where T: Coordinates, @@ -91,16 +93,96 @@ where } } +/// Compute the [cosine *similarity*] between two pre-normalized (unit magnitude) points. +/// +/// Use [prenorm_cosine_distance] instead if you are implementing [Proximity::distance()]. +/// +/// [cosine *similarity*]: https://en.wikipedia.org/wiki/Cosine_similarity +/// [`Proximity::distance()`]: Proximity#method.distance +pub fn prenorm_cosine_similarity(x: T, y: U) -> T::Value +where + T: Coordinates, + U: Coordinates, + T::Value: Real, +{ + debug_assert!(x.dims() == y.dims()); + + let mut dot: T::Value = zero(); + + for i in 0..x.dims() { + dot += x.coord(i) * y.coord(i); + } + + dot +} + +/// Compute the [cosine distance] between two pre-normalized (unit magnitude) points. +/// +/// [cosine distance]: https://en.wikipedia.org/wiki/Cosine_similarity +pub fn prenorm_cosine_distance(x: T, y: U) -> T::Value +where + T: Coordinates, + U: Coordinates, + T::Value: Real, +{ + let one: T::Value = one(); + one - prenorm_cosine_similarity(x, y) +} + +/// Equips any [coordinate space] with the [cosine distance] function for pre-normalized (unit +/// magnitude) points. +/// +/// [coordinate space]: [Coordinates] +/// [cosine distance]: https://en.wikipedia.org/wiki/Cosine_similarity +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub struct PrenormCosine(pub T); + +impl Proximity for PrenormCosine +where + T: Coordinates, + T::Value: Real, +{ + type Distance = T::Value; + + fn distance(&self, other: &Self) -> Self::Distance { + prenorm_cosine_distance(&self.0, &other.0) + } +} + +impl Proximity for PrenormCosine +where + T: Coordinates, + T::Value: Real, +{ + type Distance = T::Value; + + fn distance(&self, other: &T) -> Self::Distance { + prenorm_cosine_distance(&self.0, other) + } +} + +impl Proximity> for T +where + T: Coordinates, + T::Value: Real, +{ + type Distance = T::Value; + + fn distance(&self, other: &PrenormCosine) -> Self::Distance { + prenorm_cosine_distance(self, &other.0) + } +} + /// Compute the [angular distance] between two points. /// /// [angular distance]: https://en.wikipedia.org/wiki/Cosine_similarity#Angular_distance_and_similarity -pub fn angular_distance(x: T, y: U) -> T::Value +pub fn angular_distance(x: T, y: U) -> AngularDistance where T: Coordinates, U: Coordinates, T::Value: Real, { - cosine_similarity(x, y).acos() + AngularDistance::from_cos(cosine_similarity(x, y)) } /// Equips any [coordinate space] with the [angular distance] metric. @@ -114,11 +196,12 @@ impl Proximity for Angular where T: Coordinates, T::Value: Real, + AngularDistance: Distance, { - type Distance = T::Value; + type Distance = AngularDistance; fn distance(&self, other: &Self) -> Self::Distance { - cosine_distance(&self.0, &other.0) + angular_distance(&self.0, &other.0) } } @@ -126,8 +209,9 @@ impl Proximity for Angular where T: Coordinates, T::Value: Real, + AngularDistance: Distance, { - type Distance = T::Value; + type Distance = AngularDistance; fn distance(&self, other: &T) -> Self::Distance { angular_distance(&self.0, other) @@ -138,8 +222,9 @@ impl Proximity> for T where T: Coordinates, T::Value: Real, + AngularDistance: Distance, { - type Distance = T::Value; + type Distance = AngularDistance; fn distance(&self, other: &Angular) -> Self::Distance { angular_distance(self, &other.0) @@ -151,6 +236,7 @@ impl Metric for Angular where T: Coordinates, T::Value: Real, + AngularDistance: Distance, {} /// Angular distance is a metric. @@ -158,6 +244,7 @@ impl Metric for Angular where T: Coordinates, T::Value: Real, + AngularDistance: Distance, {} /// Angular distance is a metric. @@ -165,12 +252,179 @@ impl Metric> for T where T: Coordinates, T::Value: Real, + AngularDistance: Distance, +{} + +/// Compute the [angular distance] between two points. +/// +/// [angular distance]: https://en.wikipedia.org/wiki/Cosine_similarity#Angular_distance_and_similarity +pub fn prenorm_angular_distance(x: T, y: U) -> AngularDistance +where + T: Coordinates, + U: Coordinates, + T::Value: Real, +{ + AngularDistance::from_cos(prenorm_cosine_similarity(x, y)) +} + +/// Equips any [coordinate space] with the [angular distance] metric for pre-normalized (unit +/// magnitude) points. +/// +/// [coordinate space]: [Coordinates] +/// [angular distance]: https://en.wikipedia.org/wiki/Cosine_similarity#Angular_distance_and_similarity +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub struct PrenormAngular(pub T); + +impl Proximity for PrenormAngular +where + T: Coordinates, + T::Value: Real, + AngularDistance: Distance, +{ + type Distance = AngularDistance; + + fn distance(&self, other: &Self) -> Self::Distance { + prenorm_angular_distance(&self.0, &other.0) + } +} + +impl Proximity for PrenormAngular +where + T: Coordinates, + T::Value: Real, + AngularDistance: Distance, +{ + type Distance = AngularDistance; + + fn distance(&self, other: &T) -> Self::Distance { + prenorm_angular_distance(&self.0, other) + } +} + +impl Proximity> for T +where + T: Coordinates, + T::Value: Real, + AngularDistance: Distance, +{ + type Distance = AngularDistance; + + fn distance(&self, other: &PrenormAngular) -> Self::Distance { + prenorm_angular_distance(self, &other.0) + } +} + +/// Angular distance is a metric. +impl Metric for PrenormAngular +where + T: Coordinates, + T::Value: Real, + AngularDistance: Distance, {} +/// Angular distance is a metric. +impl Metric for PrenormAngular +where + T: Coordinates, + T::Value: Real, + AngularDistance: Distance, +{} + +/// Angular distance is a metric. +impl Metric> for T +where + T: Coordinates, + T::Value: Real, + AngularDistance: Distance, +{} + +/// An [angular distance]. +/// +/// This type stores the cosine of the angle, to avoid computing the expensive trancendental +/// `acos()` function until absolutely necessary. +/// +/// # use acap::distance::Distance; +/// # use acap::cos::AngularDistance; +/// let zero = AngularDistance::from_cos(1.0); +/// let pi_2 = AngularDistance::from_cos(0.0); +/// let pi = AngularDistance::from_cos(-1.0); +/// assert!(zero < pi_2 && pi_2 < pi); +/// +/// [angular distance]: https://en.wikipedia.org/wiki/Cosine_similarity#Angular_distance_and_similarity +#[derive(Clone, Copy, Debug, PartialEq)] +pub struct AngularDistance(T); + +impl AngularDistance { + /// Creates an `AngularDistance` from the cosine of an angle. + pub fn from_cos(value: T) -> Self { + Self(value) + } + + /// Get the cosine of this angle. + pub fn cos(self) -> T { + self.0 + } +} + +impl PartialOrd for AngularDistance { + fn partial_cmp(&self, other: &AngularDistance) -> Option { + // acos() is decreasing, so swap the comparison order + other.0.partial_cmp(&self.0) + } +} + +macro_rules! impl_distance { + ($f:ty) => { + impl From> for $f { + #[inline] + fn from(value: AngularDistance<$f>) -> $f { + value.0.acos() + } + } + + impl PartialOrd<$f> for AngularDistance<$f> { + #[inline] + fn partial_cmp(&self, other: &$f) -> Option { + self.value().partial_cmp(other) + } + } + + impl PartialOrd> for $f { + #[inline] + fn partial_cmp(&self, other: &AngularDistance<$f>) -> Option { + self.partial_cmp(&other.value()) + } + } + + impl PartialEq<$f> for AngularDistance<$f> { + #[inline] + fn eq(&self, other: &$f) -> bool { + self.value() == *other + } + } + + impl PartialEq> for $f { + #[inline] + fn eq(&self, other: &AngularDistance<$f>) -> bool { + *self == other.value() + } + } + + impl Distance for AngularDistance<$f> { + type Value = $f; + } + } +} + +impl_distance!(f32); +impl_distance!(f64); + #[cfg(test)] mod tests { use super::*; + use std::f64::consts::{FRAC_PI_2, FRAC_PI_4, PI, SQRT_2}; + #[test] fn test_cosine() { assert_eq!(cosine_distance([3.0, 4.0], [3.0, 4.0]), 0.0); @@ -179,17 +433,79 @@ mod tests { assert_eq!(cosine_distance([3.0, 4.0], [4.0, -3.0]), 1.0); } + #[test] + fn test_prenorm_cosine() { + assert_eq!(prenorm_cosine_distance([0.6, 0.8], [0.6, 0.8]), 0.0); + assert_eq!(prenorm_cosine_distance([0.6, 0.8], [-0.8, 0.6]), 1.0); + assert_eq!(prenorm_cosine_distance([0.6, 0.8], [-0.6, -0.8]), 2.0); + assert_eq!(prenorm_cosine_distance([0.6, 0.8], [0.8, -0.6]), 1.0); + } + #[test] fn test_angular() { - use std::f64::consts::{FRAC_PI_2, FRAC_PI_4, PI}; + let zero = angular_distance([3.0, 4.0], [3.0, 4.0]); + let pi_4 = Angular([0.0, 1.0]).distance(&Angular([1.0, 1.0])); + let pi_2 = Angular([3.0, 4.0]).distance(&[-4.0, 3.0]); + let pi = [3.0, 4.0].distance(&Angular([-3.0, -4.0])); + + assert_eq!(zero.cos(), 1.0); + assert_eq!(pi_2.cos(), 0.0); + assert_eq!(pi.cos(), -1.0); + + assert_eq!(zero, 0.0); + + assert!(zero < pi_4); + assert!(zero < pi_2); + assert!(zero < pi); - assert_eq!(angular_distance([3.0, 4.0], [3.0, 4.0]), 0.0); + assert!(pi_4 < pi_2); + assert!(pi_4 < pi); - assert!((angular_distance([3.0, 4.0], [-4.0, 3.0]) - FRAC_PI_2).abs() < 1.0e-9); - assert!((angular_distance([3.0, 4.0], [-3.0, -4.0]) - PI).abs() < 1.0e-9); - assert!((angular_distance([3.0, 4.0], [4.0, -3.0]) - FRAC_PI_2).abs() < 1.0e-9); + assert!(pi_2 < pi); - assert!((angular_distance([0.0, 1.0], [1.0, 1.0]) - FRAC_PI_4).abs() < 1.0e-9); + assert!(FRAC_PI_4 < pi_2); + assert!(pi_2 > FRAC_PI_4); + + assert!(pi_2 < PI); + assert!(PI > pi_2); + + assert!((pi_4.value() - FRAC_PI_4).abs() < 1.0e-9); + assert!((pi_2.value() - FRAC_PI_2).abs() < 1.0e-9); + assert!((pi.value() - PI).abs() < 1.0e-9); } -} + #[test] + fn test_prenorm_angular() { + let sqrt_2_inv = 1.0 / SQRT_2; + + let zero = prenorm_angular_distance([0.6, 0.8], [0.6, 0.8]); + let pi_4 = PrenormAngular([0.0, 1.0]).distance(&PrenormAngular([sqrt_2_inv, sqrt_2_inv])); + let pi_2 = PrenormAngular([0.6, 0.8]).distance(&[-0.8, 0.6]); + let pi = [0.6, 0.8].distance(&PrenormAngular([-0.6, -0.8])); + + assert_eq!(zero.cos(), 1.0); + assert_eq!(pi_2.cos(), 0.0); + assert_eq!(pi.cos(), -1.0); + + assert_eq!(zero, 0.0); + + assert!(zero < pi_4); + assert!(zero < pi_2); + assert!(zero < pi); + + assert!(pi_4 < pi_2); + assert!(pi_4 < pi); + + assert!(pi_2 < pi); + + assert!(FRAC_PI_4 < pi_2); + assert!(pi_2 > FRAC_PI_4); + + assert!(pi_2 < PI); + assert!(PI > pi_2); + + assert!((pi_4.value() - FRAC_PI_4).abs() < 1.0e-9); + assert!((pi_2.value() - FRAC_PI_2).abs() < 1.0e-9); + assert!((pi.value() - PI).abs() < 1.0e-9); + } +} -- cgit v1.2.3