diff --git a/integrations/amazon_bedrock/pyproject.toml b/integrations/amazon_bedrock/pyproject.toml
index 81fd65f544..067aa752b2 100644
--- a/integrations/amazon_bedrock/pyproject.toml
+++ b/integrations/amazon_bedrock/pyproject.toml
@@ -23,7 +23,9 @@ classifiers = [
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
-dependencies = ["haystack-ai>=2.24.1", "boto3>=1.28.57", "aioboto3>=14.0.0"]
+dependencies = ["haystack-ai>=2.24.1", "boto3>=1.42.84", "aioboto3>=14.0.0"]
+
+
[project.urls]
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/amazon_bedrock#readme"
@@ -172,6 +174,12 @@ omit = ["*/tests/*", "*/__init__.py"]
show_missing = true
exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"]
+[tool.uv]
+# aiobotocore 2.25.1 pins botocore<1.40.62, but outputConfig support in the Bedrock Converse API
+# requires botocore>=1.42.84 (newer botocore works fine with aiobotocore in practice).
+# Override the transitive constraint so uv can resolve the graph. NOTE(review): uv-only; confirm pip installs.
+override-dependencies = ["botocore>=1.42.84", "boto3>=1.42.84"]
+
[tool.pytest.ini_options]
addopts = "--strict-markers"
markers = [
diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py
index 16f45a915e..2ef9c690e9 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py
@@ -1,3 +1,4 @@
+import json
from typing import Any
import aioboto3
@@ -5,7 +6,12 @@
from botocore.eventstream import EventStream
from botocore.exceptions import ClientError
from haystack import component, default_from_dict, default_to_dict, logging
-from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingCallbackT, select_streaming_callback
+from haystack.dataclasses import (
+ ChatMessage,
+ ComponentInfo,
+ StreamingCallbackT,
+ select_streaming_callback,
+)
from haystack.tools import (
ToolsType,
_check_duplicate_tool_names,
@@ -14,7 +20,10 @@
serialize_tools_or_toolset,
)
from haystack.utils.auth import Secret
-from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
+from haystack.utils.callable_serialization import (
+ deserialize_callable,
+ serialize_callable,
+)
from haystack_integrations.common.amazon_bedrock.errors import (
AmazonBedrockConfigurationError,
@@ -27,6 +36,7 @@
_parse_completion_response,
_parse_streaming_response,
_parse_streaming_response_async,
+ _parse_structured_output,
_validate_and_format_cache_point,
_validate_guardrail_config,
)
@@ -164,13 +174,11 @@ def weather(city: str):
def __init__(
self,
model: str,
- aws_access_key_id: Secret | None = Secret.from_env_var(["AWS_ACCESS_KEY_ID"], strict=False), # noqa: B008
- aws_secret_access_key: Secret | None = Secret.from_env_var( # noqa: B008
- ["AWS_SECRET_ACCESS_KEY"], strict=False
- ),
- aws_session_token: Secret | None = Secret.from_env_var(["AWS_SESSION_TOKEN"], strict=False), # noqa: B008
- aws_region_name: Secret | None = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False), # noqa: B008
- aws_profile_name: Secret | None = Secret.from_env_var(["AWS_PROFILE"], strict=False), # noqa: B008
+ aws_access_key_id: Secret | None = None,
+ aws_secret_access_key: Secret | None = None,
+ aws_session_token: Secret | None = None,
+ aws_region_name: Secret | None = None,
+ aws_profile_name: Secret | None = None,
generation_kwargs: dict[str, Any] | None = None,
streaming_callback: StreamingCallbackT | None = None,
boto3_config: dict[str, Any] | None = None,
@@ -236,6 +244,22 @@ def __init__(
msg = "'model' cannot be None or empty string"
raise ValueError(msg)
self.model = model
+ if aws_access_key_id is None:
+ aws_access_key_id = Secret.from_env_var(["AWS_ACCESS_KEY_ID"], strict=False)
+
+ if aws_secret_access_key is None:
+ aws_secret_access_key = Secret.from_env_var(
+ ["AWS_SECRET_ACCESS_KEY"], strict=False
+ )
+
+ if aws_session_token is None:
+ aws_session_token = Secret.from_env_var(["AWS_SESSION_TOKEN"], strict=False)
+
+ if aws_region_name is None:
+ aws_region_name = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False)
+
+ if aws_profile_name is None:
+ aws_profile_name = Secret.from_env_var(["AWS_PROFILE"], strict=False)
self.aws_access_key_id = aws_access_key_id
self.aws_secret_access_key = aws_secret_access_key
self.aws_session_token = aws_session_token
@@ -247,18 +271,23 @@ def __init__(
_check_duplicate_tool_names(flatten_tools_or_toolsets(tools))
self.tools = tools
- _validate_guardrail_config(guardrail_config=guardrail_config, streaming=streaming_callback is not None)
+ _validate_guardrail_config(
+ guardrail_config=guardrail_config, streaming=streaming_callback is not None
+ )
self.guardrail_config = guardrail_config
self.tools_cachepoint_config = (
- _validate_and_format_cache_point(tools_cachepoint_config) if tools_cachepoint_config else None
+ _validate_and_format_cache_point(tools_cachepoint_config)
+ if tools_cachepoint_config
+ else None
)
def resolve_secret(secret: Secret | None) -> str | None:
return secret.resolve_value() if secret else None
config = Config(
- user_agent_extra="x-client-framework:haystack", **(self.boto3_config if self.boto3_config else {})
+ user_agent_extra="x-client-framework:haystack",
+ **(self.boto3_config if self.boto3_config else {}),
)
try:
@@ -300,13 +329,31 @@ def _get_async_session(self) -> aioboto3.Session:
try:
self.async_session = get_aws_session(
- aws_access_key_id=self.aws_access_key_id.resolve_value() if self.aws_access_key_id else None,
+ aws_access_key_id=(
+ self.aws_access_key_id.resolve_value()
+ if self.aws_access_key_id
+ else None
+ ),
aws_secret_access_key=(
- self.aws_secret_access_key.resolve_value() if self.aws_secret_access_key else None
+ self.aws_secret_access_key.resolve_value()
+ if self.aws_secret_access_key
+ else None
+ ),
+ aws_session_token=(
+ self.aws_session_token.resolve_value()
+ if self.aws_session_token
+ else None
+ ),
+ aws_region_name=(
+ self.aws_region_name.resolve_value()
+ if self.aws_region_name
+ else None
+ ),
+ aws_profile_name=(
+ self.aws_profile_name.resolve_value()
+ if self.aws_profile_name
+ else None
),
- aws_session_token=self.aws_session_token.resolve_value() if self.aws_session_token else None,
- aws_region_name=self.aws_region_name.resolve_value() if self.aws_region_name else None,
- aws_profile_name=self.aws_profile_name.resolve_value() if self.aws_profile_name else None,
async_mode=True,
)
return self.async_session
@@ -325,7 +372,11 @@ def to_dict(self) -> dict[str, Any]:
:returns:
Dictionary with serialized data.
"""
- callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
+ callback_name = (
+ serialize_callable(self.streaming_callback)
+ if self.streaming_callback
+ else None
+ )
return default_to_dict(
self,
aws_access_key_id=self.aws_access_key_id,
@@ -360,7 +411,9 @@ def from_dict(cls, data: dict[str, Any]) -> "AmazonBedrockChatGenerator":
serialized_callback_handler = init_params.get("streaming_callback")
if serialized_callback_handler:
- data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
+ data["init_parameters"]["streaming_callback"] = deserialize_callable(
+ serialized_callback_handler
+ )
deserialize_tools_or_toolset_inplace(data["init_parameters"], key="tools")
return default_from_dict(cls, data)
@@ -385,6 +438,26 @@ def _prepare_request_params(
- `stopSequences`: List of stop sequences to stop generation.
- `temperature`: Sampling temperature.
- `topP`: Nucleus sampling parameter.
+ - `json_schema`: Request structured JSON output validated against a schema. Provide a dict with:
+ - `schema` (required): a JSON Schema dict describing the expected output structure.
+ - `name` (optional): a name for the schema, defaults to ``"response_schema"``.
+ - `description` (optional): a description of the schema.
+
+ Example::
+
+ generation_kwargs={
+ "json_schema": {
+ "name": "person",
+ "schema": {
+ "type": "object",
+ "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
+ "required": ["name", "age"],
+ "additionalProperties": False,
+ },
+ }
+ }
+
+ When set, the parsed JSON object is stored in ``reply.meta["structured_output"]``.
:param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
Each tool should have a unique name.
:param requires_async: Boolean flag to indicate if an async-compatible streaming callback function is needed.
@@ -413,13 +486,36 @@ def _prepare_request_params(
flattened_tools = flatten_tools_or_toolsets(tools)
_check_duplicate_tool_names(flattened_tools)
tool_config = merged_kwargs.pop("toolConfig", None)
+ json_schema = merged_kwargs.pop("json_schema", None)
if flattened_tools:
# Format Haystack tools to Bedrock format
- tool_config = _format_tools(flattened_tools, tools_cachepoint_config=self.tools_cachepoint_config)
+ tool_config = _format_tools(
+ flattened_tools, tools_cachepoint_config=self.tools_cachepoint_config
+ )
# Any remaining kwargs go to additionalModelRequestFields
additional_fields = merged_kwargs if merged_kwargs else None
+ # Build outputConfig from json_schema for structured output support.
+ # See https://docs.aws.amazon.com/bedrock/latest/userguide/structured-output.html
+ output_config: dict[str, Any] | None = None
+ if json_schema is not None:
+ if "schema" not in json_schema:
+ msg = "'json_schema' must contain a 'schema' key with the JSON Schema dict."
+ raise ValueError(msg)
+ json_schema_block: dict[str, Any] = {
+ "name": json_schema.get("name", "response_schema"),
+ "schema": json.dumps(json_schema["schema"]),
+ }
+ if "description" in json_schema:
+ json_schema_block["description"] = json_schema["description"]
+ output_config = {
+ "textFormat": {
+ "type": "json_schema",
+ "structure": {"jsonSchema": json_schema_block},
+ }
+ }
+
# Format messages to Bedrock format
system_prompts, messages_list = _format_messages(messages)
@@ -434,6 +530,8 @@ def _prepare_request_params(
params["toolConfig"] = tool_config
if additional_fields:
params["additionalModelRequestFields"] = additional_fields
+ if output_config:
+ params["outputConfig"] = output_config
if self.guardrail_config:
params["guardrailConfig"] = self.guardrail_config
@@ -447,10 +545,14 @@ def _prepare_request_params(
return params, callback
- def _resolve_flattened_generation_kwargs(self, generation_kwargs: dict[str, Any]) -> dict[str, Any]:
+ def _resolve_flattened_generation_kwargs(
+ self, generation_kwargs: dict[str, Any]
+ ) -> dict[str, Any]:
generation_kwargs = generation_kwargs.copy()
- disable_parallel_tool_use = generation_kwargs.pop("disable_parallel_tool_use", None)
+ disable_parallel_tool_use = generation_kwargs.pop(
+ "disable_parallel_tool_use", None
+ )
parallel_tool_use = generation_kwargs.pop("parallel_tool_use", None)
if disable_parallel_tool_use is not None and parallel_tool_use is not None:
@@ -497,13 +599,16 @@ def run(
- `stopSequences`: List of stop sequences to stop generation.
- `temperature`: Sampling temperature.
- `topP`: Nucleus sampling parameter.
+ - `json_schema`: Request structured JSON output validated against a schema. See
+ :meth:`_prepare_request_params` for full details.
:param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
Each tool should have a unique name.
:returns:
A dictionary containing the model-generated replies under the `"replies"` key.
+ When ``json_schema`` is used, each reply's ``meta["structured_output"]`` contains the parsed JSON object.
:raises AmazonBedrockInferenceError:
- If the Bedrock inference API call fails.
+ If the Bedrock inference API call fails or the model returns invalid JSON for structured output.
"""
component_info = ComponentInfo.from_component(self)
@@ -536,6 +641,9 @@ def run(
msg = f"Could not perform inference for Amazon Bedrock model {self.model} due to:\n{exception}"
raise AmazonBedrockInferenceError(msg) from exception
+ if "outputConfig" in params:
+ replies = _parse_structured_output(replies)
+
return {"replies": replies}
@component.output_types(replies=list[ChatMessage])
@@ -558,13 +666,16 @@ async def run_async(
- `stopSequences`: List of stop sequences to stop generation.
- `temperature`: Sampling temperature.
- `topP`: Nucleus sampling parameter.
+ - `json_schema`: Request structured JSON output validated against a schema. See
+ :meth:`_prepare_request_params` for full details.
:param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
Each tool should have a unique name.
:returns:
A dictionary containing the model-generated replies under the `"replies"` key.
+ When ``json_schema`` is used, each reply's ``meta["structured_output"]`` contains the parsed JSON object.
:raises AmazonBedrockInferenceError:
- If the Bedrock inference API call fails.
+ If the Bedrock inference API call fails or the model returns invalid JSON for structured output.
"""
component_info = ComponentInfo.from_component(self)
@@ -581,7 +692,8 @@ async def run_async(
# Note: https://aioboto3.readthedocs.io/en/latest/usage.html
# we need to create a new client for each request
config = Config(
- user_agent_extra="x-client-framework:haystack", **(self.boto3_config if self.boto3_config else {})
+ user_agent_extra="x-client-framework:haystack",
+ **(self.boto3_config if self.boto3_config else {}),
)
async with session.client("bedrock-runtime", config=config) as async_client:
if callback:
@@ -605,4 +717,7 @@ async def run_async(
msg = f"Could not perform inference for Amazon Bedrock model {self.model} due to:\n{exception}"
raise AmazonBedrockInferenceError(msg) from exception
+ if "outputConfig" in params:
+ replies = _parse_structured_output(replies)
+
return {"replies": replies}
diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/utils.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/utils.py
index f6a5ce2251..cd9ceb4d2b 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/utils.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/utils.py
@@ -25,6 +25,8 @@
)
from haystack.tools import Tool
+from haystack_integrations.common.amazon_bedrock.errors import AmazonBedrockInferenceError
+
logger = logging.getLogger(__name__)
@@ -713,6 +715,28 @@ async def _parse_streaming_response_async(
return replies
+def _parse_structured_output(replies: list[ChatMessage]) -> list[ChatMessage]:
+    """
+    Decode the JSON text of each reply and attach the result to that reply's metadata.
+
+    Replies whose text is empty or missing are skipped; the others are updated in place.
+
+    :param replies: ChatMessage objects returned by the model.
+    :returns: The list that was passed in, with ``meta["structured_output"]`` set on text replies.
+    :raises AmazonBedrockInferenceError: If a reply's text is not valid JSON.
+    """
+    for message in replies:
+        if not message.text:
+            continue
+        try:
+            decoded = json.loads(message.text)
+        except json.JSONDecodeError as e:
+            msg = f"Structured output was requested but the model returned invalid JSON: {e}"
+            raise AmazonBedrockInferenceError(msg) from e
+        message.meta["structured_output"] = decoded
+    return replies
+
+
def _validate_guardrail_config(guardrail_config: dict[str, str] | None = None, streaming: bool = False) -> None:
"""
Validate the guardrail configuration.
diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py
index 8e7ebdc503..7d15094c2d 100644
--- a/integrations/amazon_bedrock/tests/test_chat_generator.py
+++ b/integrations/amazon_bedrock/tests/test_chat_generator.py
@@ -1,3 +1,4 @@
+import json
import os
from typing import Any
@@ -6,10 +7,20 @@
from haystack.components.agents import Agent
from haystack.components.generators.utils import print_streaming_chunk
from haystack.components.tools import ToolInvoker
-from haystack.dataclasses import ChatMessage, ChatRole, FileContent, ImageContent, StreamingChunk, TextContent, ToolCall
+from haystack.dataclasses import (
+ ChatMessage,
+ ChatRole,
+ FileContent,
+ ImageContent,
+ StreamingChunk,
+ TextContent,
+ ToolCall,
+)
from haystack.tools import Tool, Toolset, create_tool_from_function
-from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator
+from haystack_integrations.components.generators.amazon_bedrock import (
+ AmazonBedrockChatGenerator,
+)
CLASS_TYPE = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator"
MODELS_TO_TEST = [
@@ -58,9 +69,7 @@
"global.anthropic.claude-sonnet-4-6",
]
-MODELS_TO_TEST_WITH_PROMPT_CACHING = [
- "us.amazon.nova-micro-v1:0" # cheap, fast model
-]
+MODELS_TO_TEST_WITH_PROMPT_CACHING = ["us.amazon.nova-micro-v1:0"] # cheap, fast model
def hello_world():
@@ -90,7 +99,11 @@ def population(city: str):
@pytest.fixture
def tools():
- tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
+ tool_parameters = {
+ "type": "object",
+ "properties": {"city": {"type": "string"}},
+ "required": ["city"],
+ }
tool = Tool(
name="weather",
description="useful to determine the weather in a given location",
@@ -106,13 +119,21 @@ def mixed_tools():
weather_tool = Tool(
name="weather",
description="useful to determine the weather in a given location",
- parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
+ parameters={
+ "type": "object",
+ "properties": {"city": {"type": "string"}},
+ "required": ["city"],
+ },
function=weather,
)
population_tool = Tool(
name="population",
description="useful to determine the population of a given location",
- parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
+ parameters={
+ "type": "object",
+ "properties": {"city": {"type": "string"}},
+ "required": ["city"],
+ },
function=population,
)
toolset = Toolset([population_tool])
@@ -171,22 +192,48 @@ def test_to_dict(self, mock_boto3_session, boto3_config):
generation_kwargs={"temperature": 0.7},
streaming_callback=print_streaming_chunk,
boto3_config=boto3_config,
- guardrail_config={"guardrailIdentifier": "test", "guardrailVersion": "test"},
+ guardrail_config={
+ "guardrailIdentifier": "test",
+ "guardrailVersion": "test",
+ },
)
expected_dict = {
"type": CLASS_TYPE,
"init_parameters": {
- "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
- "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False},
- "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
- "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
- "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
+ "aws_access_key_id": {
+ "type": "env_var",
+ "env_vars": ["AWS_ACCESS_KEY_ID"],
+ "strict": False,
+ },
+ "aws_secret_access_key": {
+ "type": "env_var",
+ "env_vars": ["AWS_SECRET_ACCESS_KEY"],
+ "strict": False,
+ },
+ "aws_session_token": {
+ "type": "env_var",
+ "env_vars": ["AWS_SESSION_TOKEN"],
+ "strict": False,
+ },
+ "aws_region_name": {
+ "type": "env_var",
+ "env_vars": ["AWS_DEFAULT_REGION"],
+ "strict": False,
+ },
+ "aws_profile_name": {
+ "type": "env_var",
+ "env_vars": ["AWS_PROFILE"],
+ "strict": False,
+ },
"model": "cohere.command-r-plus-v1:0",
"generation_kwargs": {"temperature": 0.7},
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
"boto3_config": boto3_config,
"tools": None,
- "guardrail_config": {"guardrailIdentifier": "test", "guardrailVersion": "test"},
+ "guardrail_config": {
+ "guardrailIdentifier": "test",
+ "guardrailVersion": "test",
+ },
"tools_cachepoint_config": None,
},
}
@@ -202,15 +249,31 @@ def test_from_dict(self, mock_boto3_session: Any, boto3_config: dict[str, Any] |
{
"type": CLASS_TYPE,
"init_parameters": {
- "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
+ "aws_access_key_id": {
+ "type": "env_var",
+ "env_vars": ["AWS_ACCESS_KEY_ID"],
+ "strict": False,
+ },
"aws_secret_access_key": {
"type": "env_var",
"env_vars": ["AWS_SECRET_ACCESS_KEY"],
"strict": False,
},
- "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
- "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
- "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
+ "aws_session_token": {
+ "type": "env_var",
+ "env_vars": ["AWS_SESSION_TOKEN"],
+ "strict": False,
+ },
+ "aws_region_name": {
+ "type": "env_var",
+ "env_vars": ["AWS_DEFAULT_REGION"],
+ "strict": False,
+ },
+ "aws_profile_name": {
+ "type": "env_var",
+ "env_vars": ["AWS_PROFILE"],
+ "strict": False,
+ },
"model": "global.anthropic.claude-sonnet-4-6",
"generation_kwargs": {"temperature": 0.7},
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
@@ -250,7 +313,8 @@ def test_constructor_with_generation_kwargs(self, mock_boto3_session):
"""
generation_kwargs = {"temperature": 0.7}
layer = AmazonBedrockChatGenerator(
- model="global.anthropic.claude-sonnet-4-6", generation_kwargs=generation_kwargs
+ model="global.anthropic.claude-sonnet-4-6",
+ generation_kwargs=generation_kwargs,
)
assert layer.generation_kwargs == generation_kwargs
@@ -294,17 +358,36 @@ def test_serde_in_pipeline(self, mock_boto3_session, monkeypatch):
"generator": {
"type": CLASS_TYPE,
"init_parameters": {
- "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
+ "aws_access_key_id": {
+ "type": "env_var",
+ "env_vars": ["AWS_ACCESS_KEY_ID"],
+ "strict": False,
+ },
"aws_secret_access_key": {
"type": "env_var",
"env_vars": ["AWS_SECRET_ACCESS_KEY"],
"strict": False,
},
- "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
- "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
- "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
+ "aws_session_token": {
+ "type": "env_var",
+ "env_vars": ["AWS_SESSION_TOKEN"],
+ "strict": False,
+ },
+ "aws_region_name": {
+ "type": "env_var",
+ "env_vars": ["AWS_DEFAULT_REGION"],
+ "strict": False,
+ },
+ "aws_profile_name": {
+ "type": "env_var",
+ "env_vars": ["AWS_PROFILE"],
+ "strict": False,
+ },
"model": "global.anthropic.claude-sonnet-4-6",
- "generation_kwargs": {"temperature": 0.7, "stopSequences": ["eviscerate"]},
+ "generation_kwargs": {
+ "temperature": 0.7,
+ "stopSequences": ["eviscerate"],
+ },
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
"boto3_config": None,
"tools": [
@@ -347,13 +430,62 @@ def test_prepare_request_params_tool_config(self, top_song_tool_config, mock_bot
def test_prepare_request_params_guardrail_config(self, mock_boto3_session, set_env_variables):
generator = AmazonBedrockChatGenerator(
model="global.anthropic.claude-sonnet-4-6",
- guardrail_config={"guardrailIdentifier": "test", "guardrailVersion": "test"},
+ guardrail_config={
+ "guardrailIdentifier": "test",
+ "guardrailVersion": "test",
+ },
)
request_params, _ = generator._prepare_request_params(
messages=[ChatMessage.from_user("What's the capital of France?")],
)
assert request_params["messages"] == [{"content": [{"text": "What's the capital of France?"}], "role": "user"}]
- assert request_params["guardrailConfig"] == {"guardrailIdentifier": "test", "guardrailVersion": "test"}
+ assert request_params["guardrailConfig"] == {
+ "guardrailIdentifier": "test",
+ "guardrailVersion": "test",
+ }
+
+    def test_prepare_request_params_json_schema(self, mock_boto3_session, set_env_variables):
+        """A ``json_schema`` generation kwarg is translated into the request's ``outputConfig`` entry."""
+        person_schema = {
+            "type": "object",
+            "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
+            "required": ["name", "age"],
+            "additionalProperties": False,
+        }
+        json_schema_kwarg = {
+            "name": "person",
+            "description": "A person's name and age",
+            "schema": person_schema,
+        }
+        expected_output_config = {
+            "textFormat": {
+                "type": "json_schema",
+                "structure": {
+                    "jsonSchema": {
+                        "name": "person",
+                        "description": "A person's name and age",
+                        "schema": json.dumps(person_schema),
+                    }
+                },
+            }
+        }
+        chat_generator = AmazonBedrockChatGenerator(model="global.anthropic.claude-sonnet-4-6")
+        params, _unused_callback = chat_generator._prepare_request_params(
+            messages=[ChatMessage.from_user("Hello")],
+            generation_kwargs={"json_schema": json_schema_kwarg},
+        )
+        assert "outputConfig" in params
+        assert params["outputConfig"] == expected_output_config
+
+    def test_prepare_request_params_json_schema_missing_schema_key(self, mock_boto3_session, set_env_variables):
+        """A ``json_schema`` kwarg lacking the ``schema`` key is rejected with a ``ValueError``."""
+        schemaless_kwargs = {"json_schema": {"name": "test"}}
+        chat_generator = AmazonBedrockChatGenerator(model="global.anthropic.claude-sonnet-4-6")
+        with pytest.raises(ValueError, match="'json_schema' must contain a 'schema' key"):
+            chat_generator._prepare_request_params(
+                messages=[ChatMessage.from_user("Hello")],
+                generation_kwargs=schemaless_kwargs,
+            )
def test_init_with_mixed_tools(self, mock_boto3_session, set_env_variables):
def tool_fn(city: str) -> str:
@@ -362,13 +494,21 @@ def tool_fn(city: str) -> str:
weather_tool = Tool(
name="weather",
description="Weather lookup",
- parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
+ parameters={
+ "type": "object",
+ "properties": {"city": {"type": "string"}},
+ "required": ["city"],
+ },
function=tool_fn,
)
population_tool = Tool(
name="population",
description="Population lookup",
- parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
+ parameters={
+ "type": "object",
+ "properties": {"city": {"type": "string"}},
+ "required": ["city"],
+ },
function=tool_fn,
)
toolset = Toolset([population_tool])
@@ -387,13 +527,21 @@ def tool_fn(city: str) -> str:
weather_tool = Tool(
name="weather",
description="Weather lookup",
- parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
+ parameters={
+ "type": "object",
+ "properties": {"city": {"type": "string"}},
+ "required": ["city"],
+ },
function=tool_fn,
)
population_tool = Tool(
name="population",
description="Population lookup",
- parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
+ parameters={
+ "type": "object",
+ "properties": {"city": {"type": "string"}},
+ "required": ["city"],
+ },
function=tool_fn,
)
toolset = Toolset([population_tool])
@@ -485,7 +633,11 @@ def tool_fn(city: str) -> str:
],
)
def test_prepare_request_params_with_flattened_generation_kwargs(
- self, mock_boto3_session, set_env_variables, generation_kwargs, additional_model_request_fields
+ self,
+ mock_boto3_session,
+ set_env_variables,
+ generation_kwargs,
+ additional_model_request_fields,
):
generator = AmazonBedrockChatGenerator(model="global.anthropic.claude-sonnet-4-6")
request_params, _ = generator._prepare_request_params(
@@ -523,6 +675,54 @@ def test_default_inference_params(self, model_name, chat_messages):
assert "prompt_tokens" in first_reply.meta["usage"]
assert "completion_tokens" in first_reply.meta["usage"]
+    def test_run_with_structured_output(self):
+        """Running with a ``json_schema`` kwarg yields JSON text, also exposed via ``meta["structured_output"]``."""
+        generator = AmazonBedrockChatGenerator(model="global.anthropic.claude-sonnet-4-6")
+
+        person_schema = {
+            "type": "object",
+            "properties": {
+                "name": {"type": "string"},
+                "age": {"type": "integer"},
+            },
+            "required": ["name", "age"],
+            "additionalProperties": False,
+        }
+        structured_kwargs = {
+            "json_schema": {
+                "name": "person",
+                "description": "A person's name and age",
+                "schema": person_schema,
+            }
+        }
+
+        prompt = "Extract the person's name and age from: 'Alice is 30 years old.'"
+        messages = [ChatMessage.from_user(prompt)]
+
+        result = generator.run(
+            messages,
+            generation_kwargs=structured_kwargs,
+        )
+
+        assert "replies" in result
+        assert len(result["replies"]) > 0
+
+        first_reply = result["replies"][0]
+        assert first_reply.text is not None
+
+        # The raw reply text has to decode as a JSON object
+        decoded = json.loads(first_reply.text)
+        assert isinstance(decoded, dict)
+
+        # The same decoded object must be reachable through the reply metadata
+        assert "structured_output" in first_reply.meta
+        structured = first_reply.meta["structured_output"]
+        assert structured == decoded
+
+        # Both schema fields come back with the declared types
+        assert isinstance(structured.get("name"), str)
+        assert isinstance(structured.get("age"), int)
+
@pytest.mark.parametrize("model_name", MODELS_TO_TEST_WITH_IMAGE_INPUT)
def test_run_with_image_input(self, model_name, test_files_path):
client = AmazonBedrockChatGenerator(model=model_name)
@@ -548,7 +748,10 @@ def test_run_with_pdf_citations(self, model_name, test_files_path):
file_content = FileContent.from_file_path(file_path, extra={"citations": {"enabled": True}})
chat_message = ChatMessage.from_user(
- content_parts=["Is this document a paper on Large Language Models? Respond briefly", file_content]
+ content_parts=[
+ "Is this document a paper on Large Language Models? Respond briefly",
+ file_content,
+ ]
)
response = client.run([chat_message])
@@ -586,7 +789,9 @@ def retrieve_image():
]
image_retriever_tool = create_tool_from_function(
- name="retrieve_image", description="Tool to retrieve an image", function=retrieve_image
+ name="retrieve_image",
+ description="Tool to retrieve an image",
+ function=retrieve_image,
)
image_retriever_tool.outputs_to_string = {"raw_result": True}
@@ -927,7 +1132,9 @@ def test_live_run_with_tool_with_no_args_streaming(self, tool_with_no_parameters
"""
initial_messages = [ChatMessage.from_user("Print Hello World using the print hello world tool.")]
component = AmazonBedrockChatGenerator(
- model=model_name, tools=[tool_with_no_parameters], streaming_callback=print_streaming_chunk
+ model=model_name,
+ tools=[tool_with_no_parameters],
+ streaming_callback=print_streaming_chunk,
)
results = component.run(messages=initial_messages)
@@ -992,7 +1199,8 @@ def test_prompt_caching_live_run_with_user_message(self, model_name, streaming_c
system_message = ChatMessage.from_system("Always respond with: 'Life is beautiful' (and nothing else).")
user_message = ChatMessage.from_user(
- "User message that should be long enough to cache. " * 100, meta={"cachePoint": {"type": "default"}}
+ "User message that should be long enough to cache. " * 100,
+ meta={"cachePoint": {"type": "default"}},
)
messages = [system_message, user_message]
result = generator.run(messages=messages)
diff --git a/integrations/amazon_bedrock/tests/test_chat_generator_utils.py b/integrations/amazon_bedrock/tests/test_chat_generator_utils.py
index a09ba17b8a..b9f1ba6cd0 100644
--- a/integrations/amazon_bedrock/tests/test_chat_generator_utils.py
+++ b/integrations/amazon_bedrock/tests/test_chat_generator_utils.py
@@ -45,7 +45,11 @@ def tools():
weather_tool = Tool(
name="weather",
description="useful to determine the weather in a given location",
- parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
+ parameters={
+ "type": "object",
+ "properties": {"city": {"type": "string"}},
+ "required": ["city"],
+ },
function=weather,
)
addition_tool = Tool(
@@ -71,7 +75,11 @@ def test_format_tools(self, tools):
"name": "weather",
"description": "useful to determine the weather in a given location",
"inputSchema": {
- "json": {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
+ "json": {
+ "type": "object",
+ "properties": {"city": {"type": "string"}},
+ "required": ["city"],
+ }
},
}
},
@@ -82,7 +90,10 @@ def test_format_tools(self, tools):
"inputSchema": {
"json": {
"type": "object",
- "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
+ "properties": {
+ "a": {"type": "integer"},
+ "b": {"type": "integer"},
+ },
"required": ["a", "b"],
}
},
@@ -94,7 +105,9 @@ def test_format_tools(self, tools):
def test_convert_file_content_to_bedrock_format_no_mime_type(self):
file_content = FileContent(
- base64_data=base64.b64encode(b"This is a test file content."), mime_type=None, validation=False
+ base64_data=base64.b64encode(b"This is a test file content."),
+ mime_type=None,
+ validation=False,
)
with pytest.raises(ValueError, match="MIME type is required"):
_convert_file_content_to_bedrock_format(file_content)
@@ -102,7 +115,8 @@ def test_convert_file_content_to_bedrock_format_no_mime_type(self):
def test_convert_file_content_to_bedrock_format_document(self, test_files_path):
file_path = test_files_path / "sample_pdf_1.pdf"
file_content = FileContent.from_file_path(
- file_path, extra={"context": "Example context.", "citations": {"enabled": True}}
+ file_path,
+ extra={"context": "Example context.", "citations": {"enabled": True}},
)
formatted_file_content = _convert_file_content_to_bedrock_format(file_content)
assert formatted_file_content == {
@@ -149,12 +163,17 @@ def test_convert_file_content_to_bedrock_format_video(self, test_files_path):
file_content = FileContent.from_file_path(file_path)
formatted_file_content = _convert_file_content_to_bedrock_format(file_content)
assert formatted_file_content == {
- "video": {"format": "mp4", "source": {"bytes": base64.b64decode(file_content.base64_data)}}
+ "video": {
+ "format": "mp4",
+ "source": {"bytes": base64.b64decode(file_content.base64_data)},
+ }
}
def test_convert_file_content_to_bedrock_format_unsupported_mime_type(self):
file_content = FileContent(
- base64_data=base64.b64encode(b"This is a test file content."), mime_type="image/tiff", validation=False
+ base64_data=base64.b64encode(b"This is a test file content."),
+ mime_type="image/tiff",
+ validation=False,
)
with pytest.raises(ValueError, match="Unsupported file content MIME type"):
_convert_file_content_to_bedrock_format(file_content)
@@ -180,27 +199,52 @@ def test_format_messages(self):
]
assert formatted_messages == [
{"role": "user", "content": [{"text": "What's the capital of France?"}]},
- {"role": "assistant", "content": [{"text": "The capital of France is Paris."}]},
+ {
+ "role": "assistant",
+ "content": [{"text": "The capital of France is Paris."}],
+ },
{"role": "user", "content": [{"text": "What is the weather in Paris?"}]},
{
"role": "assistant",
- "content": [{"toolUse": {"toolUseId": "123", "name": "weather", "input": {"city": "Paris"}}}],
+ "content": [
+ {
+ "toolUse": {
+ "toolUseId": "123",
+ "name": "weather",
+ "input": {"city": "Paris"},
+ }
+ }
+ ],
},
{
"role": "user",
- "content": [{"toolResult": {"toolUseId": "123", "content": [{"text": "Sunny and 25°C"}]}}],
+ "content": [
+ {
+ "toolResult": {
+ "toolUseId": "123",
+ "content": [{"text": "Sunny and 25°C"}],
+ }
+ }
+ ],
+ },
+ {
+ "role": "assistant",
+ "content": [{"text": "The weather in Paris is sunny and 25°C."}],
},
- {"role": "assistant", "content": [{"text": "The weather in Paris is sunny and 25°C."}]},
]
def test_format_messages_with_cache_point(self):
meta = {"cachePoint": {"type": "default"}}
messages = [
- ChatMessage.from_system("\\nYou are a helpful assistant, be super brief in your responses.", meta=meta),
+ ChatMessage.from_system(
+ "\\nYou are a helpful assistant, be super brief in your responses.",
+ meta=meta,
+ ),
ChatMessage.from_user("What is the weather in Paris?", meta=meta),
ChatMessage.from_assistant(
- tool_calls=[ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"})], meta=meta
+ tool_calls=[ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"})],
+ meta=meta,
),
ChatMessage.from_tool(
tool_result="Sunny and 25°C",
@@ -217,25 +261,42 @@ def test_format_messages_with_cache_point(self):
assert formatted_messages == [
{
"role": "user",
- "content": [{"text": "What is the weather in Paris?"}, {"cachePoint": {"type": "default"}}],
+ "content": [
+ {"text": "What is the weather in Paris?"},
+ {"cachePoint": {"type": "default"}},
+ ],
},
{
"role": "assistant",
"content": [
- {"toolUse": {"toolUseId": "123", "name": "weather", "input": {"city": "Paris"}}},
+ {
+ "toolUse": {
+ "toolUseId": "123",
+ "name": "weather",
+ "input": {"city": "Paris"},
+ }
+ },
{"cachePoint": {"type": "default"}},
],
},
{
"role": "user",
"content": [
- {"toolResult": {"toolUseId": "123", "content": [{"text": "Sunny and 25°C"}]}},
+ {
+ "toolResult": {
+ "toolUseId": "123",
+ "content": [{"text": "Sunny and 25°C"}],
+ }
+ },
{"cachePoint": {"type": "default"}},
],
},
{
"role": "assistant",
- "content": [{"text": "The weather in Paris is sunny and 25°C."}, {"cachePoint": {"type": "default"}}],
+ "content": [
+ {"text": "The weather in Paris is sunny and 25°C."},
+ {"cachePoint": {"type": "default"}},
+ ],
},
]
@@ -247,25 +308,44 @@ def test_format_messages_tool_result_with_image(self):
messages = [
ChatMessage.from_user("Retrieve the image and describe it in max 5 words."),
ChatMessage.from_assistant(
- tool_calls=[ToolCall(id="123", tool_name="image_retriever", arguments={"query": "random query"})]
+ tool_calls=[
+ ToolCall(
+ id="123",
+ tool_name="image_retriever",
+ arguments={"query": "random query"},
+ )
+ ]
),
ChatMessage.from_tool(
tool_result=[
TextContent("Here's the retrieved image"),
ImageContent(base64_image=base64_image, mime_type="image/png"),
],
- origin=ToolCall(id="123", tool_name="image_retriever", arguments={"query": "random query"}),
+ origin=ToolCall(
+ id="123",
+ tool_name="image_retriever",
+ arguments={"query": "random query"},
+ ),
),
ChatMessage.from_assistant("Beautiful landscape with mountains"),
]
formatted_system_prompts, formatted_messages = _format_messages(messages)
assert formatted_system_prompts == []
assert formatted_messages == [
- {"role": "user", "content": [{"text": "Retrieve the image and describe it in max 5 words."}]},
+ {
+ "role": "user",
+ "content": [{"text": "Retrieve the image and describe it in max 5 words."}],
+ },
{
"role": "assistant",
"content": [
- {"toolUse": {"toolUseId": "123", "name": "image_retriever", "input": {"query": "random query"}}}
+ {
+ "toolUse": {
+ "toolUseId": "123",
+ "name": "image_retriever",
+ "input": {"query": "random query"},
+ }
+ }
],
},
{
@@ -276,13 +356,21 @@ def test_format_messages_tool_result_with_image(self):
"toolUseId": "123",
"content": [
{"text": "Here's the retrieved image"},
- {"image": {"format": "png", "source": {"bytes": base64.b64decode(base64_image)}}},
+ {
+ "image": {
+ "format": "png",
+ "source": {"bytes": base64.b64decode(base64_image)},
+ }
+ },
],
}
}
],
},
- {"role": "assistant", "content": [{"text": "Beautiful landscape with mountains"}]},
+ {
+ "role": "assistant",
+ "content": [{"text": "Beautiful landscape with mountains"}],
+ },
]
def test_format_message_thinking(self):
@@ -334,7 +422,13 @@ def test_format_message_thinking(self):
}
},
{"text": "This is a test message with a tool call."},
- {"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {"key": "value"}}},
+ {
+ "toolUse": {
+ "toolUseId": "123",
+ "name": "test_tool",
+ "input": {"key": "value"},
+ }
+ },
],
}
@@ -352,14 +446,23 @@ def test_format_message_thinking(self):
"content": [
{"reasoningContent": {"reasoningText": {"text": "[REDACTED]"}}},
{"text": "This is a test message with a tool call."},
- {"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {"key": "value"}}},
+ {
+ "toolUse": {
+ "toolUseId": "123",
+ "name": "test_tool",
+ "input": {"key": "value"},
+ }
+ },
],
}
def test_format_user_message(self):
plain_user_message = ChatMessage.from_user("This is a test message.")
formatted_message = _format_user_message(plain_user_message)
- assert formatted_message == {"role": "user", "content": [{"text": "This is a test message."}]}
+ assert formatted_message == {
+ "role": "user",
+ "content": [{"text": "This is a test message."}],
+ }
base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+ip1sAAAAASUVORK5CYII="
image_content = ImageContent(base64_image)
@@ -369,7 +472,12 @@ def test_format_user_message(self):
"role": "user",
"content": [
{"text": "This is a test message."},
- {"image": {"format": "png", "source": {"bytes": base64.b64decode(base64_image)}}},
+ {
+ "image": {
+ "format": "png",
+ "source": {"bytes": base64.b64decode(base64_image)},
+ }
+ },
],
}
@@ -418,10 +526,14 @@ def test_format_messages_multi_tool(self):
"this information for you.",
tool_calls=[
ToolCall(
- tool_name="weather_tool", arguments={"location": "Berlin"}, id="tooluse_evFtOFYeSiG_TQ0cAAgy4Q"
+ tool_name="weather_tool",
+ arguments={"location": "Berlin"},
+ id="tooluse_evFtOFYeSiG_TQ0cAAgy4Q",
),
ToolCall(
- tool_name="weather_tool", arguments={"location": "Paris"}, id="tooluse_Oc0n2we2RvquHwuPEflaQA"
+ tool_name="weather_tool",
+ arguments={"location": "Paris"},
+ id="tooluse_Oc0n2we2RvquHwuPEflaQA",
),
],
name=None,
@@ -429,26 +541,37 @@ def test_format_messages_multi_tool(self):
"model": "global.anthropic.claude-sonnet-4-6",
"index": 0,
"finish_reason": "tool_use",
- "usage": {"prompt_tokens": 366, "completion_tokens": 134, "total_tokens": 500},
+ "usage": {
+ "prompt_tokens": 366,
+ "completion_tokens": 134,
+ "total_tokens": 500,
+ },
},
),
ChatMessage.from_tool(
tool_result="Mostly sunny",
origin=ToolCall(
- tool_name="weather_tool", arguments={"location": "Berlin"}, id="tooluse_evFtOFYeSiG_TQ0cAAgy4Q"
+ tool_name="weather_tool",
+ arguments={"location": "Berlin"},
+ id="tooluse_evFtOFYeSiG_TQ0cAAgy4Q",
),
),
ChatMessage.from_tool(
tool_result="Mostly cloudy",
origin=ToolCall(
- tool_name="weather_tool", arguments={"location": "Paris"}, id="tooluse_Oc0n2we2RvquHwuPEflaQA"
+ tool_name="weather_tool",
+ arguments={"location": "Paris"},
+ id="tooluse_Oc0n2we2RvquHwuPEflaQA",
),
),
]
formatted_system_prompts, formatted_messages = _format_messages(messages)
assert formatted_system_prompts == []
assert formatted_messages == [
- {"role": "user", "content": [{"text": "What is the weather in Berlin and Paris?"}]},
+ {
+ "role": "user",
+ "content": [{"text": "What is the weather in Berlin and Paris?"}],
+ },
{
"role": "assistant",
"content": [
@@ -495,7 +618,12 @@ def test_format_messages_multi_tool(self):
def test_extract_replies_from_text_response(self, mock_boto3_session):
model = "global.anthropic.claude-sonnet-4-6"
text_response = {
- "output": {"message": {"role": "assistant", "content": [{"text": "This is a test response"}]}},
+ "output": {
+ "message": {
+ "role": "assistant",
+ "content": [{"text": "This is a test response"}],
+ }
+ },
"stopReason": "end_turn",
"usage": {
"inputTokens": 10,
@@ -531,7 +659,15 @@ def test_extract_replies_from_tool_response(self, mock_boto3_session):
"output": {
"message": {
"role": "assistant",
- "content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {"key": "value"}}}],
+ "content": [
+ {
+ "toolUse": {
+ "toolUseId": "123",
+ "name": "test_tool",
+ "input": {"key": "value"},
+ }
+ }
+ ],
}
},
"stopReason": "tool_use",
@@ -567,7 +703,13 @@ def test_extract_replies_from_text_mixed_response(self, mock_boto3_session):
"role": "assistant",
"content": [
{"text": "Let me help you with that. I'll use the search tool to find the answer."},
- {"toolUse": {"toolUseId": "456", "name": "search_tool", "input": {"query": "test"}}},
+ {
+ "toolUse": {
+ "toolUseId": "456",
+ "name": "search_tool",
+ "input": {"query": "test"},
+ }
+ },
],
}
},
@@ -650,10 +792,14 @@ def test_extract_replies_from_multi_tool_response(self, mock_boto3_session):
"information for you.",
tool_calls=[
ToolCall(
- tool_name="weather_tool", arguments={"location": "Berlin"}, id="tooluse_evFtOFYeSiG_TQ0cAAgy4Q"
+ tool_name="weather_tool",
+ arguments={"location": "Berlin"},
+ id="tooluse_evFtOFYeSiG_TQ0cAAgy4Q",
),
ToolCall(
- tool_name="weather_tool", arguments={"location": "Paris"}, id="tooluse_Oc0n2we2RvquHwuPEflaQA"
+ tool_name="weather_tool",
+ arguments={"location": "Paris"},
+ id="tooluse_Oc0n2we2RvquHwuPEflaQA",
),
],
name=None,
@@ -731,7 +877,11 @@ def test_extract_replies_from_one_tool_response_with_thinking(self, mock_boto3_s
expected_message = ChatMessage.from_assistant(
text="I'll check the current weather in Paris for you.",
tool_calls=[
- ToolCall(tool_name="weather", arguments={"city": "Paris"}, id="tooluse_iUqy8-ypSByLK5zFkka8uA")
+ ToolCall(
+ tool_name="weather",
+ arguments={"city": "Paris"},
+ id="tooluse_iUqy8-ypSByLK5zFkka8uA",
+ )
],
reasoning=ReasoningContent(
reasoning_text="The user wants to know the weather in Paris. I have a `weather` function available "
@@ -767,7 +917,12 @@ def test_extract_replies_with_guardrail(self, mock_boto3_session):
"test_guardrail_id": {
"topicPolicy": {
"topics": [
- {"name": "Investments topic", "type": "DENY", "action": "BLOCKED", "detected": True}
+ {
+ "name": "Investments topic",
+ "type": "DENY",
+ "action": "BLOCKED",
+ "detected": True,
+ }
]
},
"invocationMetrics": {
@@ -803,7 +958,10 @@ def test_extract_replies_with_guardrail(self, mock_boto3_session):
"RetryAttempts": 0,
},
"output": {
- "message": {"role": "assistant", "content": [{"text": "Sorry, the model cannot answer this question."}]}
+ "message": {
+ "role": "assistant",
+ "content": [{"text": "Sorry, the model cannot answer this question."}],
+ }
},
"stopReason": "guardrail_intervened",
"usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0},
@@ -873,7 +1031,13 @@ def test_parse_completion_response_with_citations(self, mock_boto3_session):
)
}
],
- "location": {"documentPage": {"documentIndex": 0, "start": 1, "end": 2}},
+ "location": {
+ "documentPage": {
+ "documentIndex": 0,
+ "start": 1,
+ "end": 2,
+ }
+ },
}
],
}
@@ -932,33 +1096,112 @@ def test_callback(chunk: StreamingChunk):
# Simulate a stream of events for both text and tool use
events = [
{"messageStart": {"role": "assistant"}},
- {"contentBlockDelta": {"delta": {"text": "Certainly! I can"}, "contentBlockIndex": 0}},
- {"contentBlockDelta": {"delta": {"text": " help you find out"}, "contentBlockIndex": 0}},
- {"contentBlockDelta": {"delta": {"text": " the weather"}, "contentBlockIndex": 0}},
- {"contentBlockDelta": {"delta": {"text": " in Berlin. To"}, "contentBlockIndex": 0}},
- {"contentBlockDelta": {"delta": {"text": " get this information, I'll"}, "contentBlockIndex": 0}},
- {"contentBlockDelta": {"delta": {"text": " use the weather tool available"}, "contentBlockIndex": 0}},
- {"contentBlockDelta": {"delta": {"text": " to me."}, "contentBlockIndex": 0}},
- {"contentBlockDelta": {"delta": {"text": " Let me fetch"}, "contentBlockIndex": 0}},
- {"contentBlockDelta": {"delta": {"text": " that data for"}, "contentBlockIndex": 0}},
+ {
+ "contentBlockDelta": {
+ "delta": {"text": "Certainly! I can"},
+ "contentBlockIndex": 0,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"text": " help you find out"},
+ "contentBlockIndex": 0,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"text": " the weather"},
+ "contentBlockIndex": 0,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"text": " in Berlin. To"},
+ "contentBlockIndex": 0,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"text": " get this information, I'll"},
+ "contentBlockIndex": 0,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"text": " use the weather tool available"},
+ "contentBlockIndex": 0,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"text": " to me."},
+ "contentBlockIndex": 0,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"text": " Let me fetch"},
+ "contentBlockIndex": 0,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"text": " that data for"},
+ "contentBlockIndex": 0,
+ }
+ },
{"contentBlockDelta": {"delta": {"text": " you."}, "contentBlockIndex": 0}},
{"contentBlockStop": {"contentBlockIndex": 0}},
{
"contentBlockStart": {
- "start": {"toolUse": {"toolUseId": "tooluse_pLGRAmK7TNKoZQ_rntVN_Q", "name": "weather_tool"}},
+ "start": {
+ "toolUse": {
+ "toolUseId": "tooluse_pLGRAmK7TNKoZQ_rntVN_Q",
+ "name": "weather_tool",
+ }
+ },
+ "contentBlockIndex": 1,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"toolUse": {"input": ""}},
+ "contentBlockIndex": 1,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"toolUse": {"input": '{"'}},
+ "contentBlockIndex": 1,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"toolUse": {"input": 'location": '}},
+ "contentBlockIndex": 1,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"toolUse": {"input": '"B'}},
+ "contentBlockIndex": 1,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"toolUse": {"input": 'erlin"}'}},
"contentBlockIndex": 1,
}
},
- {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}, "contentBlockIndex": 1}},
- {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"'}}, "contentBlockIndex": 1}},
- {"contentBlockDelta": {"delta": {"toolUse": {"input": 'location": '}}, "contentBlockIndex": 1}},
- {"contentBlockDelta": {"delta": {"toolUse": {"input": '"B'}}, "contentBlockIndex": 1}},
- {"contentBlockDelta": {"delta": {"toolUse": {"input": 'erlin"}'}}, "contentBlockIndex": 1}},
{"contentBlockStop": {"contentBlockIndex": 1}},
{"messageStop": {"stopReason": "tool_use"}},
{
"metadata": {
- "usage": {"inputTokens": 364, "outputTokens": 71, "totalTokens": 435},
+ "usage": {
+ "inputTokens": 364,
+ "outputTokens": 71,
+ "totalTokens": 435,
+ },
"metrics": {"latencyMs": 2449},
}
},
@@ -976,7 +1219,9 @@ def test_callback(chunk: StreamingChunk):
name=None,
tool_calls=[
ToolCall(
- tool_name="weather_tool", arguments={"location": "Berlin"}, id="tooluse_pLGRAmK7TNKoZQ_rntVN_Q"
+ tool_name="weather_tool",
+ arguments={"location": "Berlin"},
+ id="tooluse_pLGRAmK7TNKoZQ_rntVN_Q",
)
],
meta={
@@ -997,12 +1242,33 @@ def test_callback(chunk: StreamingChunk):
expected_chunks = [
StreamingChunk(content="", meta=base_meta, component_info=c_info),
- StreamingChunk(content="Certainly! I can", meta=base_meta, component_info=c_info, index=0, start=True),
- StreamingChunk(content=" help you find out", meta=base_meta, component_info=c_info, index=0),
+ StreamingChunk(
+ content="Certainly! I can",
+ meta=base_meta,
+ component_info=c_info,
+ index=0,
+ start=True,
+ ),
+ StreamingChunk(
+ content=" help you find out",
+ meta=base_meta,
+ component_info=c_info,
+ index=0,
+ ),
StreamingChunk(content=" the weather", meta=base_meta, component_info=c_info, index=0),
StreamingChunk(content=" in Berlin. To", meta=base_meta, component_info=c_info, index=0),
- StreamingChunk(content=" get this information, I'll", meta=base_meta, component_info=c_info, index=0),
- StreamingChunk(content=" use the weather tool available", meta=base_meta, component_info=c_info, index=0),
+ StreamingChunk(
+ content=" get this information, I'll",
+ meta=base_meta,
+ component_info=c_info,
+ index=0,
+ ),
+ StreamingChunk(
+ content=" use the weather tool available",
+ meta=base_meta,
+ component_info=c_info,
+ index=0,
+ ),
StreamingChunk(content=" to me.", meta=base_meta, component_info=c_info, index=0),
StreamingChunk(content=" Let me fetch", meta=base_meta, component_info=c_info, index=0),
StreamingChunk(content=" that data for", meta=base_meta, component_info=c_info, index=0),
@@ -1013,7 +1279,13 @@ def test_callback(chunk: StreamingChunk):
meta=base_meta,
component_info=c_info,
index=1,
- tool_calls=[ToolCallDelta(index=1, tool_name="weather_tool", id="tooluse_pLGRAmK7TNKoZQ_rntVN_Q")],
+ tool_calls=[
+ ToolCallDelta(
+ index=1,
+ tool_name="weather_tool",
+ id="tooluse_pLGRAmK7TNKoZQ_rntVN_Q",
+ )
+ ],
start=True,
),
StreamingChunk(
@@ -1052,7 +1324,12 @@ def test_callback(chunk: StreamingChunk):
tool_calls=[ToolCallDelta(index=1, arguments='erlin"}')],
),
StreamingChunk(content="", meta=base_meta, component_info=c_info),
- StreamingChunk(content="", meta=base_meta, component_info=c_info, finish_reason="tool_calls"),
+ StreamingChunk(
+ content="",
+ meta=base_meta,
+ component_info=c_info,
+ finish_reason="tool_calls",
+ ),
StreamingChunk(
content="",
meta={
@@ -1103,7 +1380,12 @@ def test_callback(chunk: StreamingChunk):
"contentBlockIndex": 0,
}
},
- {"contentBlockDelta": {"delta": {"reasoningContent": {"text": " access to a"}}, "contentBlockIndex": 0}},
+ {
+ "contentBlockDelta": {
+ "delta": {"reasoningContent": {"text": " access to a"}},
+ "contentBlockIndex": 0,
+ }
+ },
{
"contentBlockDelta": {
"delta": {"reasoningContent": {"text": " weather function that takes"}},
@@ -1134,26 +1416,75 @@ def test_callback(chunk: StreamingChunk):
"contentBlockIndex": 0,
}
},
- {"contentBlockDelta": {"delta": {"reasoningContent": {"text": " function call."}}, "contentBlockIndex": 0}},
- {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "..."}}, "contentBlockIndex": 0}},
+ {
+ "contentBlockDelta": {
+ "delta": {"reasoningContent": {"text": " function call."}},
+ "contentBlockIndex": 0,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"reasoningContent": {"signature": "..."}},
+ "contentBlockIndex": 0,
+ }
+ },
{"contentBlockStop": {"contentBlockIndex": 0}},
{
"contentBlockStart": {
- "start": {"toolUse": {"toolUseId": "tooluse_1gPhO4A1RNWgzKbt1PXWLg", "name": "weather"}},
+ "start": {
+ "toolUse": {
+ "toolUseId": "tooluse_1gPhO4A1RNWgzKbt1PXWLg",
+ "name": "weather",
+ }
+ },
+ "contentBlockIndex": 1,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"toolUse": {"input": ""}},
+ "contentBlockIndex": 1,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"toolUse": {"input": '{"ci'}},
+ "contentBlockIndex": 1,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"toolUse": {"input": "ty"}},
+ "contentBlockIndex": 1,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"toolUse": {"input": '": "P'}},
+ "contentBlockIndex": 1,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"toolUse": {"input": "aris"}},
+ "contentBlockIndex": 1,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"toolUse": {"input": '"}'}},
"contentBlockIndex": 1,
}
},
- {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}, "contentBlockIndex": 1}},
- {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"ci'}}, "contentBlockIndex": 1}},
- {"contentBlockDelta": {"delta": {"toolUse": {"input": "ty"}}, "contentBlockIndex": 1}},
- {"contentBlockDelta": {"delta": {"toolUse": {"input": '": "P'}}, "contentBlockIndex": 1}},
- {"contentBlockDelta": {"delta": {"toolUse": {"input": "aris"}}, "contentBlockIndex": 1}},
- {"contentBlockDelta": {"delta": {"toolUse": {"input": '"}'}}, "contentBlockIndex": 1}},
{"contentBlockStop": {"contentBlockIndex": 1}},
{"messageStop": {"stopReason": "tool_use"}},
{
"metadata": {
- "usage": {"inputTokens": 412, "outputTokens": 104, "totalTokens": 516},
+ "usage": {
+ "inputTokens": 412,
+ "outputTokens": 104,
+ "totalTokens": 516,
+ },
"metrics": {"latencyMs": 2134},
}
},
@@ -1164,7 +1495,11 @@ def test_callback(chunk: StreamingChunk):
expected_messages = [
ChatMessage.from_assistant(
tool_calls=[
- ToolCall(tool_name="weather", arguments={"city": "Paris"}, id="tooluse_1gPhO4A1RNWgzKbt1PXWLg")
+ ToolCall(
+ tool_name="weather",
+ arguments={"city": "Paris"},
+ id="tooluse_1gPhO4A1RNWgzKbt1PXWLg",
+ )
],
reasoning=ReasoningContent(
reasoning_text="The user is asking about the weather in Paris. I have access to a weather function "
@@ -1211,41 +1546,130 @@ def test_callback(chunk: StreamingChunk):
events = [
{"messageStart": {"role": "assistant"}},
{"contentBlockDelta": {"delta": {"text": "To"}, "contentBlockIndex": 0}},
- {"contentBlockDelta": {"delta": {"text": " answer your question about the"}, "contentBlockIndex": 0}},
- {"contentBlockDelta": {"delta": {"text": " weather in Berlin and Paris, I'll"}, "contentBlockIndex": 0}},
- {"contentBlockDelta": {"delta": {"text": " need to use the weather_tool"}, "contentBlockIndex": 0}},
- {"contentBlockDelta": {"delta": {"text": " for each city. Let"}, "contentBlockIndex": 0}},
- {"contentBlockDelta": {"delta": {"text": " me fetch that information for"}, "contentBlockIndex": 0}},
+ {
+ "contentBlockDelta": {
+ "delta": {"text": " answer your question about the"},
+ "contentBlockIndex": 0,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"text": " weather in Berlin and Paris, I'll"},
+ "contentBlockIndex": 0,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"text": " need to use the weather_tool"},
+ "contentBlockIndex": 0,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"text": " for each city. Let"},
+ "contentBlockIndex": 0,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"text": " me fetch that information for"},
+ "contentBlockIndex": 0,
+ }
+ },
{"contentBlockDelta": {"delta": {"text": " you."}, "contentBlockIndex": 0}},
{"contentBlockStop": {"contentBlockIndex": 0}},
{
"contentBlockStart": {
- "start": {"toolUse": {"toolUseId": "tooluse_A0jTtaiQTFmqD_cIq8I1BA", "name": "weather_tool"}},
+ "start": {
+ "toolUse": {
+ "toolUseId": "tooluse_A0jTtaiQTFmqD_cIq8I1BA",
+ "name": "weather_tool",
+ }
+ },
+ "contentBlockIndex": 1,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"toolUse": {"input": ""}},
+ "contentBlockIndex": 1,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"toolUse": {"input": '{"location":'}},
+ "contentBlockIndex": 1,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"toolUse": {"input": ' "Be'}},
+ "contentBlockIndex": 1,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"toolUse": {"input": 'rlin"}'}},
"contentBlockIndex": 1,
}
},
- {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}, "contentBlockIndex": 1}},
- {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"location":'}}, "contentBlockIndex": 1}},
- {"contentBlockDelta": {"delta": {"toolUse": {"input": ' "Be'}}, "contentBlockIndex": 1}},
- {"contentBlockDelta": {"delta": {"toolUse": {"input": 'rlin"}'}}, "contentBlockIndex": 1}},
{"contentBlockStop": {"contentBlockIndex": 1}},
{
"contentBlockStart": {
- "start": {"toolUse": {"toolUseId": "tooluse_LTc2TUMgTRiobK5Z5CCNSw", "name": "weather_tool"}},
+ "start": {
+ "toolUse": {
+ "toolUseId": "tooluse_LTc2TUMgTRiobK5Z5CCNSw",
+ "name": "weather_tool",
+ }
+ },
+ "contentBlockIndex": 2,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"toolUse": {"input": ""}},
+ "contentBlockIndex": 2,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"toolUse": {"input": '{"l'}},
+ "contentBlockIndex": 2,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"toolUse": {"input": "ocati"}},
+ "contentBlockIndex": 2,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"toolUse": {"input": 'on": "P'}},
+ "contentBlockIndex": 2,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"toolUse": {"input": "ari"}},
+ "contentBlockIndex": 2,
+ }
+ },
+ {
+ "contentBlockDelta": {
+ "delta": {"toolUse": {"input": 's"}'}},
"contentBlockIndex": 2,
}
},
- {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}, "contentBlockIndex": 2}},
- {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"l'}}, "contentBlockIndex": 2}},
- {"contentBlockDelta": {"delta": {"toolUse": {"input": "ocati"}}, "contentBlockIndex": 2}},
- {"contentBlockDelta": {"delta": {"toolUse": {"input": 'on": "P'}}, "contentBlockIndex": 2}},
- {"contentBlockDelta": {"delta": {"toolUse": {"input": "ari"}}, "contentBlockIndex": 2}},
- {"contentBlockDelta": {"delta": {"toolUse": {"input": 's"}'}}, "contentBlockIndex": 2}},
{"contentBlockStop": {"contentBlockIndex": 2}},
{"messageStop": {"stopReason": "tool_use"}},
{
"metadata": {
- "usage": {"inputTokens": 366, "outputTokens": 83, "totalTokens": 449},
+ "usage": {
+ "inputTokens": 366,
+ "outputTokens": 83,
+ "totalTokens": 449,
+ },
"metrics": {"latencyMs": 3194},
}
},
@@ -1259,10 +1683,14 @@ def test_callback(chunk: StreamingChunk):
name=None,
tool_calls=[
ToolCall(
- tool_name="weather_tool", arguments={"location": "Berlin"}, id="tooluse_A0jTtaiQTFmqD_cIq8I1BA"
+ tool_name="weather_tool",
+ arguments={"location": "Berlin"},
+ id="tooluse_A0jTtaiQTFmqD_cIq8I1BA",
),
ToolCall(
- tool_name="weather_tool", arguments={"location": "Paris"}, id="tooluse_LTc2TUMgTRiobK5Z5CCNSw"
+ tool_name="weather_tool",
+ arguments={"location": "Paris"},
+ id="tooluse_LTc2TUMgTRiobK5Z5CCNSw",
),
],
meta={
@@ -1297,7 +1725,12 @@ def test_parse_streaming_response_with_guardrail(self, mock_boto3_session):
"test_guardrail_id": {
"topicPolicy": {
"topics": [
- {"name": "Investments topic", "type": "DENY", "action": "BLOCKED", "detected": True}
+ {
+ "name": "Investments topic",
+ "type": "DENY",
+ "action": "BLOCKED",
+ "detected": True,
+ }
]
},
"invocationMetrics": {
@@ -1366,7 +1799,9 @@ def test_callback(chunk: StreamingChunk):
]
assert replies == expected_messages
- def test_convert_streaming_chunks_to_chat_message_tool_call_with_empty_arguments(self):
+ def test_convert_streaming_chunks_to_chat_message_tool_call_with_empty_arguments(
+ self,
+ ):
chunks = [
StreamingChunk(
content="Certainly! I",
@@ -1504,7 +1939,12 @@ def test_convert_streaming_chunks_to_chat_message_tool_call_with_empty_arguments
},
index=1,
tool_calls=[
- ToolCallDelta(index=1, id="tooluse_QZlUqTveTwyUaCQGQbWP6g", tool_name="hello_world", arguments="")
+ ToolCallDelta(
+ index=1,
+ id="tooluse_QZlUqTveTwyUaCQGQbWP6g",
+ tool_name="hello_world",
+ arguments="",
+ )
],
),
StreamingChunk(
@@ -1536,7 +1976,11 @@ def test_convert_streaming_chunks_to_chat_message_tool_call_with_empty_arguments
meta={
"model": "global.anthropic.claude-sonnet-4-6",
"received_at": "2025-07-31T08:46:08.596700+00:00",
- "usage": {"prompt_tokens": 349, "completion_tokens": 84, "total_tokens": 433},
+ "usage": {
+ "prompt_tokens": 349,
+ "completion_tokens": 84,
+ "total_tokens": 433,
+ },
},
),
]
@@ -1560,15 +2004,27 @@ def test_convert_streaming_chunks_to_chat_message_tool_call_with_empty_arguments
assert message._meta["model"] == "global.anthropic.claude-sonnet-4-6"
assert message._meta["index"] == 0
assert message._meta["finish_reason"] == "tool_calls"
- assert message._meta["usage"] == {"completion_tokens": 84, "prompt_tokens": 349, "total_tokens": 433}
+ assert message._meta["usage"] == {
+ "completion_tokens": 84,
+ "prompt_tokens": 349,
+ "total_tokens": 433,
+ }
def test_validate_guardrail_config_with_valid_configs(self):
_validate_guardrail_config(guardrail_config=None, streaming=False)
_validate_guardrail_config(
- guardrail_config={"guardrailIdentifier": "test", "guardrailVersion": "test"}, streaming=False
+ guardrail_config={
+ "guardrailIdentifier": "test",
+ "guardrailVersion": "test",
+ },
+ streaming=False,
)
_validate_guardrail_config(
- guardrail_config={"guardrailIdentifier": "test", "guardrailVersion": "test"}, streaming=True
+ guardrail_config={
+ "guardrailIdentifier": "test",
+ "guardrailVersion": "test",
+ },
+ streaming=True,
)
_validate_guardrail_config(
guardrail_config={
@@ -1580,11 +2036,20 @@ def test_validate_guardrail_config_with_valid_configs(self):
)
def test_validate_guardrail_config_with_invalid_configs(self):
- with pytest.raises(ValueError, match="`guardrailIdentifier` and `guardrailVersion` fields are required"):
+ with pytest.raises(
+ ValueError,
+ match="`guardrailIdentifier` and `guardrailVersion` fields are required",
+ ):
_validate_guardrail_config(guardrail_config={"guardrailIdentifier": "test"}, streaming=False)
- with pytest.raises(ValueError, match="`guardrailIdentifier` and `guardrailVersion` fields are required"):
+ with pytest.raises(
+ ValueError,
+ match="`guardrailIdentifier` and `guardrailVersion` fields are required",
+ ):
_validate_guardrail_config(guardrail_config={"guardrailVersion": "test"}, streaming=False)
- with pytest.raises(ValueError, match="`streamProcessingMode` field is only supported for streaming"):
+ with pytest.raises(
+ ValueError,
+ match="`streamProcessingMode` field is only supported for streaming",
+ ):
_validate_guardrail_config(
guardrail_config={
"guardrailIdentifier": "test",
@@ -1607,10 +2072,16 @@ def test_validate_and_format_cache_point(self):
cache_point = _validate_and_format_cache_point({"type": "default", "ttl": "5m"})
assert cache_point == {"cachePoint": {"type": "default", "ttl": "5m"}}
- with pytest.raises(ValueError, match=r"Cache point must have a 'type' key with value 'default'."):
+ with pytest.raises(
+ ValueError,
+ match=r"Cache point must have a 'type' key with value 'default'.",
+ ):
_validate_and_format_cache_point({"invalid": "config"})
- with pytest.raises(ValueError, match=r"Cache point must have a 'type' key with value 'default'."):
+ with pytest.raises(
+ ValueError,
+ match=r"Cache point must have a 'type' key with value 'default'.",
+ ):
_validate_and_format_cache_point({"type": "invalid"})
with pytest.raises(ValueError, match=r"Cache point can only contain 'type' and 'ttl' keys."):
diff --git a/integrations/amazon_bedrock/tests/test_document_embedder.py b/integrations/amazon_bedrock/tests/test_document_embedder.py
index 867f726efd..f996f5b240 100644
--- a/integrations/amazon_bedrock/tests/test_document_embedder.py
+++ b/integrations/amazon_bedrock/tests/test_document_embedder.py
@@ -11,7 +11,9 @@
AmazonBedrockConfigurationError,
AmazonBedrockInferenceError,
)
-from haystack_integrations.components.embedders.amazon_bedrock import AmazonBedrockDocumentEmbedder
+from haystack_integrations.components.embedders.amazon_bedrock import (
+ AmazonBedrockDocumentEmbedder,
+)
TYPE = "haystack_integrations.components.embedders.amazon_bedrock.document_embedder.AmazonBedrockDocumentEmbedder"
@@ -76,11 +78,31 @@ def test_to_dict(self, mock_boto3_session: Any, boto3_config: dict[str, Any] | N
expected_dict = {
"type": TYPE,
"init_parameters": {
- "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
- "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False},
- "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
- "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
- "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
+ "aws_access_key_id": {
+ "type": "env_var",
+ "env_vars": ["AWS_ACCESS_KEY_ID"],
+ "strict": False,
+ },
+ "aws_secret_access_key": {
+ "type": "env_var",
+ "env_vars": ["AWS_SECRET_ACCESS_KEY"],
+ "strict": False,
+ },
+ "aws_session_token": {
+ "type": "env_var",
+ "env_vars": ["AWS_SESSION_TOKEN"],
+ "strict": False,
+ },
+ "aws_region_name": {
+ "type": "env_var",
+ "env_vars": ["AWS_DEFAULT_REGION"],
+ "strict": False,
+ },
+ "aws_profile_name": {
+ "type": "env_var",
+ "env_vars": ["AWS_PROFILE"],
+ "strict": False,
+ },
"model": "cohere.embed-english-v3",
"input_type": "search_document",
"batch_size": 32,
@@ -98,11 +120,31 @@ def test_from_dict(self, mock_boto3_session: Any, boto3_config: dict[str, Any] |
data = {
"type": TYPE,
"init_parameters": {
- "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
- "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False},
- "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
- "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
- "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
+ "aws_access_key_id": {
+ "type": "env_var",
+ "env_vars": ["AWS_ACCESS_KEY_ID"],
+ "strict": False,
+ },
+ "aws_secret_access_key": {
+ "type": "env_var",
+ "env_vars": ["AWS_SECRET_ACCESS_KEY"],
+ "strict": False,
+ },
+ "aws_session_token": {
+ "type": "env_var",
+ "env_vars": ["AWS_SESSION_TOKEN"],
+ "strict": False,
+ },
+ "aws_region_name": {
+ "type": "env_var",
+ "env_vars": ["AWS_DEFAULT_REGION"],
+ "strict": False,
+ },
+ "aws_profile_name": {
+ "type": "env_var",
+ "env_vars": ["AWS_PROFILE"],
+ "strict": False,
+ },
"model": "cohere.embed-english-v3",
"input_type": "search_document",
"batch_size": 32,
@@ -151,11 +193,17 @@ def test_run_invocation_error(self, mock_boto3_session):
def test_prepare_texts_to_embed_w_metadata(self, mock_boto3_session):
documents = [
- Document(content=f"document number {i}: content", meta={"meta_field": f"meta_value {i}"}) for i in range(5)
+ Document(
+ content=f"document number {i}: content",
+ meta={"meta_field": f"meta_value {i}"},
+ )
+ for i in range(5)
]
embedder = AmazonBedrockDocumentEmbedder(
- model="cohere.embed-english-v3", meta_fields_to_embed=["meta_field"], embedding_separator=" | "
+ model="cohere.embed-english-v3",
+ meta_fields_to_embed=["meta_field"],
+ embedding_separator=" | ",
)
prepared_texts = embedder._prepare_texts_to_embed(documents)
@@ -332,11 +380,17 @@ def test_run_titan_does_not_modify_original_documents(self, mock_boto3_session):
or not os.getenv("AWS_DEFAULT_REGION"),
reason="AWS credentials are not set",
)
- @pytest.mark.parametrize("model", ["cohere.embed-v4:0", "cohere.embed-english-v3", "amazon.titan-embed-text-v1"])
+ @pytest.mark.parametrize(
+ "model",
+ ["cohere.embed-v4:0", "cohere.embed-english-v3", "amazon.titan-embed-text-v1"],
+ )
def test_live_run(self, model):
embedder = AmazonBedrockDocumentEmbedder(model=model)
- documents = [Document(content="this is a test document"), Document(content="I love pizza")]
+ documents = [
+ Document(content="this is a test document"),
+ Document(content="I love pizza"),
+ ]
docs_with_embeddings = embedder.run(documents=documents)["documents"]
diff --git a/integrations/amazon_bedrock/tests/test_document_image_embedder.py b/integrations/amazon_bedrock/tests/test_document_image_embedder.py
index 7a4f339ff1..8b63c86767 100644
--- a/integrations/amazon_bedrock/tests/test_document_image_embedder.py
+++ b/integrations/amazon_bedrock/tests/test_document_image_embedder.py
@@ -12,7 +12,9 @@
AmazonBedrockConfigurationError,
AmazonBedrockInferenceError,
)
-from haystack_integrations.components.embedders.amazon_bedrock import AmazonBedrockDocumentImageEmbedder
+from haystack_integrations.components.embedders.amazon_bedrock import (
+ AmazonBedrockDocumentImageEmbedder,
+)
TYPE = (
"haystack_integrations.components.embedders.amazon_bedrock."
@@ -22,7 +24,11 @@
@pytest.fixture
def image_paths(test_files_path):
- return [test_files_path / "apple.jpg", test_files_path / "haystack-logo.png", test_files_path / "sample_pdf_1.pdf"]
+ return [
+ test_files_path / "apple.jpg",
+ test_files_path / "haystack-logo.png",
+ test_files_path / "sample_pdf_1.pdf",
+ ]
class TestAmazonBedrockDocumentImageEmbedder:
@@ -76,11 +82,31 @@ def test_to_dict(self, mock_boto3_session: Any, boto3_config: dict[str, Any] | N
expected_dict = {
"type": TYPE,
"init_parameters": {
- "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
- "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False},
- "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
- "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
- "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
+ "aws_access_key_id": {
+ "type": "env_var",
+ "env_vars": ["AWS_ACCESS_KEY_ID"],
+ "strict": False,
+ },
+ "aws_secret_access_key": {
+ "type": "env_var",
+ "env_vars": ["AWS_SECRET_ACCESS_KEY"],
+ "strict": False,
+ },
+ "aws_session_token": {
+ "type": "env_var",
+ "env_vars": ["AWS_SESSION_TOKEN"],
+ "strict": False,
+ },
+ "aws_region_name": {
+ "type": "env_var",
+ "env_vars": ["AWS_DEFAULT_REGION"],
+ "strict": False,
+ },
+ "aws_profile_name": {
+ "type": "env_var",
+ "env_vars": ["AWS_PROFILE"],
+ "strict": False,
+ },
"model": "cohere.embed-english-v3",
"file_path_meta_field": "file_path",
"embedding_types": ["float"],
@@ -98,11 +124,31 @@ def test_from_dict(self, mock_boto3_session: Any, boto3_config: dict[str, Any] |
data = {
"type": TYPE,
"init_parameters": {
- "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
- "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False},
- "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
- "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
- "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
+ "aws_access_key_id": {
+ "type": "env_var",
+ "env_vars": ["AWS_ACCESS_KEY_ID"],
+ "strict": False,
+ },
+ "aws_secret_access_key": {
+ "type": "env_var",
+ "env_vars": ["AWS_SECRET_ACCESS_KEY"],
+ "strict": False,
+ },
+ "aws_session_token": {
+ "type": "env_var",
+ "env_vars": ["AWS_SESSION_TOKEN"],
+ "strict": False,
+ },
+ "aws_region_name": {
+ "type": "env_var",
+ "env_vars": ["AWS_DEFAULT_REGION"],
+ "strict": False,
+ },
+ "aws_profile_name": {
+ "type": "env_var",
+ "env_vars": ["AWS_PROFILE"],
+ "strict": False,
+ },
"model": "cohere.embed-english-v3",
"embedding_types": ["float"],
"root_path": None,
diff --git a/integrations/amazon_bedrock/tests/test_generator.py b/integrations/amazon_bedrock/tests/test_generator.py
index 9452b0948f..eab81eb8dd 100644
--- a/integrations/amazon_bedrock/tests/test_generator.py
+++ b/integrations/amazon_bedrock/tests/test_generator.py
@@ -7,7 +7,9 @@
from haystack_integrations.common.amazon_bedrock.errors import (
AmazonBedrockConfigurationError,
)
-from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator
+from haystack_integrations.components.generators.amazon_bedrock import (
+ AmazonBedrockGenerator,
+)
from haystack_integrations.components.generators.amazon_bedrock.adapters import (
AI21LabsJurassic2Adapter,
AmazonTitanAdapter,
@@ -26,17 +28,40 @@ def test_to_dict(mock_boto3_session: Any, boto3_config: dict[str, Any] | None):
Test that the to_dict method returns the correct dictionary without aws credentials
"""
generator = AmazonBedrockGenerator(
- model="anthropic.claude-v2", max_length=99, temperature=10, boto3_config=boto3_config
+ model="anthropic.claude-v2",
+ max_length=99,
+ temperature=10,
+ boto3_config=boto3_config,
)
expected_dict = {
"type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator",
"init_parameters": {
- "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
- "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False},
- "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
- "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
- "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
+ "aws_access_key_id": {
+ "type": "env_var",
+ "env_vars": ["AWS_ACCESS_KEY_ID"],
+ "strict": False,
+ },
+ "aws_secret_access_key": {
+ "type": "env_var",
+ "env_vars": ["AWS_SECRET_ACCESS_KEY"],
+ "strict": False,
+ },
+ "aws_session_token": {
+ "type": "env_var",
+ "env_vars": ["AWS_SESSION_TOKEN"],
+ "strict": False,
+ },
+ "aws_region_name": {
+ "type": "env_var",
+ "env_vars": ["AWS_DEFAULT_REGION"],
+ "strict": False,
+ },
+ "aws_profile_name": {
+ "type": "env_var",
+ "env_vars": ["AWS_PROFILE"],
+ "strict": False,
+ },
"model": "anthropic.claude-v2",
"max_length": 99,
"temperature": 10,
@@ -58,11 +83,31 @@ def test_from_dict(mock_boto3_session: Any, boto3_config: dict[str, Any] | None)
{
"type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator",
"init_parameters": {
- "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
- "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False},
- "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
- "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
- "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
+ "aws_access_key_id": {
+ "type": "env_var",
+ "env_vars": ["AWS_ACCESS_KEY_ID"],
+ "strict": False,
+ },
+ "aws_secret_access_key": {
+ "type": "env_var",
+ "env_vars": ["AWS_SECRET_ACCESS_KEY"],
+ "strict": False,
+ },
+ "aws_session_token": {
+ "type": "env_var",
+ "env_vars": ["AWS_SESSION_TOKEN"],
+ "strict": False,
+ },
+ "aws_region_name": {
+ "type": "env_var",
+ "env_vars": ["AWS_DEFAULT_REGION"],
+ "strict": False,
+ },
+ "aws_profile_name": {
+ "type": "env_var",
+ "env_vars": ["AWS_PROFILE"],
+ "strict": False,
+ },
"model": "anthropic.claude-v2",
"max_length": 99,
"boto3_config": boto3_config,
@@ -147,7 +192,10 @@ def test_constructor_with_empty_model():
("ai21.j2-mega-v5", AI21LabsJurassic2Adapter), # artificial
("amazon.titan-text-lite-v1", AmazonTitanAdapter),
("amazon.titan-text-express-v1", AmazonTitanAdapter),
- ("us.amazon.titan-text-express-v1", AmazonTitanAdapter), # cross-region inference
+ (
+ "us.amazon.titan-text-express-v1",
+ AmazonTitanAdapter,
+ ), # cross-region inference
("amazon.titan-text-agile-v1", AmazonTitanAdapter),
("amazon.titan-text-lightning-v8", AmazonTitanAdapter), # artificial
("meta.llama2-13b-chat-v1", MetaLlamaAdapter),
@@ -161,8 +209,14 @@ def test_constructor_with_empty_model():
("mistral.mistral-7b-instruct-v0:2", MistralAdapter),
("mistral.mixtral-8x7b-instruct-v0:1", MistralAdapter),
("mistral.mistral-large-2402-v1:0", MistralAdapter),
- ("eu.mistral.mixtral-8x7b-instruct-v0:1", MistralAdapter), # cross-region inference
- ("us.mistral.mistral-large-2402-v1:0", MistralAdapter), # cross-region inference
+ (
+ "eu.mistral.mixtral-8x7b-instruct-v0:1",
+ MistralAdapter,
+ ), # cross-region inference
+ (
+ "us.mistral.mistral-large-2402-v1:0",
+ MistralAdapter,
+ ), # cross-region inference
("mistral.mistral-medium-v8:0", MistralAdapter), # artificial
],
)
@@ -497,7 +551,11 @@ def test_get_stream_responses_with_thinking(self) -> None:
call(
StreamingChunk(
content="",
- meta={"type": "content_block_start", "content_block": {"type": "thinking"}, "index": 0},
+ meta={
+ "type": "content_block_start",
+ "content_block": {"type": "thinking"},
+ "index": 0,
+ },
)
),
call(StreamingChunk(content="This", meta={"delta": {"thinking": "This"}})),
@@ -508,7 +566,11 @@ def test_get_stream_responses_with_thinking(self) -> None:
call(
StreamingChunk(
content="\n\n",
- meta={"type": "content_block_start", "content_block": {"type": "text"}, "index": 1},
+ meta={
+ "type": "content_block_start",
+ "content_block": {"type": "text"},
+ "index": 1,
+ },
)
),
call(StreamingChunk(content="This", meta={"delta": {"text": "This"}})),
@@ -580,7 +642,11 @@ def test_get_stream_responses_with_thinking_custom_thinking_tag(self) -> None:
call(
StreamingChunk(
content="",
- meta={"type": "content_block_start", "content_block": {"type": "thinking"}, "index": 0},
+ meta={
+ "type": "content_block_start",
+ "content_block": {"type": "thinking"},
+ "index": 0,
+ },
)
),
call(StreamingChunk(content="This", meta={"delta": {"thinking": "This"}})),
@@ -591,7 +657,11 @@ def test_get_stream_responses_with_thinking_custom_thinking_tag(self) -> None:
call(
StreamingChunk(
content="\n\n",
- meta={"type": "content_block_start", "content_block": {"type": "text"}, "index": 1},
+ meta={
+ "type": "content_block_start",
+ "content_block": {"type": "text"},
+ "index": 1,
+ },
)
),
call(StreamingChunk(content="This", meta={"delta": {"text": "This"}})),
@@ -630,7 +700,11 @@ def test_get_stream_responses_with_thinking_no_thinking_tag(self) -> None:
call(
StreamingChunk(
content="",
- meta={"type": "content_block_start", "content_block": {"type": "thinking"}, "index": 0},
+ meta={
+ "type": "content_block_start",
+ "content_block": {"type": "thinking"},
+ "index": 0,
+ },
)
),
call(StreamingChunk(content="This", meta={"delta": {"thinking": "This"}})),
@@ -641,7 +715,11 @@ def test_get_stream_responses_with_thinking_no_thinking_tag(self) -> None:
call(
StreamingChunk(
content="\n\n",
- meta={"type": "content_block_start", "content_block": {"type": "text"}, "index": 1},
+ meta={
+ "type": "content_block_start",
+ "content_block": {"type": "text"},
+ "index": 1,
+ },
)
),
call(StreamingChunk(content="This", meta={"delta": {"text": "This"}})),
@@ -652,7 +730,9 @@ def test_get_stream_responses_with_thinking_no_thinking_tag(self) -> None:
]
)
- def test_get_stream_responses_with_thinking_redacted_thinking_is_ignored(self) -> None:
+ def test_get_stream_responses_with_thinking_redacted_thinking_is_ignored(
+ self,
+ ) -> None:
stream_mock = MagicMock()
streaming_callback_mock = MagicMock()
@@ -687,7 +767,11 @@ def test_get_stream_responses_with_thinking_redacted_thinking_is_ignored(self) -
call(
StreamingChunk(
content="",
- meta={"type": "content_block_start", "content_block": {"type": "thinking"}, "index": 1},
+ meta={
+ "type": "content_block_start",
+ "content_block": {"type": "thinking"},
+ "index": 1,
+ },
)
),
call(StreamingChunk(content="This", meta={"delta": {"thinking": "This"}})),
@@ -698,7 +782,11 @@ def test_get_stream_responses_with_thinking_redacted_thinking_is_ignored(self) -
call(
StreamingChunk(
content="\n\n",
- meta={"type": "content_block_start", "content_block": {"type": "text"}, "index": 2},
+ meta={
+ "type": "content_block_start",
+ "content_block": {"type": "text"},
+ "index": 2,
+ },
)
),
call(StreamingChunk(content="This", meta={"delta": {"text": "This"}})),
@@ -856,7 +944,11 @@ class TestMistralAdapter:
def test_prepare_body_with_default_params(self) -> None:
layer = MistralAdapter(model_kwargs={}, max_length=99)
prompt = "Hello, how are you?"
- expected_body = {"prompt": "[INST] Hello, how are you? [/INST]", "max_tokens": 99, "stop": []}
+ expected_body = {
+ "prompt": "[INST] Hello, how are you? [/INST]",
+ "max_tokens": 99,
+ "stop": [],
+ }
body = layer.prepare_body(prompt)
assert body == expected_body
@@ -1159,7 +1251,12 @@ def test_get_stream_responses(self) -> None:
call(StreamingChunk(content=" a", meta={"text": " a"})),
call(StreamingChunk(content=" single", meta={"text": " single"})),
call(StreamingChunk(content=" response.", meta={"text": " response."})),
- call(StreamingChunk(content="", meta={"finish_reason": "MAX_TOKENS", "is_finished": True})),
+ call(
+ StreamingChunk(
+ content="",
+ meta={"finish_reason": "MAX_TOKENS", "is_finished": True},
+ )
+ ),
]
)
@@ -1185,7 +1282,10 @@ def test_prepare_body(self) -> None:
],
"documents": [
{"title": "France", "snippet": "Paris is the capital of France."},
- {"title": "Germany", "snippet": "Berlin is the capital of Germany."},
+ {
+ "title": "Germany",
+ "snippet": "Berlin is the capital of Germany.",
+ },
],
"search_query_only": False,
"preamble": "preamble",
@@ -1213,9 +1313,15 @@ def test_prepare_body(self) -> None:
],
"tool_results": [
{
- "call": {"name": "query_daily_sales_report", "parameters": {"day": "2023-09-29"}},
+ "call": {
+ "name": "query_daily_sales_report",
+ "parameters": {"day": "2023-09-29"},
+ },
"outputs": [
- {"date": "2023-09-29", "summary": "Total Sales Amount: 10000, Total Units Sold: 250"}
+ {
+ "date": "2023-09-29",
+ "summary": "Total Sales Amount: 10000, Total Units Sold: 250",
+ }
],
}
],
@@ -1263,8 +1369,16 @@ def test_prepare_body(self) -> None:
],
"tool_results": [
{
- "call": {"name": "query_daily_sales_report", "parameters": {"day": "2023-09-29"}},
- "outputs": [{"date": "2023-09-29", "summary": "Total Sales Amount: 10000, Total Units Sold: 250"}],
+ "call": {
+ "name": "query_daily_sales_report",
+ "parameters": {"day": "2023-09-29"},
+ },
+ "outputs": [
+ {
+ "date": "2023-09-29",
+ "summary": "Total Sales Amount: 10000, Total Units Sold: 250",
+ }
+ ],
}
],
"stop_sequences": ["\n\n"],
@@ -1714,7 +1828,10 @@ def test_run_with_metadata(self, mock_boto3_session) -> None:
"ResponseMetadata": {
"RequestId": "test-request-id",
"HTTPStatusCode": 200,
- "HTTPHeaders": {"x-amzn-requestid": "test-request-id", "content-type": "application/json"},
+ "HTTPHeaders": {
+ "x-amzn-requestid": "test-request-id",
+ "content-type": "application/json",
+ },
},
}
mock_client.invoke_model.return_value = mock_response
diff --git a/integrations/amazon_bedrock/tests/test_ranker.py b/integrations/amazon_bedrock/tests/test_ranker.py
index d97aca307a..367301ec89 100644
--- a/integrations/amazon_bedrock/tests/test_ranker.py
+++ b/integrations/amazon_bedrock/tests/test_ranker.py
@@ -39,7 +39,12 @@ def test_bedrock_ranker_run(mock_aws_session):
aws_region_name=Secret.from_token("us-west-2"),
)
- mock_response = {"results": [{"index": 0, "relevanceScore": 0.9}, {"index": 1, "relevanceScore": 0.7}]}
+ mock_response = {
+ "results": [
+ {"index": 0, "relevanceScore": 0.9},
+ {"index": 1, "relevanceScore": 0.7},
+ ]
+ }
mock_aws_session.rerank.return_value = mock_response
diff --git a/integrations/amazon_bedrock/tests/test_s3_downloader.py b/integrations/amazon_bedrock/tests/test_s3_downloader.py
index 61486e29fe..ff7d5a4826 100644
--- a/integrations/amazon_bedrock/tests/test_s3_downloader.py
+++ b/integrations/amazon_bedrock/tests/test_s3_downloader.py
@@ -69,11 +69,31 @@ def test_to_dict(self, mock_boto3_session: Any, tmp_path, boto3_config: dict[str
expected = {
"type": TYPE,
"init_parameters": {
- "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
- "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False},
- "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
- "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
- "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
+ "aws_access_key_id": {
+ "type": "env_var",
+ "env_vars": ["AWS_ACCESS_KEY_ID"],
+ "strict": False,
+ },
+ "aws_secret_access_key": {
+ "type": "env_var",
+ "env_vars": ["AWS_SECRET_ACCESS_KEY"],
+ "strict": False,
+ },
+ "aws_region_name": {
+ "type": "env_var",
+ "env_vars": ["AWS_DEFAULT_REGION"],
+ "strict": False,
+ },
+ "aws_session_token": {
+ "type": "env_var",
+ "env_vars": ["AWS_SESSION_TOKEN"],
+ "strict": False,
+ },
+ "aws_profile_name": {
+ "type": "env_var",
+ "env_vars": ["AWS_PROFILE"],
+ "strict": False,
+ },
"file_root_path": str(tmp_path),
"file_extensions": None,
"max_cache_size": 100,
@@ -90,11 +110,31 @@ def test_from_dict(self, mock_boto3_session: Any, tmp_path, boto3_config: dict[s
data = {
"type": TYPE,
"init_parameters": {
- "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
- "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False},
- "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
- "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
- "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
+ "aws_access_key_id": {
+ "type": "env_var",
+ "env_vars": ["AWS_ACCESS_KEY_ID"],
+ "strict": False,
+ },
+ "aws_secret_access_key": {
+ "type": "env_var",
+ "env_vars": ["AWS_SECRET_ACCESS_KEY"],
+ "strict": False,
+ },
+ "aws_region_name": {
+ "type": "env_var",
+ "env_vars": ["AWS_DEFAULT_REGION"],
+ "strict": False,
+ },
+ "aws_session_token": {
+ "type": "env_var",
+ "env_vars": ["AWS_SESSION_TOKEN"],
+ "strict": False,
+ },
+ "aws_profile_name": {
+ "type": "env_var",
+ "env_vars": ["AWS_PROFILE"],
+ "strict": False,
+ },
"file_root_path": str(tmp_path),
"s3_key_generation_function": None,
"s3_bucket_name_env": "S3_DOWNLOADER_BUCKET",
@@ -116,11 +156,31 @@ def test_to_dict_with_parameters(self, tmp_path):
expected = {
"type": TYPE,
"init_parameters": {
- "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
- "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False},
- "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
- "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
- "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
+ "aws_access_key_id": {
+ "type": "env_var",
+ "env_vars": ["AWS_ACCESS_KEY_ID"],
+ "strict": False,
+ },
+ "aws_secret_access_key": {
+ "type": "env_var",
+ "env_vars": ["AWS_SECRET_ACCESS_KEY"],
+ "strict": False,
+ },
+ "aws_region_name": {
+ "type": "env_var",
+ "env_vars": ["AWS_DEFAULT_REGION"],
+ "strict": False,
+ },
+ "aws_session_token": {
+ "type": "env_var",
+ "env_vars": ["AWS_SESSION_TOKEN"],
+ "strict": False,
+ },
+ "aws_profile_name": {
+ "type": "env_var",
+ "env_vars": ["AWS_PROFILE"],
+ "strict": False,
+ },
"file_root_path": str(tmp_path),
"file_extensions": [".txt"],
"max_cache_size": 400,
@@ -167,7 +227,10 @@ def test_run_with_input_file_meta_key(self, tmp_path, mock_s3_storage, mock_boto
assert out["documents"][0].meta["custom_file_key"] == "a.txt"
def test_run_with_s3_key_generation_function(self, tmp_path, mock_s3_storage, mock_boto3_session):
- d = S3Downloader(file_root_path=str(tmp_path), s3_key_generation_function=s3_key_generation_function)
+ d = S3Downloader(
+ file_root_path=str(tmp_path),
+ s3_key_generation_function=s3_key_generation_function,
+ )
d._storage = mock_s3_storage
docs = [Document(meta={"file_id": str(uuid4()), "file_name": "a.txt"})]
diff --git a/integrations/amazon_bedrock/tests/test_text_embedder.py b/integrations/amazon_bedrock/tests/test_text_embedder.py
index e1e9fb98bd..7d5bf3bca3 100644
--- a/integrations/amazon_bedrock/tests/test_text_embedder.py
+++ b/integrations/amazon_bedrock/tests/test_text_embedder.py
@@ -9,7 +9,9 @@
AmazonBedrockConfigurationError,
AmazonBedrockInferenceError,
)
-from haystack_integrations.components.embedders.amazon_bedrock import AmazonBedrockTextEmbedder
+from haystack_integrations.components.embedders.amazon_bedrock import (
+ AmazonBedrockTextEmbedder,
+)
class TestAmazonBedrockTextEmbedder:
@@ -52,11 +54,31 @@ def test_to_dict(self, mock_boto3_session):
expected_dict = {
"type": "haystack_integrations.components.embedders.amazon_bedrock.text_embedder.AmazonBedrockTextEmbedder",
"init_parameters": {
- "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
- "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False},
- "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
- "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
- "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
+ "aws_access_key_id": {
+ "type": "env_var",
+ "env_vars": ["AWS_ACCESS_KEY_ID"],
+ "strict": False,
+ },
+ "aws_secret_access_key": {
+ "type": "env_var",
+ "env_vars": ["AWS_SECRET_ACCESS_KEY"],
+ "strict": False,
+ },
+ "aws_session_token": {
+ "type": "env_var",
+ "env_vars": ["AWS_SESSION_TOKEN"],
+ "strict": False,
+ },
+ "aws_region_name": {
+ "type": "env_var",
+ "env_vars": ["AWS_DEFAULT_REGION"],
+ "strict": False,
+ },
+ "aws_profile_name": {
+ "type": "env_var",
+ "env_vars": ["AWS_PROFILE"],
+ "strict": False,
+ },
"model": "cohere.embed-english-v3",
"input_type": "search_query",
"boto3_config": None,
@@ -69,11 +91,31 @@ def test_from_dict(self, mock_boto3_session):
data = {
"type": "haystack_integrations.components.embedders.amazon_bedrock.text_embedder.AmazonBedrockTextEmbedder",
"init_parameters": {
- "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False},
- "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False},
- "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
- "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
- "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
+ "aws_access_key_id": {
+ "type": "env_var",
+ "env_vars": ["AWS_ACCESS_KEY_ID"],
+ "strict": False,
+ },
+ "aws_secret_access_key": {
+ "type": "env_var",
+ "env_vars": ["AWS_SECRET_ACCESS_KEY"],
+ "strict": False,
+ },
+ "aws_session_token": {
+ "type": "env_var",
+ "env_vars": ["AWS_SESSION_TOKEN"],
+ "strict": False,
+ },
+ "aws_region_name": {
+ "type": "env_var",
+ "env_vars": ["AWS_DEFAULT_REGION"],
+ "strict": False,
+ },
+ "aws_profile_name": {
+ "type": "env_var",
+ "env_vars": ["AWS_PROFILE"],
+ "strict": False,
+ },
"model": "cohere.embed-english-v3",
"input_type": "search_query",
"boto3_config": {
@@ -162,7 +204,10 @@ def test_run_invocation_error(self, mock_boto3_session):
or not os.getenv("AWS_DEFAULT_REGION"),
reason="AWS credentials are not set",
)
- @pytest.mark.parametrize("model", ["cohere.embed-v4:0", "cohere.embed-english-v3", "amazon.titan-embed-text-v1"])
+ @pytest.mark.parametrize(
+ "model",
+ ["cohere.embed-v4:0", "cohere.embed-english-v3", "amazon.titan-embed-text-v1"],
+ )
def test_live_run(self, model):
embedder = AmazonBedrockTextEmbedder(model=model)