Skip to content

Commit db7317b

Browse files
authored
feat: unified chat handler (#92)
1 parent eaaed3e commit db7317b

12 files changed

Lines changed: 579 additions & 844 deletions

File tree

Lines changed: 58 additions & 264 deletions
Original file line numberDiff line numberDiff line change
@@ -1,297 +1,91 @@
11
mod span_attributes;
22
mod types;
33

4-
use std::{convert::Infallible, time::Duration};
5-
6-
use axum::{
7-
Json,
8-
extract::State,
9-
response::{
10-
IntoResponse, Response,
11-
sse::{Event as SseEvent, Sse},
12-
},
13-
};
14-
use fastrace::prelude::{Event as TraceEvent, *};
15-
use log::error;
4+
use axum::response::sse::Event as SseEvent;
5+
use fastrace::Span;
166
use opentelemetry_semantic_conventions::attribute::GEN_AI_RESPONSE_FINISH_REASONS;
7+
use reqwest::Url;
178
use span_attributes::{
18-
StreamOutputCollector, apply_span_properties, chunk_span_properties, request_span_properties,
19-
response_span_properties, usage_span_properties,
9+
StreamOutputCollector, chunk_span_properties, request_span_properties, response_span_properties,
2010
};
21-
use tokio::sync::{oneshot, oneshot::error::TryRecvError};
2211
pub use types::ChatCompletionError;
2312

2413
use crate::{
25-
config::entities::{Model, ResourceEntry},
2614
gateway::{
27-
error::GatewayError,
2815
formats::OpenAIChatFormat,
29-
traits::ChatFormat,
16+
traits::ProviderCapabilities,
3017
types::{
3118
common::Usage,
32-
openai::{ChatCompletionRequest, ChatCompletionResponse},
33-
response::{ChatResponse, ChatResponseStream},
19+
openai::{ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse},
3420
},
3521
},
36-
proxy::{
37-
AppState,
38-
hooks::{self, RequestContext},
39-
provider::create_provider_instance,
40-
},
41-
utils::future::{WithSpan, maybe_timeout},
22+
proxy::handlers::format_handler::FormatHandlerAdapter,
4223
};
4324

44-
pub async fn chat_completions(
45-
State(state): State<AppState>,
46-
mut request_ctx: RequestContext,
47-
Json(mut request_data): Json<ChatCompletionRequest>,
48-
) -> Result<Response, ChatCompletionError> {
49-
hooks::observability::record_start_time(&mut request_ctx).await;
50-
hooks::authorization::check(
51-
&mut request_ctx,
52-
OpenAIChatFormat::extract_model(&request_data).to_owned(),
53-
)
54-
.await?;
55-
hooks::rate_limit::pre_check(&mut request_ctx).await?;
56-
57-
let model = request_ctx
58-
.extensions()
59-
.await
60-
.get::<ResourceEntry<Model>>()
61-
.cloned()
62-
.ok_or(ChatCompletionError::MissingModelInContext)?;
63-
64-
// Replace request model name with real model name
65-
request_data.model = model.model.clone();
66-
let timeout = model.timeout.map(Duration::from_millis);
67-
68-
let gateway = state.gateway();
69-
let resources = state.resources();
70-
let provider = model.provider(resources.as_ref()).ok_or_else(|| {
71-
GatewayError::Internal(format!("provider {} not found", model.provider_id))
72-
})?;
73-
let provider_instance = create_provider_instance(gateway.as_ref(), &provider)?;
74-
let provider_base_url = provider_instance.effective_base_url().ok();
25+
pub(crate) struct ChatCompletionsAdapter;
7526

76-
let span = Span::enter_with_local_parent("aisix.llm.chat_completions");
77-
apply_span_properties(
78-
&span,
79-
request_span_properties(
80-
&request_data,
81-
provider_instance.def.as_ref(),
82-
provider_base_url.as_ref(),
83-
),
84-
);
27+
impl FormatHandlerAdapter for ChatCompletionsAdapter {
28+
type Format = OpenAIChatFormat;
29+
type Request = ChatCompletionRequest;
30+
type Response = ChatCompletionResponse;
31+
type StreamChunk = ChatCompletionChunk;
32+
type Error = ChatCompletionError;
33+
type Collector = StreamOutputCollector;
8534

86-
let (response, span) = (WithSpan {
87-
inner: maybe_timeout(
88-
timeout,
89-
gateway.chat_completion(&request_data, &provider_instance),
90-
),
91-
span: Some(span),
92-
})
93-
.await;
94-
95-
match response {
96-
Ok(Ok(ChatResponse::Complete { response, usage })) => {
97-
span.add_properties(|| response_span_properties(&response, &usage));
98-
handle_regular_request(response, usage, &mut request_ctx).await
99-
}
100-
Ok(Ok(ChatResponse::Stream { stream, usage_rx })) => {
101-
handle_stream_request(stream, usage_rx, &mut request_ctx, span).await
102-
}
103-
Ok(Err(err)) => {
104-
span.add_property(|| ("error.type", "gateway_error"));
105-
Err(err.into())
106-
}
107-
Err(err) => {
108-
span.add_property(|| ("error.type", "timeout"));
109-
Err(ChatCompletionError::Timeout(err))
110-
}
35+
fn span_name() -> &'static str {
36+
"aisix.llm.chat_completions"
11137
}
112-
}
11338

114-
async fn handle_regular_request(
115-
response: ChatCompletionResponse,
116-
usage: Usage,
117-
request_ctx: &mut RequestContext,
118-
) -> Result<Response, ChatCompletionError> {
119-
if let Err(err) = hooks::rate_limit::post_check(request_ctx, &usage).await {
120-
error!("Rate limit post_check error: {}", err);
39+
fn missing_model_error() -> Self::Error {
40+
ChatCompletionError::MissingModelInContext
12141
}
12242

123-
let mut resp = Json(response).into_response();
124-
hooks::rate_limit::inject_response_headers(request_ctx, resp.headers_mut()).await;
125-
hooks::observability::record_usage(request_ctx, &usage).await;
43+
fn set_model(request: &mut Self::Request, model: String) {
44+
request.model = model;
45+
}
12646

127-
Ok(resp)
128-
}
47+
fn request_span_properties(
48+
request: &Self::Request,
49+
provider: &dyn ProviderCapabilities,
50+
base_url: Option<&Url>,
51+
) -> Vec<(String, String)> {
52+
request_span_properties(request, provider, base_url)
53+
}
12954

130-
fn spawn_stream_usage_observer(request_ctx: RequestContext, usage_rx: oneshot::Receiver<Usage>) {
131-
tokio::spawn(async move {
132-
let mut request_ctx = request_ctx;
55+
fn response_span_properties(response: &Self::Response, usage: &Usage) -> Vec<(String, String)> {
56+
response_span_properties(response, usage)
57+
}
13358

134-
match usage_rx.await {
135-
Ok(usage) => {
136-
if let Err(err) =
137-
hooks::rate_limit::post_check_streaming(&mut request_ctx, &usage).await
138-
{
139-
error!("Rate limit post_check_streaming error: {}", err);
140-
}
141-
hooks::observability::record_streaming_usage(&mut request_ctx, &usage).await;
142-
}
143-
Err(err) => {
144-
error!("Failed to receive streaming usage from gateway: {}", err);
145-
}
59+
fn apply_chunk_span_properties(span: &Span, chunk: &Self::StreamChunk, is_first_item: bool) {
60+
if is_first_item {
61+
span.add_properties(|| chunk_span_properties(chunk));
62+
return;
14663
}
147-
});
148-
}
149-
150-
async fn handle_stream_request(
151-
stream: ChatResponseStream<OpenAIChatFormat>,
152-
usage_rx: oneshot::Receiver<Usage>,
153-
request_ctx: &mut RequestContext,
154-
span: Span,
155-
) -> Result<Response, ChatCompletionError> {
156-
use futures::stream::StreamExt;
157-
158-
let stream_request_ctx = request_ctx.clone();
159-
let sse_stream = futures::stream::unfold(
160-
(
161-
stream,
162-
span,
163-
0usize,
164-
stream_request_ctx,
165-
false,
166-
false,
167-
Some(usage_rx),
168-
StreamOutputCollector::default(),
169-
),
170-
|(
171-
mut stream,
172-
span,
173-
idx,
174-
mut request_ctx,
175-
done,
176-
saw_chunk,
177-
mut usage_rx,
178-
mut output_collector,
179-
)| async move {
180-
if done {
181-
drop(span);
182-
return None;
183-
}
184-
185-
match stream.next().await {
186-
Some(Ok(chunk)) => {
187-
output_collector.record_chunk(&chunk);
18864

189-
if idx == 0 {
190-
hooks::observability::record_first_token_latency(&mut request_ctx).await;
191-
span.add_event(
192-
TraceEvent::new("first token arrived")
193-
.with_property(|| ("kind", "first_token_arrived")),
194-
);
195-
span.add_properties(|| chunk_span_properties(&chunk));
196-
} else {
197-
let properties = chunk_span_properties(&chunk);
198-
properties
199-
.iter()
200-
.filter(|(key, _)| {
201-
key == GEN_AI_RESPONSE_FINISH_REASONS
202-
|| key == "llm.finish_reason"
203-
|| key == "llm.token_count.completion_details.reasoning"
204-
})
205-
.for_each(|item| span.add_property(|| item.clone()));
206-
}
207-
208-
let mut event =
209-
SseEvent::default().data(OpenAIChatFormat::serialize_chunk_payload(&chunk));
210-
if let Some(event_type) = OpenAIChatFormat::sse_event_type(&chunk) {
211-
event = event.event(event_type);
212-
}
213-
let event = Ok::<SseEvent, Infallible>(event);
65+
let properties = chunk_span_properties(chunk);
66+
properties
67+
.iter()
68+
.filter(|(key, _)| {
69+
key == GEN_AI_RESPONSE_FINISH_REASONS
70+
|| key == "llm.finish_reason"
71+
|| key == "llm.token_count.completion_details.reasoning"
72+
})
73+
.for_each(|item| span.add_property(|| item.clone()));
74+
}
21475

215-
Some((
216-
event,
217-
(
218-
stream,
219-
span,
220-
idx + 1,
221-
request_ctx,
222-
false,
223-
true,
224-
usage_rx,
225-
output_collector,
226-
),
227-
))
228-
}
229-
Some(Err(err)) => {
230-
error!("Gateway stream error: {}", err);
231-
span.add_property(|| ("error.type", "stream_error"));
232-
span.add_properties(|| output_collector.output_message_span_properties());
233-
if let Some(usage_rx) = usage_rx.take() {
234-
spawn_stream_usage_observer(request_ctx.clone(), usage_rx);
235-
}
236-
drop(span);
237-
None
238-
}
239-
None => {
240-
span.add_properties(|| output_collector.output_message_span_properties());
76+
fn starts_output(_chunk: &Self::StreamChunk) -> bool {
77+
true
78+
}
24179

242-
if let Some(mut usage_rx) = usage_rx.take() {
243-
match usage_rx.try_recv() {
244-
Ok(usage) => {
245-
if let Err(err) = hooks::rate_limit::post_check_streaming(
246-
&mut request_ctx,
247-
&usage,
248-
)
249-
.await
250-
{
251-
error!("Rate limit post_check_streaming error: {}", err);
252-
}
253-
hooks::observability::record_streaming_usage(
254-
&mut request_ctx,
255-
&usage,
256-
)
257-
.await;
258-
span.add_properties(|| usage_span_properties(&usage));
259-
}
260-
Err(TryRecvError::Empty) => {
261-
spawn_stream_usage_observer(request_ctx.clone(), usage_rx);
262-
}
263-
Err(TryRecvError::Closed) => {
264-
error!(
265-
"Failed to receive streaming usage from gateway: channel closed"
266-
);
267-
}
268-
}
269-
}
80+
fn record_stream_item(collector: &mut Self::Collector, chunk: &Self::StreamChunk) {
81+
collector.record_chunk(chunk);
82+
}
27083

271-
if saw_chunk {
272-
Some((
273-
Ok(SseEvent::default().data("[DONE]")),
274-
(
275-
stream,
276-
span,
277-
idx + 1,
278-
request_ctx,
279-
true,
280-
saw_chunk,
281-
usage_rx,
282-
output_collector,
283-
),
284-
))
285-
} else {
286-
drop(span);
287-
None
288-
}
289-
}
290-
}
291-
},
292-
);
84+
fn output_message_span_properties(collector: &Self::Collector) -> Vec<(String, String)> {
85+
collector.output_message_span_properties()
86+
}
29387

294-
let mut response = Sse::new(sse_stream).into_response();
295-
hooks::rate_limit::inject_response_headers(request_ctx, response.headers_mut()).await;
296-
Ok(response)
88+
fn end_of_stream_event(saw_item: bool) -> Option<SseEvent> {
89+
saw_item.then(|| SseEvent::default().data("[DONE]"))
90+
}
29791
}

src/proxy/handlers/chat_completions/span_attributes/mod.rs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,5 @@ pub(super) use telemetry::{
77
chunk_span_properties, request_span_properties, response_span_properties,
88
};
99

10-
pub(super) use crate::proxy::utils::trace::span_attributes::{
11-
apply_span_properties, usage_span_properties,
12-
};
13-
1410
#[cfg(test)]
1511
mod tests;

src/proxy/handlers/chat_completions/span_attributes/stream_output.rs

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,12 @@ struct StreamOutputChoice {
2222
}
2323

2424
#[derive(Default)]
25-
pub(in crate::proxy::handlers::chat_completions) struct StreamOutputCollector {
25+
pub(crate) struct StreamOutputCollector {
2626
choices: BTreeMap<u32, StreamOutputChoice>,
2727
}
2828

2929
impl StreamOutputCollector {
30-
pub(in crate::proxy::handlers::chat_completions) fn record_chunk(
31-
&mut self,
32-
chunk: &ChatCompletionChunk,
33-
) {
30+
pub(crate) fn record_chunk(&mut self, chunk: &ChatCompletionChunk) {
3431
for choice in &chunk.choices {
3532
let output_choice = self.choices.entry(choice.index).or_default();
3633

@@ -69,9 +66,7 @@ impl StreamOutputCollector {
6966
}
7067
}
7168

72-
pub(in crate::proxy::handlers::chat_completions) fn output_message_span_properties(
73-
&self,
74-
) -> Vec<(String, String)> {
69+
pub(crate) fn output_message_span_properties(&self) -> Vec<(String, String)> {
7570
output_message_span_properties(&self.output_message_views())
7671
}
7772

0 commit comments

Comments
 (0)