|
| 1 | +use std::collections::HashMap; |
| 2 | + |
| 3 | +use serde::{Serialize, de::DeserializeOwned}; |
| 4 | +use serde_json::Value; |
| 5 | + |
| 6 | +use crate::gateway::{ |
| 7 | + error::{GatewayError, Result}, |
| 8 | + traits::{NativeHandler, ProviderCapabilities}, |
| 9 | + types::{ |
| 10 | + common::BridgeContext, |
| 11 | + openai::{ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse}, |
| 12 | + }, |
| 13 | +}; |
| 14 | + |
| 15 | +/// A complete chat API format contract and its bridge rules to the hub format. |
| 16 | +pub trait ChatFormat: Send + Sync + 'static { |
| 17 | + /// Request type for this format. |
| 18 | + type Request: DeserializeOwned + Serialize + Send + Sync; |
| 19 | + /// Non-streaming response type for this format. |
| 20 | + type Response: Serialize + Send + Sync; |
| 21 | + /// Streaming chunk type for this format. |
| 22 | + type StreamChunk: Serialize + Send + Sync; |
| 23 | + /// Stateful bridge data used while converting hub chunks. |
| 24 | + type BridgeState: Default + Send + Unpin; |
| 25 | + /// Stateful bridge data used on native streaming paths. |
| 26 | + type NativeStreamState: Default + Send + Unpin; |
| 27 | + |
| 28 | + /// Stable format name used for logs and diagnostics. |
| 29 | + fn name() -> &'static str; |
| 30 | + |
| 31 | + /// Whether the request expects a streaming response. |
| 32 | + fn is_stream(req: &Self::Request) -> bool; |
| 33 | + |
| 34 | + /// Extract the model identifier from the request. |
| 35 | + fn extract_model(req: &Self::Request) -> &str; |
| 36 | + |
| 37 | + /// Convert this request into the hub request plus side-channel bridge data. |
| 38 | + fn to_hub(req: &Self::Request) -> Result<(ChatCompletionRequest, BridgeContext)>; |
| 39 | + |
| 40 | + /// Convert a hub response back into this format. |
| 41 | + fn from_hub(resp: &ChatCompletionResponse, ctx: &BridgeContext) -> Result<Self::Response>; |
| 42 | + |
| 43 | + /// Convert a hub streaming chunk into zero or more chunks of this format. |
| 44 | + fn from_hub_stream( |
| 45 | + chunk: &ChatCompletionChunk, |
| 46 | + state: &mut Self::BridgeState, |
| 47 | + ctx: &BridgeContext, |
| 48 | + ) -> Result<Vec<Self::StreamChunk>>; |
| 49 | + |
| 50 | + /// Emit any format-specific end-of-stream events. |
| 51 | + fn stream_end_events( |
| 52 | + _state: &mut Self::BridgeState, |
| 53 | + _ctx: &BridgeContext, |
| 54 | + ) -> Vec<Self::StreamChunk> { |
| 55 | + vec![] |
| 56 | + } |
| 57 | + |
| 58 | + /// Return a native handler when the provider can bypass the hub format. |
| 59 | + fn native_support(_provider: &dyn ProviderCapabilities) -> Option<NativeHandler<'_>> |
| 60 | + where |
| 61 | + Self: Sized, |
| 62 | + { |
| 63 | + None |
| 64 | + } |
| 65 | + |
| 66 | + /// Prepare a native request body for providers that support this format directly. |
| 67 | + fn call_native( |
| 68 | + native: &NativeHandler<'_>, |
| 69 | + request: &Self::Request, |
| 70 | + stream: bool, |
| 71 | + ) -> Result<(String, Value)> |
| 72 | + where |
| 73 | + Self: Sized, |
| 74 | + { |
| 75 | + let _ = (request, stream); |
| 76 | + Err(GatewayError::NativeNotSupported { |
| 77 | + provider: native.provider_name().into(), |
| 78 | + }) |
| 79 | + } |
| 80 | + |
| 81 | + /// Convert a native streaming chunk into zero or more chunks of this format. |
| 82 | + fn transform_native_stream_chunk( |
| 83 | + provider: &dyn ProviderCapabilities, |
| 84 | + raw: &str, |
| 85 | + state: &mut Self::NativeStreamState, |
| 86 | + ) -> Result<Vec<Self::StreamChunk>>; |
| 87 | + |
| 88 | + /// Parse a native non-streaming response into this format. |
| 89 | + fn parse_native_response(native: &NativeHandler<'_>, body: Value) -> Result<Self::Response> |
| 90 | + where |
| 91 | + Self: Sized, |
| 92 | + { |
| 93 | + let _ = body; |
| 94 | + Err(GatewayError::Bridge(format!( |
| 95 | + "parse_native_response called on a non-native format for provider {}", |
| 96 | + native.provider_name() |
| 97 | + ))) |
| 98 | + } |
| 99 | + |
| 100 | + /// Serialize a chunk into the JSON payload used by SSE framing. |
| 101 | + fn serialize_chunk_payload(chunk: &Self::StreamChunk) -> String; |
| 102 | + |
| 103 | + /// Optional SSE event type for this chunk. |
| 104 | + fn sse_event_type(_chunk: &Self::StreamChunk) -> Option<&'static str> { |
| 105 | + None |
| 106 | + } |
| 107 | +} |
| 108 | + |
| 109 | +/// Incremental state for reconstructing tool calls across hub chunks. |
| 110 | +#[derive(Debug, Clone, Default)] |
| 111 | +pub struct ToolCallAccumulator { |
| 112 | + pub id: Option<String>, |
| 113 | + pub kind: Option<String>, |
| 114 | + pub name: Option<String>, |
| 115 | + pub arguments: String, |
| 116 | +} |
| 117 | + |
| 118 | +/// Key for partially assembled tool calls: (choice_index, tool_call_index). |
| 119 | +pub type ToolCallAccumulatorKey = (u32, usize); |
| 120 | + |
| 121 | +/// Stateful data used while transforming provider chunks into hub chunks. |
| 122 | +#[derive(Debug, Clone, Default)] |
| 123 | +pub struct ChatStreamState { |
| 124 | + pub chunk_index: usize, |
| 125 | + pub tool_call_accumulators: HashMap<ToolCallAccumulatorKey, ToolCallAccumulator>, |
| 126 | + pub input_tokens: u32, |
| 127 | + pub output_tokens: u32, |
| 128 | +} |
| 129 | + |
| 130 | +#[cfg(test)] |
| 131 | +mod tests { |
| 132 | + use std::borrow::Cow; |
| 133 | + |
| 134 | + use http::HeaderMap; |
| 135 | + use serde_json::json; |
| 136 | + |
| 137 | + use super::{ChatFormat, ChatStreamState, ToolCallAccumulator}; |
| 138 | + use crate::gateway::{ |
| 139 | + error::GatewayError, |
| 140 | + traits::{ |
| 141 | + NativeHandler, NativeOpenAIResponsesSupport, ProviderAuth, ProviderMeta, |
| 142 | + StreamReaderKind, provider::ChatTransform, |
| 143 | + }, |
| 144 | + types::{ |
| 145 | + common::BridgeContext, |
| 146 | + openai::{ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse}, |
| 147 | + }, |
| 148 | + }; |
| 149 | + |
| 150 | + struct DummyNativeProvider; |
| 151 | + |
| 152 | + impl ProviderMeta for DummyNativeProvider { |
| 153 | + fn name(&self) -> &'static str { |
| 154 | + "dummy-native-provider" |
| 155 | + } |
| 156 | + |
| 157 | + fn default_base_url(&self) -> &'static str { |
| 158 | + "https://example.com" |
| 159 | + } |
| 160 | + |
| 161 | + fn stream_reader_kind(&self) -> StreamReaderKind { |
| 162 | + StreamReaderKind::Sse |
| 163 | + } |
| 164 | + |
| 165 | + fn build_auth_headers( |
| 166 | + &self, |
| 167 | + _auth: &ProviderAuth, |
| 168 | + ) -> crate::gateway::error::Result<HeaderMap> { |
| 169 | + Ok(HeaderMap::new()) |
| 170 | + } |
| 171 | + } |
| 172 | + |
| 173 | + impl ChatTransform for DummyNativeProvider {} |
| 174 | + |
| 175 | + impl NativeOpenAIResponsesSupport for DummyNativeProvider { |
| 176 | + fn native_openai_responses_endpoint(&self, _model: &str) -> Cow<'static, str> { |
| 177 | + Cow::Borrowed("/v1/responses") |
| 178 | + } |
| 179 | + |
| 180 | + fn transform_openai_responses_request( |
| 181 | + &self, |
| 182 | + _req: &crate::gateway::types::openai::responses::ResponsesApiRequest, |
| 183 | + ) -> crate::gateway::error::Result<serde_json::Value> { |
| 184 | + Ok(json!({})) |
| 185 | + } |
| 186 | + |
| 187 | + fn transform_openai_responses_response( |
| 188 | + &self, |
| 189 | + _body: serde_json::Value, |
| 190 | + ) -> crate::gateway::error::Result< |
| 191 | + crate::gateway::types::openai::responses::ResponsesApiResponse, |
| 192 | + > { |
| 193 | + unreachable!("not used in this test") |
| 194 | + } |
| 195 | + |
| 196 | + fn transform_openai_responses_stream_chunk( |
| 197 | + &self, |
| 198 | + _raw: &str, |
| 199 | + _state: &mut crate::gateway::traits::OpenAIResponsesNativeStreamState, |
| 200 | + ) -> crate::gateway::error::Result< |
| 201 | + Vec<crate::gateway::types::openai::responses::ResponsesApiStreamEvent>, |
| 202 | + > { |
| 203 | + Ok(vec![]) |
| 204 | + } |
| 205 | + } |
| 206 | + |
| 207 | + struct DummyFormat; |
| 208 | + |
| 209 | + impl ChatFormat for DummyFormat { |
| 210 | + type Request = serde_json::Value; |
| 211 | + type Response = serde_json::Value; |
| 212 | + type StreamChunk = serde_json::Value; |
| 213 | + type BridgeState = (); |
| 214 | + type NativeStreamState = (); |
| 215 | + |
| 216 | + fn name() -> &'static str { |
| 217 | + "dummy" |
| 218 | + } |
| 219 | + |
| 220 | + fn is_stream(_req: &Self::Request) -> bool { |
| 221 | + false |
| 222 | + } |
| 223 | + |
| 224 | + fn extract_model(_req: &Self::Request) -> &str { |
| 225 | + "dummy-model" |
| 226 | + } |
| 227 | + |
| 228 | + fn to_hub( |
| 229 | + _req: &Self::Request, |
| 230 | + ) -> crate::gateway::error::Result<(ChatCompletionRequest, BridgeContext)> { |
| 231 | + unreachable!("not used in this test") |
| 232 | + } |
| 233 | + |
| 234 | + fn from_hub( |
| 235 | + _resp: &ChatCompletionResponse, |
| 236 | + _ctx: &BridgeContext, |
| 237 | + ) -> crate::gateway::error::Result<Self::Response> { |
| 238 | + unreachable!("not used in this test") |
| 239 | + } |
| 240 | + |
| 241 | + fn from_hub_stream( |
| 242 | + _chunk: &ChatCompletionChunk, |
| 243 | + _state: &mut Self::BridgeState, |
| 244 | + _ctx: &BridgeContext, |
| 245 | + ) -> crate::gateway::error::Result<Vec<Self::StreamChunk>> { |
| 246 | + Ok(vec![]) |
| 247 | + } |
| 248 | + |
| 249 | + fn transform_native_stream_chunk( |
| 250 | + _provider: &dyn crate::gateway::traits::ProviderCapabilities, |
| 251 | + _raw: &str, |
| 252 | + _state: &mut Self::NativeStreamState, |
| 253 | + ) -> crate::gateway::error::Result<Vec<Self::StreamChunk>> { |
| 254 | + Ok(vec![]) |
| 255 | + } |
| 256 | + |
| 257 | + fn serialize_chunk_payload(chunk: &Self::StreamChunk) -> String { |
| 258 | + serde_json::to_string(chunk).unwrap() |
| 259 | + } |
| 260 | + } |
| 261 | + |
| 262 | + #[test] |
| 263 | + fn default_call_native_uses_provider_name() { |
| 264 | + let provider = DummyNativeProvider; |
| 265 | + let native = NativeHandler::OpenAIResponses(&provider); |
| 266 | + |
| 267 | + let error = DummyFormat::call_native(&native, &json!({}), false).unwrap_err(); |
| 268 | + assert!(matches!( |
| 269 | + error, |
| 270 | + GatewayError::NativeNotSupported { provider } if provider == "dummy-native-provider" |
| 271 | + )); |
| 272 | + } |
| 273 | + |
| 274 | + #[test] |
| 275 | + fn default_parse_native_response_returns_bridge_error() { |
| 276 | + let provider = DummyNativeProvider; |
| 277 | + let native = NativeHandler::OpenAIResponses(&provider); |
| 278 | + |
| 279 | + let error = DummyFormat::parse_native_response(&native, json!({})).unwrap_err(); |
| 280 | + assert!(matches!( |
| 281 | + error, |
| 282 | + GatewayError::Bridge(message) |
| 283 | + if message.contains("parse_native_response called on a non-native format") |
| 284 | + && message.contains("dummy-native-provider") |
| 285 | + )); |
| 286 | + } |
| 287 | + |
| 288 | + #[test] |
| 289 | + fn stream_state_separates_tool_call_accumulators_by_choice_and_index() { |
| 290 | + let mut state = ChatStreamState::default(); |
| 291 | + state.tool_call_accumulators.insert( |
| 292 | + (0, 0), |
| 293 | + ToolCallAccumulator { |
| 294 | + arguments: "first".into(), |
| 295 | + ..Default::default() |
| 296 | + }, |
| 297 | + ); |
| 298 | + state.tool_call_accumulators.insert( |
| 299 | + (1, 0), |
| 300 | + ToolCallAccumulator { |
| 301 | + arguments: "second".into(), |
| 302 | + ..Default::default() |
| 303 | + }, |
| 304 | + ); |
| 305 | + |
| 306 | + assert_eq!(state.tool_call_accumulators.len(), 2); |
| 307 | + assert_eq!( |
| 308 | + state.tool_call_accumulators.get(&(0, 0)).unwrap().arguments, |
| 309 | + "first" |
| 310 | + ); |
| 311 | + assert_eq!( |
| 312 | + state.tool_call_accumulators.get(&(1, 0)).unwrap().arguments, |
| 313 | + "second" |
| 314 | + ); |
| 315 | + } |
| 316 | +} |
0 commit comments