Skip to content

Commit 497a930

Browse files
danieldkDaniël de Kok
authored andcommitted
Embedding::{embedding_into,embedding_with_norm}: use Storage::embeddings
1 parent 8e92399 commit 497a930

1 file changed

Lines changed: 6 additions & 6 deletions

File tree

src/embeddings.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,10 @@ where
174174
WordIndex::Subword(indices) => {
175175
target.fill(0.);
176176

177-
for idx in indices {
178-
target += &self.storage.embedding(idx).view();
177+
let embeds = self.storage.embeddings(&indices);
178+
179+
for embed in embeds.outer_iter() {
180+
target += &embed;
179181
}
180182

181183
l2_normalize(target.view_mut());
@@ -204,10 +206,8 @@ where
204206
norm: self.norms().map(|n| n[idx]).unwrap_or(1.),
205207
}),
206208
WordIndex::Subword(indices) => {
207-
let mut embed = Array1::zeros((self.storage.shape().1,));
208-
for idx in indices {
209-
embed += &self.storage.embedding(idx).view();
210-
}
209+
let embeds = self.storage.embeddings(&indices);
210+
let mut embed = embeds.sum_axis(Axis(0));
211211

212212
let norm = l2_normalize(embed.view_mut());
213213

0 commit comments

Comments
 (0)