Skip to content

Commit 9009a08

Browse files
committed
add real-time response streaming option for Anthropic; replace yanked pulldown-cmark-mdcat with pulldown-cmark
1 parent b0915b3 commit 9009a08

10 files changed

Lines changed: 677 additions & 1285 deletions

File tree

Cargo.lock

Lines changed: 21 additions & 1154 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,7 @@ toml = "1"
2525
thiserror = "2"
2626
colored = "3"
2727
syntect = "5"
28-
# pulldown-cmark-mdcat v2.7 needs pulldown-cmark v0.12
29-
pulldown-cmark = "0.12"
30-
pulldown-cmark-mdcat = "2.7"
28+
pulldown-cmark = "0.13"
3129
similar = "2"
3230
crossterm = "0.29"
3331
reedline = "0.46"

src/api/anthropic.rs

Lines changed: 256 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
use super::types::*;
22
use super::utils::{self, REQUEST_TIMEOUT};
33
use crate::error::{Result, SofosError};
4-
use futures::stream::{Stream, StreamExt};
4+
use futures::stream::StreamExt;
55
use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE};
6-
use std::pin::Pin;
6+
use std::sync::atomic::{AtomicBool, Ordering};
7+
use std::sync::Arc;
78
use std::time::Duration;
89

910
const API_BASE: &str = "https://api.anthropic.com/v1";
@@ -57,13 +58,7 @@ impl AnthropicClient {
5758
}
5859
}
5960

