Skip to content

Commit c164661

Browse files
committed
fix(providers): preserve OpenAI-compatible reasoning history
1 parent ae2f203 commit c164661

3 files changed

Lines changed: 61 additions & 107 deletions

File tree

rust/crates/api/src/providers/mod.rs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -296,9 +296,6 @@ pub fn metadata_for_model(model: &str) -> Option<ProviderMetadata> {
296296
None
297297
}
298298

299-
300-
301-
302299
#[must_use]
303300
pub fn strip_provider_prefix(canonical_model: &str) -> String {
304301
if let Some(pos) = canonical_model.find('/') {
@@ -308,8 +305,6 @@ pub fn strip_provider_prefix(canonical_model: &str) -> String {
308305
}
309306
}
310307

311-
312-
313308
#[must_use]
314309
pub fn provider_diagnostics_for_model(model: &str) -> ProviderDiagnostics {
315310
let resolved_model = resolve_model_alias(model);

rust/crates/api/src/providers/openai_compat.rs

Lines changed: 46 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@ use crate::types::{
1616
ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
1717
};
1818

19-
use super::{preflight_message_request, Provider, ProviderFuture, resolve_model_alias, strip_provider_prefix};
20-
19+
use super::{
20+
preflight_message_request, resolve_model_alias, strip_provider_prefix, Provider, ProviderFuture,
21+
};
2122

2223
pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1";
2324
pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
@@ -213,80 +214,23 @@ impl OpenAiCompatClient {
213214
}
214215

215216
pub async fn send_message(
216-
&self,
217-
request: &MessageRequest,
218-
) -> Result<MessageResponse, ApiError> {
219-
// 1. Keep track of what Claw originally asked for
220-
let original_model = request.model.clone();
221-
let canonical = resolve_model_alias(&request.model);
222-
223-
// 2. Clean the model string (e.g., "openai/deepseek-v4-flash" -> "deepseek-v4-flash")
224-
let downstream_model = strip_provider_prefix(&canonical);
225-
226-
let mut request = MessageRequest {
227-
stream: false,
228-
..request.clone()
229-
};
230-
request.model = downstream_model; // Use the clean name for the API payload
231-
232-
preflight_message_request(&request)?;
233-
let response = self.send_with_retry(&request).await?;
234-
let request_id = request_id_from_headers(response.headers());
235-
let body = response.text().await.map_err(ApiError::from)?;
236-
237-
// Some backends return {"error":{"message":"...","type":"...","code":...}}
238-
// instead of a valid completion object. Check for this before attempting
239-
// full deserialization so the user sees the actual error, not a cryptic.
240-
if let Ok(raw) = serde_json::from_str::<serde_json::Value>(&body) {
241-
if let Some(err_obj) = raw.get("error") {
242-
let msg = err_obj
243-
.get("message")
244-
.and_then(|m| m.as_str())
245-
.unwrap_or("provider returned an error")
246-
.to_string();
247-
let code = err_obj
248-
.get("code")
249-
.and_then(serde_json::Value::as_u64)
250-
.map(|c| c as u16);
251-
return Err(ApiError::Api {
252-
status: reqwest::StatusCode::from_u16(code.unwrap_or(400))
253-
.unwrap_or(reqwest::StatusCode::BAD_REQUEST),
254-
error_type: err_obj
255-
.get("type")
256-
.and_then(|t| t.as_str())
257-
.map(str::to_owned),
258-
message: Some(msg),
259-
request_id,
260-
body,
261-
retryable: false,
262-
suggested_action: suggested_action_for_status(
263-
reqwest::StatusCode::from_u16(code.unwrap_or(400))
264-
.unwrap_or(reqwest::StatusCode::BAD_REQUEST),
265-
),
266-
retry_after: None,
267-
});
268-
}
269-
}
270-
271-
// Pass original_model to the deserializer error context so debugging logs are accurate
272-
let payload = serde_json::from_str::<ChatCompletionResponse>(&body).map_err(|error| {
273-
ApiError::json_deserialize(self.config.provider_name, &original_model, &body, error)
274-
})?;
275-
276-
let mut normalized = normalize_response(&request.model, payload)?;
277-
if normalized.request_id.is_none() {
278-
normalized.request_id = request_id;
279-
}
217+
&self,
218+
request: &MessageRequest,
219+
) -> Result<MessageResponse, ApiError> {
220+
let original_model = request.model.clone();
221+
let canonical = resolve_model_alias(&request.model);
222+
let downstream_model = strip_provider_prefix(&canonical);
280223

281-
// 3. CRITICAL: Put the original model string back so Claw's internal routing stays happy
282-
normalized.model = original_model;
224+
let mut request = MessageRequest {
225+
stream: false,
226+
..request.clone()
227+
};
228+
request.model = downstream_model;
283229

284-
Ok(normalized)
285-
}
286-
// Some backends return {"error":{"message":"...","type":"...","code":...}}
287-
// instead of a valid completion object. Check for this before attempting
288-
// full deserialization so the user sees the actual error, not a cryptic
289-
// "missing field 'id'" parse failure.
230+
preflight_message_request(&request)?;
231+
let response = self.send_with_retry(&request).await?;
232+
let request_id = request_id_from_headers(response.headers());
233+
let body = response.text().await.map_err(ApiError::from)?;
290234
if let Ok(raw) = serde_json::from_str::<serde_json::Value>(&body) {
291235
if let Some(err_obj) = raw.get("error") {
292236
let msg = err_obj
@@ -318,41 +262,42 @@ impl OpenAiCompatClient {
318262
}
319263
}
320264
let payload = serde_json::from_str::<ChatCompletionResponse>(&body).map_err(|error| {
321-
ApiError::json_deserialize(self.config.provider_name, &request.model, &body, error)
265+
ApiError::json_deserialize(self.config.provider_name, &original_model, &body, error)
322266
})?;
323267
let mut normalized = normalize_response(&request.model, payload)?;
324268
if normalized.request_id.is_none() {
325269
normalized.request_id = request_id;
326270
}
271+
normalized.model = original_model;
327272
Ok(normalized)
328273
}
329274

330-
pub async fn stream_message(
331-
&self,
332-
request: &MessageRequest,
333-
) -> Result<MessageStream, ApiError> {
334-
// 1. Keep track of the original model name
335-
let original_model = request.model.clone();
336-
let canonical = resolve_model_alias(&request.model);
337-
338-
// 2. Clean it up for DeepSeek
339-
let downstream_model = strip_provider_prefix(&canonical);
340-
341-
let mut streaming_request = request.clone().with_streaming();
342-
streaming_request.model = downstream_model;
343-
344-
preflight_message_request(&streaming_request)?;
345-
let response = self.send_with_retry(&streaming_request).await?;
346-
347-
Ok(MessageStream {
348-
request_id: request_id_from_headers(response.headers()),
349-
response,
350-
parser: OpenAiSseParser::with_context(self.config.provider_name, original_model.clone()),
351-
pending: VecDeque::new(),
352-
done: false,
353-
state: StreamState::new(original_model), // 3. Use the original name here
354-
})
355-
}
275+
pub async fn stream_message(
276+
&self,
277+
request: &MessageRequest,
278+
) -> Result<MessageStream, ApiError> {
279+
let original_model = request.model.clone();
280+
let canonical = resolve_model_alias(&request.model);
281+
let downstream_model = strip_provider_prefix(&canonical);
282+
283+
let mut streaming_request = request.clone().with_streaming();
284+
streaming_request.model = downstream_model;
285+
286+
preflight_message_request(&streaming_request)?;
287+
let response = self.send_with_retry(&streaming_request).await?;
288+
289+
Ok(MessageStream {
290+
request_id: request_id_from_headers(response.headers()),
291+
response,
292+
parser: OpenAiSseParser::with_context(
293+
self.config.provider_name,
294+
original_model.clone(),
295+
),
296+
pending: VecDeque::new(),
297+
done: false,
298+
state: StreamState::new(original_model),
299+
})
300+
}
356301

357302
async fn send_with_retry(
358303
&self,

rust/crates/rusty-claude-cli/src/main.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13737,8 +13737,15 @@ fn push_output_block(
1373713737
};
1373813738
*pending_tool = Some((id, name, initial_input));
1373913739
}
13740-
OutputContentBlock::Thinking { thinking, .. } => {
13740+
OutputContentBlock::Thinking {
13741+
thinking,
13742+
signature,
13743+
} => {
1374113744
render_thinking_block_summary(out, Some(thinking.chars().count()), false)?;
13745+
events.push(AssistantEvent::Thinking {
13746+
thinking,
13747+
signature,
13748+
});
1374213749
*block_has_thinking_summary = true;
1374313750
}
1374413751
OutputContentBlock::RedactedThinking { .. } => {
@@ -19073,6 +19080,13 @@ UU conflicted.rs",
1907319080

1907419081
assert!(matches!(
1907519082
&events[0],
19083+
AssistantEvent::Thinking {
19084+
thinking,
19085+
signature
19086+
} if thinking == "step 1" && signature.as_deref() == Some("sig_123")
19087+
));
19088+
assert!(matches!(
19089+
&events[1],
1907619090
AssistantEvent::TextDelta(text) if text == "Final answer"
1907719091
));
1907819092
let rendered = String::from_utf8(out).expect("utf8");

0 commit comments

Comments
 (0)