Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion integrations/amazon_bedrock/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ classifiers = [
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = ["haystack-ai>=2.24.1", "boto3>=1.28.57", "aioboto3>=14.0.0"]
dependencies = ["haystack-ai>=2.24.1", "boto3>=1.42.84", "aioboto3>=14.0.0"]



[project.urls]
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/amazon_bedrock#readme"
Expand Down Expand Up @@ -172,6 +174,12 @@ omit = ["*/tests/*", "*/__init__.py"]
show_missing = true
exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"]

[tool.uv]
# aiobotocore 2.25.1 pins botocore<1.40.62 but works fine with newer botocore in practice.
# outputConfig support in the Bedrock Converse API requires botocore>=1.42.84.
# Override the transitive constraint so uv can resolve the dependency graph.
override-dependencies = ["botocore>=1.42.84", "boto3>=1.42.84"]

[tool.pytest.ini_options]
addopts = "--strict-markers"
markers = [
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import json
from typing import Any

import aioboto3
from botocore.config import Config
from botocore.eventstream import EventStream
from botocore.exceptions import ClientError
from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingCallbackT, select_streaming_callback
from haystack.dataclasses import (
ChatMessage,
ComponentInfo,
StreamingCallbackT,
select_streaming_callback,
)
from haystack.tools import (
ToolsType,
_check_duplicate_tool_names,
Expand All @@ -14,7 +20,10 @@
serialize_tools_or_toolset,
)
from haystack.utils.auth import Secret
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
from haystack.utils.callable_serialization import (
deserialize_callable,
serialize_callable,
)

from haystack_integrations.common.amazon_bedrock.errors import (
AmazonBedrockConfigurationError,
Expand All @@ -27,6 +36,7 @@
_parse_completion_response,
_parse_streaming_response,
_parse_streaming_response_async,
_parse_structured_output,
_validate_and_format_cache_point,
_validate_guardrail_config,
)
Expand Down Expand Up @@ -164,13 +174,11 @@ def weather(city: str):
def __init__(
self,
model: str,
aws_access_key_id: Secret | None = Secret.from_env_var(["AWS_ACCESS_KEY_ID"], strict=False), # noqa: B008
aws_secret_access_key: Secret | None = Secret.from_env_var( # noqa: B008
["AWS_SECRET_ACCESS_KEY"], strict=False
),
aws_session_token: Secret | None = Secret.from_env_var(["AWS_SESSION_TOKEN"], strict=False), # noqa: B008
aws_region_name: Secret | None = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False), # noqa: B008
aws_profile_name: Secret | None = Secret.from_env_var(["AWS_PROFILE"], strict=False), # noqa: B008
aws_access_key_id: Secret | None = None,
aws_secret_access_key: Secret | None = None,
aws_session_token: Secret | None = None,
aws_region_name: Secret | None = None,
aws_profile_name: Secret | None = None,
generation_kwargs: dict[str, Any] | None = None,
streaming_callback: StreamingCallbackT | None = None,
boto3_config: dict[str, Any] | None = None,
Expand Down Expand Up @@ -236,6 +244,22 @@ def __init__(
msg = "'model' cannot be None or empty string"
raise ValueError(msg)
self.model = model
if aws_access_key_id is None:
aws_access_key_id = Secret.from_env_var(["AWS_ACCESS_KEY_ID"], strict=False)

if aws_secret_access_key is None:
aws_secret_access_key = Secret.from_env_var(
["AWS_SECRET_ACCESS_KEY"], strict=False
)

if aws_session_token is None:
aws_session_token = Secret.from_env_var(["AWS_SESSION_TOKEN"], strict=False)

if aws_region_name is None:
aws_region_name = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False)

if aws_profile_name is None:
aws_profile_name = Secret.from_env_var(["AWS_PROFILE"], strict=False)
self.aws_access_key_id = aws_access_key_id
self.aws_secret_access_key = aws_secret_access_key
self.aws_session_token = aws_session_token
Expand All @@ -247,18 +271,23 @@ def __init__(
_check_duplicate_tool_names(flatten_tools_or_toolsets(tools))
self.tools = tools

_validate_guardrail_config(guardrail_config=guardrail_config, streaming=streaming_callback is not None)
_validate_guardrail_config(
guardrail_config=guardrail_config, streaming=streaming_callback is not None
)
self.guardrail_config = guardrail_config

self.tools_cachepoint_config = (
_validate_and_format_cache_point(tools_cachepoint_config) if tools_cachepoint_config else None
_validate_and_format_cache_point(tools_cachepoint_config)
if tools_cachepoint_config
else None
)

def resolve_secret(secret: Secret | None) -> str | None:
return secret.resolve_value() if secret else None

config = Config(
user_agent_extra="x-client-framework:haystack", **(self.boto3_config if self.boto3_config else {})
user_agent_extra="x-client-framework:haystack",
**(self.boto3_config if self.boto3_config else {}),
)

try:
Expand Down Expand Up @@ -300,13 +329,31 @@ def _get_async_session(self) -> aioboto3.Session:

try:
self.async_session = get_aws_session(
aws_access_key_id=self.aws_access_key_id.resolve_value() if self.aws_access_key_id else None,
aws_access_key_id=(
self.aws_access_key_id.resolve_value()
if self.aws_access_key_id
else None
),
aws_secret_access_key=(
self.aws_secret_access_key.resolve_value() if self.aws_secret_access_key else None
self.aws_secret_access_key.resolve_value()
if self.aws_secret_access_key
else None
),
aws_session_token=(
self.aws_session_token.resolve_value()
if self.aws_session_token
else None
),
aws_region_name=(
self.aws_region_name.resolve_value()
if self.aws_region_name
else None
),
aws_profile_name=(
self.aws_profile_name.resolve_value()
if self.aws_profile_name
else None
),
aws_session_token=self.aws_session_token.resolve_value() if self.aws_session_token else None,
aws_region_name=self.aws_region_name.resolve_value() if self.aws_region_name else None,
aws_profile_name=self.aws_profile_name.resolve_value() if self.aws_profile_name else None,
async_mode=True,
)
return self.async_session
Expand All @@ -325,7 +372,11 @@ def to_dict(self) -> dict[str, Any]:
:returns:
Dictionary with serialized data.
"""
callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
callback_name = (
serialize_callable(self.streaming_callback)
if self.streaming_callback
else None
)
return default_to_dict(
self,
aws_access_key_id=self.aws_access_key_id,
Expand Down Expand Up @@ -360,7 +411,9 @@ def from_dict(cls, data: dict[str, Any]) -> "AmazonBedrockChatGenerator":

serialized_callback_handler = init_params.get("streaming_callback")
if serialized_callback_handler:
data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
data["init_parameters"]["streaming_callback"] = deserialize_callable(
serialized_callback_handler
)
deserialize_tools_or_toolset_inplace(data["init_parameters"], key="tools")
return default_from_dict(cls, data)

Expand All @@ -385,6 +438,26 @@ def _prepare_request_params(
- `stopSequences`: List of stop sequences to stop generation.
- `temperature`: Sampling temperature.
- `topP`: Nucleus sampling parameter.
- `json_schema`: Request structured JSON output validated against a schema. Provide a dict with:
- `schema` (required): a JSON Schema dict describing the expected output structure.
- `name` (optional): a name for the schema, defaults to ``"response_schema"``.
- `description` (optional): a description of the schema.

Example::

generation_kwargs={
"json_schema": {
"name": "person",
"schema": {
"type": "object",
"properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
"required": ["name", "age"],
"additionalProperties": False,
},
}
}

When set, the parsed JSON object is stored in ``reply.meta["structured_output"]``.
:param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
Each tool should have a unique name.
:param requires_async: Boolean flag to indicate if an async-compatible streaming callback function is needed.
Expand Down Expand Up @@ -413,13 +486,36 @@ def _prepare_request_params(
flattened_tools = flatten_tools_or_toolsets(tools)
_check_duplicate_tool_names(flattened_tools)
tool_config = merged_kwargs.pop("toolConfig", None)
json_schema = merged_kwargs.pop("json_schema", None)
if flattened_tools:
# Format Haystack tools to Bedrock format
tool_config = _format_tools(flattened_tools, tools_cachepoint_config=self.tools_cachepoint_config)
tool_config = _format_tools(
flattened_tools, tools_cachepoint_config=self.tools_cachepoint_config
)

# Any remaining kwargs go to additionalModelRequestFields
additional_fields = merged_kwargs if merged_kwargs else None

# Build outputConfig from json_schema for structured output support.
# See https://docs.aws.amazon.com/bedrock/latest/userguide/structured-output.html
output_config: dict[str, Any] | None = None
if json_schema is not None:
if "schema" not in json_schema:
msg = "'json_schema' must contain a 'schema' key with the JSON Schema dict."
raise ValueError(msg)
json_schema_block: dict[str, Any] = {
"name": json_schema.get("name", "response_schema"),
"schema": json.dumps(json_schema["schema"]),
}
if "description" in json_schema:
json_schema_block["description"] = json_schema["description"]
output_config = {
"textFormat": {
"type": "json_schema",
"structure": {"jsonSchema": json_schema_block},
}
}

# Format messages to Bedrock format
system_prompts, messages_list = _format_messages(messages)

Expand All @@ -434,6 +530,8 @@ def _prepare_request_params(
params["toolConfig"] = tool_config
if additional_fields:
params["additionalModelRequestFields"] = additional_fields
if output_config:
params["outputConfig"] = output_config
if self.guardrail_config:
params["guardrailConfig"] = self.guardrail_config

Expand All @@ -447,10 +545,14 @@ def _prepare_request_params(

return params, callback

def _resolve_flattened_generation_kwargs(self, generation_kwargs: dict[str, Any]) -> dict[str, Any]:
def _resolve_flattened_generation_kwargs(
self, generation_kwargs: dict[str, Any]
) -> dict[str, Any]:
generation_kwargs = generation_kwargs.copy()

disable_parallel_tool_use = generation_kwargs.pop("disable_parallel_tool_use", None)
disable_parallel_tool_use = generation_kwargs.pop(
"disable_parallel_tool_use", None
)
parallel_tool_use = generation_kwargs.pop("parallel_tool_use", None)

if disable_parallel_tool_use is not None and parallel_tool_use is not None:
Expand Down Expand Up @@ -497,13 +599,16 @@ def run(
- `stopSequences`: List of stop sequences to stop generation.
- `temperature`: Sampling temperature.
- `topP`: Nucleus sampling parameter.
- `json_schema`: Request structured JSON output validated against a schema. See
:meth:`_prepare_request_params` for full details.
:param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
Each tool should have a unique name.

:returns:
A dictionary containing the model-generated replies under the `"replies"` key.
When ``json_schema`` is used, each reply's ``meta["structured_output"]`` contains the parsed JSON object.
:raises AmazonBedrockInferenceError:
If the Bedrock inference API call fails.
If the Bedrock inference API call fails or the model returns invalid JSON for structured output.
"""
component_info = ComponentInfo.from_component(self)

Expand Down Expand Up @@ -536,6 +641,9 @@ def run(
msg = f"Could not perform inference for Amazon Bedrock model {self.model} due to:\n{exception}"
raise AmazonBedrockInferenceError(msg) from exception

if "outputConfig" in params:
replies = _parse_structured_output(replies)

return {"replies": replies}

@component.output_types(replies=list[ChatMessage])
Expand All @@ -558,13 +666,16 @@ async def run_async(
- `stopSequences`: List of stop sequences to stop generation.
- `temperature`: Sampling temperature.
- `topP`: Nucleus sampling parameter.
- `json_schema`: Request structured JSON output validated against a schema. See
:meth:`_prepare_request_params` for full details.
:param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
Each tool should have a unique name.

:returns:
A dictionary containing the model-generated replies under the `"replies"` key.
When ``json_schema`` is used, each reply's ``meta["structured_output"]`` contains the parsed JSON object.
:raises AmazonBedrockInferenceError:
If the Bedrock inference API call fails.
If the Bedrock inference API call fails or the model returns invalid JSON for structured output.
"""
component_info = ComponentInfo.from_component(self)

Expand All @@ -581,7 +692,8 @@ async def run_async(
# Note: https://aioboto3.readthedocs.io/en/latest/usage.html
# we need to create a new client for each request
config = Config(
user_agent_extra="x-client-framework:haystack", **(self.boto3_config if self.boto3_config else {})
user_agent_extra="x-client-framework:haystack",
**(self.boto3_config if self.boto3_config else {}),
)
async with session.client("bedrock-runtime", config=config) as async_client:
if callback:
Expand All @@ -605,4 +717,7 @@ async def run_async(
msg = f"Could not perform inference for Amazon Bedrock model {self.model} due to:\n{exception}"
raise AmazonBedrockInferenceError(msg) from exception

if "outputConfig" in params:
replies = _parse_structured_output(replies)

return {"replies": replies}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
)
from haystack.tools import Tool

from haystack_integrations.common.amazon_bedrock.errors import AmazonBedrockInferenceError

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -713,6 +715,28 @@ async def _parse_streaming_response_async(
return replies


def _parse_structured_output(replies: list[ChatMessage]) -> list[ChatMessage]:
    """
    Attach parsed JSON structured output to each text reply's metadata.

    When structured output is requested via ``json_schema`` in ``generation_kwargs``, the model
    emits its answer as JSON text. Every reply that carries text has that text decoded, and the
    resulting object is stored under ``reply.meta["structured_output"]``. Replies without text
    (e.g. pure tool-call replies) are left untouched.

    :param replies: ChatMessage objects returned by the model.
    :returns: The same list, with ``meta["structured_output"]`` populated on text replies.
    :raises AmazonBedrockInferenceError: If a reply's text is not valid JSON.
    """
    for message in replies:
        text = message.text
        if not text:
            continue
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError as decode_error:
            msg = f"Structured output was requested but the model returned invalid JSON: {decode_error}"
            raise AmazonBedrockInferenceError(msg) from decode_error
        message.meta["structured_output"] = parsed
    return replies


def _validate_guardrail_config(guardrail_config: dict[str, str] | None = None, streaming: bool = False) -> None:
"""
Validate the guardrail configuration.
Expand Down
Loading
Loading