From c53a3031f7a8ea0578634d53597c2817f586665b Mon Sep 17 00:00:00 2001 From: Tavian Barnes Date: Thu, 25 Jun 2020 09:04:36 -0400 Subject: cos: Implement cosine and angular distance --- src/cos.rs | 195 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 1 + 2 files changed, 196 insertions(+) create mode 100644 src/cos.rs diff --git a/src/cos.rs b/src/cos.rs new file mode 100644 index 0000000..2fde4ce --- /dev/null +++ b/src/cos.rs @@ -0,0 +1,195 @@ +//! [Cosine distance](https://en.wikipedia.org/wiki/Cosine_similarity). + +use crate::coords::Coordinates; +use crate::distance::{Metric, Proximity}; + +use num_traits::real::Real; +use num_traits::{one, zero}; + +/// Compute the [cosine *similarity*] between two points. +/// +/// This is not suitable for implementing [`Proximity::distance()`] because the result is reversed +/// +/// [cosine *similarity*]: https://en.wikipedia.org/wiki/Cosine_similarity +/// [`Proximity::distance()`]: Proximity#method.distance +pub fn cosine_similarity(x: T, y: U) -> T::Value +where + T: Coordinates, + U: Coordinates, + T::Value: Real, +{ + debug_assert!(x.dims() == y.dims()); + + let mut dot: T::Value = zero(); + let mut xx: T::Value = zero(); + let mut yy: T::Value = zero(); + + for i in 0..x.dims() { + let xi = x.coord(i); + let yi = y.coord(i); + dot += xi * yi; + xx += xi * xi; + yy += yi * yi; + } + + dot / (xx * yy).sqrt() +} + +/// Compute the [cosine distance] between two points. +/// +/// [cosine distance]: https://en.wikipedia.org/wiki/Cosine_similarity +pub fn cosine_distance(x: T, y: U) -> T::Value +where + T: Coordinates, + U: Coordinates, + T::Value: Real, +{ + let one: T::Value = one(); + one - cosine_similarity(x, y) +} + +/// Equips any [coordinate space] with the [cosine distance] function. +/// +/// [coordinate space]: [Coordinates] +/// [cosine distance]: https://en.wikipedia.org/wiki/Cosine_similarity +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub struct Cosine(pub T); + +impl Proximity for Cosine +where + T: Coordinates, + T::Value: Real, +{ + type Distance = T::Value; + + fn distance(&self, other: &Self) -> Self::Distance { + cosine_distance(&self.0, &other.0) + } +} + +impl Proximity for Cosine +where + T: Coordinates, + T::Value: Real, +{ + type Distance = T::Value; + + fn distance(&self, other: &T) -> Self::Distance { + cosine_distance(&self.0, other) + } +} + +impl Proximity> for T +where + T: Coordinates, + T::Value: Real, +{ + type Distance = T::Value; + + fn distance(&self, other: &Cosine) -> Self::Distance { + cosine_distance(self, &other.0) + } +} + +/// Compute the [angular distance] between two points. +/// +/// [angular distance]: https://en.wikipedia.org/wiki/Cosine_similarity#Angular_distance_and_similarity +pub fn angular_distance(x: T, y: U) -> T::Value +where + T: Coordinates, + U: Coordinates, + T::Value: Real, +{ + cosine_similarity(x, y).acos() +} + +/// Equips any [coordinate space] with the [angular distance] metric. +/// +/// [coordinate space]: [Coordinates] +/// [angular distance]: https://en.wikipedia.org/wiki/Cosine_similarity#Angular_distance_and_similarity +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub struct Angular(pub T); + +impl Proximity for Angular +where + T: Coordinates, + T::Value: Real, +{ + type Distance = T::Value; + + fn distance(&self, other: &Self) -> Self::Distance { + cosine_distance(&self.0, &other.0) + } +} + +impl Proximity for Angular +where + T: Coordinates, + T::Value: Real, +{ + type Distance = T::Value; + + fn distance(&self, other: &T) -> Self::Distance { + angular_distance(&self.0, other) + } +} + +impl Proximity> for T +where + T: Coordinates, + T::Value: Real, +{ + type Distance = T::Value; + + fn distance(&self, other: &Angular) -> Self::Distance { + angular_distance(self, &other.0) + } +} + +/// Angular distance is a metric. +impl Metric for Angular +where + T: Coordinates, + T::Value: Real, +{} + +/// Angular distance is a metric. +impl Metric for Angular +where + T: Coordinates, + T::Value: Real, +{} + +/// Angular distance is a metric. +impl Metric> for T +where + T: Coordinates, + T::Value: Real, +{} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_cosine() { + assert_eq!(cosine_distance([3.0, 4.0], [3.0, 4.0]), 0.0); + assert_eq!(cosine_distance([3.0, 4.0], [-4.0, 3.0]), 1.0); + assert_eq!(cosine_distance([3.0, 4.0], [-3.0, -4.0]), 2.0); + assert_eq!(cosine_distance([3.0, 4.0], [4.0, -3.0]), 1.0); + } + + #[test] + fn test_angular() { + use std::f64::consts::{FRAC_PI_2, FRAC_PI_4, PI}; + + assert_eq!(angular_distance([3.0, 4.0], [3.0, 4.0]), 0.0); + + assert!((angular_distance([3.0, 4.0], [-4.0, 3.0]) - FRAC_PI_2).abs() < 1.0e-9); + assert!((angular_distance([3.0, 4.0], [-3.0, -4.0]) - PI).abs() < 1.0e-9); + assert!((angular_distance([3.0, 4.0], [4.0, -3.0]) - FRAC_PI_2).abs() < 1.0e-9); + + assert!((angular_distance([0.0, 1.0], [1.0, 1.0]) - FRAC_PI_4).abs() < 1.0e-9); + } +} + diff --git a/src/lib.rs b/src/lib.rs index 43d9bf1..e6ca957 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -92,6 +92,7 @@ pub mod chebyshev; pub mod coords; +pub mod cos; pub mod distance; pub mod euclid; pub mod exhaustive; -- cgit v1.2.3