From 57f4d9dbe851439b24e31977b8c5dc60e246dda3 Mon Sep 17 00:00:00 2001
From: Tavian Barnes <tavianator@tavianator.com>
Date: Thu, 25 Jun 2020 11:44:47 -0400
Subject: cos: Add prenormalized cosine/angular distances, and an order
 embedding

---
 src/cos.rs | 348 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 332 insertions(+), 16 deletions(-)

(limited to 'src')

diff --git a/src/cos.rs b/src/cos.rs
index 2fde4ce..3d3219c 100644
--- a/src/cos.rs
+++ b/src/cos.rs
@@ -1,17 +1,19 @@
 //! [Cosine distance](https://en.wikipedia.org/wiki/Cosine_similarity).
 
 use crate::coords::Coordinates;
-use crate::distance::{Metric, Proximity};
+use crate::distance::{Distance, Metric, Proximity, Value};
 
 use num_traits::real::Real;
 use num_traits::{one, zero};
 
+use std::cmp::Ordering;
+
 /// Compute the [cosine *similarity*] between two points.
 ///
-/// This is not suitable for implementing [`Proximity::distance()`] because the result is reversed
+/// Use [cosine_distance] instead if you are implementing [Proximity::distance()].
 ///
 /// [cosine *similarity*]: https://en.wikipedia.org/wiki/Cosine_similarity
-/// [`Proximity::distance()`]: Proximity#method.distance
+/// [Proximity::distance()]: Proximity#method.distance
 pub fn cosine_similarity<T, U>(x: T, y: U) -> T::Value
 where
     T: Coordinates,
@@ -91,16 +93,96 @@ where
     }
 }
 
+/// Compute the [cosine *similarity*] between two pre-normalized (unit magnitude) points.
+///
+/// Use [prenorm_cosine_distance] instead if you are implementing [Proximity::distance()].
+///
+/// [cosine *similarity*]: https://en.wikipedia.org/wiki/Cosine_similarity
+/// [`Proximity::distance()`]: Proximity#method.distance
+pub fn prenorm_cosine_similarity<T, U>(x: T, y: U) -> T::Value
+where
+    T: Coordinates,
+    U: Coordinates<Value = T::Value>,
+    T::Value: Real,
+{
+    debug_assert!(x.dims() == y.dims());
+
+    let mut dot: T::Value = zero();
+
+    for i in 0..x.dims() {
+        dot += x.coord(i) * y.coord(i);
+    }
+
+    dot
+}
+
+/// Compute the [cosine distance] between two pre-normalized (unit magnitude) points.
+///
+/// [cosine distance]: https://en.wikipedia.org/wiki/Cosine_similarity
+pub fn prenorm_cosine_distance<T, U>(x: T, y: U) -> T::Value
+where
+    T: Coordinates,
+    U: Coordinates<Value = T::Value>,
+    T::Value: Real,
+{
+    let one: T::Value = one();
+    one - prenorm_cosine_similarity(x, y)
+}
+
+/// Equips any [coordinate space] with the [cosine distance] function for pre-normalized (unit
+/// magnitude) points.
+///
+/// [coordinate space]: [Coordinates]
+/// [cosine distance]: https://en.wikipedia.org/wiki/Cosine_similarity
+#[derive(Clone, Copy, Debug, Eq, PartialEq)]
+pub struct PrenormCosine<T>(pub T);
+
+impl<T> Proximity for PrenormCosine<T>
+where
+    T: Coordinates,
+    T::Value: Real,
+{
+    type Distance = T::Value;
+
+    fn distance(&self, other: &Self) -> Self::Distance {
+        prenorm_cosine_distance(&self.0, &other.0)
+    }
+}
+
+impl<T> Proximity<T> for PrenormCosine<T>
+where
+    T: Coordinates,
+    T::Value: Real,
+{
+    type Distance = T::Value;
+
+    fn distance(&self, other: &T) -> Self::Distance {
+        prenorm_cosine_distance(&self.0, other)
+    }
+}
+
+impl<T> Proximity<PrenormCosine<T>> for T
+where
+    T: Coordinates,
+    T::Value: Real,
+{
+    type Distance = T::Value;
+
+    fn distance(&self, other: &PrenormCosine<T>) -> Self::Distance {
+        prenorm_cosine_distance(self, &other.0)
+    }
+}
+
 /// Compute the [angular distance] between two points.
 ///
 /// [angular distance]: https://en.wikipedia.org/wiki/Cosine_similarity#Angular_distance_and_similarity
