Skip to content

Commit 82879b9

Browse files
committed
fix(embeddings): prevent race condition in session run
- Wrap session execution and output extraction in a closure
- Ensure mutex is held until tensor data is fully consumed
- Refactor SessionWrapper to use with_session for safe access
1 parent 980b24b commit 82879b9

1 file changed

Lines changed: 72 additions & 74 deletions

File tree

embeddings/src/model/local.rs

Lines changed: 72 additions & 74 deletions
Original file line number · Diff line number · Diff line change
@@ -76,11 +76,8 @@ impl SessionWrapper {
7676
}
7777
}
7878

79-
fn run<'s, 'i, 'v: 'i, const N: usize>(
80-
&'s self,
81-
input_values: impl Into<ort::session::SessionInputs<'i, 'v, N>>,
82-
) -> ort::Result<ort::session::SessionOutputs<'s>> {
83-
unsafe { &mut *self.inner.get() }.run(input_values)
79+
fn with_session<R>(&self, f: impl FnOnce(&mut ort::session::Session) -> R) -> R {
80+
f(unsafe { &mut *self.inner.get() })
8481
}
8582
}
8683

@@ -102,13 +99,12 @@ impl SessionWrapper {
10299
}
103100
}
104101

105-
fn run<'s, 'i, 'v: 'i, const N: usize>(
106-
&'s self,
107-
input_values: impl Into<ort::session::SessionInputs<'i, 'v, N>>,
108-
) -> ort::Result<ort::session::SessionOutputs<'s>> {
102+
/// Mutex is held for the entire closure — covers inference, output extraction,
103+
/// and drop of SessionOutputs. Prevents the race where another thread calls
104+
/// run() while outputs are still being consumed.
105+
fn with_session<R>(&self, f: impl FnOnce(&mut ort::session::Session) -> R) -> R {
109106
let guard = self.inner.lock().unwrap();
110-
// SAFETY: Mutex ensures exclusive access. UnsafeCell provides &mut.
111-
unsafe { &mut *guard.get() }.run(input_values)
107+
f(unsafe { &mut *guard.get() })
112108
}
113109
}
114110

