Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 46 additions & 6 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import json
import logging
import os
import warnings
from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union, cast

import boto3
Expand All @@ -29,7 +30,7 @@

logger = logging.getLogger(__name__)

DEFAULT_BEDROCK_MODEL_ID = "us.anthropic.claude-sonnet-4-20250514-v1:0"
DEFAULT_BEDROCK_MODEL_ID = "{}.anthropic.claude-sonnet-4-20250514-v1:0"
Comment thread
afarntrog marked this conversation as resolved.
Outdated
Comment thread
dbschmigelski marked this conversation as resolved.
Outdated
DEFAULT_BEDROCK_REGION = "us-west-2"

BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES = [
Expand All @@ -47,6 +48,7 @@

DEFAULT_READ_TIMEOUT = 120


class BedrockModel(Model):
"""AWS Bedrock model provider implementation.

Expand Down Expand Up @@ -129,13 +131,18 @@ def __init__(
if region_name and boto_session:
raise ValueError("Cannot specify both `region_name` and `boto_session`.")

self.config = BedrockModel.BedrockConfig(model_id=DEFAULT_BEDROCK_MODEL_ID, include_tool_result_status="auto")
session = boto_session or boto3.Session()
resolved_region = region_name or session.region_name or os.environ.get("AWS_REGION") or DEFAULT_BEDROCK_REGION
self.config = BedrockModel.BedrockConfig(
model_id=DEFAULT_BEDROCK_MODEL_ID.format(
BedrockModel.get_model_prefix_with_warning(resolved_region, model_config)
),
include_tool_result_status="auto",
)
self.update_config(**model_config)

logger.debug("config=<%s> | initializing", self.config)

session = boto_session or boto3.Session()

# Add strands-agents to the request user agent
if boto_client_config:
existing_user_agent = getattr(boto_client_config, "user_agent_extra", None)
Expand All @@ -150,8 +157,6 @@ def __init__(
else:
client_config = BotocoreConfig(user_agent_extra="strands-agents", read_timeout=DEFAULT_READ_TIMEOUT)

resolved_region = region_name or session.region_name or os.environ.get("AWS_REGION") or DEFAULT_BEDROCK_REGION

self.client = session.client(
service_name="bedrock-runtime",
config=client_config,
Expand Down Expand Up @@ -763,3 +768,38 @@ async def structured_output(
raise ValueError("No valid tool use or tool use input was found in the Bedrock response.")

yield {"output": output_model(**output_response)}

@staticmethod
def get_model_prefix_with_warning(region_name: str, model_config: Optional[BedrockConfig] = None) -> str:
Comment thread
afarntrog marked this conversation as resolved.
Outdated
"""Get model prefix for bedrock model based on region.

If the region is not **known** to support inference then we show a helpful warning
that compliments the exception that Bedrock will throw.
If the customer provided a model_id in their config then we should not
show any warnings as this is only for the **default** model we provide.

Args:
region_name (str): region for bedrock model
model_config (Optional[dict[str, Any]]): Model Config that caller passes in on init
"""
prefix_infr_map = {"ap": "apac"} # some inference endpoints can be a bit different then the region prefix
Comment thread
afarntrog marked this conversation as resolved.
Outdated
Comment thread
afarntrog marked this conversation as resolved.
Outdated
model_config = model_config or {}
prefix = "-".join(region_name.split("-")[:-2]).lower() # handles `us-east-1` or `us-gov-east-1`
if prefix not in {"us", "eu", "ap"} and not model_config.get("model_id"):
warnings.warn(
f"""
================== WARNING ==================

This region {region_name} does not support
our default inference endpoint: {DEFAULT_BEDROCK_MODEL_ID.format(prefix)}.
Update the agent to pass in a 'model_id' like so:
```
Agent(..., model='valid_model_id', ...)
````
Documentation: https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html

==================================================
""",
stacklevel=2,
)
return prefix_infr_map.get(prefix, prefix)
7 changes: 5 additions & 2 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
from tests.fixtures.mock_session_repository import MockedSessionRepository
from tests.fixtures.mocked_model_provider import MockedModelProvider

# For unit testing we will use the the us inference
FORMATTED_DEFAULT_MODEL_ID = DEFAULT_BEDROCK_MODEL_ID.format("us")


@pytest.fixture
def mock_randint():
Expand Down Expand Up @@ -211,7 +214,7 @@ def test_agent__init__with_default_model():
agent = Agent()

assert isinstance(agent.model, BedrockModel)
assert agent.model.config["model_id"] == DEFAULT_BEDROCK_MODEL_ID
assert agent.model.config["model_id"] == FORMATTED_DEFAULT_MODEL_ID


def test_agent__init__with_explicit_model(mock_model):
Expand Down Expand Up @@ -891,7 +894,7 @@ def test_agent__del__(agent):
def test_agent_init_with_no_model_or_model_id():
agent = Agent()
assert agent.model is not None
assert agent.model.get_config().get("model_id") == DEFAULT_BEDROCK_MODEL_ID
assert agent.model.get_config().get("model_id") == FORMATTED_DEFAULT_MODEL_ID


def test_agent_tool_no_parameter_conflict(agent, tool_registry, mock_randint, agenerator):
Expand Down
77 changes: 76 additions & 1 deletion tests/strands/models/test_bedrock.py
Comment thread
afarntrog marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from strands.types.exceptions import ModelThrottledException
from strands.types.tools import ToolSpec

FORMATTED_DEFAULT_MODEL_ID = DEFAULT_BEDROCK_MODEL_ID.format("us")


@pytest.fixture
def session_cls():
Expand Down Expand Up @@ -119,7 +121,7 @@ def test__init__default_model_id(bedrock_client):
model = BedrockModel()

tru_model_id = model.get_config().get("model_id")
exp_model_id = DEFAULT_BEDROCK_MODEL_ID
exp_model_id = FORMATTED_DEFAULT_MODEL_ID

assert tru_model_id == exp_model_id

Expand Down Expand Up @@ -1477,3 +1479,76 @@ def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings
assert len(captured_warnings) == 1
assert "Invalid configuration parameters" in str(captured_warnings[0].message)
assert "wrong_param" in str(captured_warnings[0].message)


def test_get_model_prefix_with_warning_supported_regions_shows_no_warning(captured_warnings):
"""Test get_model_prefix_with_warning doesn't warn for supported region prefixes."""
prefix_us = BedrockModel.get_model_prefix_with_warning("us-west-2")
assert prefix_us == "us"

prefix_eu = BedrockModel.get_model_prefix_with_warning("eu-west-1")
assert prefix_eu == "eu"

assert len(captured_warnings) == 0


def test_get_model_prefix_with_warning_unsupported_region_warns(captured_warnings):
"""Test get_model_prefix_with_warning warns for unsupported regions."""
prefix = BedrockModel.get_model_prefix_with_warning("ca-central-1")

assert prefix == "ca"
assert len(captured_warnings) == 1
assert "This region ca-central-1 does not support" in str(captured_warnings[0].message)
assert "our default inference endpoint" in str(captured_warnings[0].message)


def test_get_model_prefix_with_warning_no_warning_with_custom_model_id(captured_warnings):
"""Test get_model_prefix_with_warning doesn't warn when custom model_id provided."""
model_config = {"model_id": "custom-model"}
prefix = BedrockModel.get_model_prefix_with_warning("ca-central-1", model_config)

assert prefix == "ca"
assert len(captured_warnings) == 0


def test_init_with_unsupported_region_warns(session_cls, captured_warnings):
"""Test BedrockModel initialization warns for unsupported regions."""
BedrockModel(region_name="ca-central-1")

assert len(captured_warnings) == 1
assert "This region ca-central-1 does not support" in str(captured_warnings[0].message)


def test_init_with_unsupported_region_custom_model_no_warning(session_cls, captured_warnings):
"""Test BedrockModel initialization doesn't warn when custom model_id provided."""
BedrockModel(region_name="ca-central-1", model_id="custom-model")

assert len(captured_warnings) == 0


def test_default_model_id_format_with_region_prefix(session_cls):
"""Test that default model ID is formatted with region prefix."""
session_cls.return_value.region_name = "eu-west-1"

model = BedrockModel()
model_id = model.get_config().get("model_id")

assert model_id.startswith("eu.")
assert "anthropic.claude-sonnet-4" in model_id


def test_custom_model_id_not_overridden_by_region_formatting(session_cls):
"""Test that custom model_id is not overridden by region formatting."""
custom_model_id = "custom.model.id"

model = BedrockModel(model_id=custom_model_id)
model_id = model.get_config().get("model_id")

assert model_id == custom_model_id


def test_get_model_prefix_for_ap_region_converts_to_apac_endpoint(captured_warnings):
"""Test get_model_prefix_with_warning warns for APAC regions since 'ap' is not in supported prefixes."""
prefix = BedrockModel.get_model_prefix_with_warning("ap-southeast-1")

assert prefix == "apac"
Loading