Skip to content

Commit 36808fc

Browse files
danieldksebpuetz
authored and committed
Add the Embedding::into method
This method realizes an embedding into a user-provided array. This makes it possible to look up embeddings without additional allocations (for unknown words). Fixes #110.
1 parent 7fa5196 commit 36808fc

1 file changed

Lines changed: 64 additions & 2 deletions

File tree

src/embeddings.rs

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use std::iter::Enumerate;
66
use std::mem;
77
use std::slice;
88

9-
use ndarray::{Array1, CowArray, Ix1};
9+
use ndarray::{Array1, ArrayViewMut1, CowArray, Ix1};
1010
use rand::{RngCore, SeedableRng};
1111
use rand_xorshift::XorShiftRng;
1212
use reductive::pq::TrainPQ;
@@ -147,6 +147,47 @@ where
147147
}
148148
}
149149

150+
/// Realize the embedding of a word into the given vector.
151+
///
152+
/// This variant of `embedding` realizes the embedding into the
153+
/// given vector. This makes it possible to look up embeddings
154+
/// without any additional allocations. This method returns
155+
/// `false` and does not modify the vector if no embedding could
156+
/// be found.
157+
///
158+
/// Panics when then the vector does not have the same
159+
/// dimensionality as the word embeddings.
160+
pub fn embedding_into(&self, word: &str, mut target: ArrayViewMut1<f32>) -> bool {
161+
assert_eq!(
162+
target.len(),
163+
self.dims(),
164+
"Embeddings have {} dimensions, whereas target array has {}",
165+
self.dims(),
166+
target.len()
167+
);
168+
169+
let index = if let Some(idx) = self.vocab.idx(word) {
170+
idx
171+
} else {
172+
return false;
173+
};
174+
175+
match index {
176+
WordIndex::Word(idx) => target.assign(&self.storage.embedding(idx)),
177+
WordIndex::Subword(indices) => {
178+
target.fill(0.);
179+
180+
for idx in indices {
181+
target += &self.storage.embedding(idx).view();
182+
}
183+
184+
l2_normalize(target.view_mut());
185+
}
186+
}
187+
188+
true
189+
}
190+
150191
/// Get the embedding and original norm of a word.
151192
///
152193
/// Returns for a word:
@@ -531,14 +572,15 @@ mod tests {
531572
use std::io::{BufReader, Cursor, Seek, SeekFrom};
532573

533574
use approx::AbsDiffEq;
534-
use ndarray::array;
575+
use ndarray::{array, Array1};
535576
use toml::toml;
536577

537578
use super::Embeddings;
538579
use crate::chunks::metadata::Metadata;
539580
use crate::chunks::norms::NdNorms;
540581
use crate::chunks::storage::{MmapArray, NdArray, StorageView};
541582
use crate::chunks::vocab::SimpleVocab;
583+
use crate::compat::fasttext::ReadFastText;
542584
use crate::compat::word2vec::ReadWord2VecRaw;
543585
use crate::io::{MmapEmbeddings, ReadEmbeddings, WriteEmbeddings};
544586

@@ -559,6 +601,26 @@ mod tests {
559601
})
560602
}
561603

604+
/// Checks that `embedding_into` produces the same vector as
/// `embedding` for three lookups against the fastText test model.
#[test]
fn embedding_into_equal_to_embedding() {
    let file = File::open("testdata/fasttext.bin").unwrap();
    let mut reader = BufReader::new(file);
    let embeds = Embeddings::read_fasttext(&mut reader).unwrap();

    // First lookup ("Known word" case): start from a zeroed buffer.
    let mut known = Array1::zeros(embeds.dims());
    let found = embeds.embedding_into("ganz", known.view_mut());
    assert!(found);
    assert_eq!(known, embeds.embedding("ganz").unwrap());

    // Second lookup ("Unknown word" case): again a zeroed buffer.
    let mut buffer = Array1::zeros(embeds.dims());
    let found = embeds.embedding_into("iddqd", buffer.view_mut());
    assert!(found);
    assert_eq!(buffer, embeds.embedding("iddqd").unwrap());

    // Third lookup ("Unknown word, non-zero vector" case): the buffer
    // is deliberately reused while it still holds the previous result,
    // so leftover values must not leak into the new embedding.
    let found = embeds.embedding_into("idspispopd", buffer.view_mut());
    assert!(found);
    assert_eq!(buffer, embeds.embedding("idspispopd").unwrap());
}
623+
562624
#[test]
563625
fn mmap() {
564626
let check_embeds = test_embeddings();

0 commit comments

Comments
 (0)