File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -174,8 +174,10 @@ where
174174 WordIndex :: Subword ( indices) => {
175175 target. fill ( 0. ) ;
176176
177- for idx in indices {
178- target += & self . storage . embedding ( idx) . view ( ) ;
177+ let embeds = self . storage . embeddings ( & indices) ;
178+
179+ for embed in embeds. outer_iter ( ) {
180+ target += & embed;
179181 }
180182
181183 l2_normalize ( target. view_mut ( ) ) ;
@@ -204,10 +206,8 @@ where
204206 norm : self . norms ( ) . map ( |n| n[ idx] ) . unwrap_or ( 1. ) ,
205207 } ) ,
206208 WordIndex :: Subword ( indices) => {
207- let mut embed = Array1 :: zeros ( ( self . storage . shape ( ) . 1 , ) ) ;
208- for idx in indices {
209- embed += & self . storage . embedding ( idx) . view ( ) ;
210- }
209+ let embeds = self . storage . embeddings ( & indices) ;
210+ let mut embed = embeds. sum_axis ( Axis ( 0 ) ) ;
211211
212212 let norm = l2_normalize ( embed. view_mut ( ) ) ;
213213
You can’t perform that action at this time.
0 commit comments