Skip to content

Commit 02730d2

Browse files
committed
Merge branch 'master' into ae/arbitrary-models
2 parents 2827b13 + feffa7d commit 02730d2

4 files changed

Lines changed: 172 additions & 24 deletions

File tree

.github/workflows/embedding_build_template.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,9 @@ jobs:
298298
env:
299299
GIT_COMMIT_ID: ${{ steps.git_meta.outputs.commit }}
300300
GIT_TIMESTAMP_ID: ${{ steps.git_meta.outputs.timestamp }}
301+
# Windows: opt for speed (opt-level=3) since size is less critical there;
302+
# other targets keep Cargo.toml's opt-level=z for smaller binaries.
303+
CARGO_PROFILE_RELEASE_OPT_LEVEL: ${{ inputs.distr == 'windows' && '3' || 'z' }}
301304

302305
- run: |
303306
mkdir build

embeddings/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ name = "manticore_knn_embeddings"
3636
crate-type = ["cdylib"]
3737

3838
[profile.release]
39-
opt-level = 3
39+
opt-level = "z"
4040
codegen-units = 1
4141
lto = true
4242
strip = "debuginfo"

embeddings/src/model/local.rs

Lines changed: 78 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -401,9 +401,16 @@ pub fn load_tokenizer(path: &PathBuf) -> Result<Tokenizer, Box<dyn Error>> {
401401
Tokenizer::from_bytes(&bytes).map_err(|_| LibError::ModelTokenizerLoadFailed.into())
402402
}
403403

