|
1 | 1 | mod span_attributes; |
2 | 2 | mod types; |
3 | 3 |
|
4 | | -use std::{convert::Infallible, time::Duration}; |
5 | | - |
6 | | -use axum::{ |
7 | | - Json, |
8 | | - extract::State, |
9 | | - response::{ |
10 | | - IntoResponse, Response, |
11 | | - sse::{Event as SseEvent, Sse}, |
12 | | - }, |
13 | | -}; |
14 | | -use fastrace::prelude::{Event as TraceEvent, *}; |
15 | | -use log::error; |
| 4 | +use axum::response::sse::Event as SseEvent; |
| 5 | +use fastrace::Span; |
16 | 6 | use opentelemetry_semantic_conventions::attribute::GEN_AI_RESPONSE_FINISH_REASONS; |
| 7 | +use reqwest::Url; |
17 | 8 | use span_attributes::{ |
18 | | - StreamOutputCollector, apply_span_properties, chunk_span_properties, request_span_properties, |
19 | | - response_span_properties, usage_span_properties, |
| 9 | + StreamOutputCollector, chunk_span_properties, request_span_properties, response_span_properties, |
20 | 10 | }; |
21 | | -use tokio::sync::{oneshot, oneshot::error::TryRecvError}; |
22 | 11 | pub use types::ChatCompletionError; |
23 | 12 |
|
24 | 13 | use crate::{ |
25 | | - config::entities::{Model, ResourceEntry}, |
26 | 14 | gateway::{ |
27 | | - error::GatewayError, |
28 | 15 | formats::OpenAIChatFormat, |
29 | | - traits::ChatFormat, |
| 16 | + traits::ProviderCapabilities, |
30 | 17 | types::{ |
31 | 18 | common::Usage, |
32 | | - openai::{ChatCompletionRequest, ChatCompletionResponse}, |
33 | | - response::{ChatResponse, ChatResponseStream}, |
| 19 | + openai::{ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse}, |
34 | 20 | }, |
35 | 21 | }, |
36 | | - proxy::{ |
37 | | - AppState, |
38 | | - hooks::{self, RequestContext}, |
39 | | - provider::create_provider_instance, |
40 | | - }, |
41 | | - utils::future::{WithSpan, maybe_timeout}, |
| 22 | + proxy::handlers::format_handler::FormatHandlerAdapter, |
42 | 23 | }; |
43 | 24 |
|
44 | | -pub async fn chat_completions( |
45 | | - State(state): State<AppState>, |
46 | | - mut request_ctx: RequestContext, |
47 | | - Json(mut request_data): Json<ChatCompletionRequest>, |
48 | | -) -> Result<Response, ChatCompletionError> { |
49 | | - hooks::observability::record_start_time(&mut request_ctx).await; |
50 | | - hooks::authorization::check( |
51 | | - &mut request_ctx, |
52 | | - OpenAIChatFormat::extract_model(&request_data).to_owned(), |
53 | | - ) |
54 | | - .await?; |
55 | | - hooks::rate_limit::pre_check(&mut request_ctx).await?; |
56 | | - |
57 | | - let model = request_ctx |
58 | | - .extensions() |
59 | | - .await |
60 | | - .get::<ResourceEntry<Model>>() |
61 | | - .cloned() |
62 | | - .ok_or(ChatCompletionError::MissingModelInContext)?; |
63 | | - |
64 | | - // Replace request model name with real model name |
65 | | - request_data.model = model.model.clone(); |
66 | | - let timeout = model.timeout.map(Duration::from_millis); |
67 | | - |
68 | | - let gateway = state.gateway(); |
69 | | - let resources = state.resources(); |
70 | | - let provider = model.provider(resources.as_ref()).ok_or_else(|| { |
71 | | - GatewayError::Internal(format!("provider {} not found", model.provider_id)) |
72 | | - })?; |
73 | | - let provider_instance = create_provider_instance(gateway.as_ref(), &provider)?; |
74 | | - let provider_base_url = provider_instance.effective_base_url().ok(); |
| 25 | +pub(crate) struct ChatCompletionsAdapter; |
75 | 26 |
|
76 | | - let span = Span::enter_with_local_parent("aisix.llm.chat_completions"); |
77 | | - apply_span_properties( |
78 | | - &span, |
79 | | - request_span_properties( |
80 | | - &request_data, |
81 | | - provider_instance.def.as_ref(), |
82 | | - provider_base_url.as_ref(), |
83 | | - ), |
84 | | - ); |
| 27 | +impl FormatHandlerAdapter for ChatCompletionsAdapter { |
| 28 | + type Format = OpenAIChatFormat; |
| 29 | + type Request = ChatCompletionRequest; |
| 30 | + type Response = ChatCompletionResponse; |
| 31 | + type StreamChunk = ChatCompletionChunk; |
| 32 | + type Error = ChatCompletionError; |
| 33 | + type Collector = StreamOutputCollector; |
85 | 34 |
|
86 | | - let (response, span) = (WithSpan { |
87 | | - inner: maybe_timeout( |
88 | | - timeout, |
89 | | - gateway.chat_completion(&request_data, &provider_instance), |
90 | | - ), |
91 | | - span: Some(span), |
92 | | - }) |
93 | | - .await; |
94 | | - |
95 | | - match response { |
96 | | - Ok(Ok(ChatResponse::Complete { response, usage })) => { |
97 | | - span.add_properties(|| response_span_properties(&response, &usage)); |
98 | | - handle_regular_request(response, usage, &mut request_ctx).await |
99 | | - } |
100 | | - Ok(Ok(ChatResponse::Stream { stream, usage_rx })) => { |
101 | | - handle_stream_request(stream, usage_rx, &mut request_ctx, span).await |
102 | | - } |
103 | | - Ok(Err(err)) => { |
104 | | - span.add_property(|| ("error.type", "gateway_error")); |
105 | | - Err(err.into()) |
106 | | - } |
107 | | - Err(err) => { |
108 | | - span.add_property(|| ("error.type", "timeout")); |
109 | | - Err(ChatCompletionError::Timeout(err)) |
110 | | - } |
| 35 | + fn span_name() -> &'static str { |
| 36 | + "aisix.llm.chat_completions" |
111 | 37 | } |
112 | | -} |
113 | 38 |
|
114 | | -async fn handle_regular_request( |
115 | | - response: ChatCompletionResponse, |
116 | | - usage: Usage, |
117 | | - request_ctx: &mut RequestContext, |
118 | | -) -> Result<Response, ChatCompletionError> { |
119 | | - if let Err(err) = hooks::rate_limit::post_check(request_ctx, &usage).await { |
120 | | - error!("Rate limit post_check error: {}", err); |
| 39 | + fn missing_model_error() -> Self::Error { |
| 40 | + ChatCompletionError::MissingModelInContext |
121 | 41 | } |
122 | 42 |
|
123 | | - let mut resp = Json(response).into_response(); |
124 | | - hooks::rate_limit::inject_response_headers(request_ctx, resp.headers_mut()).await; |
125 | | - hooks::observability::record_usage(request_ctx, &usage).await; |
| 43 | + fn set_model(request: &mut Self::Request, model: String) { |
| 44 | + request.model = model; |
| 45 | + } |
126 | 46 |
|
127 | | - Ok(resp) |
128 | | -} |
| 47 | + fn request_span_properties( |
| 48 | + request: &Self::Request, |
| 49 | + provider: &dyn ProviderCapabilities, |
| 50 | + base_url: Option<&Url>, |
| 51 | + ) -> Vec<(String, String)> { |
| 52 | + request_span_properties(request, provider, base_url) |
| 53 | + } |
129 | 54 |
|
130 | | -fn spawn_stream_usage_observer(request_ctx: RequestContext, usage_rx: oneshot::Receiver<Usage>) { |
131 | | - tokio::spawn(async move { |
132 | | - let mut request_ctx = request_ctx; |
| 55 | + fn response_span_properties(response: &Self::Response, usage: &Usage) -> Vec<(String, String)> { |
| 56 | + response_span_properties(response, usage) |
| 57 | + } |
133 | 58 |
|
134 | | - match usage_rx.await { |
135 | | - Ok(usage) => { |
136 | | - if let Err(err) = |
137 | | - hooks::rate_limit::post_check_streaming(&mut request_ctx, &usage).await |
138 | | - { |
139 | | - error!("Rate limit post_check_streaming error: {}", err); |
140 | | - } |
141 | | - hooks::observability::record_streaming_usage(&mut request_ctx, &usage).await; |
142 | | - } |
143 | | - Err(err) => { |
144 | | - error!("Failed to receive streaming usage from gateway: {}", err); |
145 | | - } |
| 59 | + fn apply_chunk_span_properties(span: &Span, chunk: &Self::StreamChunk, is_first_item: bool) { |
| 60 | + if is_first_item { |
| 61 | + span.add_properties(|| chunk_span_properties(chunk)); |
| 62 | + return; |
146 | 63 | } |
147 | | - }); |
148 | | -} |
149 | | - |
150 | | -async fn handle_stream_request( |
151 | | - stream: ChatResponseStream<OpenAIChatFormat>, |
152 | | - usage_rx: oneshot::Receiver<Usage>, |
153 | | - request_ctx: &mut RequestContext, |
154 | | - span: Span, |
155 | | -) -> Result<Response, ChatCompletionError> { |
156 | | - use futures::stream::StreamExt; |
157 | | - |
158 | | - let stream_request_ctx = request_ctx.clone(); |
159 | | - let sse_stream = futures::stream::unfold( |
160 | | - ( |
161 | | - stream, |
162 | | - span, |
163 | | - 0usize, |
164 | | - stream_request_ctx, |
165 | | - false, |
166 | | - false, |
167 | | - Some(usage_rx), |
168 | | - StreamOutputCollector::default(), |
169 | | - ), |
170 | | - |( |
171 | | - mut stream, |
172 | | - span, |
173 | | - idx, |
174 | | - mut request_ctx, |
175 | | - done, |
176 | | - saw_chunk, |
177 | | - mut usage_rx, |
178 | | - mut output_collector, |
179 | | - )| async move { |
180 | | - if done { |
181 | | - drop(span); |
182 | | - return None; |
183 | | - } |
184 | | - |
185 | | - match stream.next().await { |
186 | | - Some(Ok(chunk)) => { |
187 | | - output_collector.record_chunk(&chunk); |
188 | 64 |
|
189 | | - if idx == 0 { |
190 | | - hooks::observability::record_first_token_latency(&mut request_ctx).await; |
191 | | - span.add_event( |
192 | | - TraceEvent::new("first token arrived") |
193 | | - .with_property(|| ("kind", "first_token_arrived")), |
194 | | - ); |
195 | | - span.add_properties(|| chunk_span_properties(&chunk)); |
196 | | - } else { |
197 | | - let properties = chunk_span_properties(&chunk); |
198 | | - properties |
199 | | - .iter() |
200 | | - .filter(|(key, _)| { |
201 | | - key == GEN_AI_RESPONSE_FINISH_REASONS |
202 | | - || key == "llm.finish_reason" |
203 | | - || key == "llm.token_count.completion_details.reasoning" |
204 | | - }) |
205 | | - .for_each(|item| span.add_property(|| item.clone())); |
206 | | - } |
207 | | - |
208 | | - let mut event = |
209 | | - SseEvent::default().data(OpenAIChatFormat::serialize_chunk_payload(&chunk)); |
210 | | - if let Some(event_type) = OpenAIChatFormat::sse_event_type(&chunk) { |
211 | | - event = event.event(event_type); |
212 | | - } |
213 | | - let event = Ok::<SseEvent, Infallible>(event); |
| 65 | + let properties = chunk_span_properties(chunk); |
| 66 | + properties |
| 67 | + .iter() |
| 68 | + .filter(|(key, _)| { |
| 69 | + key == GEN_AI_RESPONSE_FINISH_REASONS |
| 70 | + || key == "llm.finish_reason" |
| 71 | + || key == "llm.token_count.completion_details.reasoning" |
| 72 | + }) |
| 73 | + .for_each(|item| span.add_property(|| item.clone())); |
| 74 | + } |
214 | 75 |
|
215 | | - Some(( |
216 | | - event, |
217 | | - ( |
218 | | - stream, |
219 | | - span, |
220 | | - idx + 1, |
221 | | - request_ctx, |
222 | | - false, |
223 | | - true, |
224 | | - usage_rx, |
225 | | - output_collector, |
226 | | - ), |
227 | | - )) |
228 | | - } |
229 | | - Some(Err(err)) => { |
230 | | - error!("Gateway stream error: {}", err); |
231 | | - span.add_property(|| ("error.type", "stream_error")); |
232 | | - span.add_properties(|| output_collector.output_message_span_properties()); |
233 | | - if let Some(usage_rx) = usage_rx.take() { |
234 | | - spawn_stream_usage_observer(request_ctx.clone(), usage_rx); |
235 | | - } |
236 | | - drop(span); |
237 | | - None |
238 | | - } |
239 | | - None => { |
240 | | - span.add_properties(|| output_collector.output_message_span_properties()); |
| 76 | + fn starts_output(_chunk: &Self::StreamChunk) -> bool { |
| 77 | + true |
| 78 | + } |
241 | 79 |
|
242 | | - if let Some(mut usage_rx) = usage_rx.take() { |
243 | | - match usage_rx.try_recv() { |
244 | | - Ok(usage) => { |
245 | | - if let Err(err) = hooks::rate_limit::post_check_streaming( |
246 | | - &mut request_ctx, |
247 | | - &usage, |
248 | | - ) |
249 | | - .await |
250 | | - { |
251 | | - error!("Rate limit post_check_streaming error: {}", err); |
252 | | - } |
253 | | - hooks::observability::record_streaming_usage( |
254 | | - &mut request_ctx, |
255 | | - &usage, |
256 | | - ) |
257 | | - .await; |
258 | | - span.add_properties(|| usage_span_properties(&usage)); |
259 | | - } |
260 | | - Err(TryRecvError::Empty) => { |
261 | | - spawn_stream_usage_observer(request_ctx.clone(), usage_rx); |
262 | | - } |
263 | | - Err(TryRecvError::Closed) => { |
264 | | - error!( |
265 | | - "Failed to receive streaming usage from gateway: channel closed" |
266 | | - ); |
267 | | - } |
268 | | - } |
269 | | - } |
| 80 | + fn record_stream_item(collector: &mut Self::Collector, chunk: &Self::StreamChunk) { |
| 81 | + collector.record_chunk(chunk); |
| 82 | + } |
270 | 83 |
|
271 | | - if saw_chunk { |
272 | | - Some(( |
273 | | - Ok(SseEvent::default().data("[DONE]")), |
274 | | - ( |
275 | | - stream, |
276 | | - span, |
277 | | - idx + 1, |
278 | | - request_ctx, |
279 | | - true, |
280 | | - saw_chunk, |
281 | | - usage_rx, |
282 | | - output_collector, |
283 | | - ), |
284 | | - )) |
285 | | - } else { |
286 | | - drop(span); |
287 | | - None |
288 | | - } |
289 | | - } |
290 | | - } |
291 | | - }, |
292 | | - ); |
| 84 | + fn output_message_span_properties(collector: &Self::Collector) -> Vec<(String, String)> { |
| 85 | + collector.output_message_span_properties() |
| 86 | + } |
293 | 87 |
|
294 | | - let mut response = Sse::new(sse_stream).into_response(); |
295 | | - hooks::rate_limit::inject_response_headers(request_ctx, response.headers_mut()).await; |
296 | | - Ok(response) |
| 88 | + fn end_of_stream_event(saw_item: bool) -> Option<SseEvent> { |
| 89 | + saw_item.then(|| SseEvent::default().data("[DONE]")) |
| 90 | + } |
297 | 91 | } |
0 commit comments