Skip to content

Commit 316ee94

Browse files
authored
feat: support str in addition to Secret for aws_region_name (#3423)
1 parent 197032e commit 316ee94

14 files changed

Lines changed: 198 additions & 45 deletions

File tree

integrations/amazon_bedrock/src/haystack_integrations/components/downloaders/s3/s3_downloader.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __init__(
3737
"AWS_SECRET_ACCESS_KEY", strict=False
3838
),
3939
aws_session_token: Secret | None = Secret.from_env_var("AWS_SESSION_TOKEN", strict=False), # noqa: B008
40-
aws_region_name: Secret | None = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False), # noqa: B008
40+
aws_region_name: Secret | str | None = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False), # noqa: B008
4141
aws_profile_name: Secret | None = Secret.from_env_var("AWS_PROFILE", strict=False), # noqa: B008
4242
boto3_config: dict[str, Any] | None = None,
4343
file_root_path: str | None = None,
@@ -115,8 +115,8 @@ def __init__(
115115

116116
self._storage: S3Storage | None = None
117117

118-
def resolve_secret(secret: Secret | None) -> str | None:
119-
return secret.resolve_value() if secret else None
118+
def resolve_secret(secret: Secret | str | None) -> str | None:
119+
return secret.resolve_value() if isinstance(secret, Secret) else secret
120120

121121
self._session = get_aws_session(
122122
aws_access_key_id=resolve_secret(aws_access_key_id),

integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/document_embedder.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ class AmazonBedrockDocumentEmbedder:
2929
```python
3030
import os
3131
from haystack.dataclasses import Document
32-
from haystack_integrations.components.embedders.amazon_bedrock import AmazonBedrockDocumentEmbedder
32+
from haystack_integrations.components.embedders.amazon_bedrock import (
33+
AmazonBedrockDocumentEmbedder,
34+
)
3335
3436
os.environ["AWS_ACCESS_KEY_ID"] = "..."
3537
os.environ["AWS_SECRET_ACCESS_KEY_ID"] = "..."
@@ -43,7 +45,7 @@ class AmazonBedrockDocumentEmbedder:
4345
doc = Document(content="I love Paris in the winter.", meta={"name": "doc1"})
4446
4547
result = embedder.run([doc])
46-
print(result['documents'][0].embedding)
48+
print(result["documents"][0].embedding)
4749
4850
# [0.002, 0.032, 0.504, ...]
4951
```
@@ -57,7 +59,7 @@ def __init__(
5759
"AWS_SECRET_ACCESS_KEY", strict=False
5860
),
5961
aws_session_token: Secret | None = Secret.from_env_var("AWS_SESSION_TOKEN", strict=False), # noqa: B008
60-
aws_region_name: Secret | None = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False), # noqa: B008
62+
aws_region_name: Secret | str | None = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False), # noqa: B008
6163
aws_profile_name: Secret | None = Secret.from_env_var("AWS_PROFILE", strict=False), # noqa: B008
6264
batch_size: int = 32,
6365
progress_bar: bool = True,
@@ -120,8 +122,8 @@ def __init__(
120122
self.boto3_config = boto3_config
121123
self.kwargs = kwargs
122124

123-
def resolve_secret(secret: Secret | None) -> str | None:
124-
return secret.resolve_value() if secret else None
125+
def resolve_secret(secret: Secret | str | None) -> str | None:
126+
return secret.resolve_value() if isinstance(secret, Secret) else secret
125127

126128
try:
127129
session = get_aws_session(

integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/document_image_embedder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def __init__(
7373
"AWS_SECRET_ACCESS_KEY", strict=False
7474
),
7575
aws_session_token: Secret | None = Secret.from_env_var("AWS_SESSION_TOKEN", strict=False), # noqa: B008
76-
aws_region_name: Secret | None = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False), # noqa: B008
76+
aws_region_name: Secret | str | None = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False), # noqa: B008
7777
aws_profile_name: Secret | None = Secret.from_env_var("AWS_PROFILE", strict=False), # noqa: B008
7878
file_path_meta_field: str = "file_path",
7979
root_path: str | None = None,
@@ -146,8 +146,8 @@ def __init__(
146146
raise ValueError(msg)
147147
self.embedding_types = embedding_types
148148

149-
def resolve_secret(secret: Secret | None) -> str | None:
150-
return secret.resolve_value() if secret else None
149+
def resolve_secret(secret: Secret | str | None) -> str | None:
150+
return secret.resolve_value() if isinstance(secret, Secret) else secret
151151

152152
try:
153153
session = get_aws_session(

integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/text_embedder.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ class AmazonBedrockTextEmbedder:
2323
Usage example:
2424
```python
2525
import os
26-
from haystack_integrations.components.embedders.amazon_bedrock import AmazonBedrockTextEmbedder
26+
from haystack_integrations.components.embedders.amazon_bedrock import (
27+
AmazonBedrockTextEmbedder,
28+
)
2729
2830
os.environ["AWS_ACCESS_KEY_ID"] = "..."
2931
os.environ["AWS_SECRET_ACCESS_KEY_ID"] = "..."
@@ -48,7 +50,7 @@ def __init__(
4850
"AWS_SECRET_ACCESS_KEY", strict=False
4951
),
5052
aws_session_token: Secret | None = Secret.from_env_var("AWS_SESSION_TOKEN", strict=False), # noqa: B008
51-
aws_region_name: Secret | None = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False), # noqa: B008
53+
aws_region_name: Secret | str | None = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False), # noqa: B008
5254
aws_profile_name: Secret | None = Secret.from_env_var("AWS_PROFILE", strict=False), # noqa: B008
5355
boto3_config: dict[str, Any] | None = None,
5456
**kwargs: Any,
@@ -97,8 +99,8 @@ def __init__(
9799
self.boto3_config = boto3_config
98100
self.kwargs = kwargs
99101

100-
def resolve_secret(secret: Secret | None) -> str | None:
101-
return secret.resolve_value() if secret else None
102+
def resolve_secret(secret: Secret | str | None) -> str | None:
103+
return secret.resolve_value() if isinstance(secret, Secret) else secret
102104

103105
try:
104106
session = get_aws_session(

integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,24 @@ class AmazonBedrockChatGenerator:
5555
**Usage example**
5656
5757
```python
58-
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator
58+
from haystack_integrations.components.generators.amazon_bedrock import (
59+
AmazonBedrockChatGenerator,
60+
)
5961
from haystack.dataclasses import ChatMessage
6062
from haystack.components.generators.utils import print_streaming_chunk
6163
62-
messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant, answer in German only"),
63-
ChatMessage.from_user("What's Natural Language Processing?")]
64+
messages = [
65+
ChatMessage.from_system(
66+
"\\nYou are a helpful, respectful and honest assistant, answer in German only"
67+
),
68+
ChatMessage.from_user("What's Natural Language Processing?"),
69+
]
6470
6571
66-
client = AmazonBedrockChatGenerator(model="global.anthropic.claude-sonnet-4-6",
67-
streaming_callback=print_streaming_chunk)
72+
client = AmazonBedrockChatGenerator(
73+
model="global.anthropic.claude-sonnet-4-6",
74+
streaming_callback=print_streaming_chunk,
75+
)
6876
client.run(messages, generation_kwargs={"max_tokens": 512})
6977
```
7078
@@ -152,7 +160,9 @@ def weather(city: str):
152160
To cache messages, you can use the `cachePoint` key in `ChatMessage.meta` attribute.
153161
154162
```python
155-
ChatMessage.from_user("Long message...", meta={"cachePoint": {"type": "default"}})
163+
ChatMessage.from_user(
164+
"Long message...", meta={"cachePoint": {"type": "default"}}
165+
)
156166
```
157167
158168
For more information, see the [Amazon Bedrock documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html).
@@ -179,7 +189,7 @@ def __init__(
179189
["AWS_SECRET_ACCESS_KEY"], strict=False
180190
),
181191
aws_session_token: Secret | None = Secret.from_env_var(["AWS_SESSION_TOKEN"], strict=False), # noqa: B008
182-
aws_region_name: Secret | None = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False), # noqa: B008
192+
aws_region_name: Secret | str | None = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False), # noqa: B008
183193
aws_profile_name: Secret | None = Secret.from_env_var(["AWS_PROFILE"], strict=False), # noqa: B008
184194
generation_kwargs: dict[str, Any] | None = None,
185195
streaming_callback: StreamingCallbackT | None = None,
@@ -219,12 +229,15 @@ def __init__(
219229
220230
Example::
221231
222-
generation_kwargs={
232+
generation_kwargs = {
223233
"response_format": {
224234
"name": "person",
225235
"schema": {
226236
"type": "object",
227-
"properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
237+
"properties": {
238+
"name": {"type": "string"},
239+
"age": {"type": "integer"},
240+
},
228241
"required": ["name", "age"],
229242
"additionalProperties": False,
230243
},
@@ -288,8 +301,8 @@ def __init__(
288301
_validate_and_format_cache_point(tools_cachepoint_config) if tools_cachepoint_config else None
289302
)
290303

291-
def resolve_secret(secret: Secret | None) -> str | None:
292-
return secret.resolve_value() if secret else None
304+
def resolve_secret(secret: Secret | str | None) -> str | None:
305+
return secret.resolve_value() if isinstance(secret, Secret) else secret
293306

294307
config = Config(
295308
user_agent_extra="x-client-framework:haystack",
@@ -401,7 +414,6 @@ def _prepare_request_params(
401414
tools: ToolsType | None = None,
402415
requires_async: bool = False,
403416
) -> tuple[dict[str, Any], StreamingCallbackT | None]:
404-
405417
generation_kwargs = generation_kwargs or {}
406418

407419
# Merge generation_kwargs with defaults
@@ -636,7 +648,11 @@ async def run_async(
636648
self.aws_secret_access_key.resolve_value() if self.aws_secret_access_key else None
637649
),
638650
aws_session_token=(self.aws_session_token.resolve_value() if self.aws_session_token else None),
639-
region_name=(self.aws_region_name.resolve_value() if self.aws_region_name else None),
651+
region_name=(
652+
self.aws_region_name.resolve_value()
653+
if isinstance(self.aws_region_name, Secret)
654+
else self.aws_region_name
655+
),
640656
config=config,
641657
) as async_client:
642658
if callback:

integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,12 @@ class AmazonBedrockGenerator:
4444
### Usage example
4545
4646
```python
47-
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator
48-
49-
generator = AmazonBedrockGenerator(
50-
model="anthropic.claude-v2",
51-
max_length=99
47+
from haystack_integrations.components.generators.amazon_bedrock import (
48+
AmazonBedrockGenerator,
5249
)
5350
51+
generator = AmazonBedrockGenerator(model="anthropic.claude-v2", max_length=99)
52+
5453
print(generator.run("Who is the best American actor?"))
5554
```
5655
@@ -106,7 +105,7 @@ def __init__(
106105
"AWS_SECRET_ACCESS_KEY", strict=False
107106
),
108107
aws_session_token: Secret | None = Secret.from_env_var("AWS_SESSION_TOKEN", strict=False), # noqa: B008
109-
aws_region_name: Secret | None = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False), # noqa: B008
108+
aws_region_name: Secret | str | None = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False), # noqa: B008
110109
aws_profile_name: Secret | None = Secret.from_env_var("AWS_PROFILE", strict=False), # noqa: B008
111110
max_length: int | None = None,
112111
truncate: bool | None = None,
@@ -170,8 +169,8 @@ def __init__(
170169
self.kwargs = kwargs
171170
self.model_family = model_family
172171

173-
def resolve_secret(secret: Secret | None) -> str | None:
174-
return secret.resolve_value() if secret else None
172+
def resolve_secret(secret: Secret | str | None) -> str | None:
173+
return secret.resolve_value() if isinstance(secret, Secret) else secret
175174

176175
try:
177176
session = get_aws_session(

integrations/amazon_bedrock/src/haystack_integrations/components/rankers/amazon_bedrock/ranker.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,14 @@ class AmazonBedrockRanker:
3131
```python
3232
from haystack import Document
3333
from haystack.utils import Secret
34-
from haystack_integrations.components.rankers.amazon_bedrock import AmazonBedrockRanker
34+
from haystack_integrations.components.rankers.amazon_bedrock import (
35+
AmazonBedrockRanker,
36+
)
3537
3638
ranker = AmazonBedrockRanker(
3739
model="cohere.rerank-v3-5:0",
3840
top_k=2,
39-
aws_region_name=Secret.from_token("eu-central-1")
41+
aws_region_name=Secret.from_token("eu-central-1"),
4042
)
4143
4244
docs = [Document(content="Paris"), Document(content="Berlin")]
@@ -66,7 +68,7 @@ def __init__(
6668
["AWS_SECRET_ACCESS_KEY"], strict=False
6769
),
6870
aws_session_token: Secret | None = Secret.from_env_var(["AWS_SESSION_TOKEN"], strict=False), # noqa: B008
69-
aws_region_name: Secret | None = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False), # noqa: B008
71+
aws_region_name: Secret | str | None = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False), # noqa: B008
7072
aws_profile_name: Secret | None = Secret.from_env_var(["AWS_PROFILE"], strict=False), # noqa: B008
7173
max_chunks_per_doc: int | None = None,
7274
meta_fields_to_embed: list[str] | None = None,
@@ -104,8 +106,8 @@ def __init__(
104106
self.meta_fields_to_embed = meta_fields_to_embed or []
105107
self.meta_data_separator = meta_data_separator
106108

107-
def resolve_secret(secret: Secret | None) -> str | None:
108-
return secret.resolve_value() if secret else None
109+
def resolve_secret(secret: Secret | str | None) -> str | None:
110+
return secret.resolve_value() if isinstance(secret, Secret) else secret
109111

110112
try:
111113
session = get_aws_session(
@@ -199,8 +201,8 @@ def run(self, query: str, documents: list[Document], top_k: int | None = None) -
199201
if not documents:
200202
return {"documents": []}
201203

202-
def resolve_secret(secret: Secret | None) -> str | None:
203-
return secret.resolve_value() if secret else None
204+
def resolve_secret(secret: Secret | str | None) -> str | None:
205+
return secret.resolve_value() if isinstance(secret, Secret) else secret
204206

205207
region = resolve_secret(self.aws_region_name)
206208

integrations/amazon_bedrock/tests/test_chat_generator.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,25 @@ def test_from_dict(self, mock_boto3_session: Any, boto3_config: dict[str, Any] |
291291
assert generator.streaming_callback == print_streaming_chunk
292292
assert generator.boto3_config == boto3_config
293293

294+
def test_from_dict_aws_region_name(self, mock_boto3_session: Any):
295+
"""
296+
Test that aws_region_name as str value is correctly parsed
297+
"""
298+
generator = AmazonBedrockChatGenerator.from_dict(
299+
{
300+
"type": CLASS_TYPE,
301+
"init_parameters": {
302+
"aws_region_name": "my-fake-region",
303+
"model": "global.anthropic.claude-sonnet-4-6",
304+
},
305+
}
306+
)
307+
assert generator.model == "global.anthropic.claude-sonnet-4-6"
308+
assert generator.aws_region_name == "my-fake-region"
309+
310+
serialized = generator.to_dict()
311+
assert serialized["init_parameters"]["aws_region_name"] == "my-fake-region"
312+
294313
def test_default_constructor(self, mock_boto3_session, mock_aioboto3_session, set_env_variables):
295314
"""
296315
Test that the default constructor sets the correct values
@@ -448,7 +467,6 @@ def test_prepare_request_params_guardrail_config(self, mock_boto3_session, set_e
448467
}
449468

450469
def test_prepare_request_params_response_format(self, mock_boto3_session, set_env_variables):
451-
452470
generator = AmazonBedrockChatGenerator(model="global.anthropic.claude-sonnet-4-6")
453471
schema = {
454472
"type": "object",
@@ -724,7 +742,6 @@ def test_run_completion(self, mock_boto3_session, set_env_variables):
724742
assert result["replies"][0].text == "Paris"
725743

726744
def test_run_client_error(self, mock_boto3_session, set_env_variables):
727-
728745
generator = AmazonBedrockChatGenerator(model="global.anthropic.claude-sonnet-4-6")
729746
generator.client = MagicMock()
730747
generator.client.converse.side_effect = ClientError(

integrations/amazon_bedrock/tests/test_document_embedder.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,25 @@ def test_from_dict(self, mock_boto3_session: Any, boto3_config: dict[str, Any] |
165165
assert embedder.embedding_separator == "\n"
166166
assert embedder.boto3_config == boto3_config
167167

168+
def test_from_dict_aws_region_name(self, mock_boto3_session):
169+
"""
170+
Test that aws_region_name as str value is correctly parsed
171+
"""
172+
embedder = AmazonBedrockDocumentEmbedder.from_dict(
173+
{
174+
"type": TYPE,
175+
"init_parameters": {
176+
"aws_region_name": "my-fake-region",
177+
"model": "cohere.embed-english-v3",
178+
},
179+
}
180+
)
181+
assert embedder.model == "cohere.embed-english-v3"
182+
assert embedder.aws_region_name == "my-fake-region"
183+
184+
serialized = embedder.to_dict()
185+
assert serialized["init_parameters"]["aws_region_name"] == "my-fake-region"
186+
168187
def test_init_invalid_model(self):
169188
with pytest.raises(ValueError):
170189
AmazonBedrockDocumentEmbedder(model="")

integrations/amazon_bedrock/tests/test_document_image_embedder.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,25 @@ def test_from_dict(self, mock_boto3_session: Any, boto3_config: dict[str, Any] |
165165
assert embedder.progress_bar
166166
assert embedder.boto3_config == boto3_config
167167

168+
def test_from_dict_aws_region_name(self, mock_boto3_session):
169+
"""
170+
Test that aws_region_name as str value is correctly parsed
171+
"""
172+
embedder = AmazonBedrockDocumentImageEmbedder.from_dict(
173+
{
174+
"type": TYPE,
175+
"init_parameters": {
176+
"aws_region_name": "my-fake-region",
177+
"model": "cohere.embed-english-v3",
178+
},
179+
}
180+
)
181+
assert embedder.model == "cohere.embed-english-v3"
182+
assert embedder.aws_region_name == "my-fake-region"
183+
184+
serialized = embedder.to_dict()
185+
assert serialized["init_parameters"]["aws_region_name"] == "my-fake-region"
186+
168187
def test_init_invalid_model(self):
169188
with pytest.raises(ValueError):
170189
AmazonBedrockDocumentImageEmbedder(model="")

0 commit comments

Comments
 (0)