
Commit f561707

Merge branch 'main' into feat/strict-tool-use
2 parents 297b930 + a245e6d commit f561707

7 files changed: 500 additions & 53 deletions

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ litellm = ["litellm>=1.75.9,<=1.83.13", "openai>=1.68.0,<3.0.0"]
 llamaapi = ["llama-api-client>=0.1.0,<1.0.0"]
 mistral = ["mistralai>=1.8.2,<2.0.0"]
 ollama = ["ollama>=0.4.8,<1.0.0"]
-openai = ["openai>=1.68.0,<3.0.0"]
+openai = ["openai>=1.68.0,<3.0.0", "aws-bedrock-token-generator>=1.1.0,<2.0.0"]
 writer = ["writer-sdk>=2.2.0,<3.0.0"]
 sagemaker = [
     "boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0",
src/strands/models/_openai_bedrock.py

Lines changed: 126 additions & 0 deletions

@@ -0,0 +1,126 @@
+"""Internal helpers for routing OpenAI-compatible clients to Bedrock Mantle.
+
+Converts a ``bedrock_mantle_config`` dict into the ``base_url`` and ``api_key`` that the
+OpenAI Python SDK consumes. Tokens are minted on demand via
+``aws_bedrock_token_generator.provide_token`` so long-running agents survive the
+bearer token's maximum lifetime.
+
+``aws_bedrock_token_generator`` is part of the ``openai`` extras group
+(``pip install strands-agents[openai]``) but is *not* included in the ``litellm``
+or ``sagemaker`` extras, which also pull in the ``openai`` package. The import is
+therefore lazy — it happens inside :func:`resolve_bedrock_client_args` so that
+those other extras never trigger an ``ImportError`` at module load.
+"""
+
+from __future__ import annotations
+
+from datetime import timedelta
+from typing import Any, TypedDict
+
+import boto3
+from botocore.credentials import CredentialProvider
+
+_MANTLE_BASE_URL_TEMPLATE = "https://bedrock-mantle.{region}.api.aws/v1"
+_MANTLE_DOCS_URL = "https://docs.aws.amazon.com/bedrock/latest/userguide/inference-openai.html"
+
+
+class BedrockMantleConfig(TypedDict, total=False):
+    """Config for routing an OpenAI-compatible client through Bedrock Mantle.
+
+    Attributes:
+        region: AWS region hosting the Bedrock Mantle endpoint. If omitted, resolved
+            from ``boto_session`` (if provided) or the standard boto3 chain
+            (``AWS_REGION`` / ``AWS_DEFAULT_REGION`` / active profile / EC2 metadata).
+            A :class:`ValueError` is raised if none resolve.
+        boto_session: Optional :class:`boto3.Session` used to resolve the region when
+            ``region`` is not provided. Useful for picking up a non-default profile
+            without exporting env vars.
+        credentials_provider: Optional botocore :class:`~botocore.credentials.CredentialProvider`
+            forwarded to ``provide_token``. Omit to let the token generator use the
+            standard AWS credential chain.
+        expiry: Optional ``timedelta`` for the bearer token's lifetime, forwarded to
+            ``provide_token``. Defaults to the generator's built-in lifetime when
+            omitted.
+    """
+
+    region: str
+    boto_session: boto3.Session
+    credentials_provider: CredentialProvider
+    expiry: timedelta
+
+
+def _resolve_region(config: BedrockMantleConfig) -> str:
+    """Resolve the AWS region, preferring explicit config then falling back to boto3.
+
+    Raises:
+        ValueError: If no region can be resolved from the config, an attached session,
+            or the standard boto3 credential chain.
+    """
+    region = config.get("region")
+    if region:
+        return region
+
+    session = config.get("boto_session")
+    if session is not None and session.region_name:
+        return str(session.region_name)
+
+    # ``boto3.Session()`` with no args reads ``AWS_REGION`` / ``AWS_DEFAULT_REGION``,
+    # the active profile, and falls back to EC2 instance metadata — the same chain
+    # :class:`BedrockModel` uses.
+    default_region = boto3.Session().region_name
+    if default_region:
+        return str(default_region)
+
+    raise ValueError(
+        "Could not resolve an AWS region for Bedrock Mantle. Pass 'region' in "
+        "bedrock_mantle_config, attach a boto_session with a configured region, or set "
+        f"AWS_REGION in the environment. See {_MANTLE_DOCS_URL} for supported regions."
+    )
+
+
+def resolve_bedrock_client_args(
+    config: BedrockMantleConfig, client_args: dict[str, Any] | None = None
+) -> dict[str, Any]:
+    """Resolve a ``BedrockMantleConfig`` (plus optional ``client_args``) into OpenAI client kwargs.
+
+    Mints a fresh bearer token on every call. Callers are expected to validate that
+    ``client_args`` does not contain ``base_url`` or ``api_key`` before calling this
+    function (typically at ``__init__`` time for fail-fast behavior).
+
+    Raises:
+        ValueError: If no region can be resolved.
+        ImportError: If ``aws-bedrock-token-generator`` is not installed.
+        RuntimeError: If token minting fails (e.g. missing AWS credentials).
+    """
+    region = _resolve_region(config)
+
+    # ``aws-bedrock-token-generator`` is included in the ``openai`` extras group but not in
+    # ``litellm`` or ``sagemaker`` (which also depend on the ``openai`` package). The lazy
+    # import keeps those extras from hitting an ImportError at module load.
+    try:
+        from aws_bedrock_token_generator import provide_token
+    except ImportError as e:
+        raise ImportError(
+            "bedrock_mantle_config requires the 'aws-bedrock-token-generator' package. "
+            "Install it with: pip install strands-agents[openai]"
+        ) from e
+
+    # Only forward kwargs the user set; provide_token rejects expiry=None.
+    token_kwargs: dict[str, Any] = {"region": region}
+    if "credentials_provider" in config:
+        token_kwargs["aws_credentials_provider"] = config["credentials_provider"]
+    if "expiry" in config:
+        token_kwargs["expiry"] = config["expiry"]
+
+    try:
+        token = provide_token(**token_kwargs)
+    except Exception as e:
+        raise RuntimeError(
+            f"Failed to mint Bedrock Mantle bearer token for region '{region}'. "
+            "Verify your AWS credentials and network connectivity."
+        ) from e
+
+    resolved: dict[str, Any] = dict(client_args or {})
+    resolved["base_url"] = _MANTLE_BASE_URL_TEMPLATE.format(region=region)
+    resolved["api_key"] = token
+    return resolved
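For review context, a minimal sketch of what the helper produces, assuming valid AWS credentials in the environment; the region, expiry, and timeout values are illustrative, and importing the private module directly is for demonstration only:

    from datetime import timedelta

    from strands.models._openai_bedrock import resolve_bedrock_client_args

    # Explicit region plus a one-hour token lifetime (illustrative values).
    config = {"region": "us-east-1", "expiry": timedelta(hours=1)}

    # Returns the caller's kwargs plus the derived endpoint and a fresh token:
    # {"timeout": 60.0,
    #  "base_url": "https://bedrock-mantle.us-east-1.api.aws/v1",
    #  "api_key": "<bearer token>"}
    kwargs = resolve_bedrock_client_args(config, client_args={"timeout": 60.0})

Because a token is minted on every call, each request-scoped ``openai.AsyncOpenAI`` client gets a fresh credential, which is what lets long-running agents outlive a single token's lifetime.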

src/strands/models/openai.py

Lines changed: 33 additions & 5 deletions
@@ -22,6 +22,7 @@
 from ..types.streaming import StreamEvent
 from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse
 from ._strict_schema import ensure_strict_json_schema
+from ._openai_bedrock import BedrockMantleConfig, resolve_bedrock_client_args
 from ._validation import _has_location_source, validate_config_keys
 from .model import BaseModelConfig, Model

@@ -72,6 +73,7 @@ def __init__(
         self,
         client: Client | None = None,
         client_args: dict[str, Any] | None = None,
+        bedrock_mantle_config: BedrockMantleConfig | None = None,
         **model_config: Unpack[OpenAIConfig],
     ) -> None:
         """Initialize provider instance.

@@ -88,23 +90,50 @@ def __init__(
                 Note: The client should not be shared across different asyncio event loops.
             client_args: Arguments for the OpenAI client (legacy approach).
                 For a complete list of supported arguments, see https://pypi.org/project/openai/.
+                May be combined with ``bedrock_mantle_config``; when both are set,
+                ``bedrock_mantle_config`` derives ``base_url`` and ``api_key`` (which must not
+                appear in ``client_args``).
+            bedrock_mantle_config: Route requests through Amazon Bedrock's Mantle
+                (OpenAI-compatible) endpoint. See :class:`BedrockMantleConfig` for accepted
+                keys. When set, a fresh bearer token is minted on every request. Cannot be
+                combined with a pre-built ``client``.
             **model_config: Configuration options for the OpenAI model.

         Raises:
-            ValueError: If both `client` and `client_args` are provided.
+            ValueError: If ``client`` is combined with ``client_args`` or ``bedrock_mantle_config``.
         """
         validate_config_keys(model_config, self.OpenAIConfig)
         self.config = dict(model_config)

-        # Validate that only one client configuration method is provided
-        if client is not None and client_args is not None and len(client_args) > 0:
+        # client_args + bedrock_mantle_config is allowed; the config derives base_url / api_key.
+        client_args_provided = client_args is not None and len(client_args) > 0
+        if client is not None and client_args_provided:
             raise ValueError("Only one of 'client' or 'client_args' should be provided, not both.")
+        if bedrock_mantle_config is not None and client is not None:
+            raise ValueError("'bedrock_mantle_config' cannot be combined with a pre-built 'client'.")
+        if bedrock_mantle_config is not None and client_args:
+            conflicting = [k for k in ("api_key", "base_url") if k in client_args]
+            if conflicting:
+                raise ValueError(
+                    f"client_args must not contain {conflicting} when bedrock_mantle_config is set; "
+                    "these are derived from the Mantle config automatically."
+                )

         self._custom_client = client
         self.client_args = client_args or {}
+        self._bedrock_mantle_config = bedrock_mantle_config

         logger.debug("config=<%s> | initializing", self.config)

+    def _resolve_client_args(self) -> dict[str, Any]:
+        """Return the kwargs to pass to ``openai.AsyncOpenAI`` for the current request.
+
+        Delegates to :func:`resolve_bedrock_client_args` when ``bedrock_mantle_config`` is set.
+        """
+        if self._bedrock_mantle_config is not None:
+            return resolve_bedrock_client_args(self._bedrock_mantle_config, self.client_args)
+        return self.client_args
+
     @override
     def update_config(self, **model_config: Unpack[OpenAIConfig]) -> None:  # type: ignore[override]
         """Update the OpenAI model configuration with the provided arguments.

@@ -596,11 +625,10 @@ async def _get_client(self) -> AsyncIterator[Any]:
             # Use the injected client (caller manages lifecycle)
             yield self._custom_client
         else:
-            # Create a new client from client_args
             # We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying
             # httpx client. The asyncio event loop does not allow connections to be shared. For more details, please
             # refer to https://github.com/encode/httpx/discussions/2959.
-            async with openai.AsyncOpenAI(**self.client_args) as client:
+            async with openai.AsyncOpenAI(**self._resolve_client_args()) as client:
                 yield client

     @override
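Taken together, a minimal usage sketch of the new parameter, assuming the provider class exported by this module is ``OpenAIModel`` and using placeholder profile and model IDs:

    import boto3

    from strands.models.openai import OpenAIModel

    model = OpenAIModel(
        # Region comes from the session's profile; no env vars needed.
        bedrock_mantle_config={"boto_session": boto3.Session(profile_name="dev")},
        # Extra client kwargs are still honored, but must not set base_url/api_key.
        client_args={"timeout": 30.0},
        model_id="openai.gpt-oss-120b-1:0",  # placeholder model id
    )

Because ``_resolve_client_args`` runs inside ``_get_client``, the token is re-minted per request rather than captured once at construction.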

src/strands/models/openai_responses.py

Lines changed: 32 additions & 4 deletions
@@ -59,6 +59,7 @@
 from ..types.streaming import StreamEvent  # noqa: E402
 from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse  # noqa: E402
 from ._strict_schema import ensure_strict_json_schema  # noqa: E402
+from ._openai_bedrock import BedrockMantleConfig, resolve_bedrock_client_args  # noqa: E402
 from ._validation import validate_config_keys  # noqa: E402
 from .model import BaseModelConfig, Model  # noqa: E402

@@ -142,21 +143,48 @@ class OpenAIResponsesConfig(BaseModelConfig, total=False):
         stateful: bool

     def __init__(
-        self, client_args: dict[str, Any] | None = None, **model_config: Unpack[OpenAIResponsesConfig]
+        self,
+        client_args: dict[str, Any] | None = None,
+        bedrock_mantle_config: BedrockMantleConfig | None = None,
+        **model_config: Unpack[OpenAIResponsesConfig],
     ) -> None:
         """Initialize provider instance.

         Args:
             client_args: Arguments for the OpenAI client.
                 For a complete list of supported arguments, see https://pypi.org/project/openai/.
+                May be combined with ``bedrock_mantle_config``; when both are set, the config
+                derives ``base_url`` and ``api_key`` (which must not appear in ``client_args``).
+            bedrock_mantle_config: Route requests through Amazon Bedrock's Mantle
+                (OpenAI-compatible) endpoint. See :class:`BedrockMantleConfig` for accepted
+                keys. When set, a fresh bearer token is minted on every request.
             **model_config: Configuration options for the OpenAI Responses API model.
         """
         validate_config_keys(model_config, self.OpenAIResponsesConfig)
         self.config = dict(model_config)
+
         self.client_args = client_args or {}
+        self._bedrock_mantle_config = bedrock_mantle_config
+
+        if bedrock_mantle_config is not None and client_args:
+            conflicting = [k for k in ("api_key", "base_url") if k in client_args]
+            if conflicting:
+                raise ValueError(
+                    f"client_args must not contain {conflicting} when bedrock_mantle_config is set; "
+                    "these are derived from the Mantle config automatically."
+                )

         logger.debug("config=<%s> | initializing", self.config)

+    def _resolve_client_args(self) -> dict[str, Any]:
+        """Return the kwargs to pass to ``openai.AsyncOpenAI`` for the current request.
+
+        Delegates to :func:`resolve_bedrock_client_args` when ``bedrock_mantle_config`` is set.
+        """
+        if self._bedrock_mantle_config is not None:
+            return resolve_bedrock_client_args(self._bedrock_mantle_config, self.client_args)
+        return self.client_args
+
     @property
     @override
     def stateful(self) -> bool:

@@ -216,7 +244,7 @@ async def count_tokens(
         count_tokens_fields = {"model", "input", "instructions", "tools"}
         request = {k: request[k] for k in request.keys() & count_tokens_fields}

-        async with openai.AsyncOpenAI(**self.client_args) as client:
+        async with openai.AsyncOpenAI(**self._resolve_client_args()) as client:
             response = await client.responses.input_tokens.count(**request)
             total_tokens: int = response.input_tokens

@@ -268,7 +296,7 @@ async def stream(
         logger.debug("invoking OpenAI Responses API model")

-        async with openai.AsyncOpenAI(**self.client_args) as client:
+        async with openai.AsyncOpenAI(**self._resolve_client_args()) as client:
             try:
                 response = await client.responses.create(**request)

@@ -448,7 +476,7 @@ async def structured_output(
             ContextWindowOverflowException: If the input exceeds the model's context window.
             ModelThrottledException: If the request is throttled by OpenAI (rate limits).
         """
-        async with openai.AsyncOpenAI(**self.client_args) as client:
+        async with openai.AsyncOpenAI(**self._resolve_client_args()) as client:
             try:
                 response = await client.responses.parse(
                     model=self.get_config()["model_id"],
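The Responses provider mirrors the pattern; a sketch assuming the class name is ``OpenAIResponsesModel`` (not shown in this diff) and a placeholder model id:

    from strands.models.openai_responses import OpenAIResponsesModel

    model = OpenAIResponsesModel(
        bedrock_mantle_config={"region": "us-west-2"},  # illustrative region
        model_id="openai.gpt-oss-120b-1:0",  # placeholder model id
    )

Note the Responses variant performs the same ``api_key``/``base_url`` conflict check but, unlike the Chat Completions provider, accepts no pre-built ``client``, so that is the only validation needed here.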
