
Commit 4e9cf55

Author: root

fix: abort LLM stream on timeout
1 parent: 9a1b44f

2 files changed

Lines changed: 329 additions & 41 deletions

File tree

src/moderation/smart.rs

Lines changed: 146 additions & 40 deletions
@@ -7,9 +7,11 @@ use std::time::{SystemTime, UNIX_EPOCH};
 
 use anyhow::{anyhow, Result};
 use indexmap::IndexMap;
+use reqwest::header::CONTENT_TYPE;
 use reqwest::Client;
 use serde::Serialize;
 use serde_json::{json, Value};
+use tokio::time::timeout as tokio_timeout;
 use tokio::sync::{OwnedSemaphorePermit, Semaphore};
 
 use crate::moderation::bow;
@@ -193,54 +195,80 @@ async fn llm_moderate(
         profile.config.ai.base_url.trim_end_matches('/')
     );
 
+    // NOTE:
+    // - Always request a streaming response (SSE) so timing out will cancel the in-flight body read,
+    //   which causes the underlying HTTP connection to be dropped instead of being left hanging.
+    // - Enforce timeout with `tokio::time::timeout` so it cannot "sometimes not timeout".
+    async fn parse_llm_response(
+        response: reqwest::Response,
+    ) -> Result<SmartModerationResult, SmartModerationError> {
+        let response = response
+            .error_for_status()
+            .map_err(|err| SmartModerationError::Other(anyhow!(err)))?;
+
+        let is_sse = response
+            .headers()
+            .get(CONTENT_TYPE)
+            .and_then(|value| value.to_str().ok())
+            .map(|value| value.starts_with("text/event-stream"))
+            .unwrap_or(false);
+
+        if is_sse {
+            let content = read_openai_chat_sse_content(response)
+                .await
+                .map_err(|err| SmartModerationError::Other(err))?;
+            parse_moderation_content(&content)
+        } else {
+            let payload = response
+                .json::<Value>()
+                .await
+                .map_err(|err| SmartModerationError::Other(anyhow!(err)))?;
+            parse_openai_moderation_response(payload)
+        }
+    }
+
     let mut attempted_models = Vec::new();
     let mut last_error = None;
     for _attempt in 0..=max_retries {
         let model = pick_model_for_attempt(&models, &attempted_models);
         attempted_models.push(model.clone());
-        let response = http_client
-            .post(&endpoint)
-            .bearer_auth(&api_key)
-            .timeout(timeout)
-            .json(&json!({
-                "model": &model,
-                "messages": [{
-                    "role": "user",
-                    "content": prompt
-                }],
-                "temperature": 0
-            }))
-            .send()
-            .await;
+
+        let response = tokio_timeout(timeout, async {
+            let response = http_client
+                .post(&endpoint)
+                .bearer_auth(&api_key)
+                .json(&json!({
+                    "model": &model,
+                    "messages": [{
+                        "role": "user",
+                        "content": prompt
+                    }],
+                    "temperature": 0,
+                    "stream": true
+                }))
+                .send()
+                .await
+                .map_err(|err| SmartModerationError::Other(anyhow!(err)))?;
+
+            parse_llm_response(response).await
+        })
+        .await;
 
         match response {
-            Ok(resp) => {
-                match resp.error_for_status() {
-                    Ok(resp) => match resp.json::<Value>().await {
-                        Ok(payload) => match parse_openai_moderation_response(payload) {
-                            Ok(parsed) => return Ok(parsed),
-                            Err(err) => {
-                                let err = match err {
-                                    SmartModerationError::ConcurrencyLimit(message) => {
-                                        anyhow!(message)
-                                    }
-                                    SmartModerationError::Other(err) => err,
-                                };
-                                last_error = Some(
-                                    err.context("failed to parse llm moderation response"),
-                                );
-                            }
-                        },
-                        Err(err) => {
-                            last_error = Some(anyhow!(err).context("failed to decode llm moderation response"));
-                        }
-                    },
-                    Err(err) => {
-                        last_error = Some(anyhow!(err).context("llm moderation request failed"));
-                    }
-                }
+            Ok(Ok(parsed)) => return Ok(parsed),
+            Ok(Err(err)) => {
+                let err = match err {
+                    SmartModerationError::ConcurrencyLimit(message) => anyhow!(message),
+                    SmartModerationError::Other(err) => err,
+                };
+                last_error = Some(err.context("llm moderation request failed"));
+            }
+            Err(_) => {
+                last_error = Some(anyhow!(
+                    "llm moderation request timed out after {}s (connection aborted)",
+                    timeout.as_secs().max(1)
+                ));
             }
-            Err(err) => last_error = Some(anyhow!(err)),
         }
     }
 
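Note: per the NOTE in the hunk above, the old per-request `.timeout(timeout)` could "sometimes not timeout", so the commit instead races the send and the entire body read against one `tokio::time::timeout`: when the deadline fires, the whole future is dropped, which aborts the in-flight read and closes the connection. A minimal standalone sketch of that pattern (the `fetch_with_deadline` helper and its URL parameter are hypothetical, not part of this commit; `bytes_stream()` requires reqwest's `stream` feature, which the code above already relies on):

use std::time::Duration;

use anyhow::Result;
use futures_util::StreamExt;

// Sketch only: connect, send, and every body chunk all race the same deadline.
// When the deadline elapses, the future is dropped and the connection torn down.
async fn fetch_with_deadline(client: &reqwest::Client, url: &str) -> Result<String> {
    let body = tokio::time::timeout(Duration::from_secs(10), async {
        let response = client.get(url).send().await?.error_for_status()?;
        let mut stream = response.bytes_stream();
        let mut bytes = Vec::new();
        while let Some(chunk) = stream.next().await {
            bytes.extend_from_slice(&chunk?);
        }
        Ok::<_, anyhow::Error>(String::from_utf8_lossy(&bytes).into_owned())
    })
    .await??; // outer `?`: deadline elapsed; inner `?`: HTTP or body error
    Ok(body)
}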
@@ -253,6 +281,84 @@ async fn llm_moderate(
     .into())
 }
 
+async fn read_openai_chat_sse_content(
+    response: reqwest::Response,
+) -> Result<String> {
+    use futures_util::StreamExt;
+
+    let mut stream = response.bytes_stream();
+    let mut buffer: Vec<u8> = Vec::new();
+    let mut content = String::new();
+
+    while let Some(chunk) = stream.next().await {
+        let chunk = chunk?;
+        buffer.extend_from_slice(&chunk);
+
+        while let Some(newline_idx) = buffer.iter().position(|b| *b == b'\n') {
+            let mut line = buffer.drain(..=newline_idx).collect::<Vec<u8>>();
+            while matches!(line.last(), Some(b'\n' | b'\r')) {
+                line.pop();
+            }
+
+            if line.is_empty() {
+                continue;
+            }
+
+            let Ok(line) = std::str::from_utf8(&line) else {
+                continue;
+            };
+            let trimmed = line.trim();
+            let Some(data) = trimmed.strip_prefix("data:") else {
+                continue;
+            };
+            let data = data.trim();
+            if data.is_empty() {
+                continue;
+            }
+            if data == "[DONE]" {
+                return Ok(content);
+            }
+
+            let Ok(payload) = serde_json::from_str::<Value>(data) else {
+                continue;
+            };
+            if let Some(delta) = extract_openai_chat_stream_delta(&payload) {
+                content.push_str(&delta);
+            }
+        }
+    }
+
+    Ok(content)
+}
+
+fn extract_openai_chat_stream_delta(payload: &Value) -> Option<String> {
+    let choice = payload
+        .get("choices")
+        .and_then(Value::as_array)
+        .and_then(|choices| choices.first())?;
+
+    if let Some(delta) = choice.get("delta") {
+        if let Some(content) = delta.get("content") {
+            let text = extract_content_text(content);
+            if !text.is_empty() {
+                return Some(text);
+            }
+        }
+    }
+
+    // Some upstreams may return a nonstandard streaming shape.
+    if let Some(message) = choice.get("message") {
+        if let Some(content) = message.get("content") {
+            let text = extract_content_text(content);
+            if !text.is_empty() {
+                return Some(text);
+            }
+        }
+    }
+
+    None
+}
+
 async fn run_ai_moderation_with_history(
     text: &str,
     profile: &ModerationProfile,
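For reference, the stream the new `read_openai_chat_sse_content` reader consumes is a series of `data:` lines, each holding a JSON chunk whose `choices[0].delta.content` carries an incremental piece of the completion, terminated by `data: [DONE]`. A small illustrative sketch of the lookup `extract_openai_chat_stream_delta` performs (the payload is a made-up example in the OpenAI chat streaming shape, not captured traffic):

use serde_json::{json, Value};

fn main() {
    // One decoded `data:` event; successive events each carry a small delta.
    let payload: Value = json!({
        "choices": [{ "delta": { "content": "flag" } }]
    });
    // Same path the helper walks: choices[0].delta.content.
    let delta = payload["choices"][0]["delta"]["content"].as_str();
    assert_eq!(delta, Some("flag"));
    // Appending each delta ("flag" + "ged" + ...) rebuilds the full reply;
    // the reader returns the accumulated text when it sees `data: [DONE]`.
}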

0 commit comments
