@@ -401,9 +401,16 @@ pub fn load_tokenizer(path: &PathBuf) -> Result<Tokenizer, Box<dyn Error>> {
401401 Tokenizer :: from_bytes ( & bytes) . map_err ( |_| LibError :: ModelTokenizerLoadFailed . into ( ) )
402402}
403403
404- /// BERT-style local embedding model
404+ /// BERT-style local embedding model.
405+ ///
406+ /// `model` is `Arc<Mutex<BertModel>>` to match T5/Causal/Quantized — candle's
407+ /// BertModel takes `&self` on forward but concurrent forward calls produced
408+ /// flaky crashes in the daemon when multiple INSERTs / queries hit the same
409+ /// model in parallel. Serialising forward here mirrors the other model types'
410+ /// existing posture and trades nothing measurable on perf (uncontended Mutex
411+ /// is sub-100ns; a BERT forward is six orders of magnitude more).
405412pub struct BertEmbeddingModel {
406- model : BertModel ,
413+ model : Arc < Mutex < BertModel > > ,
407414 tokenizer : Tokenizer ,
408415 max_input_len : usize ,
409416 hidden_size : usize ,
@@ -440,7 +447,7 @@ impl BertEmbeddingModel {
440447 let model = BertModel :: load ( vb, & config) . map_err ( |_| LibError :: ModelLoadFailed ) ?;
441448
442449 Ok ( Self {
443- model,
450+ model : Arc :: new ( Mutex :: new ( model ) ) ,
444451 tokenizer : tokenizer. clone ( ) ,
445452 max_input_len,
446453 hidden_size,
@@ -454,6 +461,27 @@ impl BertEmbeddingModel {
454461 let mut all_embeddings = Vec :: with_capacity ( chunks. len ( ) ) ;
455462
456463 for batch in chunks. chunks ( batch_size ( ) ) {
464+ // Fast path for batch-of-1 (daemon's SELECT KNN(text,...) hot path):
465+ // no padding needed, so skip the attention_mask multiply and use a
466+ // plain sum/scalar-div mean pool. Matches pre-975b294 behavior.
467+ if batch. len ( ) == 1 {
468+ let chunk = & batch[ 0 ] ;
469+ let token_ids = Tensor :: new ( chunk. as_slice ( ) , & self . device ) ?. unsqueeze ( 0 ) ?;
470+ let token_type_ids = token_ids. zeros_like ( ) ?;
471+ let emb = {
472+ let model = self . model . lock ( ) . unwrap ( ) ;
473+ model. forward ( & token_ids, & token_type_ids, None ) ?
474+ } ;
475+ let seq_len = token_ids. dims ( ) [ 1 ] ;
476+ let summed = emb. sum ( 1 ) ?. to_dtype ( DType :: F32 ) ?;
477+ let divisor = Tensor :: new ( seq_len as f32 , & self . device ) ?;
478+ let mean_emb = summed. broadcast_div ( & divisor) ?;
479+ let mut emb_vec: Vec < f32 > = mean_emb. get ( 0 ) ?. to_vec1 :: < f32 > ( ) ?;
480+ normalize ( & mut emb_vec) ;
481+ all_embeddings. push ( emb_vec) ;
482+ continue ;
483+ }
484+
457485 let batch_size = batch. len ( ) ;
458486 let max_len = batch. iter ( ) . map ( |c| c. len ( ) ) . max ( ) . unwrap_or ( 0 ) ;
459487
@@ -473,9 +501,10 @@ impl BertEmbeddingModel {
473501 Tensor :: from_vec ( flat_mask. clone ( ) , ( batch_size, max_len) , & self . device ) ?;
474502 let token_type_ids = token_ids. zeros_like ( ) ?;
475503
476- let emb = self
477- . model
478- . forward ( & token_ids, & token_type_ids, Some ( & attention_mask) ) ?;
504+ let emb = {
505+ let model = self . model . lock ( ) . unwrap ( ) ;
506+ model. forward ( & token_ids, & token_type_ids, Some ( & attention_mask) ) ?
507+ } ;
479508 // emb: [batch_size, max_len, hidden_size]
480509
481510 // Attention-mask-aware mean pooling: sum(emb * mask) / sum(mask)
@@ -1119,15 +1148,22 @@ impl LocalModel {
11191148 . map ( |t| pre_truncate_text ( t, max_input_len) )
11201149 . collect ( ) ;
11211150
1122- // Enable parallel tokenization via rayon (once)
1123- static INIT_PARALLEL : std:: sync:: Once = std:: sync:: Once :: new ( ) ;
1124- INIT_PARALLEL . call_once ( || {
1125- std:: env:: set_var ( "TOKENIZERS_PARALLELISM" , "true" ) ;
1126- } ) ;
1127-
1128- let encodings = tokenizer
1129- . encode_batch ( texts, true )
1130- . map_err ( |_| LibError :: ModelTokenizerEncodeFailed ) ?;
1151+ // Adaptive tokenization: encode_batch fans out via rayon, which is pure
1152+ // overhead for small batches. The daemon's SELECT KNN(text,...) hot path
1153+ // always sends batch=1 — go sequential there. Parallelise only when the
1154+ // batch is big enough to amortise the rayon dispatch. Threshold mirrors
1155+ // the ONNX path's "no threading overhead" cutoff.
1156+ let encodings = if texts. len ( ) > batch_size ( ) {
1157+ tokenizer
1158+ . encode_batch ( texts, true )
1159+ . map_err ( |_| LibError :: ModelTokenizerEncodeFailed ) ?
1160+ } else {
1161+ texts
1162+ . iter ( )
1163+ . map ( |t| tokenizer. encode ( * t, true ) )
1164+ . collect :: < Result < Vec < _ > , _ > > ( )
1165+ . map_err ( |_| LibError :: ModelTokenizerEncodeFailed ) ?
1166+ } ;
11311167
11321168 let truncated: Vec < Vec < u32 > > = encodings
11331169 . iter ( )
@@ -1151,6 +1187,33 @@ impl TextModel for LocalModel {
11511187 // BERT and ONNX: batched path (batch_size up to batch_size() per forward pass)
11521188 match self {
11531189 LocalModel :: Bert ( m) => {
1190+ // Dedicated single-text bypass: SELECT KNN(field, k, 'text') hits this
1191+ // path on every query. Skip all batching wrappers, intermediate Vecs,
1192+ // and the chunks.chunks() loop — go straight encode → forward → pool.
1193+ if texts. len ( ) == 1 {
1194+ let text = pre_truncate_text ( texts[ 0 ] , m. max_input_len ) ;
1195+ let enc = m
1196+ . tokenizer
1197+ . encode ( text, true )
1198+ . map_err ( |_| LibError :: ModelTokenizerEncodeFailed ) ?;
1199+ let ids = enc. get_ids ( ) ;
1200+ let ids = & ids[ ..ids. len ( ) . min ( m. max_input_len ) ] ;
1201+
1202+ let token_ids = Tensor :: new ( ids, & m. device ) ?. unsqueeze ( 0 ) ?;
1203+ let token_type_ids = token_ids. zeros_like ( ) ?;
1204+ let emb = {
1205+ let model = m. model . lock ( ) . unwrap ( ) ;
1206+ model. forward ( & token_ids, & token_type_ids, None ) ?
1207+ } ;
1208+ let seq_len = token_ids. dims ( ) [ 1 ] ;
1209+ let summed = emb. sum ( 1 ) ?. to_dtype ( DType :: F32 ) ?;
1210+ let divisor = Tensor :: new ( seq_len as f32 , & m. device ) ?;
1211+ let mean_emb = summed. broadcast_div ( & divisor) ?;
1212+ let mut emb_vec: Vec < f32 > = mean_emb. get ( 0 ) ?. to_vec1 :: < f32 > ( ) ?;
1213+ normalize ( & mut emb_vec) ;
1214+ return Ok ( vec ! [ emb_vec] ) ;
1215+ }
1216+
11541217 return Self :: predict_batched ( & m. tokenizer , m. max_input_len , texts, |chunks| {
11551218 m. predict_chunks ( chunks)
11561219 } ) ;
0 commit comments