Skip to content

Commit d94a7f4

Browse files
committed
feat: enable openai provider to use aws profile
1 parent 6e208a8 commit d94a7f4

6 files changed

Lines changed: 278 additions & 53 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ litellm = ["litellm>=1.75.9,<=1.83.13", "openai>=1.68.0,<3.0.0"]
5050
llamaapi = ["llama-api-client>=0.1.0,<1.0.0"]
5151
mistral = ["mistralai>=1.8.2,<2.0.0"]
5252
ollama = ["ollama>=0.4.8,<1.0.0"]
53-
openai = ["openai>=1.68.0,<3.0.0"]
53+
openai = ["openai>=1.68.0,<3.0.0", "aws-bedrock-token-generator>=1.1.0,<2.0.0"]
5454
writer = ["writer-sdk>=2.2.0,<3.0.0"]
5555
sagemaker = [
5656
"boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0",

src/strands/models/openai.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
2222
from ..types.streaming import StreamEvent
2323
from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse
24+
from ._openai_bedrock import AwsConfig, resolve_bedrock_client_args
2425
from ._validation import _has_location_source, validate_config_keys
2526
from .model import BaseModelConfig, Model
2627

@@ -71,6 +72,7 @@ def __init__(
7172
self,
7273
client: Client | None = None,
7374
client_args: dict[str, Any] | None = None,
75+
aws_config: AwsConfig | None = None,
7476
**model_config: Unpack[OpenAIConfig],
7577
) -> None:
7678
"""Initialize provider instance.
@@ -87,23 +89,53 @@ def __init__(
8789
Note: The client should not be shared across different asyncio event loops.
8890
client_args: Arguments for the OpenAI client (legacy approach).
8991
For a complete list of supported arguments, see https://pypi.org/project/openai/.
92+
May be combined with ``aws_config``; transport-level options like ``http_client``,
93+
``timeout``, or ``default_headers`` are preserved, while ``base_url`` and
94+
``api_key`` are always overridden by ``aws_config`` when both are set.
95+
aws_config: Route requests through Amazon Bedrock's Mantle (OpenAI-compatible)
96+
endpoint. Provide ``{"region": "us-east-1"}`` at minimum. Accepts optional
97+
``credentials_provider`` (a botocore ``CredentialProvider``) and ``expiry``
98+
(a ``datetime.timedelta`` up to 12h). When set, a fresh bearer token is minted
99+
on every request via ``aws-bedrock-token-generator`` and the OpenAI client is
100+
pointed at ``https://bedrock-mantle.<region>.api.aws/v1``. Cannot be combined
101+
with a pre-built ``client``.
90102
**model_config: Configuration options for the OpenAI model.
91103
92104
Raises:
93-
ValueError: If both `client` and `client_args` are provided.
105+
ValueError: If ``client`` is combined with ``client_args`` or ``aws_config``,
106+
or if ``aws_config`` is missing a region.
94107
"""
95108
validate_config_keys(model_config, self.OpenAIConfig)
96109
self.config = dict(model_config)
97110

98-
# Validate that only one client configuration method is provided
99-
if client is not None and client_args is not None and len(client_args) > 0:
111+
# Validate that client configuration methods are mutually exclusive where they conflict.
112+
# client_args + aws_config is allowed — aws_config will override base_url / api_key only.
113+
client_args_provided = client_args is not None and len(client_args) > 0
114+
if client is not None and client_args_provided:
100115
raise ValueError("Only one of 'client' or 'client_args' should be provided, not both.")
116+
if aws_config is not None:
117+
if client is not None:
118+
raise ValueError("'aws_config' cannot be combined with a pre-built 'client'.")
119+
if not aws_config.get("region"):
120+
raise ValueError("aws_config must include a non-empty 'region'.")
101121

102122
self._custom_client = client
103123
self.client_args = client_args or {}
124+
self._aws_config = aws_config
104125

105126
logger.debug("config=<%s> | initializing", self.config)
106127

128+
def _resolve_client_args(self) -> dict[str, Any]:
129+
"""Return the kwargs to pass to ``openai.AsyncOpenAI`` for the current request.
130+
131+
When ``aws_config`` is set, a fresh Bedrock Mantle bearer token is minted on every
132+
call and ``base_url`` / ``api_key`` are overridden. Any other entries from
133+
``client_args`` (e.g. ``http_client``, ``timeout``) are preserved.
134+
"""
135+
if self._aws_config is not None:
136+
return resolve_bedrock_client_args(self._aws_config, self.client_args)
137+
return self.client_args
138+
107139
@override
108140
def update_config(self, **model_config: Unpack[OpenAIConfig]) -> None: # type: ignore[override]
109141
"""Update the OpenAI model configuration with the provided arguments.
@@ -590,11 +622,11 @@ async def _get_client(self) -> AsyncIterator[Any]:
590622
# Use the injected client (caller manages lifecycle)
591623
yield self._custom_client
592624
else:
593-
# Create a new client from client_args
625+
# Create a new client from resolved args (static client_args or freshly-minted Bedrock creds).
594626
# We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying
595627
# httpx client. The asyncio event loop does not allow connections to be shared. For more details, please
596628
# refer to https://github.com/encode/httpx/discussions/2959.
597-
async with openai.AsyncOpenAI(**self.client_args) as client:
629+
async with openai.AsyncOpenAI(**self._resolve_client_args()) as client:
598630
yield client
599631

600632
@override

src/strands/models/openai_responses.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException # noqa: E402
5959
from ..types.streaming import StreamEvent # noqa: E402
6060
from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse # noqa: E402
61+
from ._openai_bedrock import AwsConfig, resolve_bedrock_client_args # noqa: E402
6162
from ._validation import validate_config_keys # noqa: E402
6263
from .model import BaseModelConfig, Model # noqa: E402
6364

@@ -141,21 +142,52 @@ class OpenAIResponsesConfig(BaseModelConfig, total=False):
141142
stateful: bool
142143

143144
def __init__(
144-
self, client_args: dict[str, Any] | None = None, **model_config: Unpack[OpenAIResponsesConfig]
145+
self,
146+
client_args: dict[str, Any] | None = None,
147+
aws_config: AwsConfig | None = None,
148+
**model_config: Unpack[OpenAIResponsesConfig],
145149
) -> None:
146150
"""Initialize provider instance.
147151
148152
Args:
149153
client_args: Arguments for the OpenAI client.
150154
For a complete list of supported arguments, see https://pypi.org/project/openai/.
155+
May be combined with ``aws_config``; transport-level options like ``http_client``,
156+
``timeout``, or ``default_headers`` are preserved, while ``base_url`` and
157+
``api_key`` are always overridden by ``aws_config`` when both are set.
158+
aws_config: Route requests through Amazon Bedrock's Mantle (OpenAI-compatible)
159+
endpoint. Provide ``{"region": "us-east-1"}`` at minimum. Accepts optional
160+
``credentials_provider`` (a botocore ``CredentialProvider``) and ``expiry``
161+
(a ``datetime.timedelta`` up to 12h). When set, a fresh bearer token is minted
162+
on every request via ``aws-bedrock-token-generator`` and the OpenAI client is
163+
pointed at ``https://bedrock-mantle.<region>.api.aws/v1``.
151164
**model_config: Configuration options for the OpenAI Responses API model.
165+
166+
Raises:
167+
ValueError: If ``aws_config`` is missing a region.
152168
"""
153169
validate_config_keys(model_config, self.OpenAIResponsesConfig)
154170
self.config = dict(model_config)
171+
172+
if aws_config is not None and not aws_config.get("region"):
173+
raise ValueError("aws_config must include a non-empty 'region'.")
174+
155175
self.client_args = client_args or {}
176+
self._aws_config = aws_config
156177

157178
logger.debug("config=<%s> | initializing", self.config)
158179

180+
def _resolve_client_args(self) -> dict[str, Any]:
181+
"""Return the kwargs to pass to ``openai.AsyncOpenAI`` for the current request.
182+
183+
When ``aws_config`` is set, a fresh Bedrock Mantle bearer token is minted on every
184+
call and ``base_url`` / ``api_key`` are overridden. Any other entries from
185+
``client_args`` (e.g. ``http_client``, ``timeout``) are preserved.
186+
"""
187+
if self._aws_config is not None:
188+
return resolve_bedrock_client_args(self._aws_config, self.client_args)
189+
return self.client_args
190+
159191
@property
160192
@override
161193
def stateful(self) -> bool:
@@ -215,7 +247,7 @@ async def count_tokens(
215247
count_tokens_fields = {"model", "input", "instructions", "tools"}
216248
request = {k: request[k] for k in request.keys() & count_tokens_fields}
217249

218-
async with openai.AsyncOpenAI(**self.client_args) as client:
250+
async with openai.AsyncOpenAI(**self._resolve_client_args()) as client:
219251
response = await client.responses.input_tokens.count(**request)
220252
total_tokens: int = response.input_tokens
221253

@@ -267,7 +299,7 @@ async def stream(
267299

268300
logger.debug("invoking OpenAI Responses API model")
269301

270-
async with openai.AsyncOpenAI(**self.client_args) as client:
302+
async with openai.AsyncOpenAI(**self._resolve_client_args()) as client:
271303
try:
272304
response = await client.responses.create(**request)
273305

@@ -447,7 +479,7 @@ async def structured_output(
447479
ContextWindowOverflowException: If the input exceeds the model's context window.
448480
ModelThrottledException: If the request is throttled by OpenAI (rate limits).
449481
"""
450-
async with openai.AsyncOpenAI(**self.client_args) as client:
482+
async with openai.AsyncOpenAI(**self._resolve_client_args()) as client:
451483
try:
452484
response = await client.responses.parse(
453485
model=self.get_config()["model_id"],

tests/strands/models/test_openai.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1710,3 +1710,102 @@ def test_format_request_messages_multiple_tool_calls_with_images():
17101710
},
17111711
]
17121712
assert tru_result == exp_result
1713+
1714+
1715+
# =============================================================================
1716+
# Bedrock Mantle (aws_config) integration with OpenAIModel
1717+
# =============================================================================
1718+
1719+
1720+
class TestOpenAIModelAwsConfig:
1721+
"""Tests for the Bedrock Mantle pathway via the aws_config kwarg."""
1722+
1723+
@pytest.fixture
1724+
def mock_provide_token(self):
1725+
with unittest.mock.patch("strands.models._openai_bedrock.provide_token") as mock:
1726+
mock.return_value = "bedrock-api-key-deadbeef&Version=1"
1727+
yield mock
1728+
1729+
def test_aws_config_sets_base_url_and_api_key(self, openai_client, mock_provide_token):
1730+
"""aws_config produces the Mantle base_url and a minted bearer token as api_key."""
1731+
_ = openai_client
1732+
model = OpenAIModel(model_id="openai.gpt-oss-120b", aws_config={"region": "us-east-1"})
1733+
1734+
# api_key is resolved per-request (lazy), so check via the resolved client_args at call time
1735+
resolved = model._resolve_client_args()
1736+
assert resolved["base_url"] == "https://bedrock-mantle.us-east-1.api.aws/v1"
1737+
assert resolved["api_key"] == "bedrock-api-key-deadbeef&Version=1"
1738+
# Only region is forwarded when the user did not set optional kwargs,
1739+
# so provide_token's own defaults (e.g. 12h expiry) apply.
1740+
mock_provide_token.assert_called_once_with(region="us-east-1")
1741+
1742+
def test_aws_config_forwards_credentials_provider_and_expiry(self, openai_client, mock_provide_token):
1743+
"""Optional credentials_provider and expiry are forwarded to provide_token."""
1744+
_ = openai_client
1745+
from datetime import timedelta
1746+
1747+
provider = unittest.mock.Mock()
1748+
model = OpenAIModel(
1749+
model_id="openai.gpt-oss-120b",
1750+
aws_config={
1751+
"region": "us-west-2",
1752+
"credentials_provider": provider,
1753+
"expiry": timedelta(minutes=15),
1754+
},
1755+
)
1756+
model._resolve_client_args()
1757+
mock_provide_token.assert_called_once_with(
1758+
region="us-west-2",
1759+
aws_credentials_provider=provider,
1760+
expiry=timedelta(minutes=15),
1761+
)
1762+
1763+
def test_aws_config_mints_token_per_request(self, openai_client, mock_provide_token):
1764+
"""Each call to _resolve_client_args mints a fresh token (long-lived processes)."""
1765+
_ = openai_client
1766+
model = OpenAIModel(model_id="openai.gpt-oss-120b", aws_config={"region": "us-east-1"})
1767+
model._resolve_client_args()
1768+
model._resolve_client_args()
1769+
model._resolve_client_args()
1770+
assert mock_provide_token.call_count == 3
1771+
1772+
def test_aws_config_conflicts_with_custom_client(self, openai_client):
1773+
"""Cannot pass both aws_config and a pre-built client."""
1774+
_ = openai_client
1775+
custom_client = unittest.mock.Mock()
1776+
with pytest.raises(ValueError, match="aws_config"):
1777+
OpenAIModel(
1778+
model_id="openai.gpt-oss-120b",
1779+
client=custom_client,
1780+
aws_config={"region": "us-east-1"},
1781+
)
1782+
1783+
def test_aws_config_merges_with_client_args(self, openai_client, mock_provide_token):
1784+
"""aws_config is allowed alongside client_args; base_url and api_key are overridden,
1785+
other transport-level options (timeout, http_client, default_headers) are preserved.
1786+
"""
1787+
_ = openai_client
1788+
sentinel_http_client = unittest.mock.Mock()
1789+
model = OpenAIModel(
1790+
model_id="openai.gpt-oss-120b",
1791+
client_args={
1792+
"api_key": "will-be-overridden",
1793+
"base_url": "https://also-overridden.example.com",
1794+
"timeout": 42,
1795+
"http_client": sentinel_http_client,
1796+
"default_headers": {"X-Trace-Id": "abc"},
1797+
},
1798+
aws_config={"region": "us-east-1"},
1799+
)
1800+
resolved = model._resolve_client_args()
1801+
assert resolved["base_url"] == "https://bedrock-mantle.us-east-1.api.aws/v1"
1802+
assert resolved["api_key"] == "bedrock-api-key-deadbeef&Version=1"
1803+
assert resolved["timeout"] == 42
1804+
assert resolved["http_client"] is sentinel_http_client
1805+
assert resolved["default_headers"] == {"X-Trace-Id": "abc"}
1806+
1807+
def test_aws_config_requires_region(self, openai_client):
1808+
"""aws_config must include a region."""
1809+
_ = openai_client
1810+
with pytest.raises(ValueError, match="region"):
1811+
OpenAIModel(model_id="openai.gpt-oss-120b", aws_config={})