-pub fn angular_distance<T, U>(x: T, y: U) -> T::Value
+pub fn angular_distance<T, U>(x: T, y: U) -> AngularDistance<T::Value>
 where
     T: Coordinates,
     U: Coordinates<Value = T::Value>,
     T::Value: Real,
 {
-    cosine_similarity(x, y).acos()
+    AngularDistance::from_cos(cosine_similarity(x, y))
 }
 
 /// Equips any [coordinate space] with the [angular distance] metric.
@@ -114,11 +196,12 @@ impl<T> Proximity for Angular<T>
 where
     T: Coordinates,
     T::Value: Real,
+    AngularDistance<T::Value>: Distance,
 {
-    type Distance = T::Value;
+    type Distance = AngularDistance<T::Value>;
 
     fn distance(&self, other: &Self) -> Self::Distance {
-        cosine_distance(&self.0, &other.0)
+        angular_distance(&self.0, &other.0)
     }
 }
 
@@ -126,8 +209,9 @@ impl<T> Proximity<T> for Angular<T>
 where
     T: Coordinates,
     T::Value: Real,
+    AngularDistance<T::Value>: Distance,
 {
-    type Distance = T::Value;
+    type Distance = AngularDistance<T::Value>;
 
     fn distance(&self, other: &T) -> Self::Distance {
         angular_distance(&self.0, other)
@@ -138,8 +222,9 @@ impl<T> Proximity<Angular<T>> for T
 where
     T: Coordinates,
     T::Value: Real,
+    AngularDistance<T::Value>: Distance,
 {
-    type Distance = T::Value;
+    type Distance = AngularDistance<T::Value>;
 
     fn distance(&self, other: &Angular<T>) -> Self::Distance {
         angular_distance(self, &other.0)
@@ -151,6 +236,7 @@ impl<T> Metric for Angular<T>
 where
     T: Coordinates,
     T::Value: Real,
+    AngularDistance<T::Value>: Distance,
 {}
 
 /// Angular distance is a metric.
@@ -158,6 +244,7 @@ impl<T> Metric<T> for Angular<T>
 where
     T: Coordinates,
     T::Value: Real,
+    AngularDistance<T::Value>: Distance,
 {}
 
 /// Angular distance is a metric.
@@ -165,12 +252,179 @@ impl<T> Metric<Angular<T>> for T
 where
     T: Coordinates,
     T::Value: Real,
+    AngularDistance<T::Value>: Distance,
+{}
+
+/// Compute the [angular distance] between two points.
+///
+/// [angular distance]: https://en.wikipedia.org/wiki/Cosine_similarity#Angular_distance_and_similarity
+pub fn prenorm_angular_distance<T, U>(x: T, y: U) -> AngularDistance<T::Value>
+where
+    T: Coordinates,
+    U: Coordinates<Value = T::Value>,
+    T::Value: Real,
+{
+    AngularDistance::from_cos(prenorm_cosine_similarity(x, y))
+}
+
+/// Equips any [coordinate space] with the [angular distance] metric for pre-normalized (unit
+/// magnitude) points.
+///
+/// [coordinate space]: [Coordinates]
+/// [angular distance]: https://en.wikipedia.org/wiki/Cosine_similarity#Angular_distance_and_similarity
+#[derive(Clone, Copy, Debug, Eq, PartialEq)]
+pub struct PrenormAngular<T>(pub T);
+
+impl<T> Proximity for PrenormAngular<T>
+where
+    T: Coordinates,
+    T::Value: Real,
+    AngularDistance<T::Value>: Distance,
+{
+    type Distance = AngularDistance<T::Value>;
+
+    fn distance(&self, other: &Self) -> Self::Distance {
+        prenorm_angular_distance(&self.0, &other.0)
+    }
+}
+
+impl<T> Proximity<T> for PrenormAngular<T>
+where
+    T: Coordinates,
+    T::Value: Real,
+    AngularDistance<T::Value>: Distance,
+{
+    type Distance = AngularDistance<T::Value>;
+
+    fn distance(&self, other: &T) -> Self::Distance {
+        prenorm_angular_distance(&self.0, other)
+    }
+}
+
+impl<T> Proximity<PrenormAngular<T>> for T
+where
+    T: Coordinates,
+    T::Value: Real,
+    AngularDistance<T::Value>: Distance,
+{
+    type Distance = AngularDistance<T::Value>;
+
+    fn distance(&self, other: &PrenormAngular<T>) -> Self::Distance {
+        prenorm_angular_distance(self, &other.0)
+    }
+}
+
+/// Angular distance is a metric.
+impl<T> Metric for PrenormAngular<T>
+where
+    T: Coordinates,
+    T::Value: Real,
+    AngularDistance<T::Value>: Distance,
 {}
 
+/// Angular distance is a metric.
+impl<T> Metric<T> for PrenormAngular<T>
+where
+    T: Coordinates,
+    T::Value: Real,
+    AngularDistance<T::Value>: Distance,
+{}
+
+/// Angular distance is a metric.
+impl<T> Metric<PrenormAngular<T>> for T
+where
+    T: Coordinates,
+    T::Value: Real,
+    AngularDistance<T::Value>: Distance,
+{}
+
+/// An [angular distance].
+///
+/// This type stores the cosine of the angle, to avoid computing the expensive trancendental
+/// `acos()` function until absolutely necessary.
+///
+///     # use acap::distance::Distance;
+///     # use acap::cos::AngularDistance;
+///     let zero = AngularDistance::from_cos(1.0);
+///     let pi_2 = AngularDistance::from_cos(0.0);
+///     let pi = AngularDistance::from_cos(-1.0);
+///     assert!(zero < pi_2 && pi_2 < pi);
+///
+/// [angular distance]: https://en.wikipedia.org/wiki/Cosine_similarity#Angular_distance_and_similarity
+#[derive(Clone, Copy, Debug, PartialEq)]
+pub struct AngularDistance<T>(T);
+
+impl<T: Real + Value> AngularDistance<T> {
+    /// Creates an `AngularDistance` from the cosine of an angle.
+    pub fn from_cos(value: T) -> Self {
+        Self(value)
+    }
+
+    /// Get the cosine of this angle.
+    pub fn cos(self) -> T {
+        self.0
+    }
+}
+
+impl<T: PartialOrd> PartialOrd for AngularDistance<T> {
+    fn partial_cmp(&self, other: &AngularDistance<T>) -> Option<Ordering> {
+        // acos() is decreasing, so swap the comparison order
+        other.0.partial_cmp(&self.0)
+    }
+}
+
+macro_rules! impl_distance {
+    ($f:ty) => {
+        impl From<AngularDistance<$f>> for $f {
+            #[inline]
+            fn from(value: AngularDistance<$f>) -> $f {
+                value.0.acos()
+            }
+        }
+
+        impl PartialOrd<$f> for AngularDistance<$f> {
+            #[inline]
+            fn partial_cmp(&self, other: &$f) -> Option<Ordering> {
+                self.value().partial_cmp(other)
+            }
+        }
+
+        impl PartialOrd<AngularDistance<$f>> for $f {
+            #[inline]
+            fn partial_cmp(&self, other: &AngularDistance<$f>) -> Option<Ordering> {
+                self.partial_cmp(&other.value())
+            }
+        }
+
+        impl PartialEq<$f> for AngularDistance<$f> {
+            #[inline]
+            fn eq(&self, other: &$f) -> bool {
+                self.value() == *other
+            }
+        }
+
+        impl PartialEq<AngularDistance<$f>> for $f {
+            #[inline]
+            fn eq(&self, other: &AngularDistance<$f>) -> bool {
+                *self == other.value()
+            }
+        }
+
+        impl Distance for AngularDistance<$f> {
+            type Value = $f;
+        }
+    }
+}
+
+impl_distance!(f32);
+impl_distance!(f64);
+
 #[cfg(test)]
 mod tests {
     use super::*;
 
+    use std::f64::consts::{FRAC_PI_2, FRAC_PI_4, PI, SQRT_2};
+
     #[test]
     fn test_cosine() {
         assert_eq!(cosine_distance([3.0, 4.0], [3.0, 4.0]), 0.0);
@@ -179,17 +433,79 @@ mod tests {
         assert_eq!(cosine_distance([3.0, 4.0], [4.0, -3.0]), 1.0);
     }
 
+    #[test]
+    fn test_prenorm_cosine() {
+        assert_eq!(prenorm_cosine_distance([0.6, 0.8], [0.6, 0.8]), 0.0);
+        assert_eq!(prenorm_cosine_distance([0.6, 0.8], [-0.8, 0.6]), 1.0);
+        assert_eq!(prenorm_cosine_distance([0.6, 0.8], [-0.6, -0.8]), 2.0);
+        assert_eq!(prenorm_cosine_distance([0.6, 0.8], [0.8, -0.6]), 1.0);
+    }
+
     #[test]
     fn test_angular() {
-        use std::f64::consts::{FRAC_PI_2, FRAC_PI_4, PI};
+        let zero = angular_distance([3.0, 4.0], [3.0, 4.0]);
+        let pi_4 = Angular([0.0, 1.0]).distance(&Angular([1.0, 1.0]));
+        let pi_2 = Angular([3.0, 4.0]).distance(&[-4.0, 3.0]);
+        let pi = [3.0, 4.0].distance(&Angular([-3.0, -4.0]));
+
+        assert_eq!(zero.cos(), 1.0);
+        assert_eq!(pi_2.cos(), 0.0);
+        assert_eq!(pi.cos(), -1.0);
+
+        assert_eq!(zero, 0.0);
+
+        assert!(zero < pi_4);
+        assert!(zero < pi_2);
+        assert!(zero < pi);
 
-        assert_eq!(angular_distance([3.0, 4.0], [3.0, 4.0]), 0.0);
+        assert!(pi_4 < pi_2);
+        assert!(pi_4 < pi);
 
-        assert!((angular_distance([3.0, 4.0], [-4.0, 3.0]) - FRAC_PI_2).abs() < 1.0e-9);
-        assert!((angular_distance([3.0, 4.0], [-3.0, -4.0]) - PI).abs() < 1.0e-9);
-        assert!((angular_distance([3.0, 4.0], [4.0, -3.0]) - FRAC_PI_2).abs() < 1.0e-9);
+        assert!(pi_2 < pi);
 
-        assert!((angular_distance([0.0, 1.0], [1.0, 1.0]) - FRAC_PI_4).abs() < 1.0e-9);
+        assert!(FRAC_PI_4 < pi_2);
+        assert!(pi_2 > FRAC_PI_4);
+
+        assert!(pi_2 < PI);
+        assert!(PI > pi_2);
+
+        assert!((pi_4.value() - FRAC_PI_4).abs() < 1.0e-9);
+        assert!((pi_2.value() - FRAC_PI_2).abs() < 1.0e-9);
+        assert!((pi.value() - PI).abs() < 1.0e-9);
     }
-}
 
+    #[test]
+    fn test_prenorm_angular() {
+        let sqrt_2_inv = 1.0 / SQRT_2;
+
+        let zero = prenorm_angular_distance([0.6, 0.8], [0.6, 0.8]);
+        let pi_4 = PrenormAngular([0.0, 1.0]).distance(&PrenormAngular([sqrt_2_inv, sqrt_2_inv]));
+        let pi_2 = PrenormAngular([0.6, 0.8]).distance(&[-0.8, 0.6]);
+        let pi = [0.6, 0.8].distance(&PrenormAngular([-0.6, -0.8]));
+
+        assert_eq!(zero.cos(), 1.0);
+        assert_eq!(pi_2.cos(), 0.0);
+        assert_eq!(pi.cos(), -1.0);
+
+        assert_eq!(zero, 0.0);
+
+        assert!(zero < pi_4);
+        assert!(zero < pi_2);
+        assert!(zero < pi);
+
+        assert!(pi_4 < pi_2);
+        assert!(pi_4 < pi);
+
+        assert!(pi_2 < pi);
+
+        assert!(FRAC_PI_4 < pi_2);
+        assert!(pi_2 > FRAC_PI_4);
+
+        assert!(pi_2 < PI);
+        assert!(PI > pi_2);
+
+        assert!((pi_4.value() - FRAC_PI_4).abs() < 1.0e-9);
+        assert!((pi_2.value() - FRAC_PI_2).abs() < 1.0e-9);
+        assert!((pi.value() - PI).abs() < 1.0e-9);
+    }
+}
-- 
cgit v1.2.3