@@ -76,11 +76,8 @@ impl SessionWrapper {
7676 }
7777 }
7878
79- fn run < ' s , ' i , ' v : ' i , const N : usize > (
80- & ' s self ,
81- input_values : impl Into < ort:: session:: SessionInputs < ' i , ' v , N > > ,
82- ) -> ort:: Result < ort:: session:: SessionOutputs < ' s > > {
83- unsafe { & mut * self . inner . get ( ) } . run ( input_values)
79+ fn with_session < R > ( & self , f : impl FnOnce ( & mut ort:: session:: Session ) -> R ) -> R {
80+ f ( unsafe { & mut * self . inner . get ( ) } )
8481 }
8582}
8683
@@ -102,13 +99,12 @@ impl SessionWrapper {
10299 }
103100 }
104101
105- fn run < ' s , ' i , ' v : ' i , const N : usize > (
106- & ' s self ,
107- input_values : impl Into < ort :: session :: SessionInputs < ' i , ' v , N > > ,
108- ) -> ort :: Result < ort:: session:: SessionOutputs < ' s > > {
102+ /// Mutex is held for the entire closure — covers inference, output extraction,
103+ /// and drop of SessionOutputs. Prevents the race where another thread calls
104+ /// run() while outputs are still being consumed.
105+ fn with_session < R > ( & self , f : impl FnOnce ( & mut ort:: session:: Session ) -> R ) -> R {
109106 let guard = self . inner . lock ( ) . unwrap ( ) ;
110- // SAFETY: Mutex ensures exclusive access. UnsafeCell provides &mut.
111- unsafe { & mut * guard. get ( ) } . run ( input_values)
107+ f ( unsafe { & mut * guard. get ( ) } )
112108 }
113109}
114110
@@ -861,83 +857,85 @@ impl OnnxEmbeddingModel {
861857 session : & SessionWrapper ,
862858 batch : & [ Vec < u32 > ] ,
863859 ) -> Result < Vec < Vec < f32 > > , Box < dyn Error > > {
864- let batch_size = batch. len ( ) ;
865- let max_len = batch. iter ( ) . map ( |c| c. len ( ) ) . max ( ) . unwrap_or ( 0 ) ;
860+ session. with_session ( |sess| {
861+ let batch_size = batch. len ( ) ;
862+ let max_len = batch. iter ( ) . map ( |c| c. len ( ) ) . max ( ) . unwrap_or ( 0 ) ;
866863
867- let mut flat_ids: Vec < i64 > = Vec :: with_capacity ( batch_size * max_len) ;
868- let mut flat_mask: Vec < i64 > = Vec :: with_capacity ( batch_size * max_len) ;
869- let mut flat_type_ids: Vec < i64 > = Vec :: with_capacity ( batch_size * max_len) ;
864+ let mut flat_ids: Vec < i64 > = Vec :: with_capacity ( batch_size * max_len) ;
865+ let mut flat_mask: Vec < i64 > = Vec :: with_capacity ( batch_size * max_len) ;
866+ let mut flat_type_ids: Vec < i64 > = Vec :: with_capacity ( batch_size * max_len) ;
870867
871- for chunk in batch {
872- let real_len = chunk. len ( ) ;
873- for & id in chunk. iter ( ) {
874- flat_ids. push ( id as i64 ) ;
868+ for chunk in batch {
869+ let real_len = chunk. len ( ) ;
870+ for & id in chunk. iter ( ) {
871+ flat_ids. push ( id as i64 ) ;
872+ }
873+ flat_ids. extend ( std:: iter:: repeat_n ( 0i64 , max_len - real_len) ) ;
874+ flat_mask. extend ( std:: iter:: repeat_n ( 1i64 , real_len) ) ;
875+ flat_mask. extend ( std:: iter:: repeat_n ( 0i64 , max_len - real_len) ) ;
876+ flat_type_ids. extend ( std:: iter:: repeat_n ( 0i64 , max_len) ) ;
875877 }
876- flat_ids. extend ( std:: iter:: repeat_n ( 0i64 , max_len - real_len) ) ;
877- flat_mask. extend ( std:: iter:: repeat_n ( 1i64 , real_len) ) ;
878- flat_mask. extend ( std:: iter:: repeat_n ( 0i64 , max_len - real_len) ) ;
879- flat_type_ids. extend ( std:: iter:: repeat_n ( 0i64 , max_len) ) ;
880- }
881878
882- let input_ids = ort:: value:: Tensor :: from_array ( ( vec ! [ batch_size, max_len] , flat_ids) )
883- . map_err ( |_| LibError :: OnnxModelEvalFailed ) ?;
884- let attention_mask =
885- ort:: value:: Tensor :: from_array ( ( vec ! [ batch_size, max_len] , flat_mask. clone ( ) ) )
879+ let input_ids = ort:: value:: Tensor :: from_array ( ( vec ! [ batch_size, max_len] , flat_ids) )
886880 . map_err ( |_| LibError :: OnnxModelEvalFailed ) ?;
887- let token_type_ids =
888- ort:: value:: Tensor :: from_array ( ( vec ! [ batch_size, max_len] , flat_type_ids) )
881+ let attention_mask =
882+ ort:: value:: Tensor :: from_array ( ( vec ! [ batch_size, max_len] , flat_mask. clone ( ) ) )
883+ . map_err ( |_| LibError :: OnnxModelEvalFailed ) ?;
884+ let token_type_ids =
885+ ort:: value:: Tensor :: from_array ( ( vec ! [ batch_size, max_len] , flat_type_ids) )
886+ . map_err ( |_| LibError :: OnnxModelEvalFailed ) ?;
887+
888+ let outputs = sess
889+ . run ( ort:: inputs![
890+ "input_ids" => input_ids,
891+ "attention_mask" => attention_mask,
892+ "token_type_ids" => token_type_ids,
893+ ] )
889894 . map_err ( |_| LibError :: OnnxModelEvalFailed ) ?;
890895
891- let outputs = session
892- . run ( ort:: inputs![
893- "input_ids" => input_ids,
894- "attention_mask" => attention_mask,
895- "token_type_ids" => token_type_ids,
896- ] )
897- . map_err ( |_| LibError :: OnnxModelEvalFailed ) ?;
898-
899- let ( shape, data) = outputs[ 0 ]
900- . try_extract_tensor :: < f32 > ( )
901- . map_err ( |_| LibError :: OnnxModelEvalFailed ) ?;
896+ let ( shape, data) = outputs[ 0 ]
897+ . try_extract_tensor :: < f32 > ( )
898+ . map_err ( |_| LibError :: OnnxModelEvalFailed ) ?;
902899
903- let ndim = shape. len ( ) ;
904- let mut embeddings = Vec :: with_capacity ( batch_size) ;
900+ let ndim = shape. len ( ) ;
901+ let mut embeddings = Vec :: with_capacity ( batch_size) ;
905902
906- if ndim == 2 {
907- let hidden_dim = shape[ 1 ] as usize ;
908- for i in 0 ..batch_size {
909- let start = i * hidden_dim;
910- let mut emb = data[ start..start + hidden_dim] . to_vec ( ) ;
911- normalize ( & mut emb) ;
912- embeddings. push ( emb) ;
913- }
914- } else if ndim == 3 {
915- let seq_len = shape[ 1 ] as usize ;
916- let hidden_dim = shape[ 2 ] as usize ;
917- for i in 0 ..batch_size {
918- let mut emb = vec ! [ 0.0f32 ; hidden_dim] ;
919- let mut count = 0.0f32 ;
920- for j in 0 ..seq_len {
921- let mask_val = flat_mask[ i * max_len + j] as f32 ;
922- if mask_val > 0.0 {
923- let offset = ( i * seq_len + j) * hidden_dim;
924- for k in 0 ..hidden_dim {
925- emb[ k] += data[ offset + k] ;
903+ if ndim == 2 {
904+ let hidden_dim = shape[ 1 ] as usize ;
905+ for i in 0 ..batch_size {
906+ let start = i * hidden_dim;
907+ let mut emb = data[ start..start + hidden_dim] . to_vec ( ) ;
908+ normalize ( & mut emb) ;
909+ embeddings. push ( emb) ;
910+ }
911+ } else if ndim == 3 {
912+ let seq_len = shape[ 1 ] as usize ;
913+ let hidden_dim = shape[ 2 ] as usize ;
914+ for i in 0 ..batch_size {
915+ let mut emb = vec ! [ 0.0f32 ; hidden_dim] ;
916+ let mut count = 0.0f32 ;
917+ for j in 0 ..seq_len {
918+ let mask_val = flat_mask[ i * max_len + j] as f32 ;
919+ if mask_val > 0.0 {
920+ let offset = ( i * seq_len + j) * hidden_dim;
921+ for k in 0 ..hidden_dim {
922+ emb[ k] += data[ offset + k] ;
923+ }
924+ count += 1.0 ;
926925 }
927- count += 1.0 ;
928926 }
927+ if count > 0.0 {
928+ emb. iter_mut ( ) . for_each ( |v| * v /= count) ;
929+ }
930+ normalize ( & mut emb) ;
931+ embeddings. push ( emb) ;
929932 }
930- if count > 0.0 {
931- emb. iter_mut ( ) . for_each ( |v| * v /= count) ;
932- }
933- normalize ( & mut emb) ;
934- embeddings. push ( emb) ;
933+ } else {
934+ return Err ( Box :: new ( LibError :: OnnxModelEvalFailed ) ) ;
935935 }
936- } else {
937- return Err ( Box :: new ( LibError :: OnnxModelEvalFailed ) ) ;
938- }
939936
940- Ok ( embeddings)
937+ Ok ( embeddings)
938+ } )
941939 }
942940
943941 /// Tokenize one batch and run inference.
0 commit comments