Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions embeddings/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,54 @@ codegen-units = 1
lto = true
strip = "debuginfo"

# ─── Debug-build stack-size fix for candle / BERT inference ──────────────────
#
# The Manticore daemon serves SQL queries on Boost coroutines whose stacks are
# 128 KiB (0x20000) — see "thread stack size = 0x20000" in any searchd crash
# dump. That budget is large enough for release builds of candle's BERT forward
# pass, but in *debug* builds Rust generates much fatter stack frames (no
# inlining, no dead-store elimination, no SROA). candle's transformer layers
# call deep through gemm and tensor ops; the cumulative frame size overflows
# the daemon's 128 KiB coroutine stack and silently corrupts whatever sits
# next to it in memory. The corruption then surfaces seconds later as a glibc
# heap-corruption abort in unrelated code (freeing a response buffer, tearing
# down a coroutine stack, malloc on accept) — exactly the bug we chased
# through test_481, test_490, and test_508.
#
# Fix: compile the heavy numerical dependencies with `opt-level = 1` even in
# debug. opt-level 1 enables function inlining and basic SROA, which shrinks
# candle's cumulative stack usage back to something that fits in 128 KiB. Our
# own crate stays at the dev profile's default opt-level 0, so we keep full
# debuggability on the code we actually maintain.
#
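# This is the standard Cargo pattern (per-package profile overrides) for
# heavy third-party crates that are impractical at opt-level 0 — commonly
# applied by applications to dependencies such as serde_json, image, regex,
# and ndarray. The release profile is unaffected, so there is zero runtime
# cost in release; the only cost is a small CI build-time hit in debug.
# NOTE: Cargo only honors [profile.*] tables in the workspace root manifest;
# if this crate is built as a member of a larger workspace, these overrides
# must live in that workspace's root Cargo.toml to take effect.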
[profile.dev.package.candle-core]
opt-level = 1
[profile.dev.package.candle-nn]
opt-level = 1
[profile.dev.package.candle-transformers]
opt-level = 1
[profile.dev.package.candle-kernels]
opt-level = 1
[profile.dev.package.gemm]
opt-level = 1
[profile.dev.package.gemm-common]
opt-level = 1
[profile.dev.package.gemm-f16]
opt-level = 1
[profile.dev.package.gemm-f32]
opt-level = 1
[profile.dev.package.gemm-f64]
opt-level = 1
[profile.dev.package.gemm-c32]
opt-level = 1
[profile.dev.package.gemm-c64]
opt-level = 1
[profile.dev.package."half"]
opt-level = 1

[dev-dependencies]
approx = "0.5.1"
ort = { version = "2.0.0-rc.9", default-features = false, features = ["download-binaries", "tls-rustls", "ndarray"] }
Expand Down
20 changes: 1 addition & 19 deletions embeddings/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,24 +76,6 @@ const LIB: EmbedLib = EmbedLib {

#[no_mangle]
pub extern "C" fn GetLibFuncs() -> *const EmbedLib {
// Log panics to stderr (with location + payload) instead of silently
// discarding them. The previous no-op hook was hiding the root cause of
// FFI-boundary crashes; we still need catch_unwind at every extern "C"
// entry point (see text_model_wrapper.rs) to convert the unwind into a
// clean error return, but the hook here ensures the original panic site
// appears in the daemon's log before we swallow it.
std::panic::set_hook(Box::new(|info| {
let loc = info
.location()
.map(|l| format!("{}:{}:{}", l.file(), l.line(), l.column()))
.unwrap_or_else(|| "<unknown>".to_string());
let payload = info
.payload()
.downcast_ref::<&str>()
.copied()
.or_else(|| info.payload().downcast_ref::<String>().map(|s| s.as_str()))
.unwrap_or("<non-string payload>");
eprintln!("manticore-knn-embeddings: panic at {loc}: {payload}");
}));
std::panic::set_hook(Box::new(|_| {}));
&LIB
}
124 changes: 38 additions & 86 deletions embeddings/src/model/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ impl SessionWrapper {
/// and drop of SessionOutputs. Prevents the race where another thread calls
/// run() while outputs are still being consumed.
fn with_session<R>(&self, f: impl FnOnce(&mut ort::session::Session) -> R) -> R {
let guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
let guard = self.inner.lock().unwrap();
f(unsafe { &mut *guard.get() })
}
}
Expand Down Expand Up @@ -464,35 +464,19 @@ impl BertEmbeddingModel {
// Fast path for batch-of-1 (daemon's SELECT KNN(text,...) hot path):
// no padding needed, so skip the attention_mask multiply and use a
// plain sum/scalar-div mean pool. Matches pre-975b294 behavior.
//
// Lock scope covers ALL candle ops (forward + pool + to_vec1), not
// just forward. Under concurrent inserts the daemon calls predict
// from multiple threads; candle/MKL tensor ops on output tensors
// that alias internal forward storage are not safe to run while
// another thread re-enters forward. Holding the lock until the
// f32 data has been copied into an owned Vec eliminates the race.
if batch.len() == 1 {
let chunk = &batch[0];
let token_ids = Tensor::new(chunk.as_slice(), &self.device)?.unsqueeze(0)?;
let token_type_ids = token_ids.zeros_like()?;
let mut emb_vec: Vec<f32> = {
let model = self.model.lock().unwrap_or_else(|e| e.into_inner());
let emb = model.forward(&token_ids, &token_type_ids, None)?;
let seq_len = token_ids.dims()[1];
let summed = emb.sum(1)?.to_dtype(DType::F32)?;
let divisor = Tensor::new(seq_len as f32, &self.device)?;
let mean_emb = summed.broadcast_div(&divisor)?;
// .contiguous() forces candle's to_vec1 to take its
// contiguous-offsets path (slice::to_vec, cap == len).
// The strided path uses Iterator::collect, which can
// produce Vec with cap > len from FromIterator growth
// doubling — that would mean the (ptr, len, cap) we
// hand across FFI doesn't match the canonical layout
// glibc expects when Vec::from_raw_parts drops on the
// C++ side via free_vec_result. Eliminate the path
// dependency entirely.
mean_emb.get(0)?.contiguous()?.to_vec1::<f32>()?
let emb = {
let model = self.model.lock().unwrap();
model.forward(&token_ids, &token_type_ids, None)?
};
let seq_len = token_ids.dims()[1];
let summed = emb.sum(1)?.to_dtype(DType::F32)?;
let divisor = Tensor::new(seq_len as f32, &self.device)?;
let mean_emb = summed.broadcast_div(&divisor)?;
let mut emb_vec: Vec<f32> = mean_emb.get(0)?.to_vec1::<f32>()?;
normalize(&mut emb_vec);
all_embeddings.push(emb_vec);
continue;
Expand All @@ -517,37 +501,24 @@ impl BertEmbeddingModel {
Tensor::from_vec(flat_mask.clone(), (batch_size, max_len), &self.device)?;
let token_type_ids = token_ids.zeros_like()?;

// Lock scope intentionally covers forward + the full mean-pool
// pipeline + every per-row to_vec1. See the batch-of-1 fast path
// comment above for the concurrency rationale: post-forward tensor
// ops on candle outputs are not safe to run while another thread
// re-enters forward on the same BertModel.
let mut batch_embeddings: Vec<Vec<f32>> = {
let model = self.model.lock().unwrap_or_else(|e| e.into_inner());
let emb = model.forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
// emb: [batch_size, max_len, hidden_size]

// Attention-mask-aware mean pooling: sum(emb * mask) / sum(mask)
let mask_expanded = attention_mask.unsqueeze(2)?; // [batch, max_len, 1]
let masked_emb = emb.broadcast_mul(&mask_expanded)?;
let summed = masked_emb.sum(1)?.to_dtype(DType::F32)?; // [batch, hidden]
let token_counts = attention_mask.sum(1)?.unsqueeze(1)?; // [batch, 1]
let mean_emb = summed.broadcast_div(&token_counts)?;

let mut out = Vec::with_capacity(batch_size);
for i in 0..batch_size {
// See contiguous() rationale on the batch-of-1 fast path
// above — same FFI cap/len invariant requirement applies
// to each row pulled out of the batched mean_emb.
out.push(mean_emb.get(i)?.contiguous()?.to_vec1::<f32>()?);
}
out
let emb = {
let model = self.model.lock().unwrap();
model.forward(&token_ids, &token_type_ids, Some(&attention_mask))?
};
// emb: [batch_size, max_len, hidden_size]

// Attention-mask-aware mean pooling: sum(emb * mask) / sum(mask)
let mask_expanded = attention_mask.unsqueeze(2)?; // [batch, max_len, 1]
let masked_emb = emb.broadcast_mul(&mask_expanded)?;
let summed = masked_emb.sum(1)?.to_dtype(DType::F32)?; // [batch, hidden]
let token_counts = attention_mask.sum(1)?.unsqueeze(1)?; // [batch, 1]
let mean_emb = summed.broadcast_div(&token_counts)?;

for emb_vec in batch_embeddings.iter_mut() {
normalize(emb_vec);
for i in 0..batch_size {
let mut emb_vec: Vec<f32> = mean_emb.get(i)?.to_vec1::<f32>()?;
normalize(&mut emb_vec);
all_embeddings.push(emb_vec);
}
all_embeddings.extend(batch_embeddings);
}

Ok(all_embeddings)
Expand Down Expand Up @@ -1071,20 +1042,12 @@ impl OnnxEmbeddingModel {
.collect();

for handle in handles {
// `join().unwrap()` would panic if the worker thread itself
// panicked — and that panic would unwind through rayon's
// scope into the FFI caller. Convert a panicked worker into
// a normal Err instead.
match handle.join() {
Ok(Ok(embs)) => ordered_results.push(embs),
Ok(Err(e)) => {
match handle.join().unwrap() {
Ok(embs) => ordered_results.push(embs),
Err(e) => {
error = Some(e);
break;
}
Err(_) => {
error = Some(LibError::OnnxModelEvalFailed);
break;
}
}
}
});
Expand Down Expand Up @@ -1227,9 +1190,6 @@ impl TextModel for LocalModel {
// Dedicated single-text bypass: SELECT KNN(field, k, 'text') hits this
// path on every query. Skip all batching wrappers, intermediate Vecs,
// and the chunks.chunks() loop — go straight encode → forward → pool.
//
// Lock scope covers the full candle pipeline through to_vec1; see
// BertEmbeddingModel::predict_chunks for the concurrency rationale.
if texts.len() == 1 {
let text = pre_truncate_text(texts[0], m.max_input_len);
let enc = m
Expand All @@ -1241,18 +1201,15 @@ impl TextModel for LocalModel {

let token_ids = Tensor::new(ids, &m.device)?.unsqueeze(0)?;
let token_type_ids = token_ids.zeros_like()?;
let mut emb_vec: Vec<f32> = {
let model = m.model.lock().unwrap_or_else(|e| e.into_inner());
let emb = model.forward(&token_ids, &token_type_ids, None)?;
let seq_len = token_ids.dims()[1];
let summed = emb.sum(1)?.to_dtype(DType::F32)?;
let divisor = Tensor::new(seq_len as f32, &m.device)?;
let mean_emb = summed.broadcast_div(&divisor)?;
// See contiguous() rationale on
// BertEmbeddingModel::predict_chunks above. Same FFI
// canonical-layout invariant required here.
mean_emb.get(0)?.contiguous()?.to_vec1::<f32>()?
let emb = {
let model = m.model.lock().unwrap();
model.forward(&token_ids, &token_type_ids, None)?
};
let seq_len = token_ids.dims()[1];
let summed = emb.sum(1)?.to_dtype(DType::F32)?;
let divisor = Tensor::new(seq_len as f32, &m.device)?;
let mean_emb = summed.broadcast_div(&divisor)?;
let mut emb_vec: Vec<f32> = mean_emb.get(0)?.to_vec1::<f32>()?;
normalize(&mut emb_vec);
return Ok(vec![emb_vec]);
}
Expand Down Expand Up @@ -1308,7 +1265,7 @@ impl TextModel for LocalModel {
let token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
let embeddings = match self {
LocalModel::T5(m) => {
let mut model = m.model.lock().unwrap_or_else(|e| e.into_inner());
let mut model = m.model.lock().unwrap();
let emb = model.forward(&token_ids)?;
let cls_emb = emb.i(0)?;
let first_token = cls_emb.i(0)?;
Expand Down Expand Up @@ -1356,15 +1313,15 @@ impl TextModel for LocalModel {
},
LocalModel::Quantized(m) => match &m.model {
QuantizedModelKind::Gemma { model } => {
let mut model = model.lock().unwrap_or_else(|e| e.into_inner());
let mut model = model.lock().unwrap();
let emb = model.forward(&token_ids, 0)?;
let (_, n_tokens, _) = emb.dims3()?;
let summed = emb.sum(1)?.to_dtype(DType::F32)?;
let divisor = Tensor::new(n_tokens as f32, &device)?;
summed.broadcast_div(&divisor)?
}
QuantizedModelKind::Llama { model } => {
let mut model = model.lock().unwrap_or_else(|e| e.into_inner());
let mut model = model.lock().unwrap();
let emb = model.forward(&token_ids, 0)?;
let (_, n_tokens, _) = emb.dims3()?;
let summed = emb.sum(1)?.to_dtype(DType::F32)?;
Expand All @@ -1376,12 +1333,7 @@ impl TextModel for LocalModel {
};

if let Ok(e_j) = embeddings.get(0) {
// See contiguous() rationale on BertEmbeddingModel above.
// Same FFI canonical-layout invariant for T5 / Causal /
// Quantized sequential output.
let emb_vec: Vec<f32> = e_j
.contiguous()
.map_err(|e| -> Box<dyn Error> { Box::new(e) })?
.to_vec1::<f32>()
.map_err(|e| -> Box<dyn Error> { Box::new(e) })?;
let mut emb = emb_vec;
Expand Down
Loading
Loading