Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ litellm = ["litellm>=1.75.9,<=1.83.13", "openai>=1.68.0,<3.0.0"]
llamaapi = ["llama-api-client>=0.1.0,<1.0.0"]
mistral = ["mistralai>=1.8.2,<2.0.0"]
ollama = ["ollama>=0.4.8,<1.0.0"]
openai = ["openai>=1.68.0,<3.0.0"]
openai = ["openai>=1.68.0,<3.0.0", "aws-bedrock-token-generator>=1.1.0,<2.0.0"]
Comment thread
JackYPCOnline marked this conversation as resolved.
writer = ["writer-sdk>=2.2.0,<3.0.0"]
sagemaker = [
"boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0",
Expand Down
81 changes: 81 additions & 0 deletions src/strands/models/_openai_bedrock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""Internal helpers for routing OpenAI-compatible clients to Bedrock Mantle.

Converts an ``aws_config`` dict into the ``base_url`` and ``api_key`` that the
OpenAI Python SDK consumes. Tokens are minted on demand via
``aws_bedrock_token_generator.provide_token`` so long-running agents survive the
bearer token's maximum lifetime.

``aws_bedrock_token_generator`` is imported lazily so that extras which reuse
the OpenAI package without pulling the Mantle dependency do not hit an
``ImportError`` at module load.
"""

from __future__ import annotations

from datetime import timedelta
from typing import Any, TypedDict

from typing_extensions import Required

_MANTLE_BASE_URL_TEMPLATE = "https://bedrock-mantle.{region}.api.aws/v1"


class AwsConfig(TypedDict, total=False):
Comment thread
JackYPCOnline marked this conversation as resolved.
Outdated
"""AWS-side config for reaching Bedrock Mantle via an OpenAI-compatible client.

Attributes:
region: AWS region hosting the Bedrock Mantle endpoint.
credentials_provider: Optional botocore ``CredentialProvider`` forwarded to
``provide_token``. Defaults to the AWS credential chain.
expiry: Optional ``timedelta`` for the bearer token's lifetime, forwarded
to ``provide_token``.
"""

region: Required[str]
Comment thread
JackYPCOnline marked this conversation as resolved.
Outdated
credentials_provider: Any
Comment thread
JackYPCOnline marked this conversation as resolved.
Outdated
expiry: timedelta


def resolve_bedrock_client_args(aws_config: AwsConfig, client_args: dict[str, Any] | None = None) -> dict[str, Any]:
    """Resolve an ``AwsConfig`` (plus optional ``client_args``) into OpenAI client kwargs.

    Mints a fresh bearer token on every call so long-running agents never hold an
    expired credential. When ``client_args`` is provided, its entries are preserved
    except for ``base_url`` and ``api_key``, which are always overridden by the
    values derived from ``aws_config``.

    Args:
        aws_config: Mapping with a required ``region`` plus optional
            ``credentials_provider`` / ``expiry`` keys (see :class:`AwsConfig`).
        client_args: Extra kwargs destined for ``openai.AsyncOpenAI`` (timeout,
            http_client, default_headers, ...). Never mutated; a copy is returned.

    Returns:
        A new dict of keyword arguments for the OpenAI client constructor.

    Raises:
        ValueError: If ``aws_config['region']`` is missing or empty.
        ImportError: If ``aws-bedrock-token-generator`` is not installed.
        RuntimeError: If token minting fails (e.g. missing AWS credentials).
    """
    region = aws_config.get("region")
    if not region:
        raise ValueError("aws_config must include a non-empty 'region'.")

    # Lazy import so extras that reuse the OpenAI package without the Mantle
    # dependency do not hit an ImportError at module load (see module docstring).
    try:
        from aws_bedrock_token_generator import provide_token
    except ImportError as e:
        raise ImportError(
            "aws_config requires the 'aws-bedrock-token-generator' package. "
            "Install it with: pip install strands-agents[openai]"
        ) from e

    # Only forward kwargs the user set; provide_token rejects expiry=None.
    token_kwargs: dict[str, Any] = {"region": region}
    if "credentials_provider" in aws_config:
        token_kwargs["aws_credentials_provider"] = aws_config["credentials_provider"]
    if "expiry" in aws_config:
        token_kwargs["expiry"] = aws_config["expiry"]

    try:
        token = provide_token(**token_kwargs)
    except Exception as e:
        raise RuntimeError(
            f"Failed to mint Bedrock Mantle bearer token for region '{region}'. "
            "Verify your AWS credentials and network connectivity."
        ) from e

    # Copy first so the caller's dict is never mutated, then force the two keys
    # that aws_config owns.
    resolved: dict[str, Any] = dict(client_args or {})
    resolved["base_url"] = _MANTLE_BASE_URL_TEMPLATE.format(region=region)
    resolved["api_key"] = token
    return resolved
29 changes: 24 additions & 5 deletions src/strands/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
from ..types.streaming import StreamEvent
from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse
from ._openai_bedrock import AwsConfig, resolve_bedrock_client_args
from ._validation import _has_location_source, validate_config_keys
from .model import BaseModelConfig, Model

Expand Down Expand Up @@ -71,6 +72,7 @@ def __init__(
self,
client: Client | None = None,
client_args: dict[str, Any] | None = None,
aws_config: AwsConfig | None = None,
**model_config: Unpack[OpenAIConfig],
) -> None:
"""Initialize provider instance.
Expand All @@ -87,23 +89,41 @@ def __init__(
Note: The client should not be shared across different asyncio event loops.
client_args: Arguments for the OpenAI client (legacy approach).
For a complete list of supported arguments, see https://pypi.org/project/openai/.
May be combined with ``aws_config``; when both are set, ``aws_config`` overrides
``base_url`` and ``api_key`` only.
aws_config: Route requests through Amazon Bedrock's Mantle (OpenAI-compatible)
Comment thread
JackYPCOnline marked this conversation as resolved.
Outdated
endpoint. See :class:`AwsConfig` for accepted keys. When set, a fresh bearer
token is minted on every request. Cannot be combined with a pre-built ``client``.
**model_config: Configuration options for the OpenAI model.

Raises:
ValueError: If both `client` and `client_args` are provided.
ValueError: If ``client`` is combined with ``client_args`` or ``aws_config``.
"""
validate_config_keys(model_config, self.OpenAIConfig)
self.config = dict(model_config)

# Validate that only one client configuration method is provided
if client is not None and client_args is not None and len(client_args) > 0:
# client_args + aws_config is allowed; aws_config overrides base_url / api_key only.
client_args_provided = client_args is not None and len(client_args) > 0
if client is not None and client_args_provided:
raise ValueError("Only one of 'client' or 'client_args' should be provided, not both.")
if aws_config is not None and client is not None:
raise ValueError("'aws_config' cannot be combined with a pre-built 'client'.")

self._custom_client = client
self.client_args = client_args or {}
self._aws_config = aws_config

logger.debug("config=<%s> | initializing", self.config)

def _resolve_client_args(self) -> dict[str, Any]:
    """Return the kwargs for ``openai.AsyncOpenAI`` on the current request.

    Without ``aws_config`` this is simply the stored ``client_args``; otherwise
    :func:`resolve_bedrock_client_args` mints a fresh Bedrock bearer token and
    overlays ``base_url``/``api_key`` on top of ``client_args``.
    """
    if self._aws_config is None:
        return self.client_args
    return resolve_bedrock_client_args(self._aws_config, self.client_args)

@override
def update_config(self, **model_config: Unpack[OpenAIConfig]) -> None: # type: ignore[override]
"""Update the OpenAI model configuration with the provided arguments.
Expand Down Expand Up @@ -590,11 +610,10 @@ async def _get_client(self) -> AsyncIterator[Any]:
# Use the injected client (caller manages lifecycle)
yield self._custom_client
else:
# Create a new client from client_args
# We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying
# httpx client. The asyncio event loop does not allow connections to be shared. For more details, please
# refer to https://github.com/encode/httpx/discussions/2959.
async with openai.AsyncOpenAI(**self.client_args) as client:
async with openai.AsyncOpenAI(**self._resolve_client_args()) as client:
yield client

@override
Expand Down
28 changes: 24 additions & 4 deletions src/strands/models/openai_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException # noqa: E402
from ..types.streaming import StreamEvent # noqa: E402
from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse # noqa: E402
from ._openai_bedrock import AwsConfig, resolve_bedrock_client_args # noqa: E402
from ._validation import validate_config_keys # noqa: E402
from .model import BaseModelConfig, Model # noqa: E402

Expand Down Expand Up @@ -141,21 +142,40 @@ class OpenAIResponsesConfig(BaseModelConfig, total=False):
stateful: bool

def __init__(
self, client_args: dict[str, Any] | None = None, **model_config: Unpack[OpenAIResponsesConfig]
self,
client_args: dict[str, Any] | None = None,
aws_config: AwsConfig | None = None,
**model_config: Unpack[OpenAIResponsesConfig],
) -> None:
"""Initialize provider instance.

Args:
client_args: Arguments for the OpenAI client.
For a complete list of supported arguments, see https://pypi.org/project/openai/.
May be combined with ``aws_config``; when both are set, ``aws_config`` overrides
``base_url`` and ``api_key`` only.
aws_config: Route requests through Amazon Bedrock's Mantle (OpenAI-compatible)
endpoint. See :class:`AwsConfig` for accepted keys. When set, a fresh bearer
token is minted on every request.
**model_config: Configuration options for the OpenAI Responses API model.
"""
validate_config_keys(model_config, self.OpenAIResponsesConfig)
self.config = dict(model_config)

self.client_args = client_args or {}
self._aws_config = aws_config

logger.debug("config=<%s> | initializing", self.config)

def _resolve_client_args(self) -> dict[str, Any]:
    """Compute the ``openai.AsyncOpenAI`` kwargs for the current request.

    Plain pass-through of ``client_args`` unless ``aws_config`` was supplied,
    in which case :func:`resolve_bedrock_client_args` derives Mantle routing
    (fresh bearer token + base_url) on top of ``client_args``.
    """
    aws = self._aws_config
    if aws is not None:
        return resolve_bedrock_client_args(aws, self.client_args)
    return self.client_args

@property
@override
def stateful(self) -> bool:
Expand Down Expand Up @@ -215,7 +235,7 @@ async def count_tokens(
count_tokens_fields = {"model", "input", "instructions", "tools"}
request = {k: request[k] for k in request.keys() & count_tokens_fields}

async with openai.AsyncOpenAI(**self.client_args) as client:
async with openai.AsyncOpenAI(**self._resolve_client_args()) as client:
response = await client.responses.input_tokens.count(**request)
total_tokens: int = response.input_tokens

Expand Down Expand Up @@ -267,7 +287,7 @@ async def stream(

logger.debug("invoking OpenAI Responses API model")

async with openai.AsyncOpenAI(**self.client_args) as client:
async with openai.AsyncOpenAI(**self._resolve_client_args()) as client:
try:
response = await client.responses.create(**request)

Expand Down Expand Up @@ -447,7 +467,7 @@ async def structured_output(
ContextWindowOverflowException: If the input exceeds the model's context window.
ModelThrottledException: If the request is throttled by OpenAI (rate limits).
"""
async with openai.AsyncOpenAI(**self.client_args) as client:
async with openai.AsyncOpenAI(**self._resolve_client_args()) as client:
try:
response = await client.responses.parse(
model=self.get_config()["model_id"],
Expand Down
105 changes: 105 additions & 0 deletions tests/strands/models/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -1710,3 +1710,108 @@ def test_format_request_messages_multiple_tool_calls_with_images():
},
]
assert tru_result == exp_result


# =============================================================================
# Bedrock Mantle (aws_config) integration with OpenAIModel
# =============================================================================


class TestOpenAIModelAwsConfig:
    """Bedrock Mantle (aws_config) wiring on OpenAIModel."""

    @pytest.fixture
    def mock_provide_token(self):
        """Patch the token generator so no real AWS call is ever made."""
        with unittest.mock.patch("aws_bedrock_token_generator.provide_token") as mock:
            mock.return_value = "bedrock-api-key-deadbeef&Version=1"
            yield mock

    def test_aws_config_sets_base_url_and_api_key(self, openai_client, mock_provide_token):
        """aws_config produces the Mantle base_url and a minted bearer token as api_key."""
        _ = openai_client
        model = OpenAIModel(model_id="openai.gpt-oss-120b", aws_config={"region": "us-east-1"})

        # Token is minted lazily per request, so inspect the resolved kwargs.
        kwargs = model._resolve_client_args()
        assert kwargs["base_url"] == "https://bedrock-mantle.us-east-1.api.aws/v1"
        assert kwargs["api_key"] == "bedrock-api-key-deadbeef&Version=1"
        # Optional kwargs aren't forwarded so provide_token's own defaults apply.
        mock_provide_token.assert_called_once_with(region="us-east-1")

    def test_aws_config_forwards_credentials_provider_and_expiry(self, openai_client, mock_provide_token):
        """Optional credentials_provider and expiry are forwarded to provide_token."""
        _ = openai_client
        from datetime import timedelta

        cred_provider = unittest.mock.Mock()
        model = OpenAIModel(
            model_id="openai.gpt-oss-120b",
            aws_config={
                "region": "us-west-2",
                "credentials_provider": cred_provider,
                "expiry": timedelta(minutes=15),
            },
        )
        model._resolve_client_args()
        mock_provide_token.assert_called_once_with(
            region="us-west-2",
            aws_credentials_provider=cred_provider,
            expiry=timedelta(minutes=15),
        )

    def test_aws_config_mints_token_per_request(self, openai_client, mock_provide_token):
        """Each call to _resolve_client_args mints a fresh token (long-lived processes)."""
        _ = openai_client
        model = OpenAIModel(model_id="openai.gpt-oss-120b", aws_config={"region": "us-east-1"})
        for _call in range(3):
            model._resolve_client_args()
        assert mock_provide_token.call_count == 3

    def test_aws_config_conflicts_with_custom_client(self, openai_client):
        """Cannot pass both aws_config and a pre-built client."""
        _ = openai_client
        prebuilt = unittest.mock.Mock()
        with pytest.raises(ValueError, match="aws_config"):
            OpenAIModel(
                model_id="openai.gpt-oss-120b",
                client=prebuilt,
                aws_config={"region": "us-east-1"},
            )

    def test_aws_config_merges_with_client_args(self, openai_client, mock_provide_token):
        """aws_config is allowed alongside client_args; base_url and api_key are overridden,
        other transport-level options (timeout, http_client, default_headers) are preserved.
        """
        _ = openai_client
        fake_http_client = unittest.mock.Mock()
        model = OpenAIModel(
            model_id="openai.gpt-oss-120b",
            client_args={
                "api_key": "will-be-overridden",
                "base_url": "https://also-overridden.example.com",
                "timeout": 42,
                "http_client": fake_http_client,
                "default_headers": {"X-Trace-Id": "abc"},
            },
            aws_config={"region": "us-east-1"},
        )
        kwargs = model._resolve_client_args()
        assert kwargs["base_url"] == "https://bedrock-mantle.us-east-1.api.aws/v1"
        assert kwargs["api_key"] == "bedrock-api-key-deadbeef&Version=1"
        assert kwargs["timeout"] == 42
        assert kwargs["http_client"] is fake_http_client
        assert kwargs["default_headers"] == {"X-Trace-Id": "abc"}

    def test_aws_config_requires_region(self, openai_client):
        """aws_config must include a region; validated when the helper mints a token."""
        _ = openai_client
        model = OpenAIModel(model_id="openai.gpt-oss-120b", aws_config={})
        with pytest.raises(ValueError, match="region"):
            model._resolve_client_args()

    def test_aws_config_wraps_token_failures_with_context(self, openai_client, mock_provide_token):
        """provide_token failures are wrapped in a RuntimeError with actionable context."""
        _ = openai_client
        mock_provide_token.side_effect = RuntimeError("no credentials in chain")
        model = OpenAIModel(model_id="openai.gpt-oss-120b", aws_config={"region": "us-east-1"})
        with pytest.raises(RuntimeError, match="Bedrock Mantle bearer token.*us-east-1"):
            model._resolve_client_args()
Loading
Loading