60-
pub async fn create_anthropic_message(
61-
&self,
62-
request: CreateMessageRequest,
63-
) -> Result<CreateMessageResponse> {
64-
let url = format!("{}/messages", API_BASE);
65-
let mut request = request;
66-
61+
fn prepare_request(mut request: CreateMessageRequest) -> CreateMessageRequest {
6762
request.messages = sanitize_messages_for_anthropic(request.messages);
6863

6964
if let Some(tools) = request.tools.take() {
@@ -77,6 +72,16 @@ impl AnthropicClient {
7772
}
7873
}
7974

75+
request
76+
}
77+
78+
pub async fn create_anthropic_message(
79+
&self,
80+
request: CreateMessageRequest,
81+
) -> Result<CreateMessageResponse> {
82+
let url = format!("{}/messages", API_BASE);
83+
let request = Self::prepare_request(request);
84+
8085
let client = self.client.clone();
8186
let response = utils::with_retries("Anthropic", || {
8287
let client = client.clone();
@@ -91,52 +96,261 @@ impl AnthropicClient {
9196
Ok(result)
9297
}
9398

94-
pub async fn _create_message_stream(
99+
pub async fn create_message_streaming<FText, FThink>(
95100
&self,
96-
mut request: CreateMessageRequest,
97-
) -> Result<Pin<Box<dyn Stream<Item = Result<_StreamEvent>> + Send>>> {
101+
request: CreateMessageRequest,
102+
on_text_delta: FText,
103+
on_thinking_delta: FThink,
104+
interrupt_flag: Arc<AtomicBool>,
105+
) -> Result<CreateMessageResponse>
106+
where
107+
FText: Fn(&str) + Send + Sync,
108+
FThink: Fn(&str) + Send + Sync,
109+
{
110+
let mut request = Self::prepare_request(request);
98111
request.stream = Some(true);
112+
99113
let url = format!("{}/messages", API_BASE);
100-
let response = self.client.post(&url).json(&request).send().await?;
114+
115+
let client = self.client.clone();
116+
let response = utils::with_retries("Anthropic", || {
117+
let client = client.clone();
118+
let url = url.clone();
119+
let request = request.clone();
120+
async move {
121+
client
122+
.post(&url)
123+
.json(&request)
124+
.timeout(Duration::from_secs(600))
125+
.send()
126+
.await
127+
}
128+
})
129+
.await?;
130+
101131
let response = utils::check_response_status(response).await?;
102132

103-
let stream = response
104-
.bytes_stream()
105-
.map(|result| {
106-
result.map_err(SofosError::from).and_then(|bytes| {
107-
let text = String::from_utf8_lossy(&bytes);
108-
_parse_sse_events(&text)
109-
})
110-
})
111-
.flat_map(|result| {
112-
futures::stream::iter(match result {
113-
Ok(events) => events.into_iter().map(Ok).collect::<Vec<_>>(),
114-
Err(e) => vec![Err(e)],
115-
})
116-
});
117-
118-
Ok(Box::pin(stream))
119-
}
120-
}
133+
let mut byte_stream = response.bytes_stream();
134+
let mut buffer = String::new();
121135

122-
fn _parse_sse_events(text: &str) -> Result<Vec<_StreamEvent>> {
123-
let mut events = Vec::new();
136+
let mut message_id = String::new();
137+
let mut model_name = String::new();
138+
let mut content_blocks: Vec<ContentBlock> = Vec::new();
139+
let mut input_tokens: u32 = 0;
140+
let mut output_tokens: u32 = 0;
141+
let mut stop_reason: Option<String> = None;
124142

125-
for line in text.lines() {
126-
if let Some(json_str) = line.strip_prefix("data: ") {
127-
if json_str.trim() == "[DONE]" {
128-
break;
143+
let mut current_block_type: Option<String> = None;
144+
let mut current_text = String::new();
145+
let mut current_thinking = String::new();
146+
let mut current_signature = String::new();
147+
let mut current_tool_id = String::new();
148+
let mut current_tool_name = String::new();
149+
let mut current_tool_json = String::new();
150+
151+
while let Some(chunk_result) = byte_stream.next().await {
152+
if interrupt_flag.load(Ordering::Relaxed) {
153+
return Err(SofosError::Interrupted);
129154
}
130-
match serde_json::from_str::<_StreamEvent>(json_str) {
131-
Ok(event) => events.push(event),
132-
Err(e) => {
133-
tracing::warn!("Failed to parse SSE event: {} - {}", e, json_str);
155+
156+
let chunk = chunk_result
157+
.map_err(|e| SofosError::NetworkError(format!("Stream read error: {}", e)))?;
158+
buffer.push_str(&String::from_utf8_lossy(&chunk));
159+
160+
while let Some(pos) = buffer.find('\n') {
161+
let line = buffer[..pos].to_string();
162+
buffer = buffer[pos + 1..].to_string();
163+
164+
let line = line.trim_end();
165+
let json_str = match line.strip_prefix("data: ") {
166+
Some(s) if s == "[DONE]" => continue,
167+
Some(s) => s,
168+
None => continue,
169+
};
170+
171+
let event: serde_json::Value = match serde_json::from_str(json_str) {
172+
Ok(v) => v,
173+
Err(_) => continue,
174+
};
175+
176+
let event_type = event.get("type").and_then(|t| t.as_str()).unwrap_or("");
177+
178+
match event_type {
179+
"message_start" => {
180+
if let Some(msg) = event.get("message") {
181+
message_id = msg
182+
.get("id")
183+
.and_then(|v| v.as_str())
184+
.unwrap_or("")
185+
.to_string();
186+
model_name = msg
187+
.get("model")
188+
.and_then(|v| v.as_str())
189+
.unwrap_or("")
190+
.to_string();
191+
if let Some(u) = msg.get("usage") {
192+
input_tokens =
193+
u.get("input_tokens").and_then(|v| v.as_u64()).unwrap_or(0)
194+
as u32;
195+
}
196+
}
197+
}
198+
"content_block_start" => {
199+
if let Some(block) = event.get("content_block") {
200+
let btype = block.get("type").and_then(|t| t.as_str()).unwrap_or("");
201+
current_block_type = Some(btype.to_string());
202+
match btype {
203+
"text" => current_text.clear(),
204+
"thinking" => {
205+
current_thinking.clear();
206+
current_signature.clear();
207+
}
208+
"tool_use" | "server_tool_use" => {
209+
current_tool_id = block
210+
.get("id")
211+
.and_then(|v| v.as_str())
212+
.unwrap_or("")
213+
.to_string();
214+
current_tool_name = block
215+
.get("name")
216+
.and_then(|v| v.as_str())
217+
.unwrap_or("")
218+
.to_string();
219+
current_tool_json.clear();
220+
}
221+
"web_search_tool_result" => {
222+
if let Ok(result) =
223+
serde_json::from_value::<WebSearchToolResultBlock>(
224+
block.clone(),
225+
)
226+
{
227+
content_blocks.push(ContentBlock::WebSearchToolResult {
228+
tool_use_id: result.tool_use_id,
229+
content: result.content,
230+
});
231+
}
232+
current_block_type = None;
233+
}
234+
_ => {}
235+
}
236+
}
237+
}
238+
"content_block_delta" => {
239+
if let Some(delta) = event.get("delta") {
240+
let dtype = delta.get("type").and_then(|t| t.as_str()).unwrap_or("");
241+
match dtype {
242+
"text_delta" => {
243+
if let Some(text) = delta.get("text").and_then(|v| v.as_str()) {
244+
current_text.push_str(text);
245+
on_text_delta(text);
246+
}
247+
}
248+
"thinking_delta" => {
249+
if let Some(thinking) =
250+
delta.get("thinking").and_then(|v| v.as_str())
251+
{
252+
current_thinking.push_str(thinking);
253+
on_thinking_delta(thinking);
254+
}
255+
}
256+
"signature_delta" => {
257+
if let Some(sig) =
258+
delta.get("signature").and_then(|v| v.as_str())
259+
{
260+
current_signature.push_str(sig);
261+
}
262+
}
263+
"input_json_delta" => {
264+
if let Some(json_part) =
265+
delta.get("partial_json").and_then(|v| v.as_str())
266+
{
267+
current_tool_json.push_str(json_part);
268+
}
269+
}
270+
_ => {}
271+
}
272+
}
273+
}
274+
"content_block_stop" => {
275+
match current_block_type.as_deref() {
276+
Some("text") => {
277+
content_blocks.push(ContentBlock::Text {
278+
text: current_text.clone(),
279+
});
280+
}
281+
Some("thinking") => {
282+
content_blocks.push(ContentBlock::Thinking {
283+
thinking: current_thinking.clone(),
284+
signature: current_signature.clone(),
285+
});
286+
}
287+
Some("tool_use") => {
288+
let input = serde_json::from_str(&current_tool_json)
289+
.unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
290+
content_blocks.push(ContentBlock::ToolUse {
291+
id: current_tool_id.clone(),
292+
name: current_tool_name.clone(),
293+
input,
294+
});
295+
}
296+
Some("server_tool_use") => {
297+
let input = serde_json::from_str(&current_tool_json)
298+
.unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
299+
content_blocks.push(ContentBlock::ServerToolUse {
300+
id: current_tool_id.clone(),
301+
name: current_tool_name.clone(),
302+
input,
303+
});
304+
}
305+
_ => {}
306+
}
307+
current_block_type = None;
308+
}
309+
"message_delta" => {
310+
if let Some(delta) = event.get("delta") {
311+
stop_reason = delta
312+
.get("stop_reason")
313+
.and_then(|v| v.as_str())
314+
.map(String::from);
315+
}
316+
if let Some(u) = event.get("usage") {
317+
output_tokens =
318+
u.get("output_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as u32;
319+
}
320+
}
321+
"error" => {
322+
let error_msg = event
323+
.get("error")
324+
.and_then(|e| e.get("message"))
325+
.and_then(|m| m.as_str())
326+
.unwrap_or("Unknown streaming error");
327+
return Err(SofosError::Api(format!("Streaming error: {}", error_msg)));
328+
}
329+
_ => {}
134330
}
135331
}
136332
}
333+
334+
Ok(CreateMessageResponse {
335+
_id: message_id,
336+
_response_type: "message".to_string(),
337+
_role: "assistant".to_string(),
338+
content: content_blocks,
339+
_model: model_name,
340+
stop_reason,
341+
usage: Usage {
342+
input_tokens,
343+
output_tokens,
344+
},
345+
})
137346
}
347+
}
138348

139-
Ok(events)
349+
#[derive(serde::Deserialize)]
350+
struct WebSearchToolResultBlock {
351+
tool_use_id: String,
352+
#[serde(default)]
353+
content: Vec<WebSearchResult>,
140354
}
141355

142356
fn sanitize_messages_for_anthropic(messages: Vec<Message>) -> Vec<Message> {

0 commit comments

Comments
 (0)