Skip to content

Commit 07faa4b

Browse files
danieldk (Daniël de Kok)
authored and committed
Add WordSimilarityResult::angular_similarity
1 parent 497a930 commit 07faa4b

1 file changed

Lines changed: 44 additions & 1 deletion

File tree

src/similarity.rs

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
33
use std::cmp::Ordering;
44
use std::collections::{BinaryHeap, HashSet};
5+
use std::f32;
56

67
use ndarray::{s, Array1, ArrayView1, ArrayView2, CowArray, Ix1};
78
use ordered_float::NotNan;
@@ -21,6 +22,18 @@ pub struct WordSimilarityResult<'a> {
2122
pub word: &'a str,
2223
}
2324

25+
impl<'a> WordSimilarityResult<'a> {
26+
/// Get the word's similarity in angular similarity.
27+
pub fn angular_similarity(&self) -> f32 {
28+
1f32 - (self.similarity.acos() / f32::consts::PI)
29+
}
30+
31+
/// Get the word's similarity in cosine similarity.
32+
pub fn cosine_similarity(&self) -> f32 {
33+
*self.similarity
34+
}
35+
}
36+
2437
impl<'a> Ord for WordSimilarityResult<'a> {
2538
fn cmp(&self, other: &Self) -> Ordering {
2639
match other.similarity.cmp(&self.similarity) {
@@ -427,9 +440,11 @@ mod tests {
427440
use std::fs::File;
428441
use std::io::BufReader;
429442

443+
use approx::AbsDiffEq;
444+
430445
use crate::compat::word2vec::ReadWord2Vec;
431446
use crate::embeddings::Embeddings;
432-
use crate::similarity::{Analogy, EmbeddingSimilarity, WordSimilarity};
447+
use crate::similarity::{Analogy, EmbeddingSimilarity, WordSimilarity, WordSimilarityResult};
433448

434449
static SIMILARITY_ORDER_STUTTGART_10: &'static [&'static str] = &[
435450
"Karlsruhe",
@@ -530,6 +545,34 @@ mod tests {
530545
"Westfalen",
531546
];
532547

548+
/// Checks `angular_similarity` against known cosine/angular pairs,
/// each compared with an absolute tolerance of `1e-5`.
#[test]
fn cosine_similarity_is_correctly_converted_to_angular_similarity() {
    // (cosine similarity, expected angular similarity)
    let cases: [(f32, f32); 4] = [
        (1f32, 1f32),
        (0.70710678, 0.75),
        (0f32, 0.5f32),
        (-1f32, 0f32),
    ];

    for &(cosine, expected) in cases.iter() {
        let result = WordSimilarityResult {
            word: "test",
            similarity: cosine.into(),
        };
        assert!(result.angular_similarity().abs_diff_eq(&expected, 1e-5));
    }
}
575+
533576
#[test]
534577
fn test_similarity() {
535578
let f = File::open("testdata/similarity.bin").unwrap();

0 commit comments

Comments
 (0)