22
33use std:: cmp:: Ordering ;
44use std:: collections:: { BinaryHeap , HashSet } ;
5+ use std:: f32;
56
67use ndarray:: { s, Array1 , ArrayView1 , ArrayView2 , CowArray , Ix1 } ;
78use ordered_float:: NotNan ;
@@ -21,6 +22,18 @@ pub struct WordSimilarityResult<'a> {
2122 pub word : & ' a str ,
2223}
2324
25+ impl < ' a > WordSimilarityResult < ' a > {
26+ /// Get the word's similarity in angular similarity.
27+ pub fn angular_similarity ( & self ) -> f32 {
28+ 1f32 - ( self . similarity . acos ( ) / f32:: consts:: PI )
29+ }
30+
31+ /// Get the word's similarity in cosine similarity.
32+ pub fn cosine_similarity ( & self ) -> f32 {
33+ * self . similarity
34+ }
35+ }
36+
2437impl < ' a > Ord for WordSimilarityResult < ' a > {
2538 fn cmp ( & self , other : & Self ) -> Ordering {
2639 match other. similarity . cmp ( & self . similarity ) {
@@ -427,9 +440,11 @@ mod tests {
427440 use std:: fs:: File ;
428441 use std:: io:: BufReader ;
429442
443+ use approx:: AbsDiffEq ;
444+
430445 use crate :: compat:: word2vec:: ReadWord2Vec ;
431446 use crate :: embeddings:: Embeddings ;
432- use crate :: similarity:: { Analogy , EmbeddingSimilarity , WordSimilarity } ;
447+ use crate :: similarity:: { Analogy , EmbeddingSimilarity , WordSimilarity , WordSimilarityResult } ;
433448
434449 static SIMILARITY_ORDER_STUTTGART_10 : & ' static [ & ' static str ] = & [
435450 "Karlsruhe" ,
@@ -530,6 +545,34 @@ mod tests {
530545 "Westfalen" ,
531546 ] ;
532547
548+ #[ test]
549+ fn cosine_similarity_is_correctly_converted_to_angular_similarity ( ) {
550+ assert ! ( ( WordSimilarityResult {
551+ word: "test" ,
552+ similarity: 1f32 . into( )
553+ } )
554+ . angular_similarity( )
555+ . abs_diff_eq( & 1f32 , 1e-5 ) ) ;
556+ assert ! ( ( WordSimilarityResult {
557+ word: "test" ,
558+ similarity: 0.70710678 . into( )
559+ } )
560+ . angular_similarity( )
561+ . abs_diff_eq( & 0.75 , 1e-5 ) ) ;
562+ assert ! ( ( WordSimilarityResult {
563+ word: "test" ,
564+ similarity: 0f32 . into( )
565+ } )
566+ . angular_similarity( )
567+ . abs_diff_eq( & 0.5f32 , 1e-5 ) ) ;
568+ assert ! ( ( WordSimilarityResult {
569+ word: "test" ,
570+ similarity: ( -1f32 ) . into( )
571+ } )
572+ . angular_similarity( )
573+ . abs_diff_eq( & 0f32 , 1e-5 ) ) ;
574+ }
575+
533576 #[ test]
534577 fn test_similarity ( ) {
535578 let f = File :: open ( "testdata/similarity.bin" ) . unwrap ( ) ;
0 commit comments