404-
/// BERT-style local embedding model
404+
/// BERT-style local embedding model.
405+
///
406+
/// `model` is `Arc<Mutex<BertModel>>` to match T5/Causal/Quantized — candle's
407+
/// BertModel takes `&self` on forward but concurrent forward calls produced
408+
/// flaky crashes in the daemon when multiple INSERTs / queries hit the same
409+
/// model in parallel. Serialising forward here mirrors the other model types'
410+
/// existing posture. The lock itself costs nothing measurable (uncontended Mutex
410+
/// is sub-100ns; a BERT forward is six orders of magnitude more), but parallel callers on one model now queue.
405412
pub struct BertEmbeddingModel {
406-
model: BertModel,
413+
model: Arc<Mutex<BertModel>>,
407414
tokenizer: Tokenizer,
408415
max_input_len: usize,
409416
hidden_size: usize,
@@ -440,7 +447,7 @@ impl BertEmbeddingModel {
440447
let model = BertModel::load(vb, &config).map_err(|_| LibError::ModelLoadFailed)?;
441448

442449
Ok(Self {
443-
model,
450+
model: Arc::new(Mutex::new(model)),
444451
tokenizer: tokenizer.clone(),
445452
max_input_len,
446453
hidden_size,
@@ -454,6 +461,27 @@ impl BertEmbeddingModel {
454461
let mut all_embeddings = Vec::with_capacity(chunks.len());
455462

456463
for batch in chunks.chunks(batch_size()) {
464+
// Fast path for batch-of-1 (daemon's SELECT KNN(text,...) hot path):
465+
// no padding needed, so skip the attention_mask multiply and use a
466+
// plain sum/scalar-div mean pool. Matches pre-975b294 behavior.
467+
if batch.len() == 1 {
468+
let chunk = &batch[0];
469+
let token_ids = Tensor::new(chunk.as_slice(), &self.device)?.unsqueeze(0)?;
470+
let token_type_ids = token_ids.zeros_like()?;
471+
let emb = {
472+
let model = self.model.lock().unwrap();
473+
model.forward(&token_ids, &token_type_ids, None)?
474+
};
475+
let seq_len = token_ids.dims()[1];
476+
let summed = emb.sum(1)?.to_dtype(DType::F32)?;
477+
let divisor = Tensor::new(seq_len as f32, &self.device)?;
478+
let mean_emb = summed.broadcast_div(&divisor)?;
479+
let mut emb_vec: Vec<f32> = mean_emb.get(0)?.to_vec1::<f32>()?;
480+
normalize(&mut emb_vec);
481+
all_embeddings.push(emb_vec);
482+
continue;
483+
}
484+
457485
let batch_size = batch.len();
458486
let max_len = batch.iter().map(|c| c.len()).max().unwrap_or(0);
459487

@@ -473,9 +501,10 @@ impl BertEmbeddingModel {
473501
Tensor::from_vec(flat_mask.clone(), (batch_size, max_len), &self.device)?;
474502
let token_type_ids = token_ids.zeros_like()?;
475503

476-
let emb = self
477-
.model
478-
.forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
504+
let emb = {
505+
let model = self.model.lock().unwrap();
506+
model.forward(&token_ids, &token_type_ids, Some(&attention_mask))?
507+
};
479508
// emb: [batch_size, max_len, hidden_size]
480509

481510
// Attention-mask-aware mean pooling: sum(emb * mask) / sum(mask)
@@ -1119,15 +1148,22 @@ impl LocalModel {
11191148
.map(|t| pre_truncate_text(t, max_input_len))
11201149
.collect();
11211150

1122-
// Enable parallel tokenization via rayon (once)
1123-
static INIT_PARALLEL: std::sync::Once = std::sync::Once::new();
1124-
INIT_PARALLEL.call_once(|| {
1125-
std::env::set_var("TOKENIZERS_PARALLELISM", "true");
1126-
});
1127-
1128-
let encodings = tokenizer
1129-
.encode_batch(texts, true)
1130-
.map_err(|_| LibError::ModelTokenizerEncodeFailed)?;
1151+
// Adaptive tokenization: encode_batch fans out via rayon, which is pure
1152+
// overhead for small batches. The daemon's SELECT KNN(text,...) hot path
1153+
// always sends batch=1 — go sequential there. Parallelise only when the
1154+
// batch is big enough to amortise the rayon dispatch. Threshold mirrors
1155+
// the ONNX path's "no threading overhead" cutoff.
1156+
let encodings = if texts.len() > batch_size() {
1157+
tokenizer
1158+
.encode_batch(texts, true)
1159+
.map_err(|_| LibError::ModelTokenizerEncodeFailed)?
1160+
} else {
1161+
texts
1162+
.iter()
1163+
.map(|t| tokenizer.encode(*t, true))
1164+
.collect::<Result<Vec<_>, _>>()
1165+
.map_err(|_| LibError::ModelTokenizerEncodeFailed)?
1166+
};
11311167

11321168
let truncated: Vec<Vec<u32>> = encodings
11331169
.iter()
@@ -1151,6 +1187,33 @@ impl TextModel for LocalModel {
11511187
// BERT and ONNX: batched path (batch_size up to batch_size() per forward pass)
11521188
match self {
11531189
LocalModel::Bert(m) => {
1190+
// Dedicated single-text bypass: SELECT KNN(field, k, 'text') hits this
1191+
// path on every query. Skip all batching wrappers, intermediate Vecs,
1192+
// and the chunks.chunks() loop — go straight encode → forward → pool.
1193+
if texts.len() == 1 {
1194+
let text = pre_truncate_text(texts[0], m.max_input_len);
1195+
let enc = m
1196+
.tokenizer
1197+
.encode(text, true)
1198+
.map_err(|_| LibError::ModelTokenizerEncodeFailed)?;
1199+
let ids = enc.get_ids();
1200+
let ids = &ids[..ids.len().min(m.max_input_len)];
1201+
1202+
let token_ids = Tensor::new(ids, &m.device)?.unsqueeze(0)?;
1203+
let token_type_ids = token_ids.zeros_like()?;
1204+
let emb = {
1205+
let model = m.model.lock().unwrap();
1206+
model.forward(&token_ids, &token_type_ids, None)?
1207+
};
1208+
let seq_len = token_ids.dims()[1];
1209+
let summed = emb.sum(1)?.to_dtype(DType::F32)?;
1210+
let divisor = Tensor::new(seq_len as f32, &m.device)?;
1211+
let mean_emb = summed.broadcast_div(&divisor)?;
1212+
let mut emb_vec: Vec<f32> = mean_emb.get(0)?.to_vec1::<f32>()?;
1213+
normalize(&mut emb_vec);
1214+
return Ok(vec![emb_vec]);
1215+
}
1216+
11541217
return Self::predict_batched(&m.tokenizer, m.max_input_len, texts, |chunks| {
11551218
m.predict_chunks(chunks)
11561219
});

embeddings/src/model/text_model_wrapper.rs

Lines changed: 90 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,46 @@ use crate::model::{create_model, Model, ModelOptions, TextModel};
22
use std::os::raw::c_char;
33
use std::{ffi::c_void, ptr};
44

5+
/// Sentinel written at offset 0 of every live model handle. Lets FFI entry
6+
/// points detect garbage, null, or freed pointers handed in by the C++ caller
7+
/// and return a clean error instead of dereferencing into UB.
8+
const MODEL_MAGIC: u64 = 0xC0FF_EE5E_E7BE_EFDE;
9+
10+
/// Sentinel written over MODEL_MAGIC in `Drop` before the inner fields are
11+
/// destroyed. Best-effort only: a later call on the stale handle usually sees
12+
/// DEAD and gets a clean error, but a call racing with the free is not made safe.
13+
const MODEL_DEAD: u64 = 0xDEAD_DEAD_DEAD_DEAD;
14+
15+
/// Heap-allocated wrapper that the FFI hands to C++ as `*mut c_void`. The C++
16+
/// side stores the raw pointer and passes it back into every call; we use the
17+
/// `magic` field to validate that the pointer still references a live handle.
18+
///
19+
/// Layout note: `#[repr(C)]` and `magic` as the first field guarantee that the
20+
/// first 8 bytes of the allocation are the canary, regardless of what the inner
21+
/// `Model` enum's discriminant looks like.
22+
#[repr(C)]
23+
struct ModelHandle {
24+
magic: u64,
25+
inner: Model,
26+
}
27+
28+
impl ModelHandle {
29+
fn new(inner: Model) -> Self {
30+
Self {
31+
magic: MODEL_MAGIC,
32+
inner,
33+
}
34+
}
35+
}
36+
37+
impl Drop for ModelHandle {
38+
fn drop(&mut self) {
39+
// Tombstone before the inner Model is dropped so any concurrent FFI
40+
// reader sees MODEL_DEAD rather than MODEL_MAGIC.
41+
self.magic = MODEL_DEAD;
42+
}
43+
}
44+
545
/// cbindgen:field-names=[m_pModel, m_szError]
646
#[repr(C)]
747
pub struct TextModelResult {
@@ -94,7 +134,7 @@ impl TextModelWrapper {
94134

95135
match create_model(options) {
96136
Ok(model) => TextModelResult {
97-
model: Box::into_raw(Box::new(model)) as *mut c_void,
137+
model: Box::into_raw(Box::new(ModelHandle::new(model))) as *mut c_void,
98138
error: ptr::null_mut(),
99139
},
100140
Err(e) => {
@@ -110,7 +150,9 @@ impl TextModelWrapper {
110150
pub extern "C" fn free_model_result(res: TextModelResult) {
111151
unsafe {
112152
if !res.model.is_null() {
113-
drop(Box::from_raw(res.model as *mut Model));
153+
// Drop runs ModelHandle::drop first (tombstones magic to
154+
// MODEL_DEAD), then destroys the inner Model.
155+
drop(Box::from_raw(res.model as *mut ModelHandle));
114156
}
115157

116158
if !res.error.is_null() {
@@ -119,15 +161,45 @@ impl TextModelWrapper {
119161
}
120162
}
121163

122-
fn as_model(&self) -> &Model {
123-
unsafe { &*(self.0 as *const Model) }
164+
/// Validate the handle pointer before dereferencing. Returns a static error
165+
/// string the caller can surface to C++ instead of crashing on a bad ptr.
166+
/// Catches null, plus freed (MODEL_DEAD) or garbage handles whose first 8 bytes are still readable.
167+
/// Cannot catch a free that happens mid-call — that requires shared
168+
/// ownership on the C++ side and is out of scope here.
169+
fn as_model(&self) -> Result<&Model, &'static str> {
170+
if self.0.is_null() {
171+
return Err("embeddings: model handle is null");
172+
}
173+
// Read the magic without forming a &ModelHandle reference first — that
174+
// would already be UB if the pointer is invalid. NOTE: ptr::read is a plain,
175+
// non-atomic load; a read that races with the Drop tombstone write is still a
176+
// data race in Rust's memory model. Use an AtomicU64 if that must be sound.
177+
let magic = unsafe { std::ptr::read(self.0 as *const u64) };
178+
match magic {
179+
MODEL_MAGIC => Ok(unsafe { &(*(self.0 as *const ModelHandle)).inner }),
180+
MODEL_DEAD => Err("embeddings: model has been freed (use-after-free)"),
181+
_ => Err("embeddings: model handle is corrupted (invalid magic)"),
182+
}
124183
}
125184

126185
pub extern "C" fn make_vect_embeddings(
127186
&self,
128187
texts: *const StringItem,
129188
count: usize,
130189
) -> FloatVecResult {
190+
let model = match self.as_model() {
191+
Ok(m) => m,
192+
Err(msg) => {
193+
let c_error = std::ffi::CString::new(msg).unwrap();
194+
return FloatVecResult {
195+
error: c_error.into_raw(),
196+
ptr: ptr::null(),
197+
len: 0,
198+
cap: 0,
199+
};
200+
}
201+
};
202+
131203
let string_slice = unsafe { std::slice::from_raw_parts(texts, count) };
132204

133205
// Zero-copy: borrow C++ strings directly as &str.
@@ -141,7 +213,6 @@ impl TextModelWrapper {
141213
.collect();
142214

143215
let mut float_vec_list: Vec<FloatVec> = Vec::new();
144-
let model = self.as_model();
145216
let embeddings_list = model.predict(&string_refs);
146217
let c_error = match embeddings_list {
147218
Ok(embeddings_list) => {
@@ -198,18 +269,29 @@ impl TextModelWrapper {
198269
}
199270

200271
pub extern "C" fn get_hidden_size(&self) -> usize {
201-
self.as_model().get_hidden_size()
272+
// No error channel here; return 0 on a bad handle so the C++ caller
273+
// sees an obviously-wrong dimension instead of UB. as_model() validates
274+
// the handle before the model is touched, so a 0 here means the C++
275+
// side handed us an invalid pointer.
276+
self.as_model().map(|m| m.get_hidden_size()).unwrap_or(0)
202277
}
203278

204279
pub extern "C" fn get_max_input_len(&self) -> usize {
205-
self.as_model().get_max_input_len()
280+
self.as_model().map(|m| m.get_max_input_len()).unwrap_or(0)
206281
}
207282

208283
/// Validates the API key by making a minimal test request to the API.
209284
/// Returns null on success, or an error message string on failure.
210285
/// The caller is responsible for freeing the error string using free_string().
211286
pub extern "C" fn validate_api_key(&self) -> *mut c_char {
212-
let model = self.as_model();
287+
let model = match self.as_model() {
288+
Ok(m) => m,
289+
Err(msg) => {
290+
return std::ffi::CString::new(msg)
291+
.map(|c| c.into_raw())
292+
.unwrap_or(ptr::null_mut());
293+
}
294+
};
213295
match model.validate_api_key() {
214296
Ok(()) => ptr::null_mut(),
215297
Err(e) => {

0 commit comments

Comments
 (0)