Skip to content

Commit 0755ddf

Browse files
committed
fix(providers): strip provider prefix from model names for openai_compat endpoints
1 parent 3acb677 commit 0755ddf

2 files changed

Lines changed: 112 additions & 29 deletions

File tree

rust/crates/api/src/providers/mod.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,20 @@ pub fn metadata_for_model(model: &str) -> Option<ProviderMetadata> {
296296
None
297297
}
298298

299+
300+
301+
302+
#[must_use]
303+
pub fn strip_provider_prefix(canonical_model: &str) -> String {
304+
if let Some(pos) = canonical_model.find('/') {
305+
canonical_model[pos + 1..].to_string()
306+
} else {
307+
canonical_model.to_string()
308+
}
309+
}
310+
311+
312+
299313
#[must_use]
300314
pub fn provider_diagnostics_for_model(model: &str) -> ProviderDiagnostics {
301315
let resolved_model = resolve_model_alias(model);

rust/crates/api/src/providers/openai_compat.rs

Lines changed: 98 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ use crate::types::{
1616
ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
1717
};
1818

19-
use super::{preflight_message_request, Provider, ProviderFuture};
19+
use super::{preflight_message_request, Provider, ProviderFuture, resolve_model_alias, strip_provider_prefix};
20+
2021

2122
pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1";
2223
pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
@@ -212,17 +213,76 @@ impl OpenAiCompatClient {
212213
}
213214

214215
pub async fn send_message(
215-
&self,
216-
request: &MessageRequest,
217-
) -> Result<MessageResponse, ApiError> {
218-
let request = MessageRequest {
219-
stream: false,
220-
..request.clone()
221-
};
222-
preflight_message_request(&request)?;
223-
let response = self.send_with_retry(&request).await?;
224-
let request_id = request_id_from_headers(response.headers());
225-
let body = response.text().await.map_err(ApiError::from)?;
216+
&self,
217+
request: &MessageRequest,
218+
) -> Result<MessageResponse, ApiError> {
219+
// 1. Keep track of what Claw originally asked for
220+
let original_model = request.model.clone();
221+
let canonical = resolve_model_alias(&request.model);
222+
223+
// 2. Clean the model string (e.g., "openai/deepseek-v4-flash" -> "deepseek-v4-flash")
224+
let downstream_model = strip_provider_prefix(&canonical);
225+
226+
let mut request = MessageRequest {
227+
stream: false,
228+
..request.clone()
229+
};
230+
request.model = downstream_model; // Use the clean name for the API payload
231+
232+
preflight_message_request(&request)?;
233+
let response = self.send_with_retry(&request).await?;
234+
let request_id = request_id_from_headers(response.headers());
235+
let body = response.text().await.map_err(ApiError::from)?;
236+
237+
// Some backends return {"error":{"message":"...","type":"...","code":...}}
238+
// instead of a valid completion object. Check for this before attempting
239+
// full deserialization so the user sees the actual error, not a cryptic.
240+
if let Ok(raw) = serde_json::from_str::<serde_json::Value>(&body) {
241+
if let Some(err_obj) = raw.get("error") {
242+
let msg = err_obj
243+
.get("message")
244+
.and_then(|m| m.as_str())
245+
.unwrap_or("provider returned an error")
246+
.to_string();
247+
let code = err_obj
248+
.get("code")
249+
.and_then(serde_json::Value::as_u64)
250+
.map(|c| c as u16);
251+
return Err(ApiError::Api {
252+
status: reqwest::StatusCode::from_u16(code.unwrap_or(400))
253+
.unwrap_or(reqwest::StatusCode::BAD_REQUEST),
254+
error_type: err_obj
255+
.get("type")
256+
.and_then(|t| t.as_str())
257+
.map(str::to_owned),
258+
message: Some(msg),
259+
request_id,
260+
body,
261+
retryable: false,
262+
suggested_action: suggested_action_for_status(
263+
reqwest::StatusCode::from_u16(code.unwrap_or(400))
264+
.unwrap_or(reqwest::StatusCode::BAD_REQUEST),
265+
),
266+
retry_after: None,
267+
});
268+
}
269+
}
270+
271+
// Pass original_model to the deserializer error context so debugging logs are accurate
272+
let payload = serde_json::from_str::<ChatCompletionResponse>(&body).map_err(|error| {
273+
ApiError::json_deserialize(self.config.provider_name, &original_model, &body, error)
274+
})?;
275+
276+
let mut normalized = normalize_response(&request.model, payload)?;
277+
if normalized.request_id.is_none() {
278+
normalized.request_id = request_id;
279+
}
280+
281+
// 3. CRITICAL: Put the original model string back so Claw's internal routing stays happy
282+
normalized.model = original_model;
283+
284+
Ok(normalized)
285+
}
226286
// Some backends return {"error":{"message":"...","type":"...","code":...}}
227287
// instead of a valid completion object. Check for this before attempting
228288
// full deserialization so the user sees the actual error, not a cryptic
@@ -267,23 +327,32 @@ impl OpenAiCompatClient {
267327
Ok(normalized)
268328
}
269329

270-
pub async fn stream_message(
271-
&self,
272-
request: &MessageRequest,
273-
) -> Result<MessageStream, ApiError> {
274-
preflight_message_request(request)?;
275-
let response = self
276-
.send_with_retry(&request.clone().with_streaming())
277-
.await?;
278-
Ok(MessageStream {
279-
request_id: request_id_from_headers(response.headers()),
280-
response,
281-
parser: OpenAiSseParser::with_context(self.config.provider_name, request.model.clone()),
282-
pending: VecDeque::new(),
283-
done: false,
284-
state: StreamState::new(request.model.clone()),
285-
})
286-
}
330+
pub async fn stream_message(
331+
&self,
332+
request: &MessageRequest,
333+
) -> Result<MessageStream, ApiError> {
334+
// 1. Keep track of the original model name
335+
let original_model = request.model.clone();
336+
let canonical = resolve_model_alias(&request.model);
337+
338+
// 2. Clean it up for DeepSeek
339+
let downstream_model = strip_provider_prefix(&canonical);
340+
341+
let mut streaming_request = request.clone().with_streaming();
342+
streaming_request.model = downstream_model;
343+
344+
preflight_message_request(&streaming_request)?;
345+
let response = self.send_with_retry(&streaming_request).await?;
346+
347+
Ok(MessageStream {
348+
request_id: request_id_from_headers(response.headers()),
349+
response,
350+
parser: OpenAiSseParser::with_context(self.config.provider_name, original_model.clone()),
351+
pending: VecDeque::new(),
352+
done: false,
353+
state: StreamState::new(original_model), // 3. Use the original name here
354+
})
355+
}
287356

288357
async fn send_with_retry(
289358
&self,

0 commit comments

Comments
 (0)