Skip to content

Commit 1e65115

Browse files
anakin87 and claude authored
feat: Amazon Bedrock - accept str as ChatGenerator input; deprecate generator; migrate generator example to chat generator (#3398)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 8dbd688 commit 1e65115

5 files changed

Lines changed: 72 additions & 13 deletions

File tree

integrations/amazon_bedrock/examples/embedders_generator_with_rag_example.py

Lines changed: 15 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -4,20 +4,23 @@
44
# Note: if you change the model, update the model-specific inference parameters.
55

66
from haystack import Document, Pipeline
7-
from haystack.components.builders import PromptBuilder
7+
from haystack.components.builders import ChatPromptBuilder
88
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
9+
from haystack.dataclasses import ChatMessage
910
from haystack.document_stores.in_memory import InMemoryDocumentStore
1011

1112
from haystack_integrations.components.embedders.amazon_bedrock import (
1213
AmazonBedrockDocumentEmbedder,
1314
AmazonBedrockTextEmbedder,
1415
)
15-
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator
16+
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator
1617

17-
generator_model_name = "amazon.titan-text-lite-v1"
18+
generator_model_name = "amazon.nova-lite-v1:0"
1819
embedder_model_name = "amazon.titan-embed-text-v1"
1920

20-
prompt_template = """
21+
prompt_template = [
22+
ChatMessage.from_user(
23+
"""
2124
Context:
2225
{% for document in documents %}
2326
{{ document.content }}
@@ -30,6 +33,8 @@
3033
3134
Question: {{ question }}?
3235
"""
36+
)
37+
]
3338

3439
docs = [
3540
Document(content="User ABC is using Amazon SageMaker to train ML models."),
@@ -47,21 +52,21 @@
4752
pipe = Pipeline()
4853
pipe.add_component("text_embedder", AmazonBedrockTextEmbedder(embedder_model_name))
4954
pipe.add_component("retriever", InMemoryEmbeddingRetriever(doc_store, top_k=1))
50-
pipe.add_component("prompt_builder", PromptBuilder(prompt_template))
55+
pipe.add_component("prompt_builder", ChatPromptBuilder(prompt_template))
5156
pipe.add_component(
52-
"generator",
53-
AmazonBedrockGenerator(
57+
"llm",
58+
AmazonBedrockChatGenerator(
5459
model=generator_model_name,
5560
# model-specific inference parameters
5661
generation_kwargs={
57-
"maxTokenCount": 1024,
62+
"maxTokens": 1024,
5863
"temperature": 0.0,
5964
},
6065
),
6166
)
6267
pipe.connect("text_embedder", "retriever")
6368
pipe.connect("retriever", "prompt_builder")
64-
pipe.connect("prompt_builder", "generator")
69+
pipe.connect("prompt_builder.prompt", "llm.messages")
6570

6671

6772
question = "Which user is using IaaS services for Machine Learning?"
@@ -71,4 +76,4 @@
7176
"prompt_builder": {"question": question},
7277
}
7378
)
74-
results["generator"]["replies"]
79+
print(results["llm"]["replies"][0].text) # noqa: T201

integrations/amazon_bedrock/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ classifiers = [
2323
"Programming Language :: Python :: Implementation :: CPython",
2424
"Programming Language :: Python :: Implementation :: PyPy",
2525
]
26-
dependencies = ["haystack-ai>=2.24.1", "boto3>=1.42.84,<2", "aiobotocore>=3.4.0,<4"]
26+
dependencies = ["haystack-ai>=2.30.0", "boto3>=1.42.84,<2", "aiobotocore>=3.4.0,<4"]
2727

2828

2929

integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
from botocore.eventstream import EventStream
77
from botocore.exceptions import ClientError
88
from haystack import component, default_from_dict, default_to_dict, logging
9+
from haystack.components.generators.utils import _normalize_messages
910
from haystack.dataclasses import (
1011
ChatMessage,
1112
ComponentInfo,
@@ -520,7 +521,7 @@ def _resolve_flattened_generation_kwargs(generation_kwargs: dict[str, Any]) -> d
520521
@component.output_types(replies=list[ChatMessage])
521522
def run(
522523
self,
523-
messages: list[ChatMessage],
524+
messages: list[ChatMessage] | str,
524525
streaming_callback: StreamingCallbackT | None = None,
525526
generation_kwargs: dict[str, Any] | None = None,
526527
tools: ToolsType | None = None,
@@ -531,6 +532,7 @@ def run(
531532
Supports both standard and streaming responses depending on whether a streaming callback is provided.
532533
533534
:param messages: A list of `ChatMessage` objects forming the chat history.
535+
If a string is provided, it is converted to a list containing a ChatMessage with user role.
534536
:param streaming_callback: Optional callback for handling streaming outputs.
535537
:param generation_kwargs: Optional dictionary of generation parameters. Some common parameters are:
536538
- `maxTokens`: Maximum number of tokens to generate.
@@ -546,6 +548,7 @@ def run(
546548
:raises AmazonBedrockInferenceError:
547549
If the Bedrock inference API call fails.
548550
"""
551+
messages = _normalize_messages(messages)
549552
component_info = ComponentInfo.from_component(self)
550553

551554
params, callback = self._prepare_request_params(
@@ -582,7 +585,7 @@ def run(
582585
@component.output_types(replies=list[ChatMessage])
583586
async def run_async(
584587
self,
585-
messages: list[ChatMessage],
588+
messages: list[ChatMessage] | str,
586589
streaming_callback: StreamingCallbackT | None = None,
587590
generation_kwargs: dict[str, Any] | None = None,
588591
tools: ToolsType | None = None,
@@ -593,6 +596,7 @@ async def run_async(
593596
Designed for use cases where non-blocking or concurrent execution is desired.
594597
595598
:param messages: A list of `ChatMessage` objects forming the chat history.
599+
If a string is provided, it is converted to a list containing a ChatMessage with user role.
596600
:param streaming_callback: Optional async-compatible callback for handling streaming outputs.
597601
:param generation_kwargs: Optional dictionary of generation parameters. Some common parameters are:
598602
- `maxTokens`: Maximum number of tokens to generate.
@@ -608,6 +612,7 @@ async def run_async(
608612
:raises AmazonBedrockInferenceError:
609613
If the Bedrock inference API call fails.
610614
"""
615+
messages = _normalize_messages(messages)
611616
component_info = ComponentInfo.from_component(self)
612617

613618
params, callback = self._prepare_request_params(

integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -142,6 +142,12 @@ def __init__(
142142
:raises AmazonBedrockConfigurationError: If the AWS environment is not configured correctly or the model is
143143
not supported.
144144
"""
145+
warnings.warn(
146+
"The `AmazonBedrockGenerator` component is deprecated and will be removed in a future version. "
147+
"Use `AmazonBedrockChatGenerator` instead, which now also supports string inputs.",
148+
FutureWarning,
149+
stacklevel=2,
150+
)
145151
if not model:
146152
msg = "'model' cannot be None or empty string"
147153
raise ValueError(msg)

integrations/amazon_bedrock/tests/test_chat_generator.py

Lines changed: 43 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -779,6 +779,49 @@ async def test_run_async_completion(self, mock_boto3_session, set_env_variables)
779779
assert len(result["replies"]) == 1
780780
assert result["replies"][0].text == "Paris"
781781

782+
def test_run_with_string_input(self, mock_boto3_session, set_env_variables):
783+
generator = AmazonBedrockChatGenerator(model="global.anthropic.claude-sonnet-4-6")
784+
generator.client = MagicMock()
785+
generator.client.converse.return_value = {
786+
"output": {"message": {"role": "assistant", "content": [{"text": "Paris"}]}},
787+
"stopReason": "end_turn",
788+
"usage": {"inputTokens": 10, "outputTokens": 5},
789+
"metrics": {"latencyMs": 100},
790+
}
791+
result = generator.run("What's the capital of France?")
792+
_, kwargs = generator.client.converse.call_args
793+
assert kwargs["messages"] == [{"content": [{"text": "What's the capital of France?"}], "role": "user"}]
794+
assert isinstance(result["replies"], list)
795+
assert len(result["replies"]) == 1
796+
assert isinstance(result["replies"][0], ChatMessage)
797+
798+
@pytest.mark.asyncio
799+
async def test_run_async_with_string_input(self, mock_boto3_session, set_env_variables):
800+
generator = AmazonBedrockChatGenerator(model="global.anthropic.claude-sonnet-4-6")
801+
mock_response = {
802+
"output": {"message": {"role": "assistant", "content": [{"text": "Paris"}]}},
803+
"stopReason": "end_turn",
804+
"usage": {"inputTokens": 10, "outputTokens": 5},
805+
"metrics": {"latencyMs": 100},
806+
}
807+
mock_async_client = AsyncMock()
808+
mock_async_client.converse.return_value = mock_response
809+
810+
mock_session = MagicMock()
811+
mock_cm = MagicMock()
812+
mock_cm.__aenter__ = AsyncMock(return_value=mock_async_client)
813+
mock_cm.__aexit__ = AsyncMock(return_value=False)
814+
mock_session.create_client.return_value = mock_cm
815+
816+
generator.async_session = mock_session
817+
818+
result = await generator.run_async("What's the capital of France?")
819+
_, kwargs = mock_async_client.converse.call_args
820+
assert kwargs["messages"] == [{"content": [{"text": "What's the capital of France?"}], "role": "user"}]
821+
assert isinstance(result["replies"], list)
822+
assert len(result["replies"]) == 1
823+
assert isinstance(result["replies"][0], ChatMessage)
824+
782825

783826
# In the CI, those tests are skipped if AWS Authentication fails
784827
@pytest.mark.integration

0 commit comments

Comments
 (0)