Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ litellm = ["litellm>=1.75.9,<=1.83.13", "openai>=1.68.0,<3.0.0"]
llamaapi = ["llama-api-client>=0.1.0,<1.0.0"]
mistral = ["mistralai>=1.8.2,<2.0.0"]
ollama = ["ollama>=0.4.8,<1.0.0"]
openai = ["openai>=1.68.0,<3.0.0"]
openai = ["openai>=1.68.0,<3.0.0", "aws-bedrock-token-generator>=1.1.0,<2.0.0"]
Comment thread
JackYPCOnline marked this conversation as resolved.
writer = ["writer-sdk>=2.2.0,<3.0.0"]
sagemaker = [
"boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0",
Expand Down
83 changes: 83 additions & 0 deletions src/strands/models/_openai_bedrock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""Internal helpers for routing OpenAI-compatible clients to Bedrock Mantle.

Converts an ``aws_config`` dict into the ``base_url`` and ``api_key`` that the
OpenAI Python SDK consumes. Tokens are minted on demand via
``aws_bedrock_token_generator.provide_token`` so long-running agents keep working
past the bearer token's maximum lifetime.

``aws_bedrock_token_generator`` is imported lazily inside
:func:`resolve_bedrock_client_args` so that users of OpenAI-compatible extras
(``sagemaker``, ``litellm``, bare ``openai`` without Mantle) don't pay an
``ImportError`` just for importing the model class.
"""

from __future__ import annotations

from datetime import timedelta
from typing import Any, TypedDict

from typing_extensions import Required

_MANTLE_BASE_URL_TEMPLATE = "https://bedrock-mantle.{region}.api.aws/v1"


class AwsConfig(TypedDict, total=False):
Comment thread
JackYPCOnline marked this conversation as resolved.
Outdated
"""AWS-side config for reaching Bedrock Mantle via an OpenAI-compatible client.

Attributes:
region: AWS region hosting the Bedrock Mantle endpoint (required).
credentials_provider: Optional botocore ``CredentialProvider`` forwarded to
``provide_token``. Defaults to the AWS credential chain.
expiry: Optional ``timedelta`` for the bearer token's lifetime, forwarded
to ``provide_token``.
"""

region: Required[str]
Comment thread
JackYPCOnline marked this conversation as resolved.
Outdated
credentials_provider: Any
Comment thread
JackYPCOnline marked this conversation as resolved.
Outdated
expiry: timedelta


def resolve_bedrock_client_args(aws_config: AwsConfig, client_args: dict[str, Any] | None = None) -> dict[str, Any]:
"""Resolve an ``AwsConfig`` (plus optional ``client_args``) into OpenAI client kwargs.

Mints a fresh bearer token on every call. When ``client_args`` is provided, its
entries are preserved except for ``base_url`` and ``api_key``, which are always
overridden by the values derived from ``aws_config``.

Raises:
ValueError: If ``aws_config['region']`` is missing or empty.
ImportError: If ``aws-bedrock-token-generator`` is not installed.
RuntimeError: If token minting fails (e.g. missing AWS credentials).
"""
region = aws_config.get("region")
if not region:
raise ValueError("aws_config must include a non-empty 'region'.")

try:
from aws_bedrock_token_generator import provide_token
except ImportError as e:
raise ImportError(
"aws_config requires the 'aws-bedrock-token-generator' package. "
Comment thread
JackYPCOnline marked this conversation as resolved.
Outdated
"Install it with: pip install strands-agents[openai]"
) from e

# Only forward optional kwargs when explicitly set, so provide_token's own
# defaults apply. Passing expiry=None in particular crashes the library.
token_kwargs: dict[str, Any] = {"region": region}
if "credentials_provider" in aws_config:
token_kwargs["aws_credentials_provider"] = aws_config["credentials_provider"]
if "expiry" in aws_config:
token_kwargs["expiry"] = aws_config["expiry"]

try:
token = provide_token(**token_kwargs)
except Exception as e:
raise RuntimeError(
f"Failed to mint Bedrock Mantle bearer token for region '{region}'. "
"Verify your AWS credentials and network connectivity."
) from e

resolved: dict[str, Any] = dict(client_args or {})
resolved["base_url"] = _MANTLE_BASE_URL_TEMPLATE.format(region=region)
Comment thread
JackYPCOnline marked this conversation as resolved.
resolved["api_key"] = token
return resolved
39 changes: 34 additions & 5 deletions src/strands/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
from ..types.streaming import StreamEvent
from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse
from ._openai_bedrock import AwsConfig, resolve_bedrock_client_args
from ._validation import _has_location_source, validate_config_keys
from .model import BaseModelConfig, Model

Expand Down Expand Up @@ -71,6 +72,7 @@ def __init__(
self,
client: Client | None = None,
client_args: dict[str, Any] | None = None,
aws_config: AwsConfig | None = None,
**model_config: Unpack[OpenAIConfig],
) -> None:
"""Initialize provider instance.
Expand All @@ -87,23 +89,50 @@ def __init__(
Note: The client should not be shared across different asyncio event loops.
client_args: Arguments for the OpenAI client (legacy approach).
For a complete list of supported arguments, see https://pypi.org/project/openai/.
May be combined with ``aws_config``; transport-level options like ``http_client``,
``timeout``, or ``default_headers`` are preserved, while ``base_url`` and
``api_key`` are always overridden by ``aws_config`` when both are set.
aws_config: Route requests through Amazon Bedrock's Mantle (OpenAI-compatible)
Comment thread
JackYPCOnline marked this conversation as resolved.
Outdated
endpoint. Provide ``{"region": "us-east-1"}`` at minimum. Accepts optional
``credentials_provider`` (a botocore ``CredentialProvider``) and ``expiry``
(a ``datetime.timedelta`` up to 12h). When set, a fresh bearer token is minted
on every request via ``aws-bedrock-token-generator`` and the OpenAI client is
pointed at ``https://bedrock-mantle.<region>.api.aws/v1``. Cannot be combined
with a pre-built ``client``.
**model_config: Configuration options for the OpenAI model.

Raises:
ValueError: If both `client` and `client_args` are provided.
ValueError: If ``client`` is combined with ``client_args`` or ``aws_config``,
or if ``aws_config`` is missing a region.
"""
validate_config_keys(model_config, self.OpenAIConfig)
self.config = dict(model_config)

# Validate that only one client configuration method is provided
if client is not None and client_args is not None and len(client_args) > 0:
# Validate that client configuration methods are mutually exclusive where they conflict.
# client_args + aws_config is allowed — aws_config will override base_url / api_key only.
client_args_provided = client_args is not None and len(client_args) > 0
if client is not None and client_args_provided:
raise ValueError("Only one of 'client' or 'client_args' should be provided, not both.")
if aws_config is not None and client is not None:
raise ValueError("'aws_config' cannot be combined with a pre-built 'client'.")

self._custom_client = client
self.client_args = client_args or {}
self._aws_config = aws_config

logger.debug("config=<%s> | initializing", self.config)

def _resolve_client_args(self) -> dict[str, Any]:
"""Return the kwargs to pass to ``openai.AsyncOpenAI`` for the current request.

When ``aws_config`` is set, a fresh Bedrock Mantle bearer token is minted on every
call and ``base_url`` / ``api_key`` are overridden. Any other entries from
``client_args`` (e.g. ``http_client``, ``timeout``) are preserved.
"""
if self._aws_config is not None:
return resolve_bedrock_client_args(self._aws_config, self.client_args)
return self.client_args

@override
def update_config(self, **model_config: Unpack[OpenAIConfig]) -> None: # type: ignore[override]
"""Update the OpenAI model configuration with the provided arguments.
Expand Down Expand Up @@ -590,11 +619,11 @@ async def _get_client(self) -> AsyncIterator[Any]:
# Use the injected client (caller manages lifecycle)
yield self._custom_client
else:
# Create a new client from client_args
# Create a new client from resolved args (static client_args or freshly-minted Bedrock creds).
# We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying
# httpx client. The asyncio event loop does not allow connections to be shared. For more details, please
# refer to https://github.com/encode/httpx/discussions/2959.
async with openai.AsyncOpenAI(**self.client_args) as client:
async with openai.AsyncOpenAI(**self._resolve_client_args()) as client:
yield client

@override
Expand Down
37 changes: 33 additions & 4 deletions src/strands/models/openai_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException # noqa: E402
from ..types.streaming import StreamEvent # noqa: E402
from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse # noqa: E402
from ._openai_bedrock import AwsConfig, resolve_bedrock_client_args # noqa: E402
from ._validation import validate_config_keys # noqa: E402
from .model import BaseModelConfig, Model # noqa: E402

Expand Down Expand Up @@ -141,21 +142,49 @@ class OpenAIResponsesConfig(BaseModelConfig, total=False):
stateful: bool

def __init__(
    self,
    client_args: dict[str, Any] | None = None,
    aws_config: AwsConfig | None = None,
    **model_config: Unpack[OpenAIResponsesConfig],
) -> None:
    """Initialize provider instance.

    Args:
        client_args: Arguments for the OpenAI client.
            For a complete list of supported arguments, see https://pypi.org/project/openai/.
            May be combined with ``aws_config``; transport-level options like ``http_client``,
            ``timeout``, or ``default_headers`` are preserved, while ``base_url`` and
            ``api_key`` are always overridden by ``aws_config`` when both are set.
        aws_config: Route requests through Amazon Bedrock's Mantle (OpenAI-compatible)
            endpoint. Provide ``{"region": "us-east-1"}`` at minimum. Accepts optional
            ``credentials_provider`` (a botocore ``CredentialProvider``) and ``expiry``
            (a ``datetime.timedelta`` up to 12h). When set, a fresh bearer token is minted
            on every request via ``aws-bedrock-token-generator`` and the OpenAI client is
            pointed at ``https://bedrock-mantle.<region>.api.aws/v1``.
        **model_config: Configuration options for the OpenAI Responses API model.

    Raises:
        ValueError: If ``aws_config`` is missing a region. NOTE(review): the region
            check is not performed here — it happens lazily, when the first request
            resolves client args in ``_resolve_client_args`` — so construction with a
            region-less ``aws_config`` succeeds and the error surfaces at call time.
    """
    validate_config_keys(model_config, self.OpenAIResponsesConfig)
    self.config = dict(model_config)

    # aws_config is stored as-is; tokens are minted per request, not here.
    self.client_args = client_args or {}
    self._aws_config = aws_config

    logger.debug("config=<%s> | initializing", self.config)

def _resolve_client_args(self) -> dict[str, Any]:
"""Return the kwargs to pass to ``openai.AsyncOpenAI`` for the current request.

When ``aws_config`` is set, a fresh Bedrock Mantle bearer token is minted on every
call and ``base_url`` / ``api_key`` are overridden. Any other entries from
``client_args`` (e.g. ``http_client``, ``timeout``) are preserved.
"""
if self._aws_config is not None:
return resolve_bedrock_client_args(self._aws_config, self.client_args)
return self.client_args

@property
@override
def stateful(self) -> bool:
Expand Down Expand Up @@ -215,7 +244,7 @@ async def count_tokens(
count_tokens_fields = {"model", "input", "instructions", "tools"}
request = {k: request[k] for k in request.keys() & count_tokens_fields}

async with openai.AsyncOpenAI(**self.client_args) as client:
async with openai.AsyncOpenAI(**self._resolve_client_args()) as client:
response = await client.responses.input_tokens.count(**request)
total_tokens: int = response.input_tokens

Expand Down Expand Up @@ -267,7 +296,7 @@ async def stream(

logger.debug("invoking OpenAI Responses API model")

async with openai.AsyncOpenAI(**self.client_args) as client:
async with openai.AsyncOpenAI(**self._resolve_client_args()) as client:
try:
response = await client.responses.create(**request)

Expand Down Expand Up @@ -447,7 +476,7 @@ async def structured_output(
ContextWindowOverflowException: If the input exceeds the model's context window.
ModelThrottledException: If the request is throttled by OpenAI (rate limits).
"""
async with openai.AsyncOpenAI(**self.client_args) as client:
async with openai.AsyncOpenAI(**self._resolve_client_args()) as client:
try:
response = await client.responses.parse(
model=self.get_config()["model_id"],
Expand Down
108 changes: 108 additions & 0 deletions tests/strands/models/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -1710,3 +1710,111 @@ def test_format_request_messages_multiple_tool_calls_with_images():
},
]
assert tru_result == exp_result


# =============================================================================
# Bedrock Mantle (aws_config) integration with OpenAIModel
# =============================================================================


class TestOpenAIModelAwsConfig:
    """Tests for the Bedrock Mantle pathway via the aws_config kwarg."""

    @pytest.fixture
    def mock_provide_token(self):
        # Patch the module-level symbol: the model imports it lazily inside
        # resolve_bedrock_client_args, so the patch is picked up per request.
        with unittest.mock.patch("aws_bedrock_token_generator.provide_token") as mock:
            mock.return_value = "bedrock-api-key-deadbeef&Version=1"
            yield mock

    def test_aws_config_sets_base_url_and_api_key(self, openai_client, mock_provide_token):
        """aws_config produces the Mantle base_url and a minted bearer token as api_key."""
        _ = openai_client
        model = OpenAIModel(model_id="openai.gpt-oss-120b", aws_config={"region": "us-east-1"})

        # api_key is resolved per-request (lazy), so check via the resolved client_args at call time
        resolved = model._resolve_client_args()
        assert resolved["base_url"] == "https://bedrock-mantle.us-east-1.api.aws/v1"
        assert resolved["api_key"] == "bedrock-api-key-deadbeef&Version=1"
        # Only region is forwarded when the user did not set optional kwargs,
        # so provide_token's own defaults (e.g. 12h expiry) apply.
        mock_provide_token.assert_called_once_with(region="us-east-1")

    def test_aws_config_forwards_credentials_provider_and_expiry(self, openai_client, mock_provide_token):
        """Optional credentials_provider and expiry are forwarded to provide_token."""
        _ = openai_client
        from datetime import timedelta

        provider = unittest.mock.Mock()
        model = OpenAIModel(
            model_id="openai.gpt-oss-120b",
            aws_config={
                "region": "us-west-2",
                "credentials_provider": provider,
                "expiry": timedelta(minutes=15),
            },
        )
        model._resolve_client_args()
        # Note: the config key 'credentials_provider' maps to provide_token's
        # 'aws_credentials_provider' kwarg.
        mock_provide_token.assert_called_once_with(
            region="us-west-2",
            aws_credentials_provider=provider,
            expiry=timedelta(minutes=15),
        )

    def test_aws_config_mints_token_per_request(self, openai_client, mock_provide_token):
        """Each call to _resolve_client_args mints a fresh token (long-lived processes)."""
        _ = openai_client
        model = OpenAIModel(model_id="openai.gpt-oss-120b", aws_config={"region": "us-east-1"})
        model._resolve_client_args()
        model._resolve_client_args()
        model._resolve_client_args()
        assert mock_provide_token.call_count == 3

    def test_aws_config_conflicts_with_custom_client(self, openai_client):
        """Cannot pass both aws_config and a pre-built client."""
        _ = openai_client
        custom_client = unittest.mock.Mock()
        with pytest.raises(ValueError, match="aws_config"):
            OpenAIModel(
                model_id="openai.gpt-oss-120b",
                client=custom_client,
                aws_config={"region": "us-east-1"},
            )

    def test_aws_config_merges_with_client_args(self, openai_client, mock_provide_token):
        """aws_config is allowed alongside client_args; base_url and api_key are overridden,
        other transport-level options (timeout, http_client, default_headers) are preserved.
        """
        _ = openai_client
        sentinel_http_client = unittest.mock.Mock()
        model = OpenAIModel(
            model_id="openai.gpt-oss-120b",
            client_args={
                "api_key": "will-be-overridden",
                "base_url": "https://also-overridden.example.com",
                "timeout": 42,
                "http_client": sentinel_http_client,
                "default_headers": {"X-Trace-Id": "abc"},
            },
            aws_config={"region": "us-east-1"},
        )
        resolved = model._resolve_client_args()
        assert resolved["base_url"] == "https://bedrock-mantle.us-east-1.api.aws/v1"
        assert resolved["api_key"] == "bedrock-api-key-deadbeef&Version=1"
        assert resolved["timeout"] == 42
        assert resolved["http_client"] is sentinel_http_client
        assert resolved["default_headers"] == {"X-Trace-Id": "abc"}

    def test_aws_config_requires_region(self, openai_client):
        """aws_config must include a region; validated when the helper mints a token."""
        _ = openai_client
        # Construction succeeds — the region check is deferred to request time.
        model = OpenAIModel(model_id="openai.gpt-oss-120b", aws_config={})
        with pytest.raises(ValueError, match="region"):
            model._resolve_client_args()

    def test_aws_config_wraps_token_failures_with_context(self, openai_client, mock_provide_token):
        """provide_token failures are wrapped in a RuntimeError with actionable context."""
        _ = openai_client
        mock_provide_token.side_effect = RuntimeError("no credentials in chain")
        model = OpenAIModel(model_id="openai.gpt-oss-120b", aws_config={"region": "us-east-1"})
        with pytest.raises(RuntimeError, match="Bedrock Mantle bearer token.*us-east-1"):
            model._resolve_client_args()
Loading
Loading