tests/strands/models/test_openai_responses.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1298,3 +1298,80 @@ async def test_fallback_logs_debug(self, model, openai_client, messages, caplog)
12981298
await model.count_tokens(messages=messages)
12991299

13001300
assert any("native token counting failed" in record.message for record in caplog.records)
1301+
1302+
1303+
# =============================================================================
1304+
# Bedrock Mantle (aws_config) integration with OpenAIResponsesModel
1305+
# =============================================================================
1306+
1307+
1308+
class TestOpenAIResponsesModelAwsConfig:
1309+
"""Tests for the Bedrock Mantle pathway via the aws_config kwarg."""
1310+
1311+
@pytest.fixture
1312+
def mock_provide_token(self):
1313+
with unittest.mock.patch("strands.models._openai_bedrock.provide_token") as mock:
1314+
mock.return_value = "bedrock-api-key-deadbeef&Version=1"
1315+
yield mock
1316+
1317+
def test_aws_config_sets_base_url_and_api_key(self, openai_client, mock_provide_token):
1318+
_ = openai_client
1319+
model = OpenAIResponsesModel(model_id="openai.gpt-oss-120b", aws_config={"region": "us-east-1"})
1320+
resolved = model._resolve_client_args()
1321+
assert resolved["base_url"] == "https://bedrock-mantle.us-east-1.api.aws/v1"
1322+
assert resolved["api_key"] == "bedrock-api-key-deadbeef&Version=1"
1323+
mock_provide_token.assert_called_once_with(region="us-east-1")
1324+
1325+
def test_aws_config_forwards_credentials_provider_and_expiry(self, openai_client, mock_provide_token):
1326+
_ = openai_client
1327+
from datetime import timedelta
1328+
1329+
provider = unittest.mock.Mock()
1330+
model = OpenAIResponsesModel(
1331+
model_id="openai.gpt-oss-120b",
1332+
aws_config={
1333+
"region": "us-west-2",
1334+
"credentials_provider": provider,
1335+
"expiry": timedelta(minutes=15),
1336+
},
1337+
)
1338+
model._resolve_client_args()
1339+
mock_provide_token.assert_called_once_with(
1340+
region="us-west-2",
1341+
aws_credentials_provider=provider,
1342+
expiry=timedelta(minutes=15),
1343+
)
1344+
1345+
def test_aws_config_mints_token_per_request(self, openai_client, mock_provide_token):
1346+
_ = openai_client
1347+
model = OpenAIResponsesModel(model_id="openai.gpt-oss-120b", aws_config={"region": "us-east-1"})
1348+
model._resolve_client_args()
1349+
model._resolve_client_args()
1350+
assert mock_provide_token.call_count == 2
1351+
1352+
def test_aws_config_merges_with_client_args(self, openai_client, mock_provide_token):
1353+
"""aws_config is allowed alongside client_args; base_url and api_key are overridden,
1354+
other transport-level options are preserved.
1355+
"""
1356+
_ = openai_client
1357+
sentinel_http_client = unittest.mock.Mock()
1358+
model = OpenAIResponsesModel(
1359+
model_id="openai.gpt-oss-120b",
1360+
client_args={
1361+
"api_key": "will-be-overridden",
1362+
"base_url": "https://also-overridden.example.com",
1363+
"timeout": 42,
1364+
"http_client": sentinel_http_client,
1365+
},
1366+
aws_config={"region": "us-east-1"},
1367+
)
1368+
resolved = model._resolve_client_args()
1369+
assert resolved["base_url"] == "https://bedrock-mantle.us-east-1.api.aws/v1"
1370+
assert resolved["api_key"] == "bedrock-api-key-deadbeef&Version=1"
1371+
assert resolved["timeout"] == 42
1372+
assert resolved["http_client"] is sentinel_http_client
1373+
1374+
def test_aws_config_requires_region(self, openai_client):
1375+
_ = openai_client
1376+
with pytest.raises(ValueError, match="region"):
1377+
OpenAIResponsesModel(model_id="openai.gpt-oss-120b", aws_config={})

0 commit comments

Comments
 (0)