11use super :: types:: * ;
22use super :: utils:: { self , REQUEST_TIMEOUT } ;
33use crate :: error:: { Result , SofosError } ;
4- use futures:: stream:: { Stream , StreamExt } ;
4+ use futures:: stream:: StreamExt ;
55use reqwest:: header:: { HeaderMap , HeaderValue , CONTENT_TYPE } ;
6- use std:: pin:: Pin ;
6+ use std:: sync:: atomic:: { AtomicBool , Ordering } ;
7+ use std:: sync:: Arc ;
78use std:: time:: Duration ;
89
910const API_BASE : & str = "https://api.anthropic.com/v1" ;
@@ -57,13 +58,7 @@ impl AnthropicClient {
5758 }
5859 }
5960
60- pub async fn create_anthropic_message (
61- & self ,
62- request : CreateMessageRequest ,
63- ) -> Result < CreateMessageResponse > {
64- let url = format ! ( "{}/messages" , API_BASE ) ;
65- let mut request = request;
66-
61+ fn prepare_request ( mut request : CreateMessageRequest ) -> CreateMessageRequest {
6762 request. messages = sanitize_messages_for_anthropic ( request. messages ) ;
6863
6964 if let Some ( tools) = request. tools . take ( ) {
@@ -77,6 +72,16 @@ impl AnthropicClient {
7772 }
7873 }
7974
75+ request
76+ }
77+
78+ pub async fn create_anthropic_message (
79+ & self ,
80+ request : CreateMessageRequest ,
81+ ) -> Result < CreateMessageResponse > {
82+ let url = format ! ( "{}/messages" , API_BASE ) ;
83+ let request = Self :: prepare_request ( request) ;
84+
8085 let client = self . client . clone ( ) ;
8186 let response = utils:: with_retries ( "Anthropic" , || {
8287 let client = client. clone ( ) ;
@@ -91,52 +96,261 @@ impl AnthropicClient {
9196 Ok ( result)
9297 }
9398
94- pub async fn _create_message_stream (
99+ pub async fn create_message_streaming < FText , FThink > (
95100 & self ,
96- mut request : CreateMessageRequest ,
97- ) -> Result < Pin < Box < dyn Stream < Item = Result < _StreamEvent > > + Send > > > {
101+ request : CreateMessageRequest ,
102+ on_text_delta : FText ,
103+ on_thinking_delta : FThink ,
104+ interrupt_flag : Arc < AtomicBool > ,
105+ ) -> Result < CreateMessageResponse >
106+ where
107+ FText : Fn ( & str ) + Send + Sync ,
108+ FThink : Fn ( & str ) + Send + Sync ,
109+ {
110+ let mut request = Self :: prepare_request ( request) ;
98111 request. stream = Some ( true ) ;
112+
99113 let url = format ! ( "{}/messages" , API_BASE ) ;
100- let response = self . client . post ( & url) . json ( & request) . send ( ) . await ?;
114+
115+ let client = self . client . clone ( ) ;
116+ let response = utils:: with_retries ( "Anthropic" , || {
117+ let client = client. clone ( ) ;
118+ let url = url. clone ( ) ;
119+ let request = request. clone ( ) ;
120+ async move {
121+ client
122+ . post ( & url)
123+ . json ( & request)
124+ . timeout ( Duration :: from_secs ( 600 ) )
125+ . send ( )
126+ . await
127+ }
128+ } )
129+ . await ?;
130+
101131 let response = utils:: check_response_status ( response) . await ?;
102132
103- let stream = response
104- . bytes_stream ( )
105- . map ( |result| {
106- result. map_err ( SofosError :: from) . and_then ( |bytes| {
107- let text = String :: from_utf8_lossy ( & bytes) ;
108- _parse_sse_events ( & text)
109- } )
110- } )
111- . flat_map ( |result| {
112- futures:: stream:: iter ( match result {
113- Ok ( events) => events. into_iter ( ) . map ( Ok ) . collect :: < Vec < _ > > ( ) ,
114- Err ( e) => vec ! [ Err ( e) ] ,
115- } )
116- } ) ;
117-
118- Ok ( Box :: pin ( stream) )
119- }
120- }
133+ let mut byte_stream = response. bytes_stream ( ) ;
134+ let mut buffer = String :: new ( ) ;
121135
122- fn _parse_sse_events ( text : & str ) -> Result < Vec < _StreamEvent > > {
123- let mut events = Vec :: new ( ) ;
136+ let mut message_id = String :: new ( ) ;
137+ let mut model_name = String :: new ( ) ;
138+ let mut content_blocks: Vec < ContentBlock > = Vec :: new ( ) ;
139+ let mut input_tokens: u32 = 0 ;
140+ let mut output_tokens: u32 = 0 ;
141+ let mut stop_reason: Option < String > = None ;
124142
125- for line in text. lines ( ) {
126- if let Some ( json_str) = line. strip_prefix ( "data: " ) {
127- if json_str. trim ( ) == "[DONE]" {
128- break ;
143+ let mut current_block_type: Option < String > = None ;
144+ let mut current_text = String :: new ( ) ;
145+ let mut current_thinking = String :: new ( ) ;
146+ let mut current_signature = String :: new ( ) ;
147+ let mut current_tool_id = String :: new ( ) ;
148+ let mut current_tool_name = String :: new ( ) ;
149+ let mut current_tool_json = String :: new ( ) ;
150+
151+ while let Some ( chunk_result) = byte_stream. next ( ) . await {
152+ if interrupt_flag. load ( Ordering :: Relaxed ) {
153+ return Err ( SofosError :: Interrupted ) ;
129154 }
130- match serde_json:: from_str :: < _StreamEvent > ( json_str) {
131- Ok ( event) => events. push ( event) ,
132- Err ( e) => {
133- tracing:: warn!( "Failed to parse SSE event: {} - {}" , e, json_str) ;
155+
156+ let chunk = chunk_result
157+ . map_err ( |e| SofosError :: NetworkError ( format ! ( "Stream read error: {}" , e) ) ) ?;
158+ buffer. push_str ( & String :: from_utf8_lossy ( & chunk) ) ;
159+
160+ while let Some ( pos) = buffer. find ( '\n' ) {
161+ let line = buffer[ ..pos] . to_string ( ) ;
162+ buffer = buffer[ pos + 1 ..] . to_string ( ) ;
163+
164+ let line = line. trim_end ( ) ;
165+ let json_str = match line. strip_prefix ( "data: " ) {
166+ Some ( s) if s == "[DONE]" => continue ,
167+ Some ( s) => s,
168+ None => continue ,
169+ } ;
170+
171+ let event: serde_json:: Value = match serde_json:: from_str ( json_str) {
172+ Ok ( v) => v,
173+ Err ( _) => continue ,
174+ } ;
175+
176+ let event_type = event. get ( "type" ) . and_then ( |t| t. as_str ( ) ) . unwrap_or ( "" ) ;
177+
178+ match event_type {
179+ "message_start" => {
180+ if let Some ( msg) = event. get ( "message" ) {
181+ message_id = msg
182+ . get ( "id" )
183+ . and_then ( |v| v. as_str ( ) )
184+ . unwrap_or ( "" )
185+ . to_string ( ) ;
186+ model_name = msg
187+ . get ( "model" )
188+ . and_then ( |v| v. as_str ( ) )
189+ . unwrap_or ( "" )
190+ . to_string ( ) ;
191+ if let Some ( u) = msg. get ( "usage" ) {
192+ input_tokens =
193+ u. get ( "input_tokens" ) . and_then ( |v| v. as_u64 ( ) ) . unwrap_or ( 0 )
194+ as u32 ;
195+ }
196+ }
197+ }
198+ "content_block_start" => {
199+ if let Some ( block) = event. get ( "content_block" ) {
200+ let btype = block. get ( "type" ) . and_then ( |t| t. as_str ( ) ) . unwrap_or ( "" ) ;
201+ current_block_type = Some ( btype. to_string ( ) ) ;
202+ match btype {
203+ "text" => current_text. clear ( ) ,
204+ "thinking" => {
205+ current_thinking. clear ( ) ;
206+ current_signature. clear ( ) ;
207+ }
208+ "tool_use" | "server_tool_use" => {
209+ current_tool_id = block
210+ . get ( "id" )
211+ . and_then ( |v| v. as_str ( ) )
212+ . unwrap_or ( "" )
213+ . to_string ( ) ;
214+ current_tool_name = block
215+ . get ( "name" )
216+ . and_then ( |v| v. as_str ( ) )
217+ . unwrap_or ( "" )
218+ . to_string ( ) ;
219+ current_tool_json. clear ( ) ;
220+ }
221+ "web_search_tool_result" => {
222+ if let Ok ( result) =
223+ serde_json:: from_value :: < WebSearchToolResultBlock > (
224+ block. clone ( ) ,
225+ )
226+ {
227+ content_blocks. push ( ContentBlock :: WebSearchToolResult {
228+ tool_use_id : result. tool_use_id ,
229+ content : result. content ,
230+ } ) ;
231+ }
232+ current_block_type = None ;
233+ }
234+ _ => { }
235+ }
236+ }
237+ }
238+ "content_block_delta" => {
239+ if let Some ( delta) = event. get ( "delta" ) {
240+ let dtype = delta. get ( "type" ) . and_then ( |t| t. as_str ( ) ) . unwrap_or ( "" ) ;
241+ match dtype {
242+ "text_delta" => {
243+ if let Some ( text) = delta. get ( "text" ) . and_then ( |v| v. as_str ( ) ) {
244+ current_text. push_str ( text) ;
245+ on_text_delta ( text) ;
246+ }
247+ }
248+ "thinking_delta" => {
249+ if let Some ( thinking) =
250+ delta. get ( "thinking" ) . and_then ( |v| v. as_str ( ) )
251+ {
252+ current_thinking. push_str ( thinking) ;
253+ on_thinking_delta ( thinking) ;
254+ }
255+ }
256+ "signature_delta" => {
257+ if let Some ( sig) =
258+ delta. get ( "signature" ) . and_then ( |v| v. as_str ( ) )
259+ {
260+ current_signature. push_str ( sig) ;
261+ }
262+ }
263+ "input_json_delta" => {
264+ if let Some ( json_part) =
265+ delta. get ( "partial_json" ) . and_then ( |v| v. as_str ( ) )
266+ {
267+ current_tool_json. push_str ( json_part) ;
268+ }
269+ }
270+ _ => { }
271+ }
272+ }
273+ }
274+ "content_block_stop" => {
275+ match current_block_type. as_deref ( ) {
276+ Some ( "text" ) => {
277+ content_blocks. push ( ContentBlock :: Text {
278+ text : current_text. clone ( ) ,
279+ } ) ;
280+ }
281+ Some ( "thinking" ) => {
282+ content_blocks. push ( ContentBlock :: Thinking {
283+ thinking : current_thinking. clone ( ) ,
284+ signature : current_signature. clone ( ) ,
285+ } ) ;
286+ }
287+ Some ( "tool_use" ) => {
288+ let input = serde_json:: from_str ( & current_tool_json)
289+ . unwrap_or ( serde_json:: Value :: Object ( serde_json:: Map :: new ( ) ) ) ;
290+ content_blocks. push ( ContentBlock :: ToolUse {
291+ id : current_tool_id. clone ( ) ,
292+ name : current_tool_name. clone ( ) ,
293+ input,
294+ } ) ;
295+ }
296+ Some ( "server_tool_use" ) => {
297+ let input = serde_json:: from_str ( & current_tool_json)
298+ . unwrap_or ( serde_json:: Value :: Object ( serde_json:: Map :: new ( ) ) ) ;
299+ content_blocks. push ( ContentBlock :: ServerToolUse {
300+ id : current_tool_id. clone ( ) ,
301+ name : current_tool_name. clone ( ) ,
302+ input,
303+ } ) ;
304+ }
305+ _ => { }
306+ }
307+ current_block_type = None ;
308+ }
309+ "message_delta" => {
310+ if let Some ( delta) = event. get ( "delta" ) {
311+ stop_reason = delta
312+ . get ( "stop_reason" )
313+ . and_then ( |v| v. as_str ( ) )
314+ . map ( String :: from) ;
315+ }
316+ if let Some ( u) = event. get ( "usage" ) {
317+ output_tokens =
318+ u. get ( "output_tokens" ) . and_then ( |v| v. as_u64 ( ) ) . unwrap_or ( 0 ) as u32 ;
319+ }
320+ }
321+ "error" => {
322+ let error_msg = event
323+ . get ( "error" )
324+ . and_then ( |e| e. get ( "message" ) )
325+ . and_then ( |m| m. as_str ( ) )
326+ . unwrap_or ( "Unknown streaming error" ) ;
327+ return Err ( SofosError :: Api ( format ! ( "Streaming error: {}" , error_msg) ) ) ;
328+ }
329+ _ => { }
134330 }
135331 }
136332 }
333+
334+ Ok ( CreateMessageResponse {
335+ _id : message_id,
336+ _response_type : "message" . to_string ( ) ,
337+ _role : "assistant" . to_string ( ) ,
338+ content : content_blocks,
339+ _model : model_name,
340+ stop_reason,
341+ usage : Usage {
342+ input_tokens,
343+ output_tokens,
344+ } ,
345+ } )
137346 }
347+ }
138348
139- Ok ( events)
349+ #[ derive( serde:: Deserialize ) ]
350+ struct WebSearchToolResultBlock {
351+ tool_use_id : String ,
352+ #[ serde( default ) ]
353+ content : Vec < WebSearchResult > ,
140354}
141355
142356fn sanitize_messages_for_anthropic ( messages : Vec < Message > ) -> Vec < Message > {
0 commit comments