@@ -6,7 +6,7 @@ use std::iter::Enumerate;
66use std:: mem;
77use std:: slice;
88
9- use ndarray:: { Array1 , CowArray , Ix1 } ;
9+ use ndarray:: { Array1 , ArrayViewMut1 , CowArray , Ix1 } ;
1010use rand:: { RngCore , SeedableRng } ;
1111use rand_xorshift:: XorShiftRng ;
1212use reductive:: pq:: TrainPQ ;
@@ -147,6 +147,47 @@ where
147147 }
148148 }
149149
150+ /// Realize the embedding of a word into the given vector.
151+ ///
152+ /// This variant of `embedding` realizes the embedding into the
153+ /// given vector. This makes it possible to look up embeddings
154+ /// without any additional allocations. This method returns
155+ /// `false` and does not modify the vector if no embedding could
156+ /// be found.
157+ ///
158+ /// Panics when then the vector does not have the same
159+ /// dimensionality as the word embeddings.
160+ pub fn embedding_into ( & self , word : & str , mut target : ArrayViewMut1 < f32 > ) -> bool {
161+ assert_eq ! (
162+ target. len( ) ,
163+ self . dims( ) ,
164+ "Embeddings have {} dimensions, whereas target array has {}" ,
165+ self . dims( ) ,
166+ target. len( )
167+ ) ;
168+
169+ let index = if let Some ( idx) = self . vocab . idx ( word) {
170+ idx
171+ } else {
172+ return false ;
173+ } ;
174+
175+ match index {
176+ WordIndex :: Word ( idx) => target. assign ( & self . storage . embedding ( idx) ) ,
177+ WordIndex :: Subword ( indices) => {
178+ target. fill ( 0. ) ;
179+
180+ for idx in indices {
181+ target += & self . storage . embedding ( idx) . view ( ) ;
182+ }
183+
184+ l2_normalize ( target. view_mut ( ) ) ;
185+ }
186+ }
187+
188+ true
189+ }
190+
150191 /// Get the embedding and original norm of a word.
151192 ///
152193 /// Returns for a word:
@@ -531,14 +572,15 @@ mod tests {
531572 use std:: io:: { BufReader , Cursor , Seek , SeekFrom } ;
532573
533574 use approx:: AbsDiffEq ;
534- use ndarray:: array;
575+ use ndarray:: { array, Array1 } ;
535576 use toml:: toml;
536577
537578 use super :: Embeddings ;
538579 use crate :: chunks:: metadata:: Metadata ;
539580 use crate :: chunks:: norms:: NdNorms ;
540581 use crate :: chunks:: storage:: { MmapArray , NdArray , StorageView } ;
541582 use crate :: chunks:: vocab:: SimpleVocab ;
583+ use crate :: compat:: fasttext:: ReadFastText ;
542584 use crate :: compat:: word2vec:: ReadWord2VecRaw ;
543585 use crate :: io:: { MmapEmbeddings , ReadEmbeddings , WriteEmbeddings } ;
544586
@@ -559,6 +601,26 @@ mod tests {
559601 } )
560602 }
561603
/// Checks that `embedding_into` writes the same vector that
/// `embedding` returns, for an in-vocabulary word, for a word that
/// is looked up as unknown, and when the target buffer already
/// holds non-zero data.
#[test]
fn embedding_into_equal_to_embedding() {
    let file = File::open("testdata/fasttext.bin").unwrap();
    let mut reader = BufReader::new(file);
    let embeds = Embeddings::read_fasttext(&mut reader).unwrap();

    // Known word, zero-initialized buffer.
    let mut known = Array1::zeros(embeds.dims());
    let found = embeds.embedding_into("ganz", known.view_mut());
    assert!(found);
    assert_eq!(known, embeds.embedding("ganz").unwrap());

    // Unknown word, zero-initialized buffer.
    let mut unknown = Array1::zeros(embeds.dims());
    let found = embeds.embedding_into("iddqd", unknown.view_mut());
    assert!(found);
    assert_eq!(unknown, embeds.embedding("iddqd").unwrap());

    // Unknown word, reusing the buffer that still holds the previous
    // (non-zero) result.
    let found = embeds.embedding_into("idspispopd", unknown.view_mut());
    assert!(found);
    assert_eq!(unknown, embeds.embedding("idspispopd").unwrap());
}
623+
562624 #[ test]
563625 fn mmap ( ) {
564626 let check_embeds = test_embeddings ( ) ;
0 commit comments