@@ -861,83 +857,85 @@ impl OnnxEmbeddingModel {
861857
session: &SessionWrapper,
862858
batch: &[Vec<u32>],
863859
) -> Result<Vec<Vec<f32>>, Box<dyn Error>> {
864-
let batch_size = batch.len();
865-
let max_len = batch.iter().map(|c| c.len()).max().unwrap_or(0);
860+
session.with_session(|sess| {
861+
let batch_size = batch.len();
862+
let max_len = batch.iter().map(|c| c.len()).max().unwrap_or(0);
866863

867-
let mut flat_ids: Vec<i64> = Vec::with_capacity(batch_size * max_len);
868-
let mut flat_mask: Vec<i64> = Vec::with_capacity(batch_size * max_len);
869-
let mut flat_type_ids: Vec<i64> = Vec::with_capacity(batch_size * max_len);
864+
let mut flat_ids: Vec<i64> = Vec::with_capacity(batch_size * max_len);
865+
let mut flat_mask: Vec<i64> = Vec::with_capacity(batch_size * max_len);
866+
let mut flat_type_ids: Vec<i64> = Vec::with_capacity(batch_size * max_len);
870867

871-
for chunk in batch {
872-
let real_len = chunk.len();
873-
for &id in chunk.iter() {
874-
flat_ids.push(id as i64);
868+
for chunk in batch {
869+
let real_len = chunk.len();
870+
for &id in chunk.iter() {
871+
flat_ids.push(id as i64);
872+
}
873+
flat_ids.extend(std::iter::repeat_n(0i64, max_len - real_len));
874+
flat_mask.extend(std::iter::repeat_n(1i64, real_len));
875+
flat_mask.extend(std::iter::repeat_n(0i64, max_len - real_len));
876+
flat_type_ids.extend(std::iter::repeat_n(0i64, max_len));
875877
}
876-
flat_ids.extend(std::iter::repeat_n(0i64, max_len - real_len));
877-
flat_mask.extend(std::iter::repeat_n(1i64, real_len));
878-
flat_mask.extend(std::iter::repeat_n(0i64, max_len - real_len));
879-
flat_type_ids.extend(std::iter::repeat_n(0i64, max_len));
880-
}
881878

882-
let input_ids = ort::value::Tensor::from_array((vec![batch_size, max_len], flat_ids))
883-
.map_err(|_| LibError::OnnxModelEvalFailed)?;
884-
let attention_mask =
885-
ort::value::Tensor::from_array((vec![batch_size, max_len], flat_mask.clone()))
879+
let input_ids = ort::value::Tensor::from_array((vec![batch_size, max_len], flat_ids))
886880
.map_err(|_| LibError::OnnxModelEvalFailed)?;
887-
let token_type_ids =
888-
ort::value::Tensor::from_array((vec![batch_size, max_len], flat_type_ids))
881+
let attention_mask =
882+
ort::value::Tensor::from_array((vec![batch_size, max_len], flat_mask.clone()))
883+
.map_err(|_| LibError::OnnxModelEvalFailed)?;
884+
let token_type_ids =
885+
ort::value::Tensor::from_array((vec![batch_size, max_len], flat_type_ids))
886+
.map_err(|_| LibError::OnnxModelEvalFailed)?;
887+
888+
let outputs = sess
889+
.run(ort::inputs![
890+
"input_ids" => input_ids,
891+
"attention_mask" => attention_mask,
892+
"token_type_ids" => token_type_ids,
893+
])
889894
.map_err(|_| LibError::OnnxModelEvalFailed)?;
890895

891-
let outputs = session
892-
.run(ort::inputs![
893-
"input_ids" => input_ids,
894-
"attention_mask" => attention_mask,
895-
"token_type_ids" => token_type_ids,
896-
])
897-
.map_err(|_| LibError::OnnxModelEvalFailed)?;
898-
899-
let (shape, data) = outputs[0]
900-
.try_extract_tensor::<f32>()
901-
.map_err(|_| LibError::OnnxModelEvalFailed)?;
896+
let (shape, data) = outputs[0]
897+
.try_extract_tensor::<f32>()
898+
.map_err(|_| LibError::OnnxModelEvalFailed)?;
902899

903-
let ndim = shape.len();
904-
let mut embeddings = Vec::with_capacity(batch_size);
900+
let ndim = shape.len();
901+
let mut embeddings = Vec::with_capacity(batch_size);
905902

906-
if ndim == 2 {
907-
let hidden_dim = shape[1] as usize;
908-
for i in 0..batch_size {
909-
let start = i * hidden_dim;
910-
let mut emb = data[start..start + hidden_dim].to_vec();
911-
normalize(&mut emb);
912-
embeddings.push(emb);
913-
}
914-
} else if ndim == 3 {
915-
let seq_len = shape[1] as usize;
916-
let hidden_dim = shape[2] as usize;
917-
for i in 0..batch_size {
918-
let mut emb = vec![0.0f32; hidden_dim];
919-
let mut count = 0.0f32;
920-
for j in 0..seq_len {
921-
let mask_val = flat_mask[i * max_len + j] as f32;
922-
if mask_val > 0.0 {
923-
let offset = (i * seq_len + j) * hidden_dim;
924-
for k in 0..hidden_dim {
925-
emb[k] += data[offset + k];
903+
if ndim == 2 {
904+
let hidden_dim = shape[1] as usize;
905+
for i in 0..batch_size {
906+
let start = i * hidden_dim;
907+
let mut emb = data[start..start + hidden_dim].to_vec();
908+
normalize(&mut emb);
909+
embeddings.push(emb);
910+
}
911+
} else if ndim == 3 {
912+
let seq_len = shape[1] as usize;
913+
let hidden_dim = shape[2] as usize;
914+
for i in 0..batch_size {
915+
let mut emb = vec![0.0f32; hidden_dim];
916+
let mut count = 0.0f32;
917+
for j in 0..seq_len {
918+
let mask_val = flat_mask[i * max_len + j] as f32;
919+
if mask_val > 0.0 {
920+
let offset = (i * seq_len + j) * hidden_dim;
921+
for k in 0..hidden_dim {
922+
emb[k] += data[offset + k];
923+
}
924+
count += 1.0;
926925
}
927-
count += 1.0;
928926
}
927+
if count > 0.0 {
928+
emb.iter_mut().for_each(|v| *v /= count);
929+
}
930+
normalize(&mut emb);
931+
embeddings.push(emb);
929932
}
930-
if count > 0.0 {
931-
emb.iter_mut().for_each(|v| *v /= count);
932-
}
933-
normalize(&mut emb);
934-
embeddings.push(emb);
933+
} else {
934+
return Err(Box::new(LibError::OnnxModelEvalFailed));
935935
}
936-
} else {
937-
return Err(Box::new(LibError::OnnxModelEvalFailed));
938-
}
939936

940-
Ok(embeddings)
937+
Ok(embeddings)
938+
})
941939
}
942940

943941
/// Tokenize one batch and run inference.

0 commit comments

Comments (0)