Skip to content

Commit 9704b7f

Browse files
authored
fix: fix Anthropic types + add py.typed (#1940)
* draft
* fix: Anthropic - fix types + add py.typed
* CI
* simplify
1 parent fd2ae52 commit 9704b7f

6 files changed

Lines changed: 96 additions & 115 deletions

File tree

.github/workflows/anthropic.yml

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -50,11 +50,9 @@ jobs:
5050
- name: Install Hatch
5151
run: pip install --upgrade hatch
5252

53-
# TODO: Once this integration is properly typed, use hatch run test:types
54-
# https://github.com/deepset-ai/haystack-core-integrations/issues/1771
5553
- name: Lint
5654
if: matrix.python-version == '3.9' && runner.os == 'Linux'
57-
run: hatch run fmt-check && hatch run lint:typing
55+
run: hatch run fmt-check && hatch run test:types
5856

5957
- name: Run tests
6058
run: hatch run test:cov-retry

integrations/anthropic/pyproject.toml

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -66,18 +66,14 @@ integration = 'pytest -m "integration" {args:tests}'
6666
all = 'pytest {args:tests}'
6767
cov-retry = 'all --cov=haystack_integrations --reruns 3 --reruns-delay 30 -x'
6868

69-
types = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
69+
types = "mypy -p haystack_integrations.components.generators.anthropic {args}"
7070

71-
# TODO: remove lint environment once this integration is properly typed
72-
# test environment should be used instead
73-
# https://github.com/deepset-ai/haystack-core-integrations/issues/1771
74-
[tool.hatch.envs.lint]
75-
installer = "uv"
76-
detached = true
77-
dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"]
71+
[tool.mypy]
72+
install_types = true
73+
non_interactive = true
74+
check_untyped_defs = true
75+
disallow_incomplete_defs = true
7876

79-
[tool.hatch.envs.lint.scripts]
80-
typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
8177

8278
[tool.black]
8379
target-version = ["py38"]
@@ -159,15 +155,6 @@ omit = ["*/tests/*", "*/__init__.py"]
159155
show_missing = true
160156
exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"]
161157

162-
[[tool.mypy.overrides]]
163-
module = [
164-
"anthropic.*",
165-
"haystack.*",
166-
"haystack_integrations.*",
167-
"pytest.*",
168-
"numpy.*",
169-
]
170-
ignore_missing_imports = true
171158

172159
[tool.pytest.ini_options]
173160
addopts = "--strict-markers"

integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py

Lines changed: 84 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
import json
2-
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union
2+
from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, Union
33

44
from haystack import component, default_from_dict, default_to_dict, logging
5-
from haystack.dataclasses import (
5+
from haystack.dataclasses.chat_message import ChatMessage, ChatRole, ToolCall, ToolCallResult
6+
from haystack.dataclasses.streaming_chunk import (
67
AsyncStreamingCallbackT,
7-
ChatMessage,
8-
ChatRole,
98
StreamingCallbackT,
109
StreamingChunk,
11-
ToolCall,
12-
ToolCallResult,
10+
SyncStreamingCallbackT,
1311
select_streaming_callback,
1412
)
1513
from haystack.tools import (
@@ -19,98 +17,102 @@
1917
deserialize_tools_or_toolset_inplace,
2018
serialize_tools_or_toolset,
2119
)
22-
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
20+
from haystack.utils.auth import Secret, deserialize_secrets_inplace
21+
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
2322

2423
from anthropic import Anthropic, AsyncAnthropic
24+
from anthropic.resources.messages.messages import Message, RawMessageStreamEvent, Stream
25+
from anthropic.types import MessageParam, TextBlockParam, ToolParam, ToolResultBlockParam, ToolUseBlockParam
2526

2627
logger = logging.getLogger(__name__)
2728

2829

2930
def _update_anthropic_message_with_tool_call_results(
30-
tool_call_results: List[ToolCallResult], anthropic_msg: Dict[str, Any]
31+
tool_call_results: List[ToolCallResult],
32+
content: List[Union[TextBlockParam, ToolUseBlockParam, ToolResultBlockParam]],
3133
) -> None:
3234
"""
33-
Update an Anthropic message with tool call results.
35+
Update an Anthropic message content list with tool call results.
3436
3537
:param tool_call_results: The list of ToolCallResults to update the message with.
36-
:param anthropic_msg: The Anthropic message to update.
38+
:param content: The Anthropic message content list to update.
3739
"""
38-
if "content" not in anthropic_msg:
39-
anthropic_msg["content"] = []
40-
4140
for tool_call_result in tool_call_results:
4241
if tool_call_result.origin.id is None:
4342
msg = "`ToolCall` must have a non-null `id` attribute to be used with Anthropic."
4443
raise ValueError(msg)
45-
anthropic_msg["content"].append(
46-
{
47-
"type": "tool_result",
48-
"tool_use_id": tool_call_result.origin.id,
49-
"content": [{"type": "text", "text": tool_call_result.result}],
50-
"is_error": tool_call_result.error,
51-
}
44+
45+
tool_result_block = ToolResultBlockParam(
46+
type="tool_result",
47+
tool_use_id=tool_call_result.origin.id,
48+
content=[{"type": "text", "text": tool_call_result.result}],
49+
is_error=tool_call_result.error,
5250
)
51+
content.append(tool_result_block)
5352

5453

55-
def _convert_tool_calls_to_anthropic_format(tool_calls: List[ToolCall]) -> List[Dict[str, Any]]:
54+
def _convert_tool_calls_to_anthropic_format(tool_calls: List[ToolCall]) -> List[ToolUseBlockParam]:
5655
"""
5756
Convert a list of tool calls to the format expected by Anthropic Chat API.
5857
5958
:param tool_calls: The list of ToolCalls to convert.
60-
:return: A list of dictionaries in the format expected by Anthropic API.
59+
:return: A list of ToolUseBlockParam objects in the format expected by Anthropic API.
6160
"""
6261
anthropic_tool_calls = []
6362
for tc in tool_calls:
6463
if tc.id is None:
6564
msg = "`ToolCall` must have a non-null `id` attribute to be used with Anthropic."
6665
raise ValueError(msg)
67-
anthropic_tool_calls.append(
68-
{
69-
"type": "tool_use",
70-
"id": tc.id,
71-
"name": tc.tool_name,
72-
"input": tc.arguments,
73-
}
66+
67+
tool_use_block = ToolUseBlockParam(
68+
type="tool_use",
69+
id=tc.id,
70+
name=tc.tool_name,
71+
input=tc.arguments,
7472
)
73+
anthropic_tool_calls.append(tool_use_block)
7574
return anthropic_tool_calls
7675

7776

7877
def _convert_messages_to_anthropic_format(
7978
messages: List[ChatMessage],
80-
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
79+
) -> Tuple[List[TextBlockParam], List[MessageParam]]:
8180
"""
8281
Convert a list of messages to the format expected by Anthropic Chat API.
8382
8483
:param messages: The list of ChatMessages to convert.
8584
:return: A tuple of two lists:
86-
- A list of system message dictionaries in the format expected by Anthropic API.
87-
- A list of non-system message dictionaries in the format expected by Anthropic API.
85+
- A list of system message TextBlockParam objects in the format expected by Anthropic API.
86+
- A list of non-system MessageParam objects in the format expected by Anthropic API.
8887
"""
8988

90-
anthropic_system_messages = []
91-
anthropic_non_system_messages = []
89+
anthropic_system_messages: List[TextBlockParam] = []
90+
anthropic_non_system_messages: List[MessageParam] = []
9291

9392
i = 0
9493
while i < len(messages):
9594
message = messages[i]
9695

97-
# allow passing cache_control
98-
cache_control = {"cache_control": message.meta.get("cache_control")} if "cache_control" in message.meta else {}
99-
10096
# system messages have special format requirements for Anthropic API
10197
# they can have only type and text fields, and they need to be passed separately
10298
# to the Anthropic API endpoint
103-
if message.is_from(ChatRole.SYSTEM):
104-
anthropic_system_messages.append({"type": "text", "text": message.text, **cache_control})
99+
if message.is_from(ChatRole.SYSTEM) and message.text:
100+
sys_message = TextBlockParam(type="text", text=message.text)
101+
if cache_control := message.meta.get("cache_control"):
102+
sys_message["cache_control"] = cache_control
103+
anthropic_system_messages.append(sys_message)
105104
i += 1
106105
continue
107106

108-
anthropic_msg: Dict[str, Any] = {"role": message._role.value, "content": [], **cache_control}
107+
content: List[Union[TextBlockParam, ToolUseBlockParam, ToolResultBlockParam]] = []
109108

110109
if message.texts and message.texts[0]:
111-
anthropic_msg["content"].append({"type": "text", "text": message.texts[0]})
110+
text_block = TextBlockParam(type="text", text=message.texts[0])
111+
content.append(text_block)
112+
112113
if message.tool_calls:
113-
anthropic_msg["content"] += _convert_tool_calls_to_anthropic_format(message.tool_calls)
114+
tool_use_blocks = _convert_tool_calls_to_anthropic_format(message.tool_calls)
115+
content.extend(tool_use_blocks)
114116

115117
if message.tool_call_results:
116118
results = message.tool_call_results.copy()
@@ -119,14 +121,20 @@ def _convert_messages_to_anthropic_format(
119121
i += 1
120122
results.extend(messages[i].tool_call_results)
121123

122-
_update_anthropic_message_with_tool_call_results(results, anthropic_msg)
123-
anthropic_msg["role"] = "user"
124+
_update_anthropic_message_with_tool_call_results(results, content)
124125

125-
if not anthropic_msg["content"]:
126+
if not content:
126127
msg = "A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`."
127128
raise ValueError(msg)
128129

129-
anthropic_non_system_messages.append(anthropic_msg)
130+
# Anthropic only supports assistant and user roles in messages. User role is also used for tool messages.
131+
# System messages are passed separately.
132+
role: Union[Literal["assistant"], Literal["user"]] = "user"
133+
if message._role == ChatRole.ASSISTANT:
134+
role = "assistant"
135+
136+
anthropic_message = MessageParam(role=role, content=content)
137+
anthropic_non_system_messages.append(anthropic_message)
130138
i += 1
131139

132140
return anthropic_system_messages, anthropic_non_system_messages
@@ -340,11 +348,14 @@ def _convert_streaming_chunks_to_chat_message(
340348
for chunk in chunks:
341349
chunk_type = chunk.meta.get("type")
342350
if chunk_type == "content_block_start":
343-
if chunk.meta.get("content_block", {}).get("type") == "tool_use":
344-
delta_block = chunk.meta.get("content_block")
351+
content_block = chunk.meta.get("content_block")
352+
if content_block is None:
353+
msg = "Invalid streaming chunk. Expected 'content_block' field."
354+
raise ValueError(msg)
355+
if content_block.get("type") == "tool_use":
345356
current_tool_call = {
346-
"id": delta_block.get("id"),
347-
"name": delta_block.get("name"),
357+
"id": content_block.get("id"),
358+
"name": content_block.get("name"),
348359
"arguments": "",
349360
}
350361
elif chunk_type == "content_block_delta":
@@ -388,21 +399,12 @@ def _convert_streaming_chunks_to_chat_message(
388399

389400
return message
390401

391-
@staticmethod
392-
def _remove_cache_control(message: Dict[str, Any]) -> Dict[str, Any]:
393-
"""
394-
Removes the cache_control key from the message.
395-
:param message: The message to remove the cache_control key from.
396-
:returns: The message with the cache_control key removed.
397-
"""
398-
return {k: v for k, v in message.items() if k != "cache_control"}
399-
400402
def _prepare_request_params(
401403
self,
402404
messages: List[ChatMessage],
403405
generation_kwargs: Optional[Dict[str, Any]] = None,
404406
tools: Optional[Union[List[Tool], Toolset]] = None,
405-
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], Dict[str, Any], List[Dict[str, Any]]]:
407+
) -> Tuple[List[TextBlockParam], List[MessageParam], Dict[str, Any], List[ToolParam]]:
406408
"""
407409
Prepare the parameters for the Anthropic API request.
408410
@@ -433,8 +435,8 @@ def _prepare_request_params(
433435
# prompt caching
434436
extra_headers = generation_kwargs.get("extra_headers", {})
435437
prompt_caching_on = "anthropic-beta" in extra_headers and "prompt-caching" in extra_headers["anthropic-beta"]
436-
has_cached_messages = any("cache_control" in m for m in system_messages) or any(
437-
"cache_control" in m for m in non_system_messages
438+
has_cached_messages = any(m.get("cache_control") is not None for m in system_messages) or any(
439+
m.get("cache_control") is not None for m in non_system_messages
438440
)
439441
if has_cached_messages and not prompt_caching_on:
440442
# this avoids Anthropic errors when prompt caching is not enabled
@@ -443,32 +445,28 @@ def _prepare_request_params(
443445
"Prompt caching is not enabled but you requested individual messages to be cached. "
444446
"Messages will be sent to the API without prompt caching."
445447
)
446-
system_messages = list(map(self._remove_cache_control, system_messages))
447-
non_system_messages = list(map(self._remove_cache_control, non_system_messages))
448+
for message in system_messages:
449+
if message.get("cache_control"):
450+
del message["cache_control"]
448451

449452
# tools management
450453
tools = tools or self.tools
451454
tools = list(tools) if isinstance(tools, Toolset) else tools
452455
_check_duplicate_tool_names(tools) # handles Toolset as well
453-
anthropic_tools = (
454-
[
455-
{
456-
"name": tool.name,
457-
"description": tool.description,
458-
"input_schema": tool.parameters,
459-
}
460-
for tool in tools
461-
]
462-
if tools
463-
else []
464-
)
456+
457+
anthropic_tools: List[ToolParam] = []
458+
if tools:
459+
for tool in tools:
460+
anthropic_tools.append(
461+
ToolParam(name=tool.name, description=tool.description, input_schema=tool.parameters)
462+
)
465463

466464
return system_messages, non_system_messages, generation_kwargs, anthropic_tools
467465

468466
def _process_response(
469467
self,
470-
response: Any,
471-
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
468+
response: Union[Message, Stream[RawMessageStreamEvent]],
469+
streaming_callback: Optional[SyncStreamingCallbackT] = None,
472470
) -> Dict[str, List[ChatMessage]]:
473471
"""
474472
Process the response from the Anthropic API.
@@ -478,8 +476,8 @@ def _process_response(
478476
:returns: A dictionary containing the processed response as a list of ChatMessage objects.
479477
"""
480478
# workaround for https://github.com/DataDog/dd-trace-py/issues/12562
481-
stream = streaming_callback is not None
482-
if stream:
479+
# we cannot use isinstance(Stream)
480+
if not isinstance(response, Message):
483481
chunks: List[StreamingChunk] = []
484482
model: Optional[str] = None
485483
for chunk in response:
@@ -552,7 +550,7 @@ def run(
552550
streaming_callback: Optional[StreamingCallbackT] = None,
553551
generation_kwargs: Optional[Dict[str, Any]] = None,
554552
tools: Optional[Union[List[Tool], Toolset]] = None,
555-
):
553+
) -> Dict[str, List[ChatMessage]]:
556554
"""
557555
Invokes the Anthropic API with the given messages and generation kwargs.
558556
@@ -584,7 +582,8 @@ def run(
584582
**generation_kwargs,
585583
)
586584

587-
return self._process_response(response, streaming_callback)
585+
# select_streaming_callback returns a StreamingCallbackT, but we know it's SyncStreamingCallbackT
586+
return self._process_response(response=response, streaming_callback=streaming_callback) # type: ignore[arg-type]
588587

589588
@component.output_types(replies=List[ChatMessage])
590589
async def run_async(
@@ -593,7 +592,7 @@ async def run_async(
593592
streaming_callback: Optional[StreamingCallbackT] = None,
594593
generation_kwargs: Optional[Dict[str, Any]] = None,
595594
tools: Optional[Union[List[Tool], Toolset]] = None,
596-
):
595+
) -> Dict[str, List[ChatMessage]]:
597596
"""
598597
Async version of the run method. Invokes the Anthropic API with the given messages and generation kwargs.
599598
@@ -625,4 +624,5 @@ async def run_async(
625624
**generation_kwargs,
626625
)
627626

628-
return await self._process_response_async(response, streaming_callback)
627+
# select_streaming_callback returns a StreamingCallbackT, but we know it's AsyncStreamingCallbackT
628+
return await self._process_response_async(response, streaming_callback) # type: ignore[arg-type]

0 commit comments

Comments
 (0)