diff --git a/integrations/amazon_bedrock/pyproject.toml b/integrations/amazon_bedrock/pyproject.toml index 81fd65f544..067aa752b2 100644 --- a/integrations/amazon_bedrock/pyproject.toml +++ b/integrations/amazon_bedrock/pyproject.toml @@ -23,7 +23,9 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai>=2.24.1", "boto3>=1.28.57", "aioboto3>=14.0.0"] +dependencies = ["haystack-ai>=2.24.1", "boto3>=1.42.84", "aioboto3>=14.0.0"] + + [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/amazon_bedrock#readme" @@ -172,6 +174,12 @@ omit = ["*/tests/*", "*/__init__.py"] show_missing = true exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] +[tool.uv] +# aiobotocore 2.25.1 pins botocore<1.40.62 but works fine with newer botocore in practice. +# outputConfig support in the Bedrock Converse API requires botocore>=1.42.84. +# Override the transitive constraint so uv can resolve the dependency graph. +override-dependencies = ["botocore>=1.42.84", "boto3>=1.42.84"] + [tool.pytest.ini_options] addopts = "--strict-markers" markers = [ diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index 16f45a915e..2ef9c690e9 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -1,3 +1,4 @@ +import json from typing import Any import aioboto3 @@ -5,7 +6,12 @@ from botocore.eventstream import EventStream from botocore.exceptions import ClientError from haystack import component, default_from_dict, default_to_dict, logging -from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingCallbackT, select_streaming_callback +from haystack.dataclasses import ( + ChatMessage, + ComponentInfo, + StreamingCallbackT, + select_streaming_callback, +) from haystack.tools import ( ToolsType, _check_duplicate_tool_names, @@ -14,7 +20,10 @@ serialize_tools_or_toolset, ) from haystack.utils.auth import Secret -from haystack.utils.callable_serialization import deserialize_callable, serialize_callable +from haystack.utils.callable_serialization import ( + deserialize_callable, + serialize_callable, +) from haystack_integrations.common.amazon_bedrock.errors import ( AmazonBedrockConfigurationError, @@ -27,6 +36,7 @@ _parse_completion_response, _parse_streaming_response, _parse_streaming_response_async, + _parse_structured_output, _validate_and_format_cache_point, _validate_guardrail_config, ) @@ -164,13 +174,11 @@ def weather(city: str): def __init__( self, model: str, - aws_access_key_id: Secret | None = Secret.from_env_var(["AWS_ACCESS_KEY_ID"], strict=False), # noqa: B008 - aws_secret_access_key: Secret | None = Secret.from_env_var( # noqa: B008 - ["AWS_SECRET_ACCESS_KEY"], strict=False - ), - aws_session_token: Secret | None = Secret.from_env_var(["AWS_SESSION_TOKEN"], strict=False), # noqa: B008 - aws_region_name: Secret | None = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False), # noqa: B008 - aws_profile_name: Secret | None = Secret.from_env_var(["AWS_PROFILE"], strict=False), # noqa: B008 + aws_access_key_id: Secret | None = None, + aws_secret_access_key: Secret | None = None, + aws_session_token: Secret | None = None, + aws_region_name: Secret | None = None, + aws_profile_name: Secret | None = None, generation_kwargs: dict[str, Any] | None = None, streaming_callback: StreamingCallbackT | None = None, boto3_config: dict[str, Any] | None = None, @@ -236,6 +244,22 @@ def __init__( msg = "'model' cannot be None or empty string" raise ValueError(msg) self.model = model + if aws_access_key_id is None: + aws_access_key_id = Secret.from_env_var(["AWS_ACCESS_KEY_ID"], strict=False) + + if aws_secret_access_key is None: + aws_secret_access_key = Secret.from_env_var( + ["AWS_SECRET_ACCESS_KEY"], strict=False + ) + + if aws_session_token is None: + aws_session_token = Secret.from_env_var(["AWS_SESSION_TOKEN"], strict=False) + + if aws_region_name is None: + aws_region_name = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False) + + if aws_profile_name is None: + aws_profile_name = Secret.from_env_var(["AWS_PROFILE"], strict=False) self.aws_access_key_id = aws_access_key_id self.aws_secret_access_key = aws_secret_access_key self.aws_session_token = aws_session_token @@ -247,18 +271,23 @@ def __init__( _check_duplicate_tool_names(flatten_tools_or_toolsets(tools)) self.tools = tools - _validate_guardrail_config(guardrail_config=guardrail_config, streaming=streaming_callback is not None) + _validate_guardrail_config( + guardrail_config=guardrail_config, streaming=streaming_callback is not None + ) self.guardrail_config = guardrail_config self.tools_cachepoint_config = ( - _validate_and_format_cache_point(tools_cachepoint_config) if tools_cachepoint_config else None + _validate_and_format_cache_point(tools_cachepoint_config) + if tools_cachepoint_config + else None ) def resolve_secret(secret: Secret | None) -> str | None: return secret.resolve_value() if secret else None config = Config( - user_agent_extra="x-client-framework:haystack", **(self.boto3_config if self.boto3_config else {}) + user_agent_extra="x-client-framework:haystack", + **(self.boto3_config if self.boto3_config else {}), ) try: @@ -300,13 +329,31 @@ def _get_async_session(self) -> aioboto3.Session: try: self.async_session = get_aws_session( - aws_access_key_id=self.aws_access_key_id.resolve_value() if self.aws_access_key_id else None, + aws_access_key_id=( + self.aws_access_key_id.resolve_value() + if self.aws_access_key_id + else None + ), aws_secret_access_key=( - self.aws_secret_access_key.resolve_value() if self.aws_secret_access_key else None + self.aws_secret_access_key.resolve_value() + if self.aws_secret_access_key + else None + ), + aws_session_token=( + self.aws_session_token.resolve_value() + if self.aws_session_token + else None + ), + aws_region_name=( + self.aws_region_name.resolve_value() + if self.aws_region_name + else None + ), + aws_profile_name=( + self.aws_profile_name.resolve_value() + if self.aws_profile_name + else None ), - aws_session_token=self.aws_session_token.resolve_value() if self.aws_session_token else None, - aws_region_name=self.aws_region_name.resolve_value() if self.aws_region_name else None, - aws_profile_name=self.aws_profile_name.resolve_value() if self.aws_profile_name else None, async_mode=True, ) return self.async_session @@ -325,7 +372,11 @@ def to_dict(self) -> dict[str, Any]: :returns: Dictionary with serialized data. """ - callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None + callback_name = ( + serialize_callable(self.streaming_callback) + if self.streaming_callback + else None + ) return default_to_dict( self, aws_access_key_id=self.aws_access_key_id, @@ -360,7 +411,9 @@ def from_dict(cls, data: dict[str, Any]) -> "AmazonBedrockChatGenerator": serialized_callback_handler = init_params.get("streaming_callback") if serialized_callback_handler: - data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) + data["init_parameters"]["streaming_callback"] = deserialize_callable( + serialized_callback_handler + ) deserialize_tools_or_toolset_inplace(data["init_parameters"], key="tools") return default_from_dict(cls, data) @@ -385,6 +438,26 @@ def _prepare_request_params( - `stopSequences`: List of stop sequences to stop generation. - `temperature`: Sampling temperature. - `topP`: Nucleus sampling parameter. + - `json_schema`: Request structured JSON output validated against a schema. Provide a dict with: + - `schema` (required): a JSON Schema dict describing the expected output structure. + - `name` (optional): a name for the schema, defaults to ``"response_schema"``. + - `description` (optional): a description of the schema. + + Example:: + + generation_kwargs={ + "json_schema": { + "name": "person", + "schema": { + "type": "object", + "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}, + "required": ["name", "age"], + "additionalProperties": False, + }, + } + } + + When set, the parsed JSON object is stored in ``reply.meta["structured_output"]``. :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls. Each tool should have a unique name. :param requires_async: Boolean flag to indicate if an async-compatible streaming callback function is needed. @@ -413,13 +486,36 @@ def _prepare_request_params( flattened_tools = flatten_tools_or_toolsets(tools) _check_duplicate_tool_names(flattened_tools) tool_config = merged_kwargs.pop("toolConfig", None) + json_schema = merged_kwargs.pop("json_schema", None) if flattened_tools: # Format Haystack tools to Bedrock format - tool_config = _format_tools(flattened_tools, tools_cachepoint_config=self.tools_cachepoint_config) + tool_config = _format_tools( + flattened_tools, tools_cachepoint_config=self.tools_cachepoint_config + ) # Any remaining kwargs go to additionalModelRequestFields additional_fields = merged_kwargs if merged_kwargs else None + # Build outputConfig from json_schema for structured output support. + # See https://docs.aws.amazon.com/bedrock/latest/userguide/structured-output.html + output_config: dict[str, Any] | None = None + if json_schema is not None: + if "schema" not in json_schema: + msg = "'json_schema' must contain a 'schema' key with the JSON Schema dict." + raise ValueError(msg) + json_schema_block: dict[str, Any] = { + "name": json_schema.get("name", "response_schema"), + "schema": json.dumps(json_schema["schema"]), + } + if "description" in json_schema: + json_schema_block["description"] = json_schema["description"] + output_config = { + "textFormat": { + "type": "json_schema", + "structure": {"jsonSchema": json_schema_block}, + } + } + # Format messages to Bedrock format system_prompts, messages_list = _format_messages(messages) @@ -434,6 +530,8 @@ def _prepare_request_params( params["toolConfig"] = tool_config if additional_fields: params["additionalModelRequestFields"] = additional_fields + if output_config: + params["outputConfig"] = output_config if self.guardrail_config: params["guardrailConfig"] = self.guardrail_config @@ -447,10 +545,14 @@ def _prepare_request_params( return params, callback - def _resolve_flattened_generation_kwargs(self, generation_kwargs: dict[str, Any]) -> dict[str, Any]: + def _resolve_flattened_generation_kwargs( + self, generation_kwargs: dict[str, Any] + ) -> dict[str, Any]: generation_kwargs = generation_kwargs.copy() - disable_parallel_tool_use = generation_kwargs.pop("disable_parallel_tool_use", None) + disable_parallel_tool_use = generation_kwargs.pop( + "disable_parallel_tool_use", None + ) parallel_tool_use = generation_kwargs.pop("parallel_tool_use", None) if disable_parallel_tool_use is not None and parallel_tool_use is not None: @@ -497,13 +599,16 @@ def run( - `stopSequences`: List of stop sequences to stop generation. - `temperature`: Sampling temperature. - `topP`: Nucleus sampling parameter. + - `json_schema`: Request structured JSON output validated against a schema. See + :meth:`_prepare_request_params` for full details. :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls. Each tool should have a unique name. :returns: A dictionary containing the model-generated replies under the `"replies"` key. + When ``json_schema`` is used, each reply's ``meta["structured_output"]`` contains the parsed JSON object. :raises AmazonBedrockInferenceError: - If the Bedrock inference API call fails. + If the Bedrock inference API call fails or the model returns invalid JSON for structured output. """ component_info = ComponentInfo.from_component(self) @@ -536,6 +641,9 @@ def run( msg = f"Could not perform inference for Amazon Bedrock model {self.model} due to:\n{exception}" raise AmazonBedrockInferenceError(msg) from exception + if "outputConfig" in params: + replies = _parse_structured_output(replies) + return {"replies": replies} @component.output_types(replies=list[ChatMessage]) @@ -558,13 +666,16 @@ async def run_async( - `stopSequences`: List of stop sequences to stop generation. - `temperature`: Sampling temperature. - `topP`: Nucleus sampling parameter. + - `json_schema`: Request structured JSON output validated against a schema. See + :meth:`_prepare_request_params` for full details. :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls. Each tool should have a unique name. :returns: A dictionary containing the model-generated replies under the `"replies"` key. + When ``json_schema`` is used, each reply's ``meta["structured_output"]`` contains the parsed JSON object. :raises AmazonBedrockInferenceError: - If the Bedrock inference API call fails. + If the Bedrock inference API call fails or the model returns invalid JSON for structured output. """ component_info = ComponentInfo.from_component(self) @@ -581,7 +692,8 @@ async def run_async( # Note: https://aioboto3.readthedocs.io/en/latest/usage.html # we need to create a new client for each request config = Config( - user_agent_extra="x-client-framework:haystack", **(self.boto3_config if self.boto3_config else {}) + user_agent_extra="x-client-framework:haystack", + **(self.boto3_config if self.boto3_config else {}), ) async with session.client("bedrock-runtime", config=config) as async_client: if callback: @@ -605,4 +717,7 @@ async def run_async( msg = f"Could not perform inference for Amazon Bedrock model {self.model} due to:\n{exception}" raise AmazonBedrockInferenceError(msg) from exception + if "outputConfig" in params: + replies = _parse_structured_output(replies) + return {"replies": replies} diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/utils.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/utils.py index f6a5ce2251..cd9ceb4d2b 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/utils.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/utils.py @@ -25,6 +25,8 @@ ) from haystack.tools import Tool +from haystack_integrations.common.amazon_bedrock.errors import AmazonBedrockInferenceError + logger = logging.getLogger(__name__) @@ -713,6 +715,28 @@ async def _parse_streaming_response_async( return replies +def _parse_structured_output(replies: list[ChatMessage]) -> list[ChatMessage]: + """ + Parse JSON structured output from model replies and store it in message metadata. + + When structured output is requested via ``json_schema`` in ``generation_kwargs``, the model + returns JSON text. This function parses that JSON and stores the resulting object in + ``reply.meta["structured_output"]`` for each reply that contains text. + + :param replies: List of ChatMessage objects returned by the model. + :returns: The same list with ``meta["structured_output"]`` populated on text replies. + :raises AmazonBedrockInferenceError: If the model's response is not valid JSON. + """ + for reply in replies: + if reply.text: + try: + reply.meta["structured_output"] = json.loads(reply.text) + except json.JSONDecodeError as e: + msg = f"Structured output was requested but the model returned invalid JSON: {e}" + raise AmazonBedrockInferenceError(msg) from e + return replies + + def _validate_guardrail_config(guardrail_config: dict[str, str] | None = None, streaming: bool = False) -> None: """ Validate the guardrail configuration. diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 8e7ebdc503..7d15094c2d 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -1,3 +1,4 @@ +import json import os from typing import Any @@ -6,10 +7,20 @@ from haystack.components.agents import Agent from haystack.components.generators.utils import print_streaming_chunk from haystack.components.tools import ToolInvoker -from haystack.dataclasses import ChatMessage, ChatRole, FileContent, ImageContent, StreamingChunk, TextContent, ToolCall +from haystack.dataclasses import ( + ChatMessage, + ChatRole, + FileContent, + ImageContent, + StreamingChunk, + TextContent, + ToolCall, +) from haystack.tools import Tool, Toolset, create_tool_from_function -from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator +from haystack_integrations.components.generators.amazon_bedrock import ( + AmazonBedrockChatGenerator, +) CLASS_TYPE = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator" MODELS_TO_TEST = [ @@ -58,9 +69,7 @@ "global.anthropic.claude-sonnet-4-6", ] -MODELS_TO_TEST_WITH_PROMPT_CACHING = [ - "us.amazon.nova-micro-v1:0" # cheap, fast model -] +MODELS_TO_TEST_WITH_PROMPT_CACHING = ["us.amazon.nova-micro-v1:0"] # cheap, fast model def hello_world(): @@ -90,7 +99,11 @@ def population(city: str): @pytest.fixture def tools(): - tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} + tool_parameters = { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + } tool = Tool( name="weather", description="useful to determine the weather in a given location", @@ -106,13 +119,21 @@ def mixed_tools(): weather_tool = Tool( name="weather", description="useful to determine the weather in a given location", - parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}, + parameters={ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, function=weather, ) population_tool = Tool( name="population", description="useful to determine the population of a given location", - parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}, + parameters={ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, function=population, ) toolset = Toolset([population_tool]) @@ -171,22 +192,48 @@ def test_to_dict(self, mock_boto3_session, boto3_config): generation_kwargs={"temperature": 0.7}, streaming_callback=print_streaming_chunk, boto3_config=boto3_config, - guardrail_config={"guardrailIdentifier": "test", "guardrailVersion": "test"}, + guardrail_config={ + "guardrailIdentifier": "test", + "guardrailVersion": "test", + }, ) expected_dict = { "type": CLASS_TYPE, "init_parameters": { - "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, - "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, - "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, - "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, - "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, + "aws_access_key_id": { + "type": "env_var", + "env_vars": ["AWS_ACCESS_KEY_ID"], + "strict": False, + }, + "aws_secret_access_key": { + "type": "env_var", + "env_vars": ["AWS_SECRET_ACCESS_KEY"], + "strict": False, + }, + "aws_session_token": { + "type": "env_var", + "env_vars": ["AWS_SESSION_TOKEN"], + "strict": False, + }, + "aws_region_name": { + "type": "env_var", + "env_vars": ["AWS_DEFAULT_REGION"], + "strict": False, + }, + "aws_profile_name": { + "type": "env_var", + "env_vars": ["AWS_PROFILE"], + "strict": False, + }, "model": "cohere.command-r-plus-v1:0", "generation_kwargs": {"temperature": 0.7}, "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "boto3_config": boto3_config, "tools": None, - "guardrail_config": {"guardrailIdentifier": "test", "guardrailVersion": "test"}, + "guardrail_config": { + "guardrailIdentifier": "test", + "guardrailVersion": "test", + }, "tools_cachepoint_config": None, }, } @@ -202,15 +249,31 @@ def test_from_dict(self, mock_boto3_session: Any, boto3_config: dict[str, Any] | { "type": CLASS_TYPE, "init_parameters": { - "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, + "aws_access_key_id": { + "type": "env_var", + "env_vars": ["AWS_ACCESS_KEY_ID"], + "strict": False, + }, "aws_secret_access_key": { "type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False, }, - "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, - "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, - "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, + "aws_session_token": { + "type": "env_var", + "env_vars": ["AWS_SESSION_TOKEN"], + "strict": False, + }, + "aws_region_name": { + "type": "env_var", + "env_vars": ["AWS_DEFAULT_REGION"], + "strict": False, + }, + "aws_profile_name": { + "type": "env_var", + "env_vars": ["AWS_PROFILE"], + "strict": False, + }, "model": "global.anthropic.claude-sonnet-4-6", "generation_kwargs": {"temperature": 0.7}, "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", @@ -250,7 +313,8 @@ def test_constructor_with_generation_kwargs(self, mock_boto3_session): """ generation_kwargs = {"temperature": 0.7} layer = AmazonBedrockChatGenerator( - model="global.anthropic.claude-sonnet-4-6", generation_kwargs=generation_kwargs + model="global.anthropic.claude-sonnet-4-6", + generation_kwargs=generation_kwargs, ) assert layer.generation_kwargs == generation_kwargs @@ -294,17 +358,36 @@ def test_serde_in_pipeline(self, mock_boto3_session, monkeypatch): "generator": { "type": CLASS_TYPE, "init_parameters": { - "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, + "aws_access_key_id": { + "type": "env_var", + "env_vars": ["AWS_ACCESS_KEY_ID"], + "strict": False, + }, "aws_secret_access_key": { "type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False, }, - "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, - "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, - "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, + "aws_session_token": { + "type": "env_var", + "env_vars": ["AWS_SESSION_TOKEN"], + "strict": False, + }, + "aws_region_name": { + "type": "env_var", + "env_vars": ["AWS_DEFAULT_REGION"], + "strict": False, + }, + "aws_profile_name": { + "type": "env_var", + "env_vars": ["AWS_PROFILE"], + "strict": False, + }, "model": "global.anthropic.claude-sonnet-4-6", - "generation_kwargs": {"temperature": 0.7, "stopSequences": ["eviscerate"]}, + "generation_kwargs": { + "temperature": 0.7, + "stopSequences": ["eviscerate"], + }, "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "boto3_config": None, "tools": [ @@ -347,13 +430,62 @@ def test_prepare_request_params_tool_config(self, top_song_tool_config, mock_bot def test_prepare_request_params_guardrail_config(self, mock_boto3_session, set_env_variables): generator = AmazonBedrockChatGenerator( model="global.anthropic.claude-sonnet-4-6", - guardrail_config={"guardrailIdentifier": "test", "guardrailVersion": "test"}, + guardrail_config={ + "guardrailIdentifier": "test", + "guardrailVersion": "test", + }, ) request_params, _ = generator._prepare_request_params( messages=[ChatMessage.from_user("What's the capital of France?")], ) assert request_params["messages"] == [{"content": [{"text": "What's the capital of France?"}], "role": "user"}] - assert request_params["guardrailConfig"] == {"guardrailIdentifier": "test", "guardrailVersion": "test"} + assert request_params["guardrailConfig"] == { + "guardrailIdentifier": "test", + "guardrailVersion": "test", + } + + def test_prepare_request_params_json_schema(self, mock_boto3_session, set_env_variables): + + generator = AmazonBedrockChatGenerator(model="global.anthropic.claude-sonnet-4-6") + schema = { + "type": "object", + "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}, + "required": ["name", "age"], + "additionalProperties": False, + } + request_params, _ = generator._prepare_request_params( + messages=[ChatMessage.from_user("Hello")], + generation_kwargs={ + "json_schema": { + "name": "person", + "description": "A person's name and age", + "schema": schema, + } + }, + ) + assert "outputConfig" in request_params + assert request_params["outputConfig"] == { + "textFormat": { + "type": "json_schema", + "structure": { + "jsonSchema": { + "name": "person", + "description": "A person's name and age", + "schema": json.dumps(schema), + } + }, + } + } + + def test_prepare_request_params_json_schema_missing_schema_key( + self, mock_boto3_session, set_env_variables + ): + generator = AmazonBedrockChatGenerator(model="global.anthropic.claude-sonnet-4-6") + with pytest.raises(ValueError, match="'json_schema' must contain a 'schema' key"): + generator._prepare_request_params( + messages=[ChatMessage.from_user("Hello")], + generation_kwargs={"json_schema": {"name": "test"}}, + ) def test_init_with_mixed_tools(self, mock_boto3_session, set_env_variables): def tool_fn(city: str) -> str: @@ -362,13 +494,21 @@ def tool_fn(city: str) -> str: weather_tool = Tool( name="weather", description="Weather lookup", - parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}, + parameters={ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, function=tool_fn, ) population_tool = Tool( name="population", description="Population lookup", - parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}, + parameters={ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, function=tool_fn, ) toolset = Toolset([population_tool]) @@ -387,13 +527,21 @@ def tool_fn(city: str) -> str: weather_tool = Tool( name="weather", description="Weather lookup", - parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}, + parameters={ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, function=tool_fn, ) population_tool = Tool( name="population", description="Population lookup", - parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}, + parameters={ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, function=tool_fn, ) toolset = Toolset([population_tool]) @@ -485,7 +633,11 @@ def tool_fn(city: str) -> str: ], ) def test_prepare_request_params_with_flattened_generation_kwargs( - self, mock_boto3_session, set_env_variables, generation_kwargs, additional_model_request_fields + self, + mock_boto3_session, + set_env_variables, + generation_kwargs, + additional_model_request_fields, ): generator = AmazonBedrockChatGenerator(model="global.anthropic.claude-sonnet-4-6") request_params, _ = generator._prepare_request_params( @@ -523,6 +675,54 @@ def test_default_inference_params(self, model_name, chat_messages): assert "prompt_tokens" in first_reply.meta["usage"] assert "completion_tokens" in first_reply.meta["usage"] + def test_run_with_structured_output(self): + + client = AmazonBedrockChatGenerator(model="global.anthropic.claude-sonnet-4-6") + + messages = [ + ChatMessage.from_user( + "Extract the person's name and age from: 'Alice is 30 years old.'" + ) + ] + schema = { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + "required": ["name", "age"], + "additionalProperties": False, + } + + response = client.run( + messages, + generation_kwargs={ + "json_schema": { + "name": "person", + "description": "A person's name and age", + "schema": schema, + } + }, + ) + + assert "replies" in response + assert len(response["replies"]) > 0 + + reply = response["replies"][0] + assert reply.text is not None + + # Response text must be valid JSON + parsed = json.loads(reply.text) + assert isinstance(parsed, dict) + + # Parsed object is accessible via meta + assert "structured_output" in reply.meta + assert reply.meta["structured_output"] == parsed + + # Schema fields are present with correct types + assert isinstance(reply.meta["structured_output"].get("name"), str) + assert isinstance(reply.meta["structured_output"].get("age"), int) + @pytest.mark.parametrize("model_name", MODELS_TO_TEST_WITH_IMAGE_INPUT) def test_run_with_image_input(self, model_name, test_files_path): client = AmazonBedrockChatGenerator(model=model_name) @@ -548,7 +748,10 @@ def test_run_with_pdf_citations(self, model_name, test_files_path): file_content = FileContent.from_file_path(file_path, extra={"citations": {"enabled": True}}) chat_message = ChatMessage.from_user( - content_parts=["Is this document a paper on Large Language Models? Respond briefly", file_content] + content_parts=[ + "Is this document a paper on Large Language Models? Respond briefly", + file_content, + ] ) response = client.run([chat_message]) @@ -586,7 +789,9 @@ def retrieve_image(): ] image_retriever_tool = create_tool_from_function( - name="retrieve_image", description="Tool to retrieve an image", function=retrieve_image + name="retrieve_image", + description="Tool to retrieve an image", + function=retrieve_image, ) image_retriever_tool.outputs_to_string = {"raw_result": True} @@ -927,7 +1132,9 @@ def test_live_run_with_tool_with_no_args_streaming(self, tool_with_no_parameters """ initial_messages = [ChatMessage.from_user("Print Hello World using the print hello world tool.")] component = AmazonBedrockChatGenerator( - model=model_name, tools=[tool_with_no_parameters], streaming_callback=print_streaming_chunk + model=model_name, + tools=[tool_with_no_parameters], + streaming_callback=print_streaming_chunk, ) results = component.run(messages=initial_messages) @@ -992,7 +1199,8 @@ def test_prompt_caching_live_run_with_user_message(self, model_name, streaming_c system_message = ChatMessage.from_system("Always respond with: 'Life is beautiful' (and nothing else).") user_message = ChatMessage.from_user( - "User message that should be long enough to cache. " * 100, meta={"cachePoint": {"type": "default"}} + "User message that should be long enough to cache. " * 100, + meta={"cachePoint": {"type": "default"}}, ) messages = [system_message, user_message] result = generator.run(messages=messages) diff --git a/integrations/amazon_bedrock/tests/test_chat_generator_utils.py b/integrations/amazon_bedrock/tests/test_chat_generator_utils.py index a09ba17b8a..b9f1ba6cd0 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator_utils.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator_utils.py @@ -45,7 +45,11 @@ def tools(): weather_tool = Tool( name="weather", description="useful to determine the weather in a given location", - parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}, + parameters={ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, function=weather, ) addition_tool = Tool( @@ -71,7 +75,11 @@ def test_format_tools(self, tools): "name": "weather", "description": "useful to determine the weather in a given location", "inputSchema": { - "json": {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} + "json": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + } }, } }, @@ -82,7 +90,10 @@ def test_format_tools(self, tools): "inputSchema": { "json": { "type": "object", - "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "properties": { + "a": {"type": "integer"}, + "b": {"type": "integer"}, + }, "required": ["a", "b"], } }, @@ -94,7 +105,9 @@ def test_format_tools(self, tools): def test_convert_file_content_to_bedrock_format_no_mime_type(self): file_content = FileContent( - base64_data=base64.b64encode(b"This is a test file content."), mime_type=None, validation=False + base64_data=base64.b64encode(b"This is a test file content."), + mime_type=None, + validation=False, ) with pytest.raises(ValueError, match="MIME type is required"): _convert_file_content_to_bedrock_format(file_content) @@ -102,7 +115,8 @@ def test_convert_file_content_to_bedrock_format_no_mime_type(self): def test_convert_file_content_to_bedrock_format_document(self, test_files_path): file_path = test_files_path / "sample_pdf_1.pdf" file_content = FileContent.from_file_path( - file_path, extra={"context": "Example context.", "citations": {"enabled": True}} + file_path, + extra={"context": "Example context.", "citations": {"enabled": True}}, ) formatted_file_content = _convert_file_content_to_bedrock_format(file_content) assert formatted_file_content == { @@ -149,12 +163,17 @@ def test_convert_file_content_to_bedrock_format_video(self, test_files_path): file_content = FileContent.from_file_path(file_path) formatted_file_content = _convert_file_content_to_bedrock_format(file_content) assert formatted_file_content == { - "video": {"format": "mp4", "source": {"bytes": base64.b64decode(file_content.base64_data)}} + "video": { + "format": "mp4", + "source": {"bytes": base64.b64decode(file_content.base64_data)}, + } } def test_convert_file_content_to_bedrock_format_unsupported_mime_type(self): file_content = FileContent( - base64_data=base64.b64encode(b"This is a test file content."), mime_type="image/tiff", validation=False + base64_data=base64.b64encode(b"This is a test file content."), + mime_type="image/tiff", + validation=False, ) with pytest.raises(ValueError, match="Unsupported file content MIME type"): _convert_file_content_to_bedrock_format(file_content) @@ -180,27 +199,52 @@ def test_format_messages(self): ] assert formatted_messages == [ {"role": "user", "content": [{"text": "What's the capital of France?"}]}, - {"role": "assistant", "content": [{"text": "The capital of France is Paris."}]}, + { + "role": "assistant", + "content": [{"text": "The capital of France is Paris."}], + }, {"role": "user", "content": [{"text": "What is the weather in Paris?"}]}, { "role": "assistant", - "content": [{"toolUse": {"toolUseId": "123", "name": "weather", "input": {"city": "Paris"}}}], + "content": [ + { + "toolUse": { + "toolUseId": "123", + "name": "weather", + "input": {"city": "Paris"}, + } + } + ], }, { "role": "user", - "content": [{"toolResult": {"toolUseId": "123", "content": [{"text": "Sunny and 25°C"}]}}], + "content": [ + { + "toolResult": { + "toolUseId": "123", + "content": [{"text": "Sunny and 25°C"}], + } + } + ], + }, + { + "role": "assistant", + "content": [{"text": "The weather in Paris is sunny and 25°C."}], }, - {"role": "assistant", "content": [{"text": "The weather in Paris is sunny and 25°C."}]}, ] def test_format_messages_with_cache_point(self): meta = {"cachePoint": {"type": "default"}} messages = [ - ChatMessage.from_system("\\nYou are a helpful assistant, be super brief in your responses.", meta=meta), + ChatMessage.from_system( + "\\nYou are a helpful assistant, be super brief in your responses.", + meta=meta, + ), ChatMessage.from_user("What is the weather in Paris?", meta=meta), ChatMessage.from_assistant( - tool_calls=[ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"})], meta=meta + tool_calls=[ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"})], + meta=meta, ), ChatMessage.from_tool( tool_result="Sunny and 25°C", @@ -217,25 +261,42 @@ def test_format_messages_with_cache_point(self): assert formatted_messages == [ { "role": "user", - "content": [{"text": "What is the weather in Paris?"}, {"cachePoint": {"type": "default"}}], + "content": [ + {"text": "What is the weather in Paris?"}, + {"cachePoint": {"type": "default"}}, + ], }, { "role": "assistant", "content": [ - {"toolUse": {"toolUseId": "123", "name": "weather", "input": {"city": "Paris"}}}, + { + "toolUse": { + "toolUseId": "123", + "name": "weather", + "input": {"city": "Paris"}, + } + }, {"cachePoint": {"type": "default"}}, ], }, { "role": "user", "content": [ - {"toolResult": {"toolUseId": "123", "content": [{"text": "Sunny and 25°C"}]}}, + { + "toolResult": { + "toolUseId": "123", + "content": [{"text": "Sunny and 25°C"}], + } + }, {"cachePoint": {"type": "default"}}, ], }, { "role": "assistant", - "content": [{"text": "The weather in Paris is sunny and 25°C."}, {"cachePoint": {"type": "default"}}], + "content": [ + {"text": "The weather in Paris is sunny and 25°C."}, + {"cachePoint": {"type": "default"}}, + ], }, ] @@ -247,25 +308,44 @@ def test_format_messages_tool_result_with_image(self): messages = [ ChatMessage.from_user("Retrieve the image and describe it in max 5 words."), ChatMessage.from_assistant( - tool_calls=[ToolCall(id="123", tool_name="image_retriever", arguments={"query": "random query"})] + tool_calls=[ + ToolCall( + id="123", + tool_name="image_retriever", + arguments={"query": "random query"}, + ) + ] ), ChatMessage.from_tool( tool_result=[ TextContent("Here's the retrieved image"), ImageContent(base64_image=base64_image, mime_type="image/png"), ], - origin=ToolCall(id="123", tool_name="image_retriever", arguments={"query": "random query"}), + origin=ToolCall( + id="123", + tool_name="image_retriever", + arguments={"query": "random query"}, + ), ), ChatMessage.from_assistant("Beautiful landscape with mountains"), ] formatted_system_prompts, formatted_messages = _format_messages(messages) assert formatted_system_prompts == [] assert formatted_messages == [ - {"role": "user", "content": [{"text": "Retrieve the image and describe it in max 5 words."}]}, + { + "role": "user", + "content": [{"text": "Retrieve the image and describe it in max 5 words."}], + }, { "role": "assistant", "content": [ - {"toolUse": {"toolUseId": "123", "name": "image_retriever", "input": {"query": "random query"}}} + { + "toolUse": { + "toolUseId": "123", + "name": "image_retriever", + "input": {"query": "random query"}, + } + } ], }, { @@ -276,13 +356,21 @@ def test_format_messages_tool_result_with_image(self): "toolUseId": "123", "content": [ {"text": "Here's the retrieved image"}, - {"image": {"format": "png", "source": {"bytes": base64.b64decode(base64_image)}}}, + { + "image": { + "format": "png", + "source": {"bytes": base64.b64decode(base64_image)}, + } + }, ], } } ], }, - {"role": "assistant", "content": [{"text": "Beautiful landscape with mountains"}]}, + { + "role": "assistant", + "content": [{"text": "Beautiful landscape with mountains"}], + }, ] def test_format_message_thinking(self): @@ -334,7 +422,13 @@ def test_format_message_thinking(self): } }, {"text": "This is a test message with a tool call."}, - {"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {"key": "value"}}}, + { + "toolUse": { + "toolUseId": "123", + "name": "test_tool", + "input": {"key": "value"}, + } + }, ], } @@ -352,14 +446,23 @@ def test_format_message_thinking(self): "content": [ {"reasoningContent": {"reasoningText": {"text": "[REDACTED]"}}}, {"text": "This is a test message with a tool call."}, - {"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {"key": "value"}}}, + { + "toolUse": { + "toolUseId": "123", + "name": "test_tool", + "input": {"key": "value"}, + } + }, ], } def test_format_user_message(self): plain_user_message = ChatMessage.from_user("This is a test message.") formatted_message = _format_user_message(plain_user_message) - assert formatted_message == {"role": "user", "content": [{"text": "This is a test message."}]} + assert formatted_message == { + "role": "user", + "content": [{"text": "This is a test message."}], + } base64_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+ip1sAAAAASUVORK5CYII=" image_content = ImageContent(base64_image) @@ -369,7 +472,12 @@ def test_format_user_message(self): "role": "user", "content": [ {"text": "This is a test message."}, - {"image": {"format": "png", "source": {"bytes": base64.b64decode(base64_image)}}}, + { + "image": { + "format": "png", + "source": {"bytes": base64.b64decode(base64_image)}, + } + }, ], } @@ -418,10 +526,14 @@ def test_format_messages_multi_tool(self): "this information for you.", tool_calls=[ ToolCall( - tool_name="weather_tool", arguments={"location": "Berlin"}, id="tooluse_evFtOFYeSiG_TQ0cAAgy4Q" + tool_name="weather_tool", + arguments={"location": "Berlin"}, + id="tooluse_evFtOFYeSiG_TQ0cAAgy4Q", ), ToolCall( - tool_name="weather_tool", arguments={"location": "Paris"}, id="tooluse_Oc0n2we2RvquHwuPEflaQA" + tool_name="weather_tool", + arguments={"location": "Paris"}, + id="tooluse_Oc0n2we2RvquHwuPEflaQA", ), ], name=None, @@ -429,26 +541,37 @@ def test_format_messages_multi_tool(self): "model": "global.anthropic.claude-sonnet-4-6", "index": 0, "finish_reason": "tool_use", - "usage": {"prompt_tokens": 366, "completion_tokens": 134, "total_tokens": 500}, + "usage": { + "prompt_tokens": 366, + "completion_tokens": 134, + "total_tokens": 500, + }, }, ), ChatMessage.from_tool( tool_result="Mostly sunny", origin=ToolCall( - tool_name="weather_tool", arguments={"location": "Berlin"}, id="tooluse_evFtOFYeSiG_TQ0cAAgy4Q" + tool_name="weather_tool", + arguments={"location": "Berlin"}, + id="tooluse_evFtOFYeSiG_TQ0cAAgy4Q", ), ), ChatMessage.from_tool( tool_result="Mostly cloudy", origin=ToolCall( - tool_name="weather_tool", arguments={"location": "Paris"}, id="tooluse_Oc0n2we2RvquHwuPEflaQA" + tool_name="weather_tool", + arguments={"location": "Paris"}, + id="tooluse_Oc0n2we2RvquHwuPEflaQA", ), ), ] formatted_system_prompts, formatted_messages = _format_messages(messages) assert formatted_system_prompts == [] assert formatted_messages == [ - {"role": "user", "content": [{"text": "What is the weather in Berlin and Paris?"}]}, + { + "role": "user", + "content": [{"text": "What is the weather in Berlin and Paris?"}], + }, { "role": "assistant", "content": [ @@ -495,7 +618,12 @@ def test_format_messages_multi_tool(self): def test_extract_replies_from_text_response(self, mock_boto3_session): model = "global.anthropic.claude-sonnet-4-6" text_response = { - "output": {"message": {"role": "assistant", "content": [{"text": "This is a test response"}]}}, + "output": { + "message": { + "role": "assistant", + "content": [{"text": "This is a test response"}], + } + }, "stopReason": "end_turn", "usage": { "inputTokens": 10, @@ -531,7 +659,15 @@ def test_extract_replies_from_tool_response(self, mock_boto3_session): "output": { "message": { "role": "assistant", - "content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {"key": "value"}}}], + "content": [ + { + "toolUse": { + "toolUseId": "123", + "name": "test_tool", + "input": {"key": "value"}, + } + } + ], } }, "stopReason": "tool_use", @@ -567,7 +703,13 @@ def test_extract_replies_from_text_mixed_response(self, mock_boto3_session): "role": "assistant", "content": [ {"text": "Let me help you with that. I'll use the search tool to find the answer."}, - {"toolUse": {"toolUseId": "456", "name": "search_tool", "input": {"query": "test"}}}, + { + "toolUse": { + "toolUseId": "456", + "name": "search_tool", + "input": {"query": "test"}, + } + }, ], } }, @@ -650,10 +792,14 @@ def test_extract_replies_from_multi_tool_response(self, mock_boto3_session): "information for you.", tool_calls=[ ToolCall( - tool_name="weather_tool", arguments={"location": "Berlin"}, id="tooluse_evFtOFYeSiG_TQ0cAAgy4Q" + tool_name="weather_tool", + arguments={"location": "Berlin"}, + id="tooluse_evFtOFYeSiG_TQ0cAAgy4Q", ), ToolCall( - tool_name="weather_tool", arguments={"location": "Paris"}, id="tooluse_Oc0n2we2RvquHwuPEflaQA" + tool_name="weather_tool", + arguments={"location": "Paris"}, + id="tooluse_Oc0n2we2RvquHwuPEflaQA", ), ], name=None, @@ -731,7 +877,11 @@ def test_extract_replies_from_one_tool_response_with_thinking(self, mock_boto3_s expected_message = ChatMessage.from_assistant( text="I'll check the current weather in Paris for you.", tool_calls=[ - ToolCall(tool_name="weather", arguments={"city": "Paris"}, id="tooluse_iUqy8-ypSByLK5zFkka8uA") + ToolCall( + tool_name="weather", + arguments={"city": "Paris"}, + id="tooluse_iUqy8-ypSByLK5zFkka8uA", + ) ], reasoning=ReasoningContent( reasoning_text="The user wants to know the weather in Paris. I have a `weather` function available " @@ -767,7 +917,12 @@ def test_extract_replies_with_guardrail(self, mock_boto3_session): "test_guardrail_id": { "topicPolicy": { "topics": [ - {"name": "Investments topic", "type": "DENY", "action": "BLOCKED", "detected": True} + { + "name": "Investments topic", + "type": "DENY", + "action": "BLOCKED", + "detected": True, + } ] }, "invocationMetrics": { @@ -803,7 +958,10 @@ def test_extract_replies_with_guardrail(self, mock_boto3_session): "RetryAttempts": 0, }, "output": { - "message": {"role": "assistant", "content": [{"text": "Sorry, the model cannot answer this question."}]} + "message": { + "role": "assistant", + "content": [{"text": "Sorry, the model cannot answer this question."}], + } }, "stopReason": "guardrail_intervened", "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, @@ -873,7 +1031,13 @@ def test_parse_completion_response_with_citations(self, mock_boto3_session): ) } ], - "location": {"documentPage": {"documentIndex": 0, "start": 1, "end": 2}}, + "location": { + "documentPage": { + "documentIndex": 0, + "start": 1, + "end": 2, + } + }, } ], } @@ -932,33 +1096,112 @@ def test_callback(chunk: StreamingChunk): # Simulate a stream of events for both text and tool use events = [ {"messageStart": {"role": "assistant"}}, - {"contentBlockDelta": {"delta": {"text": "Certainly! I can"}, "contentBlockIndex": 0}}, - {"contentBlockDelta": {"delta": {"text": " help you find out"}, "contentBlockIndex": 0}}, - {"contentBlockDelta": {"delta": {"text": " the weather"}, "contentBlockIndex": 0}}, - {"contentBlockDelta": {"delta": {"text": " in Berlin. To"}, "contentBlockIndex": 0}}, - {"contentBlockDelta": {"delta": {"text": " get this information, I'll"}, "contentBlockIndex": 0}}, - {"contentBlockDelta": {"delta": {"text": " use the weather tool available"}, "contentBlockIndex": 0}}, - {"contentBlockDelta": {"delta": {"text": " to me."}, "contentBlockIndex": 0}}, - {"contentBlockDelta": {"delta": {"text": " Let me fetch"}, "contentBlockIndex": 0}}, - {"contentBlockDelta": {"delta": {"text": " that data for"}, "contentBlockIndex": 0}}, + { + "contentBlockDelta": { + "delta": {"text": "Certainly! I can"}, + "contentBlockIndex": 0, + } + }, + { + "contentBlockDelta": { + "delta": {"text": " help you find out"}, + "contentBlockIndex": 0, + } + }, + { + "contentBlockDelta": { + "delta": {"text": " the weather"}, + "contentBlockIndex": 0, + } + }, + { + "contentBlockDelta": { + "delta": {"text": " in Berlin. To"}, + "contentBlockIndex": 0, + } + }, + { + "contentBlockDelta": { + "delta": {"text": " get this information, I'll"}, + "contentBlockIndex": 0, + } + }, + { + "contentBlockDelta": { + "delta": {"text": " use the weather tool available"}, + "contentBlockIndex": 0, + } + }, + { + "contentBlockDelta": { + "delta": {"text": " to me."}, + "contentBlockIndex": 0, + } + }, + { + "contentBlockDelta": { + "delta": {"text": " Let me fetch"}, + "contentBlockIndex": 0, + } + }, + { + "contentBlockDelta": { + "delta": {"text": " that data for"}, + "contentBlockIndex": 0, + } + }, {"contentBlockDelta": {"delta": {"text": " you."}, "contentBlockIndex": 0}}, {"contentBlockStop": {"contentBlockIndex": 0}}, { "contentBlockStart": { - "start": {"toolUse": {"toolUseId": "tooluse_pLGRAmK7TNKoZQ_rntVN_Q", "name": "weather_tool"}}, + "start": { + "toolUse": { + "toolUseId": "tooluse_pLGRAmK7TNKoZQ_rntVN_Q", + "name": "weather_tool", + } + }, + "contentBlockIndex": 1, + } + }, + { + "contentBlockDelta": { + "delta": {"toolUse": {"input": ""}}, + "contentBlockIndex": 1, + } + }, + { + "contentBlockDelta": { + "delta": {"toolUse": {"input": '{"'}}, + "contentBlockIndex": 1, + } + }, + { + "contentBlockDelta": { + "delta": {"toolUse": {"input": 'location": '}}, + "contentBlockIndex": 1, + } + }, + { + "contentBlockDelta": { + "delta": {"toolUse": {"input": '"B'}}, + "contentBlockIndex": 1, + } + }, + { + "contentBlockDelta": { + "delta": {"toolUse": {"input": 'erlin"}'}}, "contentBlockIndex": 1, } }, - {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}, "contentBlockIndex": 1}}, - {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"'}}, "contentBlockIndex": 1}}, - {"contentBlockDelta": {"delta": {"toolUse": {"input": 'location": '}}, "contentBlockIndex": 1}}, - {"contentBlockDelta": {"delta": {"toolUse": {"input": '"B'}}, "contentBlockIndex": 1}}, - {"contentBlockDelta": {"delta": {"toolUse": {"input": 'erlin"}'}}, "contentBlockIndex": 1}}, {"contentBlockStop": {"contentBlockIndex": 1}}, {"messageStop": {"stopReason": "tool_use"}}, { "metadata": { - "usage": {"inputTokens": 364, "outputTokens": 71, "totalTokens": 435}, + "usage": { + "inputTokens": 364, + "outputTokens": 71, + "totalTokens": 435, + }, "metrics": {"latencyMs": 2449}, } }, @@ -976,7 +1219,9 @@ def test_callback(chunk: StreamingChunk): name=None, tool_calls=[ ToolCall( - tool_name="weather_tool", arguments={"location": "Berlin"}, id="tooluse_pLGRAmK7TNKoZQ_rntVN_Q" + tool_name="weather_tool", + arguments={"location": "Berlin"}, + id="tooluse_pLGRAmK7TNKoZQ_rntVN_Q", ) ], meta={ @@ -997,12 +1242,33 @@ def test_callback(chunk: StreamingChunk): expected_chunks = [ StreamingChunk(content="", meta=base_meta, component_info=c_info), - StreamingChunk(content="Certainly! I can", meta=base_meta, component_info=c_info, index=0, start=True), - StreamingChunk(content=" help you find out", meta=base_meta, component_info=c_info, index=0), + StreamingChunk( + content="Certainly! I can", + meta=base_meta, + component_info=c_info, + index=0, + start=True, + ), + StreamingChunk( + content=" help you find out", + meta=base_meta, + component_info=c_info, + index=0, + ), StreamingChunk(content=" the weather", meta=base_meta, component_info=c_info, index=0), StreamingChunk(content=" in Berlin. To", meta=base_meta, component_info=c_info, index=0), - StreamingChunk(content=" get this information, I'll", meta=base_meta, component_info=c_info, index=0), - StreamingChunk(content=" use the weather tool available", meta=base_meta, component_info=c_info, index=0), + StreamingChunk( + content=" get this information, I'll", + meta=base_meta, + component_info=c_info, + index=0, + ), + StreamingChunk( + content=" use the weather tool available", + meta=base_meta, + component_info=c_info, + index=0, + ), StreamingChunk(content=" to me.", meta=base_meta, component_info=c_info, index=0), StreamingChunk(content=" Let me fetch", meta=base_meta, component_info=c_info, index=0), StreamingChunk(content=" that data for", meta=base_meta, component_info=c_info, index=0), @@ -1013,7 +1279,13 @@ def test_callback(chunk: StreamingChunk): meta=base_meta, component_info=c_info, index=1, - tool_calls=[ToolCallDelta(index=1, tool_name="weather_tool", id="tooluse_pLGRAmK7TNKoZQ_rntVN_Q")], + tool_calls=[ + ToolCallDelta( + index=1, + tool_name="weather_tool", + id="tooluse_pLGRAmK7TNKoZQ_rntVN_Q", + ) + ], start=True, ), StreamingChunk( @@ -1052,7 +1324,12 @@ def test_callback(chunk: StreamingChunk): tool_calls=[ToolCallDelta(index=1, arguments='erlin"}')], ), StreamingChunk(content="", meta=base_meta, component_info=c_info), - StreamingChunk(content="", meta=base_meta, component_info=c_info, finish_reason="tool_calls"), + StreamingChunk( + content="", + meta=base_meta, + component_info=c_info, + finish_reason="tool_calls", + ), StreamingChunk( content="", meta={ @@ -1103,7 +1380,12 @@ def test_callback(chunk: StreamingChunk): "contentBlockIndex": 0, } }, - {"contentBlockDelta": {"delta": {"reasoningContent": {"text": " access to a"}}, "contentBlockIndex": 0}}, + { + "contentBlockDelta": { + "delta": {"reasoningContent": {"text": " access to a"}}, + "contentBlockIndex": 0, + } + }, { "contentBlockDelta": { "delta": {"reasoningContent": {"text": " weather function that takes"}}, @@ -1134,26 +1416,75 @@ def test_callback(chunk: StreamingChunk): "contentBlockIndex": 0, } }, - {"contentBlockDelta": {"delta": {"reasoningContent": {"text": " function call."}}, "contentBlockIndex": 0}}, - {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "..."}}, "contentBlockIndex": 0}}, + { + "contentBlockDelta": { + "delta": {"reasoningContent": {"text": " function call."}}, + "contentBlockIndex": 0, + } + }, + { + "contentBlockDelta": { + "delta": {"reasoningContent": {"signature": "..."}}, + "contentBlockIndex": 0, + } + }, {"contentBlockStop": {"contentBlockIndex": 0}}, { "contentBlockStart": { - "start": {"toolUse": {"toolUseId": "tooluse_1gPhO4A1RNWgzKbt1PXWLg", "name": "weather"}}, + "start": { + "toolUse": { + "toolUseId": "tooluse_1gPhO4A1RNWgzKbt1PXWLg", + "name": "weather", + } + }, + "contentBlockIndex": 1, + } + }, + { + "contentBlockDelta": { + "delta": {"toolUse": {"input": ""}}, + "contentBlockIndex": 1, + } + }, + { + "contentBlockDelta": { + "delta": {"toolUse": {"input": '{"ci'}}, + "contentBlockIndex": 1, + } + }, + { + "contentBlockDelta": { + "delta": {"toolUse": {"input": "ty"}}, + "contentBlockIndex": 1, + } + }, + { + "contentBlockDelta": { + "delta": {"toolUse": {"input": '": "P'}}, + "contentBlockIndex": 1, + } + }, + { + "contentBlockDelta": { + "delta": {"toolUse": {"input": "aris"}}, + "contentBlockIndex": 1, + } + }, + { + "contentBlockDelta": { + "delta": {"toolUse": {"input": '"}'}}, "contentBlockIndex": 1, } }, - {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}, "contentBlockIndex": 1}}, - {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"ci'}}, "contentBlockIndex": 1}}, - {"contentBlockDelta": {"delta": {"toolUse": {"input": "ty"}}, "contentBlockIndex": 1}}, - {"contentBlockDelta": {"delta": {"toolUse": {"input": '": "P'}}, "contentBlockIndex": 1}}, - {"contentBlockDelta": {"delta": {"toolUse": {"input": "aris"}}, "contentBlockIndex": 1}}, - {"contentBlockDelta": {"delta": {"toolUse": {"input": '"}'}}, "contentBlockIndex": 1}}, {"contentBlockStop": {"contentBlockIndex": 1}}, {"messageStop": {"stopReason": "tool_use"}}, { "metadata": { - "usage": {"inputTokens": 412, "outputTokens": 104, "totalTokens": 516}, + "usage": { + "inputTokens": 412, + "outputTokens": 104, + "totalTokens": 516, + }, "metrics": {"latencyMs": 2134}, } }, @@ -1164,7 +1495,11 @@ def test_callback(chunk: StreamingChunk): expected_messages = [ ChatMessage.from_assistant( tool_calls=[ - ToolCall(tool_name="weather", arguments={"city": "Paris"}, id="tooluse_1gPhO4A1RNWgzKbt1PXWLg") + ToolCall( + tool_name="weather", + arguments={"city": "Paris"}, + id="tooluse_1gPhO4A1RNWgzKbt1PXWLg", + ) ], reasoning=ReasoningContent( reasoning_text="The user is asking about the weather in Paris. I have access to a weather function " @@ -1211,41 +1546,130 @@ def test_callback(chunk: StreamingChunk): events = [ {"messageStart": {"role": "assistant"}}, {"contentBlockDelta": {"delta": {"text": "To"}, "contentBlockIndex": 0}}, - {"contentBlockDelta": {"delta": {"text": " answer your question about the"}, "contentBlockIndex": 0}}, - {"contentBlockDelta": {"delta": {"text": " weather in Berlin and Paris, I'll"}, "contentBlockIndex": 0}}, - {"contentBlockDelta": {"delta": {"text": " need to use the weather_tool"}, "contentBlockIndex": 0}}, - {"contentBlockDelta": {"delta": {"text": " for each city. Let"}, "contentBlockIndex": 0}}, - {"contentBlockDelta": {"delta": {"text": " me fetch that information for"}, "contentBlockIndex": 0}}, + { + "contentBlockDelta": { + "delta": {"text": " answer your question about the"}, + "contentBlockIndex": 0, + } + }, + { + "contentBlockDelta": { + "delta": {"text": " weather in Berlin and Paris, I'll"}, + "contentBlockIndex": 0, + } + }, + { + "contentBlockDelta": { + "delta": {"text": " need to use the weather_tool"}, + "contentBlockIndex": 0, + } + }, + { + "contentBlockDelta": { + "delta": {"text": " for each city. Let"}, + "contentBlockIndex": 0, + } + }, + { + "contentBlockDelta": { + "delta": {"text": " me fetch that information for"}, + "contentBlockIndex": 0, + } + }, {"contentBlockDelta": {"delta": {"text": " you."}, "contentBlockIndex": 0}}, {"contentBlockStop": {"contentBlockIndex": 0}}, { "contentBlockStart": { - "start": {"toolUse": {"toolUseId": "tooluse_A0jTtaiQTFmqD_cIq8I1BA", "name": "weather_tool"}}, + "start": { + "toolUse": { + "toolUseId": "tooluse_A0jTtaiQTFmqD_cIq8I1BA", + "name": "weather_tool", + } + }, + "contentBlockIndex": 1, + } + }, + { + "contentBlockDelta": { + "delta": {"toolUse": {"input": ""}}, + "contentBlockIndex": 1, + } + }, + { + "contentBlockDelta": { + "delta": {"toolUse": {"input": '{"location":'}}, + "contentBlockIndex": 1, + } + }, + { + "contentBlockDelta": { + "delta": {"toolUse": {"input": ' "Be'}}, + "contentBlockIndex": 1, + } + }, + { + "contentBlockDelta": { + "delta": {"toolUse": {"input": 'rlin"}'}}, "contentBlockIndex": 1, } }, - {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}, "contentBlockIndex": 1}}, - {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"location":'}}, "contentBlockIndex": 1}}, - {"contentBlockDelta": {"delta": {"toolUse": {"input": ' "Be'}}, "contentBlockIndex": 1}}, - {"contentBlockDelta": {"delta": {"toolUse": {"input": 'rlin"}'}}, "contentBlockIndex": 1}}, {"contentBlockStop": {"contentBlockIndex": 1}}, { "contentBlockStart": { - "start": {"toolUse": {"toolUseId": "tooluse_LTc2TUMgTRiobK5Z5CCNSw", "name": "weather_tool"}}, + "start": { + "toolUse": { + "toolUseId": "tooluse_LTc2TUMgTRiobK5Z5CCNSw", + "name": "weather_tool", + } + }, + "contentBlockIndex": 2, + } + }, + { + "contentBlockDelta": { + "delta": {"toolUse": {"input": ""}}, + "contentBlockIndex": 2, + } + }, + { + "contentBlockDelta": { + "delta": {"toolUse": {"input": '{"l'}}, + "contentBlockIndex": 2, + } + }, + { + "contentBlockDelta": { + "delta": {"toolUse": {"input": "ocati"}}, + "contentBlockIndex": 2, + } + }, + { + "contentBlockDelta": { + "delta": {"toolUse": {"input": 'on": "P'}}, + "contentBlockIndex": 2, + } + }, + { + "contentBlockDelta": { + "delta": {"toolUse": {"input": "ari"}}, + "contentBlockIndex": 2, + } + }, + { + "contentBlockDelta": { + "delta": {"toolUse": {"input": 's"}'}}, "contentBlockIndex": 2, } }, - {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}, "contentBlockIndex": 2}}, - {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"l'}}, "contentBlockIndex": 2}}, - {"contentBlockDelta": {"delta": {"toolUse": {"input": "ocati"}}, "contentBlockIndex": 2}}, - {"contentBlockDelta": {"delta": {"toolUse": {"input": 'on": "P'}}, "contentBlockIndex": 2}}, - {"contentBlockDelta": {"delta": {"toolUse": {"input": "ari"}}, "contentBlockIndex": 2}}, - {"contentBlockDelta": {"delta": {"toolUse": {"input": 's"}'}}, "contentBlockIndex": 2}}, {"contentBlockStop": {"contentBlockIndex": 2}}, {"messageStop": {"stopReason": "tool_use"}}, { "metadata": { - "usage": {"inputTokens": 366, "outputTokens": 83, "totalTokens": 449}, + "usage": { + "inputTokens": 366, + "outputTokens": 83, + "totalTokens": 449, + }, "metrics": {"latencyMs": 3194}, } }, @@ -1259,10 +1683,14 @@ def test_callback(chunk: StreamingChunk): name=None, tool_calls=[ ToolCall( - tool_name="weather_tool", arguments={"location": "Berlin"}, id="tooluse_A0jTtaiQTFmqD_cIq8I1BA" + tool_name="weather_tool", + arguments={"location": "Berlin"}, + id="tooluse_A0jTtaiQTFmqD_cIq8I1BA", ), ToolCall( - tool_name="weather_tool", arguments={"location": "Paris"}, id="tooluse_LTc2TUMgTRiobK5Z5CCNSw" + tool_name="weather_tool", + arguments={"location": "Paris"}, + id="tooluse_LTc2TUMgTRiobK5Z5CCNSw", ), ], meta={ @@ -1297,7 +1725,12 @@ def test_parse_streaming_response_with_guardrail(self, mock_boto3_session): "test_guardrail_id": { "topicPolicy": { "topics": [ - {"name": "Investments topic", "type": "DENY", "action": "BLOCKED", "detected": True} + { + "name": "Investments topic", + "type": "DENY", + "action": "BLOCKED", + "detected": True, + } ] }, "invocationMetrics": { @@ -1366,7 +1799,9 @@ def test_callback(chunk: StreamingChunk): ] assert replies == expected_messages - def test_convert_streaming_chunks_to_chat_message_tool_call_with_empty_arguments(self): + def test_convert_streaming_chunks_to_chat_message_tool_call_with_empty_arguments( + self, + ): chunks = [ StreamingChunk( content="Certainly! I", @@ -1504,7 +1939,12 @@ def test_convert_streaming_chunks_to_chat_message_tool_call_with_empty_arguments }, index=1, tool_calls=[ - ToolCallDelta(index=1, id="tooluse_QZlUqTveTwyUaCQGQbWP6g", tool_name="hello_world", arguments="") + ToolCallDelta( + index=1, + id="tooluse_QZlUqTveTwyUaCQGQbWP6g", + tool_name="hello_world", + arguments="", + ) ], ), StreamingChunk( @@ -1536,7 +1976,11 @@ def test_convert_streaming_chunks_to_chat_message_tool_call_with_empty_arguments meta={ "model": "global.anthropic.claude-sonnet-4-6", "received_at": "2025-07-31T08:46:08.596700+00:00", - "usage": {"prompt_tokens": 349, "completion_tokens": 84, "total_tokens": 433}, + "usage": { + "prompt_tokens": 349, + "completion_tokens": 84, + "total_tokens": 433, + }, }, ), ] @@ -1560,15 +2004,27 @@ def test_convert_streaming_chunks_to_chat_message_tool_call_with_empty_arguments assert message._meta["model"] == "global.anthropic.claude-sonnet-4-6" assert message._meta["index"] == 0 assert message._meta["finish_reason"] == "tool_calls" - assert message._meta["usage"] == {"completion_tokens": 84, "prompt_tokens": 349, "total_tokens": 433} + assert message._meta["usage"] == { + "completion_tokens": 84, + "prompt_tokens": 349, + "total_tokens": 433, + } def test_validate_guardrail_config_with_valid_configs(self): _validate_guardrail_config(guardrail_config=None, streaming=False) _validate_guardrail_config( - guardrail_config={"guardrailIdentifier": "test", "guardrailVersion": "test"}, streaming=False + guardrail_config={ + "guardrailIdentifier": "test", + "guardrailVersion": "test", + }, + streaming=False, ) _validate_guardrail_config( - guardrail_config={"guardrailIdentifier": "test", "guardrailVersion": "test"}, streaming=True + guardrail_config={ + "guardrailIdentifier": "test", + "guardrailVersion": "test", + }, + streaming=True, ) _validate_guardrail_config( guardrail_config={ @@ -1580,11 +2036,20 @@ def test_validate_guardrail_config_with_valid_configs(self): ) def test_validate_guardrail_config_with_invalid_configs(self): - with pytest.raises(ValueError, match="`guardrailIdentifier` and `guardrailVersion` fields are required"): + with pytest.raises( + ValueError, + match="`guardrailIdentifier` and `guardrailVersion` fields are required", + ): _validate_guardrail_config(guardrail_config={"guardrailIdentifier": "test"}, streaming=False) - with pytest.raises(ValueError, match="`guardrailIdentifier` and `guardrailVersion` fields are required"): + with pytest.raises( + ValueError, + match="`guardrailIdentifier` and `guardrailVersion` fields are required", + ): _validate_guardrail_config(guardrail_config={"guardrailVersion": "test"}, streaming=False) - with pytest.raises(ValueError, match="`streamProcessingMode` field is only supported for streaming"): + with pytest.raises( + ValueError, + match="`streamProcessingMode` field is only supported for streaming", + ): _validate_guardrail_config( guardrail_config={ "guardrailIdentifier": "test", @@ -1607,10 +2072,16 @@ def test_validate_and_format_cache_point(self): cache_point = _validate_and_format_cache_point({"type": "default", "ttl": "5m"}) assert cache_point == {"cachePoint": {"type": "default", "ttl": "5m"}} - with pytest.raises(ValueError, match=r"Cache point must have a 'type' key with value 'default'."): + with pytest.raises( + ValueError, + match=r"Cache point must have a 'type' key with value 'default'.", + ): _validate_and_format_cache_point({"invalid": "config"}) - with pytest.raises(ValueError, match=r"Cache point must have a 'type' key with value 'default'."): + with pytest.raises( + ValueError, + match=r"Cache point must have a 'type' key with value 'default'.", + ): _validate_and_format_cache_point({"type": "invalid"}) with pytest.raises(ValueError, match=r"Cache point can only contain 'type' and 'ttl' keys."): diff --git a/integrations/amazon_bedrock/tests/test_document_embedder.py b/integrations/amazon_bedrock/tests/test_document_embedder.py index 867f726efd..f996f5b240 100644 --- a/integrations/amazon_bedrock/tests/test_document_embedder.py +++ b/integrations/amazon_bedrock/tests/test_document_embedder.py @@ -11,7 +11,9 @@ AmazonBedrockConfigurationError, AmazonBedrockInferenceError, ) -from haystack_integrations.components.embedders.amazon_bedrock import AmazonBedrockDocumentEmbedder +from haystack_integrations.components.embedders.amazon_bedrock import ( + AmazonBedrockDocumentEmbedder, +) TYPE = "haystack_integrations.components.embedders.amazon_bedrock.document_embedder.AmazonBedrockDocumentEmbedder" @@ -76,11 +78,31 @@ def test_to_dict(self, mock_boto3_session: Any, boto3_config: dict[str, Any] | N expected_dict = { "type": TYPE, "init_parameters": { - "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, - "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, - "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, - "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, - "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, + "aws_access_key_id": { + "type": "env_var", + "env_vars": ["AWS_ACCESS_KEY_ID"], + "strict": False, + }, + "aws_secret_access_key": { + "type": "env_var", + "env_vars": ["AWS_SECRET_ACCESS_KEY"], + "strict": False, + }, + "aws_session_token": { + "type": "env_var", + "env_vars": ["AWS_SESSION_TOKEN"], + "strict": False, + }, + "aws_region_name": { + "type": "env_var", + "env_vars": ["AWS_DEFAULT_REGION"], + "strict": False, + }, + "aws_profile_name": { + "type": "env_var", + "env_vars": ["AWS_PROFILE"], + "strict": False, + }, "model": "cohere.embed-english-v3", "input_type": "search_document", "batch_size": 32, @@ -98,11 +120,31 @@ def test_from_dict(self, mock_boto3_session: Any, boto3_config: dict[str, Any] | data = { "type": TYPE, "init_parameters": { - "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, - "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, - "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, - "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, - "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, + "aws_access_key_id": { + "type": "env_var", + "env_vars": ["AWS_ACCESS_KEY_ID"], + "strict": False, + }, + "aws_secret_access_key": { + "type": "env_var", + "env_vars": ["AWS_SECRET_ACCESS_KEY"], + "strict": False, + }, + "aws_session_token": { + "type": "env_var", + "env_vars": ["AWS_SESSION_TOKEN"], + "strict": False, + }, + "aws_region_name": { + "type": "env_var", + "env_vars": ["AWS_DEFAULT_REGION"], + "strict": False, + }, + "aws_profile_name": { + "type": "env_var", + "env_vars": ["AWS_PROFILE"], + "strict": False, + }, "model": "cohere.embed-english-v3", "input_type": "search_document", "batch_size": 32, @@ -151,11 +193,17 @@ def test_run_invocation_error(self, mock_boto3_session): def test_prepare_texts_to_embed_w_metadata(self, mock_boto3_session): documents = [ - Document(content=f"document number {i}: content", meta={"meta_field": f"meta_value {i}"}) for i in range(5) + Document( + content=f"document number {i}: content", + meta={"meta_field": f"meta_value {i}"}, + ) + for i in range(5) ] embedder = AmazonBedrockDocumentEmbedder( - model="cohere.embed-english-v3", meta_fields_to_embed=["meta_field"], embedding_separator=" | " + model="cohere.embed-english-v3", + meta_fields_to_embed=["meta_field"], + embedding_separator=" | ", ) prepared_texts = embedder._prepare_texts_to_embed(documents) @@ -332,11 +380,17 @@ def test_run_titan_does_not_modify_original_documents(self, mock_boto3_session): or not os.getenv("AWS_DEFAULT_REGION"), reason="AWS credentials are not set", ) - @pytest.mark.parametrize("model", ["cohere.embed-v4:0", "cohere.embed-english-v3", "amazon.titan-embed-text-v1"]) + @pytest.mark.parametrize( + "model", + ["cohere.embed-v4:0", "cohere.embed-english-v3", "amazon.titan-embed-text-v1"], + ) def test_live_run(self, model): embedder = AmazonBedrockDocumentEmbedder(model=model) - documents = [Document(content="this is a test document"), Document(content="I love pizza")] + documents = [ + Document(content="this is a test document"), + Document(content="I love pizza"), + ] docs_with_embeddings = embedder.run(documents=documents)["documents"] diff --git a/integrations/amazon_bedrock/tests/test_document_image_embedder.py b/integrations/amazon_bedrock/tests/test_document_image_embedder.py index 7a4f339ff1..8b63c86767 100644 --- a/integrations/amazon_bedrock/tests/test_document_image_embedder.py +++ b/integrations/amazon_bedrock/tests/test_document_image_embedder.py @@ -12,7 +12,9 @@ AmazonBedrockConfigurationError, AmazonBedrockInferenceError, ) -from haystack_integrations.components.embedders.amazon_bedrock import AmazonBedrockDocumentImageEmbedder +from haystack_integrations.components.embedders.amazon_bedrock import ( + AmazonBedrockDocumentImageEmbedder, +) TYPE = ( "haystack_integrations.components.embedders.amazon_bedrock." @@ -22,7 +24,11 @@ @pytest.fixture def image_paths(test_files_path): - return [test_files_path / "apple.jpg", test_files_path / "haystack-logo.png", test_files_path / "sample_pdf_1.pdf"] + return [ + test_files_path / "apple.jpg", + test_files_path / "haystack-logo.png", + test_files_path / "sample_pdf_1.pdf", + ] class TestAmazonBedrockDocumentImageEmbedder: @@ -76,11 +82,31 @@ def test_to_dict(self, mock_boto3_session: Any, boto3_config: dict[str, Any] | N expected_dict = { "type": TYPE, "init_parameters": { - "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, - "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, - "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, - "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, - "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, + "aws_access_key_id": { + "type": "env_var", + "env_vars": ["AWS_ACCESS_KEY_ID"], + "strict": False, + }, + "aws_secret_access_key": { + "type": "env_var", + "env_vars": ["AWS_SECRET_ACCESS_KEY"], + "strict": False, + }, + "aws_session_token": { + "type": "env_var", + "env_vars": ["AWS_SESSION_TOKEN"], + "strict": False, + }, + "aws_region_name": { + "type": "env_var", + "env_vars": ["AWS_DEFAULT_REGION"], + "strict": False, + }, + "aws_profile_name": { + "type": "env_var", + "env_vars": ["AWS_PROFILE"], + "strict": False, + }, "model": "cohere.embed-english-v3", "file_path_meta_field": "file_path", "embedding_types": ["float"], @@ -98,11 +124,31 @@ def test_from_dict(self, mock_boto3_session: Any, boto3_config: dict[str, Any] | data = { "type": TYPE, "init_parameters": { - "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, - "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, - "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, - "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, - "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, + "aws_access_key_id": { + "type": "env_var", + "env_vars": ["AWS_ACCESS_KEY_ID"], + "strict": False, + }, + "aws_secret_access_key": { + "type": "env_var", + "env_vars": ["AWS_SECRET_ACCESS_KEY"], + "strict": False, + }, + "aws_session_token": { + "type": "env_var", + "env_vars": ["AWS_SESSION_TOKEN"], + "strict": False, + }, + "aws_region_name": { + "type": "env_var", + "env_vars": ["AWS_DEFAULT_REGION"], + "strict": False, + }, + "aws_profile_name": { + "type": "env_var", + "env_vars": ["AWS_PROFILE"], + "strict": False, + }, "model": "cohere.embed-english-v3", "embedding_types": ["float"], "root_path": None, diff --git a/integrations/amazon_bedrock/tests/test_generator.py b/integrations/amazon_bedrock/tests/test_generator.py index 9452b0948f..eab81eb8dd 100644 --- a/integrations/amazon_bedrock/tests/test_generator.py +++ b/integrations/amazon_bedrock/tests/test_generator.py @@ -7,7 +7,9 @@ from haystack_integrations.common.amazon_bedrock.errors import ( AmazonBedrockConfigurationError, ) -from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator +from haystack_integrations.components.generators.amazon_bedrock import ( + AmazonBedrockGenerator, +) from haystack_integrations.components.generators.amazon_bedrock.adapters import ( AI21LabsJurassic2Adapter, AmazonTitanAdapter, @@ -26,17 +28,40 @@ def test_to_dict(mock_boto3_session: Any, boto3_config: dict[str, Any] | None): Test that the to_dict method returns the correct dictionary without aws credentials """ generator = AmazonBedrockGenerator( - model="anthropic.claude-v2", max_length=99, temperature=10, boto3_config=boto3_config + model="anthropic.claude-v2", + max_length=99, + temperature=10, + boto3_config=boto3_config, ) expected_dict = { "type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator", "init_parameters": { - "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, - "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, - "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, - "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, - "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, + "aws_access_key_id": { + "type": "env_var", + "env_vars": ["AWS_ACCESS_KEY_ID"], + "strict": False, + }, + "aws_secret_access_key": { + "type": "env_var", + "env_vars": ["AWS_SECRET_ACCESS_KEY"], + "strict": False, + }, + "aws_session_token": { + "type": "env_var", + "env_vars": ["AWS_SESSION_TOKEN"], + "strict": False, + }, + "aws_region_name": { + "type": "env_var", + "env_vars": ["AWS_DEFAULT_REGION"], + "strict": False, + }, + "aws_profile_name": { + "type": "env_var", + "env_vars": ["AWS_PROFILE"], + "strict": False, + }, "model": "anthropic.claude-v2", "max_length": 99, "temperature": 10, @@ -58,11 +83,31 @@ def test_from_dict(mock_boto3_session: Any, boto3_config: dict[str, Any] | None) { "type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator", "init_parameters": { - "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, - "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, - "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, - "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, - "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, + "aws_access_key_id": { + "type": "env_var", + "env_vars": ["AWS_ACCESS_KEY_ID"], + "strict": False, + }, + "aws_secret_access_key": { + "type": "env_var", + "env_vars": ["AWS_SECRET_ACCESS_KEY"], + "strict": False, + }, + "aws_session_token": { + "type": "env_var", + "env_vars": ["AWS_SESSION_TOKEN"], + "strict": False, + }, + "aws_region_name": { + "type": "env_var", + "env_vars": ["AWS_DEFAULT_REGION"], + "strict": False, + }, + "aws_profile_name": { + "type": "env_var", + "env_vars": ["AWS_PROFILE"], + "strict": False, + }, "model": "anthropic.claude-v2", "max_length": 99, "boto3_config": boto3_config, @@ -147,7 +192,10 @@ def test_constructor_with_empty_model(): ("ai21.j2-mega-v5", AI21LabsJurassic2Adapter), # artificial ("amazon.titan-text-lite-v1", AmazonTitanAdapter), ("amazon.titan-text-express-v1", AmazonTitanAdapter), - ("us.amazon.titan-text-express-v1", AmazonTitanAdapter), # cross-region inference + ( + "us.amazon.titan-text-express-v1", + AmazonTitanAdapter, + ), # cross-region inference ("amazon.titan-text-agile-v1", AmazonTitanAdapter), ("amazon.titan-text-lightning-v8", AmazonTitanAdapter), # artificial ("meta.llama2-13b-chat-v1", MetaLlamaAdapter), @@ -161,8 +209,14 @@ def test_constructor_with_empty_model(): ("mistral.mistral-7b-instruct-v0:2", MistralAdapter), ("mistral.mixtral-8x7b-instruct-v0:1", MistralAdapter), ("mistral.mistral-large-2402-v1:0", MistralAdapter), - ("eu.mistral.mixtral-8x7b-instruct-v0:1", MistralAdapter), # cross-region inference - ("us.mistral.mistral-large-2402-v1:0", MistralAdapter), # cross-region inference + ( + "eu.mistral.mixtral-8x7b-instruct-v0:1", + MistralAdapter, + ), # cross-region inference + ( + "us.mistral.mistral-large-2402-v1:0", + MistralAdapter, + ), # cross-region inference ("mistral.mistral-medium-v8:0", MistralAdapter), # artificial ], ) @@ -497,7 +551,11 @@ def test_get_stream_responses_with_thinking(self) -> None: call( StreamingChunk( content="", - meta={"type": "content_block_start", "content_block": {"type": "thinking"}, "index": 0}, + meta={ + "type": "content_block_start", + "content_block": {"type": "thinking"}, + "index": 0, + }, ) ), call(StreamingChunk(content="This", meta={"delta": {"thinking": "This"}})), @@ -508,7 +566,11 @@ def test_get_stream_responses_with_thinking(self) -> None: call( StreamingChunk( content="\n\n", - meta={"type": "content_block_start", "content_block": {"type": "text"}, "index": 1}, + meta={ + "type": "content_block_start", + "content_block": {"type": "text"}, + "index": 1, + }, ) ), call(StreamingChunk(content="This", meta={"delta": {"text": "This"}})), @@ -580,7 +642,11 @@ def test_get_stream_responses_with_thinking_custom_thinking_tag(self) -> None: call( StreamingChunk( content="", - meta={"type": "content_block_start", "content_block": {"type": "thinking"}, "index": 0}, + meta={ + "type": "content_block_start", + "content_block": {"type": "thinking"}, + "index": 0, + }, ) ), call(StreamingChunk(content="This", meta={"delta": {"thinking": "This"}})), @@ -591,7 +657,11 @@ def test_get_stream_responses_with_thinking_custom_thinking_tag(self) -> None: call( StreamingChunk( content="\n\n", - meta={"type": "content_block_start", "content_block": {"type": "text"}, "index": 1}, + meta={ + "type": "content_block_start", + "content_block": {"type": "text"}, + "index": 1, + }, ) ), call(StreamingChunk(content="This", meta={"delta": {"text": "This"}})), @@ -630,7 +700,11 @@ def test_get_stream_responses_with_thinking_no_thinking_tag(self) -> None: call( StreamingChunk( content="", - meta={"type": "content_block_start", "content_block": {"type": "thinking"}, "index": 0}, + meta={ + "type": "content_block_start", + "content_block": {"type": "thinking"}, + "index": 0, + }, ) ), call(StreamingChunk(content="This", meta={"delta": {"thinking": "This"}})), @@ -641,7 +715,11 @@ def test_get_stream_responses_with_thinking_no_thinking_tag(self) -> None: call( StreamingChunk( content="\n\n", - meta={"type": "content_block_start", "content_block": {"type": "text"}, "index": 1}, + meta={ + "type": "content_block_start", + "content_block": {"type": "text"}, + "index": 1, + }, ) ), call(StreamingChunk(content="This", meta={"delta": {"text": "This"}})), @@ -652,7 +730,9 @@ def test_get_stream_responses_with_thinking_no_thinking_tag(self) -> None: ] ) - def test_get_stream_responses_with_thinking_redacted_thinking_is_ignored(self) -> None: + def test_get_stream_responses_with_thinking_redacted_thinking_is_ignored( + self, + ) -> None: stream_mock = MagicMock() streaming_callback_mock = MagicMock() @@ -687,7 +767,11 @@ def test_get_stream_responses_with_thinking_redacted_thinking_is_ignored(self) - call( StreamingChunk( content="", - meta={"type": "content_block_start", "content_block": {"type": "thinking"}, "index": 1}, + meta={ + "type": "content_block_start", + "content_block": {"type": "thinking"}, + "index": 1, + }, ) ), call(StreamingChunk(content="This", meta={"delta": {"thinking": "This"}})), @@ -698,7 +782,11 @@ def test_get_stream_responses_with_thinking_redacted_thinking_is_ignored(self) - call( StreamingChunk( content="\n\n", - meta={"type": "content_block_start", "content_block": {"type": "text"}, "index": 2}, + meta={ + "type": "content_block_start", + "content_block": {"type": "text"}, + "index": 2, + }, ) ), call(StreamingChunk(content="This", meta={"delta": {"text": "This"}})), @@ -856,7 +944,11 @@ class TestMistralAdapter: def test_prepare_body_with_default_params(self) -> None: layer = MistralAdapter(model_kwargs={}, max_length=99) prompt = "Hello, how are you?" - expected_body = {"prompt": "[INST] Hello, how are you? [/INST]", "max_tokens": 99, "stop": []} + expected_body = { + "prompt": "[INST] Hello, how are you? [/INST]", + "max_tokens": 99, + "stop": [], + } body = layer.prepare_body(prompt) assert body == expected_body @@ -1159,7 +1251,12 @@ def test_get_stream_responses(self) -> None: call(StreamingChunk(content=" a", meta={"text": " a"})), call(StreamingChunk(content=" single", meta={"text": " single"})), call(StreamingChunk(content=" response.", meta={"text": " response."})), - call(StreamingChunk(content="", meta={"finish_reason": "MAX_TOKENS", "is_finished": True})), + call( + StreamingChunk( + content="", + meta={"finish_reason": "MAX_TOKENS", "is_finished": True}, + ) + ), ] ) @@ -1185,7 +1282,10 @@ def test_prepare_body(self) -> None: ], "documents": [ {"title": "France", "snippet": "Paris is the capital of France."}, - {"title": "Germany", "snippet": "Berlin is the capital of Germany."}, + { + "title": "Germany", + "snippet": "Berlin is the capital of Germany.", + }, ], "search_query_only": False, "preamble": "preamble", @@ -1213,9 +1313,15 @@ def test_prepare_body(self) -> None: ], "tool_results": [ { - "call": {"name": "query_daily_sales_report", "parameters": {"day": "2023-09-29"}}, + "call": { + "name": "query_daily_sales_report", + "parameters": {"day": "2023-09-29"}, + }, "outputs": [ - {"date": "2023-09-29", "summary": "Total Sales Amount: 10000, Total Units Sold: 250"} + { + "date": "2023-09-29", + "summary": "Total Sales Amount: 10000, Total Units Sold: 250", + } ], } ], @@ -1263,8 +1369,16 @@ def test_prepare_body(self) -> None: ], "tool_results": [ { - "call": {"name": "query_daily_sales_report", "parameters": {"day": "2023-09-29"}}, - "outputs": [{"date": "2023-09-29", "summary": "Total Sales Amount: 10000, Total Units Sold: 250"}], + "call": { + "name": "query_daily_sales_report", + "parameters": {"day": "2023-09-29"}, + }, + "outputs": [ + { + "date": "2023-09-29", + "summary": "Total Sales Amount: 10000, Total Units Sold: 250", + } + ], } ], "stop_sequences": ["\n\n"], @@ -1714,7 +1828,10 @@ def test_run_with_metadata(self, mock_boto3_session) -> None: "ResponseMetadata": { "RequestId": "test-request-id", "HTTPStatusCode": 200, - "HTTPHeaders": {"x-amzn-requestid": "test-request-id", "content-type": "application/json"}, + "HTTPHeaders": { + "x-amzn-requestid": "test-request-id", + "content-type": "application/json", + }, }, } mock_client.invoke_model.return_value = mock_response diff --git a/integrations/amazon_bedrock/tests/test_ranker.py b/integrations/amazon_bedrock/tests/test_ranker.py index d97aca307a..367301ec89 100644 --- a/integrations/amazon_bedrock/tests/test_ranker.py +++ b/integrations/amazon_bedrock/tests/test_ranker.py @@ -39,7 +39,12 @@ def test_bedrock_ranker_run(mock_aws_session): aws_region_name=Secret.from_token("us-west-2"), ) - mock_response = {"results": [{"index": 0, "relevanceScore": 0.9}, {"index": 1, "relevanceScore": 0.7}]} + mock_response = { + "results": [ + {"index": 0, "relevanceScore": 0.9}, + {"index": 1, "relevanceScore": 0.7}, + ] + } mock_aws_session.rerank.return_value = mock_response diff --git a/integrations/amazon_bedrock/tests/test_s3_downloader.py b/integrations/amazon_bedrock/tests/test_s3_downloader.py index 61486e29fe..ff7d5a4826 100644 --- a/integrations/amazon_bedrock/tests/test_s3_downloader.py +++ b/integrations/amazon_bedrock/tests/test_s3_downloader.py @@ -69,11 +69,31 @@ def test_to_dict(self, mock_boto3_session: Any, tmp_path, boto3_config: dict[str expected = { "type": TYPE, "init_parameters": { - "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, - "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, - "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, - "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, - "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, + "aws_access_key_id": { + "type": "env_var", + "env_vars": ["AWS_ACCESS_KEY_ID"], + "strict": False, + }, + "aws_secret_access_key": { + "type": "env_var", + "env_vars": ["AWS_SECRET_ACCESS_KEY"], + "strict": False, + }, + "aws_region_name": { + "type": "env_var", + "env_vars": ["AWS_DEFAULT_REGION"], + "strict": False, + }, + "aws_session_token": { + "type": "env_var", + "env_vars": ["AWS_SESSION_TOKEN"], + "strict": False, + }, + "aws_profile_name": { + "type": "env_var", + "env_vars": ["AWS_PROFILE"], + "strict": False, + }, "file_root_path": str(tmp_path), "file_extensions": None, "max_cache_size": 100, @@ -90,11 +110,31 @@ def test_from_dict(self, mock_boto3_session: Any, tmp_path, boto3_config: dict[s data = { "type": TYPE, "init_parameters": { - "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, - "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, - "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, - "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, - "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, + "aws_access_key_id": { + "type": "env_var", + "env_vars": ["AWS_ACCESS_KEY_ID"], + "strict": False, + }, + "aws_secret_access_key": { + "type": "env_var", + "env_vars": ["AWS_SECRET_ACCESS_KEY"], + "strict": False, + }, + "aws_region_name": { + "type": "env_var", + "env_vars": ["AWS_DEFAULT_REGION"], + "strict": False, + }, + "aws_session_token": { + "type": "env_var", + "env_vars": ["AWS_SESSION_TOKEN"], + "strict": False, + }, + "aws_profile_name": { + "type": "env_var", + "env_vars": ["AWS_PROFILE"], + "strict": False, + }, "file_root_path": str(tmp_path), "s3_key_generation_function": None, "s3_bucket_name_env": "S3_DOWNLOADER_BUCKET", @@ -116,11 +156,31 @@ def test_to_dict_with_parameters(self, tmp_path): expected = { "type": TYPE, "init_parameters": { - "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, - "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, - "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, - "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, - "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, + "aws_access_key_id": { + "type": "env_var", + "env_vars": ["AWS_ACCESS_KEY_ID"], + "strict": False, + }, + "aws_secret_access_key": { + "type": "env_var", + "env_vars": ["AWS_SECRET_ACCESS_KEY"], + "strict": False, + }, + "aws_region_name": { + "type": "env_var", + "env_vars": ["AWS_DEFAULT_REGION"], + "strict": False, + }, + "aws_session_token": { + "type": "env_var", + "env_vars": ["AWS_SESSION_TOKEN"], + "strict": False, + }, + "aws_profile_name": { + "type": "env_var", + "env_vars": ["AWS_PROFILE"], + "strict": False, + }, "file_root_path": str(tmp_path), "file_extensions": [".txt"], "max_cache_size": 400, @@ -167,7 +227,10 @@ def test_run_with_input_file_meta_key(self, tmp_path, mock_s3_storage, mock_boto assert out["documents"][0].meta["custom_file_key"] == "a.txt" def test_run_with_s3_key_generation_function(self, tmp_path, mock_s3_storage, mock_boto3_session): - d = S3Downloader(file_root_path=str(tmp_path), s3_key_generation_function=s3_key_generation_function) + d = S3Downloader( + file_root_path=str(tmp_path), + s3_key_generation_function=s3_key_generation_function, + ) d._storage = mock_s3_storage docs = [Document(meta={"file_id": str(uuid4()), "file_name": "a.txt"})] diff --git a/integrations/amazon_bedrock/tests/test_text_embedder.py b/integrations/amazon_bedrock/tests/test_text_embedder.py index e1e9fb98bd..7d5bf3bca3 100644 --- a/integrations/amazon_bedrock/tests/test_text_embedder.py +++ b/integrations/amazon_bedrock/tests/test_text_embedder.py @@ -9,7 +9,9 @@ AmazonBedrockConfigurationError, AmazonBedrockInferenceError, ) -from haystack_integrations.components.embedders.amazon_bedrock import AmazonBedrockTextEmbedder +from haystack_integrations.components.embedders.amazon_bedrock import ( + AmazonBedrockTextEmbedder, +) class TestAmazonBedrockTextEmbedder: @@ -52,11 +54,31 @@ def test_to_dict(self, mock_boto3_session): expected_dict = { "type": "haystack_integrations.components.embedders.amazon_bedrock.text_embedder.AmazonBedrockTextEmbedder", "init_parameters": { - "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, - "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, - "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, - "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, - "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, + "aws_access_key_id": { + "type": "env_var", + "env_vars": ["AWS_ACCESS_KEY_ID"], + "strict": False, + }, + "aws_secret_access_key": { + "type": "env_var", + "env_vars": ["AWS_SECRET_ACCESS_KEY"], + "strict": False, + }, + "aws_session_token": { + "type": "env_var", + "env_vars": ["AWS_SESSION_TOKEN"], + "strict": False, + }, + "aws_region_name": { + "type": "env_var", + "env_vars": ["AWS_DEFAULT_REGION"], + "strict": False, + }, + "aws_profile_name": { + "type": "env_var", + "env_vars": ["AWS_PROFILE"], + "strict": False, + }, "model": "cohere.embed-english-v3", "input_type": "search_query", "boto3_config": None, @@ -69,11 +91,31 @@ def test_from_dict(self, mock_boto3_session): data = { "type": "haystack_integrations.components.embedders.amazon_bedrock.text_embedder.AmazonBedrockTextEmbedder", "init_parameters": { - "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, - "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, - "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, - "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, - "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, + "aws_access_key_id": { + "type": "env_var", + "env_vars": ["AWS_ACCESS_KEY_ID"], + "strict": False, + }, + "aws_secret_access_key": { + "type": "env_var", + "env_vars": ["AWS_SECRET_ACCESS_KEY"], + "strict": False, + }, + "aws_session_token": { + "type": "env_var", + "env_vars": ["AWS_SESSION_TOKEN"], + "strict": False, + }, + "aws_region_name": { + "type": "env_var", + "env_vars": ["AWS_DEFAULT_REGION"], + "strict": False, + }, + "aws_profile_name": { + "type": "env_var", + "env_vars": ["AWS_PROFILE"], + "strict": False, + }, "model": "cohere.embed-english-v3", "input_type": "search_query", "boto3_config": { @@ -162,7 +204,10 @@ def test_run_invocation_error(self, mock_boto3_session): or not os.getenv("AWS_DEFAULT_REGION"), reason="AWS credentials are not set", ) - @pytest.mark.parametrize("model", ["cohere.embed-v4:0", "cohere.embed-english-v3", "amazon.titan-embed-text-v1"]) + @pytest.mark.parametrize( + "model", + ["cohere.embed-v4:0", "cohere.embed-english-v3", "amazon.titan-embed-text-v1"], + ) def test_live_run(self, model): embedder = AmazonBedrockTextEmbedder(model=model)