Skip to content

Commit 6001156

Browse files
authored
Merge pull request #3234 from ultraworkers/fix/openai-compatible-reasoning-history
fix(providers): preserve OpenAI-compatible reasoning history
2 parents ae2f203 + a1da1ca commit 6001156

6 files changed

Lines changed: 133 additions & 120 deletions

File tree

rust/crates/api/src/providers/mod.rs

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -296,9 +296,6 @@ pub fn metadata_for_model(model: &str) -> Option<ProviderMetadata> {
296296
None
297297
}
298298

299-
300-
301-
302299
#[must_use]
303300
pub fn strip_provider_prefix(canonical_model: &str) -> String {
304301
if let Some(pos) = canonical_model.find('/') {
@@ -308,8 +305,6 @@ pub fn strip_provider_prefix(canonical_model: &str) -> String {
308305
}
309306
}
310307

311-
312-
313308
#[must_use]
314309
pub fn provider_diagnostics_for_model(model: &str) -> ProviderDiagnostics {
315310
let resolved_model = resolve_model_alias(model);

rust/crates/api/src/providers/openai_compat.rs

Lines changed: 42 additions & 101 deletions
Original file line number | Diff line number | Diff line change
@@ -16,8 +16,7 @@ use crate::types::{
1616
ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
1717
};
1818

19-
use super::{preflight_message_request, Provider, ProviderFuture, resolve_model_alias, strip_provider_prefix};
20-
19+
use super::{preflight_message_request, resolve_model_alias, Provider, ProviderFuture};
2120

2221
pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1";
2322
pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
@@ -213,80 +212,22 @@ impl OpenAiCompatClient {
213212
}
214213

215214
pub async fn send_message(
216-
&self,
217-
request: &MessageRequest,
218-
) -> Result<MessageResponse, ApiError> {
219-
// 1. Keep track of what Claw originally asked for
220-
let original_model = request.model.clone();
221-
let canonical = resolve_model_alias(&request.model);
222-
223-
// 2. Clean the model string (e.g., "openai/deepseek-v4-flash" -> "deepseek-v4-flash")
224-
let downstream_model = strip_provider_prefix(&canonical);
225-
226-
let mut request = MessageRequest {
227-
stream: false,
228-
..request.clone()
229-
};
230-
request.model = downstream_model; // Use the clean name for the API payload
231-
232-
preflight_message_request(&request)?;
233-
let response = self.send_with_retry(&request).await?;
234-
let request_id = request_id_from_headers(response.headers());
235-
let body = response.text().await.map_err(ApiError::from)?;
236-
237-
// Some backends return {"error":{"message":"...","type":"...","code":...}}
238-
// instead of a valid completion object. Check for this before attempting
239-
// full deserialization so the user sees the actual error, not a cryptic.
240-
if let Ok(raw) = serde_json::from_str::<serde_json::Value>(&body) {
241-
if let Some(err_obj) = raw.get("error") {
242-
let msg = err_obj
243-
.get("message")
244-
.and_then(|m| m.as_str())
245-
.unwrap_or("provider returned an error")
246-
.to_string();
247-
let code = err_obj
248-
.get("code")
249-
.and_then(serde_json::Value::as_u64)
250-
.map(|c| c as u16);
251-
return Err(ApiError::Api {
252-
status: reqwest::StatusCode::from_u16(code.unwrap_or(400))
253-
.unwrap_or(reqwest::StatusCode::BAD_REQUEST),
254-
error_type: err_obj
255-
.get("type")
256-
.and_then(|t| t.as_str())
257-
.map(str::to_owned),
258-
message: Some(msg),
259-
request_id,
260-
body,
261-
retryable: false,
262-
suggested_action: suggested_action_for_status(
263-
reqwest::StatusCode::from_u16(code.unwrap_or(400))
264-
.unwrap_or(reqwest::StatusCode::BAD_REQUEST),
265-
),
266-
retry_after: None,
267-
});
268-
}
269-
}
270-
271-
// Pass original_model to the deserializer error context so debugging logs are accurate
272-
let payload = serde_json::from_str::<ChatCompletionResponse>(&body).map_err(|error| {
273-
ApiError::json_deserialize(self.config.provider_name, &original_model, &body, error)
274-
})?;
275-
276-
let mut normalized = normalize_response(&request.model, payload)?;
277-
if normalized.request_id.is_none() {
278-
normalized.request_id = request_id;
279-
}
215+
&self,
216+
request: &MessageRequest,
217+
) -> Result<MessageResponse, ApiError> {
218+
let original_model = request.model.clone();
219+
let canonical = resolve_model_alias(&request.model);
280220

281-
// 3. CRITICAL: Put the original model string back so Claw's internal routing stays happy
282-
normalized.model = original_model;
221+
let mut request = MessageRequest {
222+
stream: false,
223+
..request.clone()
224+
};
225+
request.model = canonical;
283226

284-
Ok(normalized)
285-
}
286-
// Some backends return {"error":{"message":"...","type":"...","code":...}}
287-
// instead of a valid completion object. Check for this before attempting
288-
// full deserialization so the user sees the actual error, not a cryptic
289-
// "missing field 'id'" parse failure.
227+
preflight_message_request(&request)?;
228+
let response = self.send_with_retry(&request).await?;
229+
let request_id = request_id_from_headers(response.headers());
230+
let body = response.text().await.map_err(ApiError::from)?;
290231
if let Ok(raw) = serde_json::from_str::<serde_json::Value>(&body) {
291232
if let Some(err_obj) = raw.get("error") {
292233
let msg = err_obj
@@ -318,41 +259,41 @@ impl OpenAiCompatClient {
318259
}
319260
}
320261
let payload = serde_json::from_str::<ChatCompletionResponse>(&body).map_err(|error| {
321-
ApiError::json_deserialize(self.config.provider_name, &request.model, &body, error)
262+
ApiError::json_deserialize(self.config.provider_name, &original_model, &body, error)
322263
})?;
323264
let mut normalized = normalize_response(&request.model, payload)?;
324265
if normalized.request_id.is_none() {
325266
normalized.request_id = request_id;
326267
}
268+
normalized.model = original_model;
327269
Ok(normalized)
328270
}
329271

330-
pub async fn stream_message(
331-
&self,
332-
request: &MessageRequest,
333-
) -> Result<MessageStream, ApiError> {
334-
// 1. Keep track of the original model name
335-
let original_model = request.model.clone();
336-
let canonical = resolve_model_alias(&request.model);
337-
338-
// 2. Clean it up for DeepSeek
339-
let downstream_model = strip_provider_prefix(&canonical);
340-
341-
let mut streaming_request = request.clone().with_streaming();
342-
streaming_request.model = downstream_model;
343-
344-
preflight_message_request(&streaming_request)?;
345-
let response = self.send_with_retry(&streaming_request).await?;
346-
347-
Ok(MessageStream {
348-
request_id: request_id_from_headers(response.headers()),
349-
response,
350-
parser: OpenAiSseParser::with_context(self.config.provider_name, original_model.clone()),
351-
pending: VecDeque::new(),
352-
done: false,
353-
state: StreamState::new(original_model), // 3. Use the original name here
354-
})
355-
}
272+
pub async fn stream_message(
273+
&self,
274+
request: &MessageRequest,
275+
) -> Result<MessageStream, ApiError> {
276+
let original_model = request.model.clone();
277+
let canonical = resolve_model_alias(&request.model);
278+
279+
let mut streaming_request = request.clone().with_streaming();
280+
streaming_request.model = canonical;
281+
282+
preflight_message_request(&streaming_request)?;
283+
let response = self.send_with_retry(&streaming_request).await?;
284+
285+
Ok(MessageStream {
286+
request_id: request_id_from_headers(response.headers()),
287+
response,
288+
parser: OpenAiSseParser::with_context(
289+
self.config.provider_name,
290+
original_model.clone(),
291+
),
292+
pending: VecDeque::new(),
293+
done: false,
294+
state: StreamState::new(original_model),
295+
})
296+
}
356297

357298
async fn send_with_retry(
358299
&self,

rust/crates/api/tests/openai_compat_integration.rs

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -548,12 +548,13 @@ async fn openai_compatible_client_honors_http_proxy_for_requests() {
548548
.with_base_url("http://origin.invalid/v1");
549549
let response = client
550550
.send_message(&MessageRequest {
551-
model: "gpt-4o".to_string(),
551+
model: "openai/gpt-4.1-mini".to_string(),
552552
..sample_request(false)
553553
})
554554
.await
555555
.expect("proxy should return the OpenAI-compatible response");
556556

557+
assert_eq!(response.model, "openai/gpt-4.1-mini");
557558
assert_eq!(response.total_tokens(), 7);
558559
let captured = state.lock().await;
559560
let request = captured.first().expect("proxy should capture request");
@@ -562,6 +563,8 @@ async fn openai_compatible_client_honors_http_proxy_for_requests() {
562563
request.headers.get("authorization").map(String::as_str),
563564
Some("Bearer openai-test-key")
564565
);
566+
let body: serde_json::Value = serde_json::from_str(&request.body).expect("json body");
567+
assert_eq!(body["model"], json!("openai/gpt-4.1-mini"));
565568
}
566569

567570
#[allow(clippy::await_holding_lock)]

rust/crates/runtime/src/session_control.rs

Lines changed: 25 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -832,6 +832,28 @@ mod tests {
832832

833833
static TEMP_COUNTER: AtomicU64 = AtomicU64::new(0);
834834

835+
struct EnvVarGuard {
836+
key: &'static str,
837+
previous: Option<std::ffi::OsString>,
838+
}
839+
840+
impl EnvVarGuard {
841+
fn set(key: &'static str, value: &Path) -> Self {
842+
let previous = std::env::var_os(key);
843+
std::env::set_var(key, value);
844+
Self { key, previous }
845+
}
846+
}
847+
848+
impl Drop for EnvVarGuard {
849+
fn drop(&mut self) {
850+
match &self.previous {
851+
Some(value) => std::env::set_var(self.key, value),
852+
None => std::env::remove_var(self.key),
853+
}
854+
}
855+
}
856+
835857
fn temp_dir() -> PathBuf {
836858
let nanos = SystemTime::now()
837859
.duration_since(UNIX_EPOCH)
@@ -1290,8 +1312,11 @@ mod tests {
12901312
#[test]
12911313
fn latest_session_returns_all_empty_error_when_sessions_exist_but_have_no_messages() {
12921314
// given — create sessions with 0 messages (empty)
1315+
let _env_guard = crate::test_env_lock();
12931316
let base = temp_dir();
12941317
fs::create_dir_all(&base).expect("base dir should exist");
1318+
let isolated_config_home = base.join("config-home");
1319+
let _claw_config_home = EnvVarGuard::set("CLAW_CONFIG_HOME", &isolated_config_home);
12951320
let store = SessionStore::from_cwd(&base).expect("store should build");
12961321

12971322
let empty_handle = store.create_handle("empty-session");

rust/crates/runtime/src/worker_boot.rs

Lines changed: 6 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1644,16 +1644,13 @@ mod tests {
16441644

16451645
let tmp = tempfile::tempdir().expect("tempdir");
16461646
let worktree = tmp.path().join("worktree");
1647-
let git_dir = tmp.path().join("external-gitdir");
16481647
fs::create_dir_all(&worktree).expect("worktree dir");
1649-
fs::create_dir_all(git_dir.join("objects")).expect("objects dir");
1650-
fs::create_dir_all(git_dir.join("refs/heads")).expect("refs dir");
1651-
fs::write(git_dir.join("HEAD"), "ref: refs/heads/main\n").expect("HEAD");
1652-
fs::write(
1653-
worktree.join(".git"),
1654-
format!("gitdir: {}\n", git_dir.display()),
1655-
)
1656-
.expect(".git file");
1648+
Command::new("git")
1649+
.arg("init")
1650+
.current_dir(&worktree)
1651+
.output()
1652+
.expect("git init should run");
1653+
let git_dir = worktree.join(".git");
16571654

16581655
let original_permissions = fs::metadata(&git_dir)
16591656
.expect("gitdir metadata")

rust/crates/rusty-claude-cli/src/main.rs

Lines changed: 56 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -13737,8 +13737,15 @@ fn push_output_block(
1373713737
};
1373813738
*pending_tool = Some((id, name, initial_input));
1373913739
}
13740-
OutputContentBlock::Thinking { thinking, .. } => {
13740+
OutputContentBlock::Thinking {
13741+
thinking,
13742+
signature,
13743+
} => {
1374113744
render_thinking_block_summary(out, Some(thinking.chars().count()), false)?;
13745+
events.push(AssistantEvent::Thinking {
13746+
thinking,
13747+
signature,
13748+
});
1374213749
*block_has_thinking_summary = true;
1374313750
}
1374413751
OutputContentBlock::RedactedThinking { .. } => {
@@ -19073,6 +19080,13 @@ UU conflicted.rs",
1907319080

1907419081
assert!(matches!(
1907519082
&events[0],
19083+
AssistantEvent::Thinking {
19084+
thinking,
19085+
signature
19086+
} if thinking == "step 1" && signature.as_deref() == Some("sig_123")
19087+
));
19088+
assert!(matches!(
19089+
&events[1],
1907619090
AssistantEvent::TextDelta(text) if text == "Final answer"
1907719091
));
1907819092
let rendered = String::from_utf8(out).expect("utf8");
@@ -19649,6 +19663,41 @@ mod dump_manifests_tests {
1964919663

1965019664
#[cfg(test)]
1965119665
mod alias_resolution_tests {
19666+
fn ollama_env_lock() -> std::sync::MutexGuard<'static, ()> {
19667+
static LOCK: std::sync::OnceLock<std::sync::Mutex<()>> = std::sync::OnceLock::new();
19668+
LOCK.get_or_init(|| std::sync::Mutex::new(()))
19669+
.lock()
19670+
.expect("ollama env lock poisoned")
19671+
}
19672+
19673+
struct EnvVarGuard {
19674+
key: &'static str,
19675+
previous: Option<String>,
19676+
}
19677+
19678+
impl EnvVarGuard {
19679+
fn unset(key: &'static str) -> Self {
19680+
let previous = std::env::var(key).ok();
19681+
std::env::remove_var(key);
19682+
Self { key, previous }
19683+
}
19684+
19685+
fn set(key: &'static str, value: &str) -> Self {
19686+
let previous = std::env::var(key).ok();
19687+
std::env::set_var(key, value);
19688+
Self { key, previous }
19689+
}
19690+
}
19691+
19692+
impl Drop for EnvVarGuard {
19693+
fn drop(&mut self) {
19694+
match &self.previous {
19695+
Some(value) => std::env::set_var(self.key, value),
19696+
None => std::env::remove_var(self.key),
19697+
}
19698+
}
19699+
}
19700+
1965219701
use super::{resolve_model_alias_with_config, validate_model_syntax};
1965319702

1965419703
#[test]
@@ -19670,6 +19719,8 @@ mod alias_resolution_tests {
1967019719

1967119720
#[test]
1967219721
fn test_alias_resolution_syntax_validation() {
19722+
let _guard = ollama_env_lock();
19723+
let _env = EnvVarGuard::unset("OLLAMA_HOST");
1967319724
// Resolved aliases should pass syntax validation
1967419725
let resolved = resolve_model_alias_with_config("opus");
1967519726
assert!(validate_model_syntax(&resolved).is_ok());
@@ -19680,6 +19731,8 @@ mod alias_resolution_tests {
1968019731

1968119732
#[test]
1968219733
fn test_unknown_alias_fails_validation() {
19734+
let _guard = ollama_env_lock();
19735+
let _env = EnvVarGuard::unset("OLLAMA_HOST");
1968319736
// Unknown aliases resolve to themselves
1968419737
let resolved = resolve_model_alias_with_config("unknown-alias");
1968519738
assert_eq!(resolved, "unknown-alias");
@@ -19699,14 +19752,13 @@ mod alias_resolution_tests {
1969919752
}
1970019753
#[test]
1970119754
fn test_ollama_host_bypasses_model_validation() {
19702-
// Safety: test sets and clears env var within the test.
19703-
std::env::set_var("OLLAMA_HOST", "http://127.0.0.1:11434");
19755+
let _guard = ollama_env_lock();
19756+
let _env = EnvVarGuard::set("OLLAMA_HOST", "http://127.0.0.1:11434");
1970419757
// Ollama model names with colons pass
1970519758
assert!(validate_model_syntax("qwen3:8b").is_ok());
1970619759
assert!(validate_model_syntax("gemma4:e2b").is_ok());
1970719760
assert!(validate_model_syntax("qwen3.6:27b-nvfp4").is_ok());
1970819761
// Empty model still rejected
1970919762
assert!(validate_model_syntax("").is_err());
19710-
std::env::remove_var("OLLAMA_HOST");
1971119763
}
1971219764
}

0 commit comments

Comments
 (0)