11import json
2- from typing import Any , Callable , ClassVar , Dict , List , Optional , Tuple , Union
2+ from typing import Any , ClassVar , Dict , List , Literal , Optional , Tuple , Union
33
44from haystack import component , default_from_dict , default_to_dict , logging
5- from haystack .dataclasses import (
5+ from haystack .dataclasses .chat_message import ChatMessage , ChatRole , ToolCall , ToolCallResult
6+ from haystack .dataclasses .streaming_chunk import (
67 AsyncStreamingCallbackT ,
7- ChatMessage ,
8- ChatRole ,
98 StreamingCallbackT ,
109 StreamingChunk ,
11- ToolCall ,
12- ToolCallResult ,
10+ SyncStreamingCallbackT ,
1311 select_streaming_callback ,
1412)
1513from haystack .tools import (
1917 deserialize_tools_or_toolset_inplace ,
2018 serialize_tools_or_toolset ,
2119)
22- from haystack .utils import Secret , deserialize_callable , deserialize_secrets_inplace , serialize_callable
20+ from haystack .utils .auth import Secret , deserialize_secrets_inplace
21+ from haystack .utils .callable_serialization import deserialize_callable , serialize_callable
2322
2423from anthropic import Anthropic , AsyncAnthropic
24+ from anthropic .resources .messages .messages import Message , RawMessageStreamEvent , Stream
25+ from anthropic .types import MessageParam , TextBlockParam , ToolParam , ToolResultBlockParam , ToolUseBlockParam
2526
2627logger = logging .getLogger (__name__ )
2728
2829
2930def _update_anthropic_message_with_tool_call_results (
30- tool_call_results : List [ToolCallResult ], anthropic_msg : Dict [str , Any ]
31+ tool_call_results : List [ToolCallResult ],
32+ content : List [Union [TextBlockParam , ToolUseBlockParam , ToolResultBlockParam ]],
3133) -> None :
3234 """
33- Update an Anthropic message with tool call results.
35+ Update an Anthropic message content list with tool call results.
3436
3537 :param tool_call_results: The list of ToolCallResults to update the message with.
36- :param anthropic_msg : The Anthropic message to update.
38+ :param content : The Anthropic message content list to update.
3739 """
38- if "content" not in anthropic_msg :
39- anthropic_msg ["content" ] = []
40-
4140 for tool_call_result in tool_call_results :
4241 if tool_call_result .origin .id is None :
4342 msg = "`ToolCall` must have a non-null `id` attribute to be used with Anthropic."
4443 raise ValueError (msg )
45- anthropic_msg ["content" ].append (
46- {
47- "type" : "tool_result" ,
48- "tool_use_id" : tool_call_result .origin .id ,
49- "content" : [{"type" : "text" , "text" : tool_call_result .result }],
50- "is_error" : tool_call_result .error ,
51- }
44+
45+ tool_result_block = ToolResultBlockParam (
46+ type = "tool_result" ,
47+ tool_use_id = tool_call_result .origin .id ,
48+ content = [{"type" : "text" , "text" : tool_call_result .result }],
49+ is_error = tool_call_result .error ,
5250 )
51+ content .append (tool_result_block )
5352
5453
55- def _convert_tool_calls_to_anthropic_format (tool_calls : List [ToolCall ]) -> List [Dict [ str , Any ] ]:
54+ def _convert_tool_calls_to_anthropic_format (tool_calls : List [ToolCall ]) -> List [ToolUseBlockParam ]:
5655 """
5756 Convert a list of tool calls to the format expected by Anthropic Chat API.
5857
5958 :param tool_calls: The list of ToolCalls to convert.
60- :return: A list of dictionaries in the format expected by Anthropic API.
59+ :return: A list of ToolUseBlockParam objects in the format expected by Anthropic API.
6160 """
6261 anthropic_tool_calls = []
6362 for tc in tool_calls :
6463 if tc .id is None :
6564 msg = "`ToolCall` must have a non-null `id` attribute to be used with Anthropic."
6665 raise ValueError (msg )
67- anthropic_tool_calls .append (
68- {
69- "type" : "tool_use" ,
70- "id" : tc .id ,
71- "name" : tc .tool_name ,
72- "input" : tc .arguments ,
73- }
66+
67+ tool_use_block = ToolUseBlockParam (
68+ type = "tool_use" ,
69+ id = tc .id ,
70+ name = tc .tool_name ,
71+ input = tc .arguments ,
7472 )
73+ anthropic_tool_calls .append (tool_use_block )
7574 return anthropic_tool_calls
7675
7776
7877def _convert_messages_to_anthropic_format (
7978 messages : List [ChatMessage ],
80- ) -> Tuple [List [Dict [ str , Any ]] , List [Dict [ str , Any ] ]]:
79+ ) -> Tuple [List [TextBlockParam ] , List [MessageParam ]]:
8180 """
8281 Convert a list of messages to the format expected by Anthropic Chat API.
8382
8483 :param messages: The list of ChatMessages to convert.
8584 :return: A tuple of two lists:
86- - A list of system message dictionaries in the format expected by Anthropic API.
87- - A list of non-system message dictionaries in the format expected by Anthropic API.
85+ - A list of system message TextBlockParam objects in the format expected by Anthropic API.
86+ - A list of non-system MessageParam objects in the format expected by Anthropic API.
8887 """
8988
90- anthropic_system_messages = []
91- anthropic_non_system_messages = []
89+ anthropic_system_messages : List [ TextBlockParam ] = []
90+ anthropic_non_system_messages : List [ MessageParam ] = []
9291
9392 i = 0
9493 while i < len (messages ):
9594 message = messages [i ]
9695
97- # allow passing cache_control
98- cache_control = {"cache_control" : message .meta .get ("cache_control" )} if "cache_control" in message .meta else {}
99-
10096 # system messages have special format requirements for Anthropic API
10197 # they can have only type and text fields, and they need to be passed separately
10298 # to the Anthropic API endpoint
103- if message .is_from (ChatRole .SYSTEM ):
104- anthropic_system_messages .append ({"type" : "text" , "text" : message .text , ** cache_control })
99+ if message .is_from (ChatRole .SYSTEM ) and message .text :
100+ sys_message = TextBlockParam (type = "text" , text = message .text )
101+ if cache_control := message .meta .get ("cache_control" ):
102+ sys_message ["cache_control" ] = cache_control
103+ anthropic_system_messages .append (sys_message )
105104 i += 1
106105 continue
107106
108- anthropic_msg : Dict [ str , Any ] = { "role" : message . _role . value , "content" : [], ** cache_control }
107+ content : List [ Union [ TextBlockParam , ToolUseBlockParam , ToolResultBlockParam ]] = []
109108
110109 if message .texts and message .texts [0 ]:
111- anthropic_msg ["content" ].append ({"type" : "text" , "text" : message .texts [0 ]})
110+ text_block = TextBlockParam (type = "text" , text = message .texts [0 ])
111+ content .append (text_block )
112+
112113 if message .tool_calls :
113- anthropic_msg ["content" ] += _convert_tool_calls_to_anthropic_format (message .tool_calls )
114+ tool_use_blocks = _convert_tool_calls_to_anthropic_format (message .tool_calls )
115+ content .extend (tool_use_blocks )
114116
115117 if message .tool_call_results :
116118 results = message .tool_call_results .copy ()
@@ -119,14 +121,20 @@ def _convert_messages_to_anthropic_format(
119121 i += 1
120122 results .extend (messages [i ].tool_call_results )
121123
122- _update_anthropic_message_with_tool_call_results (results , anthropic_msg )
123- anthropic_msg ["role" ] = "user"
124+ _update_anthropic_message_with_tool_call_results (results , content )
124125
125- if not anthropic_msg [ " content" ] :
126+ if not content :
126127 msg = "A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`."
127128 raise ValueError (msg )
128129
129- anthropic_non_system_messages .append (anthropic_msg )
130+ # Anthropic only supports assistant and user roles in messages. User role is also used for tool messages.
131+ # System messages are passed separately.
132+ role : Union [Literal ["assistant" ], Literal ["user" ]] = "user"
133+ if message ._role == ChatRole .ASSISTANT :
134+ role = "assistant"
135+
136+ anthropic_message = MessageParam (role = role , content = content )
137+ anthropic_non_system_messages .append (anthropic_message )
130138 i += 1
131139
132140 return anthropic_system_messages , anthropic_non_system_messages
@@ -340,11 +348,14 @@ def _convert_streaming_chunks_to_chat_message(
340348 for chunk in chunks :
341349 chunk_type = chunk .meta .get ("type" )
342350 if chunk_type == "content_block_start" :
343- if chunk .meta .get ("content_block" , {}).get ("type" ) == "tool_use" :
344- delta_block = chunk .meta .get ("content_block" )
351+ content_block = chunk .meta .get ("content_block" )
352+ if content_block is None :
353+ msg = "Invalid streaming chunk. Expected 'content_block' field."
354+ raise ValueError (msg )
355+ if content_block .get ("type" ) == "tool_use" :
345356 current_tool_call = {
346- "id" : delta_block .get ("id" ),
347- "name" : delta_block .get ("name" ),
357+ "id" : content_block .get ("id" ),
358+ "name" : content_block .get ("name" ),
348359 "arguments" : "" ,
349360 }
350361 elif chunk_type == "content_block_delta" :
@@ -388,21 +399,12 @@ def _convert_streaming_chunks_to_chat_message(
388399
389400 return message
390401
391- @staticmethod
392- def _remove_cache_control (message : Dict [str , Any ]) -> Dict [str , Any ]:
393- """
394- Removes the cache_control key from the message.
395- :param message: The message to remove the cache_control key from.
396- :returns: The message with the cache_control key removed.
397- """
398- return {k : v for k , v in message .items () if k != "cache_control" }
399-
400402 def _prepare_request_params (
401403 self ,
402404 messages : List [ChatMessage ],
403405 generation_kwargs : Optional [Dict [str , Any ]] = None ,
404406 tools : Optional [Union [List [Tool ], Toolset ]] = None ,
405- ) -> Tuple [List [Dict [ str , Any ]] , List [Dict [ str , Any ]] , Dict [str , Any ], List [Dict [ str , Any ] ]]:
407+ ) -> Tuple [List [TextBlockParam ] , List [MessageParam ] , Dict [str , Any ], List [ToolParam ]]:
406408 """
407409 Prepare the parameters for the Anthropic API request.
408410
@@ -433,8 +435,8 @@ def _prepare_request_params(
433435 # prompt caching
434436 extra_headers = generation_kwargs .get ("extra_headers" , {})
435437 prompt_caching_on = "anthropic-beta" in extra_headers and "prompt-caching" in extra_headers ["anthropic-beta" ]
436- has_cached_messages = any ("cache_control" in m for m in system_messages ) or any (
437- "cache_control" in m for m in non_system_messages
438+ has_cached_messages = any (m . get ( "cache_control" ) is not None for m in system_messages ) or any (
439+ m . get ( "cache_control" ) is not None for m in non_system_messages
438440 )
439441 if has_cached_messages and not prompt_caching_on :
440442 # this avoids Anthropic errors when prompt caching is not enabled
@@ -443,32 +445,28 @@ def _prepare_request_params(
443445 "Prompt caching is not enabled but you requested individual messages to be cached. "
444446 "Messages will be sent to the API without prompt caching."
445447 )
446- system_messages = list (map (self ._remove_cache_control , system_messages ))
447- non_system_messages = list (map (self ._remove_cache_control , non_system_messages ))
448+ for message in system_messages :
449+ if message .get ("cache_control" ):
450+ del message ["cache_control" ]
448451
449452 # tools management
450453 tools = tools or self .tools
451454 tools = list (tools ) if isinstance (tools , Toolset ) else tools
452455 _check_duplicate_tool_names (tools ) # handles Toolset as well
453- anthropic_tools = (
454- [
455- {
456- "name" : tool .name ,
457- "description" : tool .description ,
458- "input_schema" : tool .parameters ,
459- }
460- for tool in tools
461- ]
462- if tools
463- else []
464- )
456+
457+ anthropic_tools : List [ToolParam ] = []
458+ if tools :
459+ for tool in tools :
460+ anthropic_tools .append (
461+ ToolParam (name = tool .name , description = tool .description , input_schema = tool .parameters )
462+ )
465463
466464 return system_messages , non_system_messages , generation_kwargs , anthropic_tools
467465
468466 def _process_response (
469467 self ,
470- response : Any ,
471- streaming_callback : Optional [Callable [[ StreamingChunk ], None ] ] = None ,
468+ response : Union [ Message , Stream [ RawMessageStreamEvent ]] ,
469+ streaming_callback : Optional [SyncStreamingCallbackT ] = None ,
472470 ) -> Dict [str , List [ChatMessage ]]:
473471 """
474472 Process the response from the Anthropic API.
@@ -478,8 +476,8 @@ def _process_response(
478476 :returns: A dictionary containing the processed response as a list of ChatMessage objects.
479477 """
480478 # workaround for https://github.com/DataDog/dd-trace-py/issues/12562
481- stream = streaming_callback is not None
482- if stream :
479+ # we cannot use isinstance(Stream)
480+ if not isinstance ( response , Message ) :
483481 chunks : List [StreamingChunk ] = []
484482 model : Optional [str ] = None
485483 for chunk in response :
@@ -552,7 +550,7 @@ def run(
552550 streaming_callback : Optional [StreamingCallbackT ] = None ,
553551 generation_kwargs : Optional [Dict [str , Any ]] = None ,
554552 tools : Optional [Union [List [Tool ], Toolset ]] = None ,
555- ):
553+ ) -> Dict [ str , List [ ChatMessage ]] :
556554 """
557555 Invokes the Anthropic API with the given messages and generation kwargs.
558556
@@ -584,7 +582,8 @@ def run(
584582 ** generation_kwargs ,
585583 )
586584
587- return self ._process_response (response , streaming_callback )
585+ # select_streaming_callback returns a StreamingCallbackT, but we know it's SyncStreamingCallbackT
586+ return self ._process_response (response = response , streaming_callback = streaming_callback ) # type: ignore[arg-type]
588587
589588 @component .output_types (replies = List [ChatMessage ])
590589 async def run_async (
@@ -593,7 +592,7 @@ async def run_async(
593592 streaming_callback : Optional [StreamingCallbackT ] = None ,
594593 generation_kwargs : Optional [Dict [str , Any ]] = None ,
595594 tools : Optional [Union [List [Tool ], Toolset ]] = None ,
596- ):
595+ ) -> Dict [ str , List [ ChatMessage ]] :
597596 """
598597 Async version of the run method. Invokes the Anthropic API with the given messages and generation kwargs.
599598
@@ -625,4 +624,5 @@ async def run_async(
625624 ** generation_kwargs ,
626625 )
627626
628- return await self ._process_response_async (response , streaming_callback )
627+ # select_streaming_callback returns a StreamingCallbackT, but we know it's AsyncStreamingCallbackT
628+ return await self ._process_response_async (response , streaming_callback ) # type: ignore[arg-type]
0 commit comments