2 changes: 1 addition & 1 deletion pyproject.toml
@@ -50,7 +50,7 @@ litellm = ["litellm>=1.75.9,<=1.83.13", "openai>=1.68.0,<3.0.0"]
llamaapi = ["llama-api-client>=0.1.0,<1.0.0"]
mistral = ["mistralai>=1.8.2,<2.0.0"]
ollama = ["ollama>=0.4.8,<1.0.0"]
openai = ["openai>=1.68.0,<3.0.0"]
openai = ["openai>=1.68.0,<3.0.0", "aws-bedrock-token-generator>=1.1.0,<2.0.0"]
writer = ["writer-sdk>=2.2.0,<3.0.0"]
sagemaker = [
"boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0",
126 changes: 126 additions & 0 deletions src/strands/models/_openai_bedrock.py
@@ -0,0 +1,126 @@
"""Internal helpers for routing OpenAI-compatible clients to Bedrock Mantle.

Converts a ``bedrock_mantle_config`` dict into the ``base_url`` and ``api_key`` that the
OpenAI Python SDK consumes. Tokens are minted on demand via
``aws_bedrock_token_generator.provide_token`` so long-running agents survive the
bearer token's maximum lifetime.

``aws_bedrock_token_generator`` is part of the ``openai`` extras group
(``pip install strands-agents[openai]``) but is *not* included in the ``litellm``
or ``sagemaker`` extras, which also pull in the ``openai`` package. The import is
therefore lazy — it happens inside :func:`resolve_bedrock_client_args` so that
those other extras never trigger an ``ImportError`` at module load.
"""

from __future__ import annotations

from datetime import timedelta
from typing import Any, TypedDict

import boto3
from botocore.credentials import CredentialProvider

_MANTLE_BASE_URL_TEMPLATE = "https://bedrock-mantle.{region}.api.aws/v1"
_MANTLE_DOCS_URL = "https://docs.aws.amazon.com/bedrock/latest/userguide/inference-openai.html"


class BedrockMantleConfig(TypedDict, total=False):
"""Config for routing an OpenAI-compatible client through Bedrock Mantle.

Attributes:
region: AWS region hosting the Bedrock Mantle endpoint. If omitted, resolved
from ``boto_session`` (if provided) or the standard boto3 chain
(``AWS_REGION`` / ``AWS_DEFAULT_REGION`` / active profile / EC2 metadata).
A :class:`ValueError` is raised if none resolve.
boto_session: Optional :class:`boto3.Session` used to resolve the region when
``region`` is not provided. Useful for picking up a non-default profile
without exporting env vars.
credentials_provider: Optional botocore :class:`~botocore.credentials.CredentialProvider`
forwarded to ``provide_token``. Omit to let the token generator use the
standard AWS credential chain.
expiry: Optional ``timedelta`` for the bearer token's lifetime, forwarded to
``provide_token``. Defaults to the generator's built-in lifetime when
omitted.
"""

region: str
boto_session: boto3.Session
credentials_provider: CredentialProvider
expiry: timedelta


def _resolve_region(config: BedrockMantleConfig) -> str:
"""Resolve the AWS region, preferring explicit config then falling back to boto3.

Raises:
ValueError: If no region can be resolved from the config, an attached session,
or the standard boto3 credential chain.
"""
region = config.get("region")
if region:
return region

session = config.get("boto_session")
if session is not None and session.region_name:
return str(session.region_name)

# ``boto3.Session()`` with no args reads ``AWS_REGION`` / ``AWS_DEFAULT_REGION``,
# the active profile, and falls back to EC2 instance metadata — the same chain
# :class:`BedrockModel` uses.
default_region = boto3.Session().region_name
if default_region:
return str(default_region)

raise ValueError(
"Could not resolve an AWS region for Bedrock Mantle. Pass 'region' in "
"bedrock_mantle_config, attach a boto_session with a configured region, or set "
f"AWS_REGION in the environment. See {_MANTLE_DOCS_URL} for supported regions."
)


def resolve_bedrock_client_args(
config: BedrockMantleConfig, client_args: dict[str, Any] | None = None
) -> dict[str, Any]:
"""Resolve a ``BedrockMantleConfig`` (plus optional ``client_args``) into OpenAI client kwargs.

Mints a fresh bearer token on every call. Callers are expected to validate that
``client_args`` does not contain ``base_url`` or ``api_key`` before calling this
function (typically at ``__init__`` time for fail-fast behavior).

Raises:
ValueError: If no region can be resolved.
ImportError: If ``aws-bedrock-token-generator`` is not installed.
RuntimeError: If token minting fails (e.g. missing AWS credentials).
"""
region = _resolve_region(config)

# ``aws-bedrock-token-generator`` is included in the ``openai`` extras group but not in
# ``litellm`` or ``sagemaker`` (which also depend on the ``openai`` package). The lazy
# import keeps those extras from hitting an ImportError at module load.
try:
from aws_bedrock_token_generator import provide_token
except ImportError as e:
raise ImportError(
"bedrock_mantle_config requires the 'aws-bedrock-token-generator' package. "
"Install it with: pip install strands-agents[openai]"
) from e

# Only forward kwargs the user set; provide_token rejects expiry=None.
token_kwargs: dict[str, Any] = {"region": region}
if "credentials_provider" in config:
token_kwargs["aws_credentials_provider"] = config["credentials_provider"]
if "expiry" in config:
token_kwargs["expiry"] = config["expiry"]

try:
token = provide_token(**token_kwargs)
except Exception as e:
raise RuntimeError(
f"Failed to mint Bedrock Mantle bearer token for region '{region}'. "
"Verify your AWS credentials and network connectivity."
) from e

resolved: dict[str, Any] = dict(client_args or {})
resolved["base_url"] = _MANTLE_BASE_URL_TEMPLATE.format(region=region)
resolved["api_key"] = token
return resolved
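
For reference, a minimal sketch of calling the new helper directly; it assumes the internal module path from this PR and working AWS credentials in the environment:

from datetime import timedelta

from strands.models._openai_bedrock import BedrockMantleConfig, resolve_bedrock_client_args

# An explicit region skips the boto3 fallback chain; omit it to resolve from
# AWS_REGION / AWS_DEFAULT_REGION, the active profile, or EC2 metadata.
config: BedrockMantleConfig = {"region": "us-west-2", "expiry": timedelta(hours=1)}

# Mints a bearer token and derives the Mantle base_url; other client_args
# (a hypothetical timeout here) pass through untouched.
kwargs = resolve_bedrock_client_args(config, client_args={"timeout": 30.0})
assert kwargs["base_url"] == "https://bedrock-mantle.us-west-2.api.aws/v1"
# kwargs["api_key"] now holds a freshly minted bearer token.
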
38 changes: 33 additions & 5 deletions src/strands/models/openai.py
@@ -21,6 +21,7 @@
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
from ..types.streaming import StreamEvent
from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse
from ._openai_bedrock import BedrockMantleConfig, resolve_bedrock_client_args
from ._validation import _has_location_source, validate_config_keys
from .model import BaseModelConfig, Model

@@ -71,6 +72,7 @@ def __init__(
self,
client: Client | None = None,
client_args: dict[str, Any] | None = None,
bedrock_mantle_config: BedrockMantleConfig | None = None,
**model_config: Unpack[OpenAIConfig],
) -> None:
"""Initialize provider instance.
@@ -87,23 +89,50 @@ def __init__(
Note: The client should not be shared across different asyncio event loops.
client_args: Arguments for the OpenAI client (legacy approach).
For a complete list of supported arguments, see https://pypi.org/project/openai/.
May be combined with ``bedrock_mantle_config``; when both are set,
``bedrock_mantle_config`` derives ``base_url`` and ``api_key`` (which must not
appear in ``client_args``).
bedrock_mantle_config: Route requests through Amazon Bedrock's Mantle
(OpenAI-compatible) endpoint. See :class:`BedrockMantleConfig` for accepted
keys. When set, a fresh bearer token is minted on every request. Cannot be
combined with a pre-built ``client``.
**model_config: Configuration options for the OpenAI model.

Raises:
ValueError: If both `client` and `client_args` are provided.
ValueError: If ``client`` is combined with ``client_args`` or ``bedrock_mantle_config``.
"""
validate_config_keys(model_config, self.OpenAIConfig)
self.config = dict(model_config)

# Validate that only one client configuration method is provided
if client is not None and client_args is not None and len(client_args) > 0:
# client_args + bedrock_mantle_config is allowed; the config derives base_url / api_key.
client_args_provided = client_args is not None and len(client_args) > 0
if client is not None and client_args_provided:
raise ValueError("Only one of 'client' or 'client_args' should be provided, not both.")
if bedrock_mantle_config is not None and client is not None:
raise ValueError("'bedrock_mantle_config' cannot be combined with a pre-built 'client'.")
if bedrock_mantle_config is not None and client_args:
conflicting = [k for k in ("api_key", "base_url") if k in client_args]
if conflicting:
raise ValueError(
f"client_args must not contain {conflicting} when bedrock_mantle_config is set; "
"these are derived from the Mantle config automatically."
)

self._custom_client = client
self.client_args = client_args or {}
self._bedrock_mantle_config = bedrock_mantle_config

logger.debug("config=<%s> | initializing", self.config)

def _resolve_client_args(self) -> dict[str, Any]:
"""Return the kwargs to pass to ``openai.AsyncOpenAI`` for the current request.

Delegates to :func:`resolve_bedrock_client_args` when ``bedrock_mantle_config`` is set.
"""
if self._bedrock_mantle_config is not None:
return resolve_bedrock_client_args(self._bedrock_mantle_config, self.client_args)
return self.client_args

@override
def update_config(self, **model_config: Unpack[OpenAIConfig]) -> None: # type: ignore[override]
"""Update the OpenAI model configuration with the provided arguments.
@@ -590,11 +619,10 @@ async def _get_client(self) -> AsyncIterator[Any]:
# Use the injected client (caller manages lifecycle)
yield self._custom_client
else:
# Create a new client from client_args
# We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying
# httpx client. The asyncio event loop does not allow connections to be shared. For more details, please
# refer to https://github.com/encode/httpx/discussions/2959.
async with openai.AsyncOpenAI(**self.client_args) as client:
async with openai.AsyncOpenAI(**self._resolve_client_args()) as client:
yield client

@override
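
A usage sketch for the chat-completions provider; the import path follows this PR's layout and the model_id value is hypothetical:

from strands.models.openai import OpenAIModel

model = OpenAIModel(
    bedrock_mantle_config={"region": "us-east-1"},
    client_args={"timeout": 60.0},  # allowed: contains no base_url or api_key
    model_id="openai.gpt-oss-120b-1:0",  # hypothetical Mantle-served model id
)

# Supplying a derived key fails fast at construction time:
# OpenAIModel(bedrock_mantle_config={"region": "us-east-1"},
#             client_args={"api_key": "sk-..."})  # raises ValueError
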
36 changes: 32 additions & 4 deletions src/strands/models/openai_responses.py
@@ -58,6 +58,7 @@
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException # noqa: E402
from ..types.streaming import StreamEvent # noqa: E402
from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse # noqa: E402
from ._openai_bedrock import BedrockMantleConfig, resolve_bedrock_client_args # noqa: E402
from ._validation import validate_config_keys # noqa: E402
from .model import BaseModelConfig, Model # noqa: E402

@@ -141,21 +142,48 @@ class OpenAIResponsesConfig(BaseModelConfig, total=False):
stateful: bool

def __init__(
self, client_args: dict[str, Any] | None = None, **model_config: Unpack[OpenAIResponsesConfig]
self,
client_args: dict[str, Any] | None = None,
bedrock_mantle_config: BedrockMantleConfig | None = None,
**model_config: Unpack[OpenAIResponsesConfig],
) -> None:
"""Initialize provider instance.

Args:
client_args: Arguments for the OpenAI client.
For a complete list of supported arguments, see https://pypi.org/project/openai/.
May be combined with ``bedrock_mantle_config``; when both are set, the config
derives ``base_url`` and ``api_key`` (which must not appear in ``client_args``).
bedrock_mantle_config: Route requests through Amazon Bedrock's Mantle
(OpenAI-compatible) endpoint. See :class:`BedrockMantleConfig` for accepted
keys. When set, a fresh bearer token is minted on every request.
**model_config: Configuration options for the OpenAI Responses API model.
"""
validate_config_keys(model_config, self.OpenAIResponsesConfig)
self.config = dict(model_config)

self.client_args = client_args or {}
self._bedrock_mantle_config = bedrock_mantle_config

if bedrock_mantle_config is not None and client_args:
conflicting = [k for k in ("api_key", "base_url") if k in client_args]
if conflicting:
raise ValueError(
f"client_args must not contain {conflicting} when bedrock_mantle_config is set; "
"these are derived from the Mantle config automatically."
)

logger.debug("config=<%s> | initializing", self.config)

def _resolve_client_args(self) -> dict[str, Any]:
"""Return the kwargs to pass to ``openai.AsyncOpenAI`` for the current request.

Delegates to :func:`resolve_bedrock_client_args` when ``bedrock_mantle_config`` is set.
"""
if self._bedrock_mantle_config is not None:
return resolve_bedrock_client_args(self._bedrock_mantle_config, self.client_args)
return self.client_args

@property
@override
def stateful(self) -> bool:
@@ -215,7 +243,7 @@ async def count_tokens(
count_tokens_fields = {"model", "input", "instructions", "tools"}
request = {k: request[k] for k in request.keys() & count_tokens_fields}

async with openai.AsyncOpenAI(**self.client_args) as client:
async with openai.AsyncOpenAI(**self._resolve_client_args()) as client:
response = await client.responses.input_tokens.count(**request)
total_tokens: int = response.input_tokens

@@ -267,7 +295,7 @@ async def stream(

logger.debug("invoking OpenAI Responses API model")

async with openai.AsyncOpenAI(**self.client_args) as client:
async with openai.AsyncOpenAI(**self._resolve_client_args()) as client:
try:
response = await client.responses.create(**request)

@@ -447,7 +475,7 @@ async def structured_output(
ContextWindowOverflowException: If the input exceeds the model's context window.
ModelThrottledException: If the request is throttled by OpenAI (rate limits).
"""
async with openai.AsyncOpenAI(**self.client_args) as client:
async with openai.AsyncOpenAI(**self._resolve_client_args()) as client:
try:
response = await client.responses.parse(
model=self.get_config()["model_id"],
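
The Responses provider follows the same pattern; this sketch assumes the class defined in this file is named OpenAIResponsesModel and again uses a hypothetical model_id:

from strands.models.openai_responses import OpenAIResponsesModel

model = OpenAIResponsesModel(
    bedrock_mantle_config={"region": "us-east-1"},
    model_id="openai.gpt-oss-120b-1:0",  # hypothetical model id
)

# Each count_tokens / stream / structured_output call opens a fresh
# openai.AsyncOpenAI(**self._resolve_client_args()) context, so a new bearer
# token is minted per request and long-running agents never hold a stale one.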