Skip to content

Commit aa5b972

Browse files
committed
feat(bedrock): route through OpenAI-compatible endpoint via openai_endpoint config
1 parent e12ac9d commit aa5b972

3 files changed

Lines changed: 579 additions & 3 deletions

File tree

src/strands/models/bedrock.py

Lines changed: 252 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,24 @@
11
"""AWS Bedrock model provider.
22
3-
- Docs: https://aws.amazon.com/bedrock/
3+
Supports two transports:
4+
5+
- Converse API (default) via the ``bedrock-runtime`` endpoint. Used when ``openai_endpoint``
6+
is not set on the config. Provides guardrails, prompt caching, and the full Converse feature
7+
set.
8+
- OpenAI-compatible endpoint (``bedrock-mantle``) when ``openai_endpoint`` is provided on the
9+
config. Routes through the OpenAI Python SDK to the Responses or Chat Completions API.
10+
Unlocks features such as server-side stateful conversations, Responses API reasoning, and
11+
built-in tools.
12+
13+
Generic inference parameters (``temperature``, ``top_p``, ``max_tokens``, ``stop_sequences``,
14+
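``streaming``) apply to both transports and live at the top level of ``BedrockConfig``. The one
exception is ``stop_sequences``, which is not forwarded when ``openai_endpoint`` selects the
Responses API.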
15+
Converse-only fields (``guardrail_*``, ``cache_*``, ``service_tier``, etc.) may not be combined
16+
with ``openai_endpoint``; doing so raises at init time.
17+
18+
Docs:
19+
20+
- Bedrock overview: https://aws.amazon.com/bedrock/
21+
- OpenAI-compatible endpoints: https://docs.aws.amazon.com/bedrock/latest/userguide/inference-openai.html
422
"""
523

624
import asyncio
@@ -9,7 +27,7 @@
927
import os
1028
import warnings
1129
from collections.abc import AsyncGenerator, Callable, Iterable, ValuesView
12-
from typing import Any, Literal, TypeVar, cast
30+
from typing import TYPE_CHECKING, Any, Literal, TypedDict, TypeVar, cast
1331

1432
import boto3
1533
from botocore.config import Config as BotocoreConfig
@@ -34,6 +52,10 @@
3452
from ._validation import validate_config_keys
3553
from .model import BaseModelConfig, CacheConfig, Model
3654

55+
if TYPE_CHECKING:
56+
from .openai import OpenAIModel
57+
from .openai_responses import OpenAIResponsesModel
58+
3759
logger = logging.getLogger(__name__)
3860

3961
# See: `BedrockModel._get_default_model_with_warning` for why we need both
@@ -57,6 +79,73 @@
5779

5880
DEFAULT_READ_TIMEOUT = 120
5981

82+
# Bedrock OpenAI-compatible endpoint (Mantle). See:
83+
# https://docs.aws.amazon.com/bedrock/latest/userguide/inference-openai.html
84+
_BEDROCK_MANTLE_BASE_URL_TEMPLATE = "https://bedrock-mantle.{region}.api.aws/v1"
85+
86+
# Config fields that only apply to the Converse transport. Setting any of these together with
87+
# ``openai_endpoint`` would silently no-op, so we reject the combination at init time.
88+
#
89+
# ``include_tool_result_status`` is intentionally excluded: it is always auto-defaulted by
90+
# ``__init__`` and only affects Converse-side tool-result serialization. It has no effect on
91+
# the Mantle path either way, so requiring users to clear it would be a pure footgun.
92+
_CONVERSE_ONLY_CONFIG_KEYS: frozenset[str] = frozenset(
93+
{
94+
"additional_request_fields",
95+
"additional_response_field_paths",
96+
"cache_prompt",
97+
"cache_config",
98+
"cache_tools",
99+
"guardrail_id",
100+
"guardrail_trace",
101+
"guardrail_version",
102+
"guardrail_stream_processing_mode",
103+
"guardrail_redact_input",
104+
"guardrail_redact_input_message",
105+
"guardrail_redact_output",
106+
"guardrail_redact_output_message",
107+
"guardrail_latest_message",
108+
"service_tier",
109+
}
110+
)
111+
112+
113+
class OpenAIEndpointConfig(TypedDict, total=False):
    """Configuration for routing a :class:`BedrockModel` through the OpenAI-compatible endpoint.

    When this config is present on :class:`BedrockModel.BedrockConfig`, requests are sent
    through the Bedrock Mantle endpoint (``bedrock-mantle.<region>.api.aws``) using the OpenAI
    Python SDK instead of the Converse API. This unlocks features that are specific to the
    OpenAI-compatible surface, such as the Responses API's server-side stateful conversations
    and reasoning controls.

    Generic inference parameters (``temperature``, ``top_p``, ``max_tokens``,
    ``stop_sequences``, ``streaming``) continue to live on :class:`BedrockModel.BedrockConfig`.
    When the delegate is built, ``max_tokens`` (as ``max_output_tokens`` for the Responses API),
    ``temperature`` and ``top_p`` are copied into the OpenAI ``params`` dict unless ``params``
    already defines the same key. ``stop_sequences`` is forwarded as ``stop`` for
    ``"chat_completions"`` only; it is not forwarded for ``"responses"``.

    Attributes:
        api: Which OpenAI API surface to use. ``"responses"`` maps to the Responses API and
            ``"chat_completions"`` maps to the Chat Completions API. Required.
        api_key: Bedrock API key to send as the bearer token. The OpenAI SDK is only the
            transport here; this is a Bedrock-issued key, not an OpenAI account key. The
            AWS docs recommend setting the ``OPENAI_API_KEY`` environment variable to your
            Bedrock API key, which is the OpenAI SDK's default env var. When ``api_key`` is
            omitted, the underlying SDK picks up ``OPENAI_API_KEY`` automatically.
        stateful: Enable server-side conversation state management. Responses API only;
            enabling it together with ``api="chat_completions"`` is rejected with a
            ``ValueError``.
        params: Extra parameters forwarded to the OpenAI SDK ``params`` dict. Use this for
            Responses-only options such as ``reasoning``. Keys set here take precedence over
            the generic inference parameters from ``BedrockConfig``.
        client_args: Extra arguments merged into the OpenAI client constructor. Use this
            to plug in a custom ``http_client`` (for example, to sign requests with AWS
            SigV4) or to override timeouts. ``api_key`` and ``base_url`` set here override
            the values derived from ``api_key`` and ``region``.
    """

    # ``total=False`` makes every key optional for type checkers, ``api`` included. The
    # "Required" contract for ``api`` is enforced at runtime by
    # ``BedrockModel._validate_openai_endpoint_config``, which raises ``ValueError``.
    # NOTE(review): ``_build_openai_delegate`` does not copy ``streaming`` into ``params`` —
    # confirm where (or whether) that setting is applied on the OpenAI-compatible path.
    api: Literal["responses", "chat_completions"]
    api_key: str | None
    stateful: bool | None
    params: dict[str, Any] | None
    client_args: dict[str, Any] | None
148+
60149

61150
class BedrockModel(Model):
62151
"""AWS Bedrock model provider implementation.
@@ -94,6 +183,10 @@ class BedrockConfig(BaseModelConfig, total=False):
94183
model_id: The Bedrock model ID (e.g., "global.anthropic.claude-sonnet-4-6")
95184
include_tool_result_status: Flag to include status field in tool results.
96185
True includes status, False removes status, "auto" determines based on model_id. Defaults to "auto".
186+
openai_endpoint: When set, route requests through Bedrock's OpenAI-compatible endpoint
187+
(``bedrock-mantle``) using the OpenAI Python SDK instead of the Converse API.
188+
See :class:`OpenAIEndpointConfig`. Converse-only fields (``guardrail_*``,
189+
``cache_*``, ``service_tier``, etc.) may not be combined with this option.
97190
service_tier: Service tier for the request, controlling the trade-off between latency and cost.
98191
Valid values: "default" (standard), "priority" (faster, premium), "flex" (cheaper, slower).
99192
Please check https://docs.aws.amazon.com/bedrock/latest/userguide/service-tiers-inference.html for
@@ -127,6 +220,7 @@ class BedrockConfig(BaseModelConfig, total=False):
127220
streaming: bool | None
128221
temperature: float | None
129222
top_p: float | None
223+
openai_endpoint: OpenAIEndpointConfig | None
130224

131225
def __init__(
132226
self,
@@ -152,6 +246,7 @@ def __init__(
152246

153247
session = boto_session or boto3.Session()
154248
resolved_region = region_name or session.region_name or os.environ.get("AWS_REGION") or DEFAULT_BEDROCK_REGION
249+
self._resolved_region = resolved_region
155250
self.config = BedrockModel.BedrockConfig(
156251
model_id=BedrockModel._get_default_model_with_warning(resolved_region, model_config),
157252
include_tool_result_status="auto",
@@ -160,6 +255,22 @@ def __init__(
160255

161256
logger.debug("config=<%s> | initializing", self.config)
162257

258+
# When ``openai_endpoint`` is configured, requests are routed through the Bedrock Mantle
259+
# OpenAI-compatible endpoint via the OpenAI Python SDK. Skip the boto client since it is
260+
# not used on that path.
261+
self._openai_delegate: OpenAIModel | OpenAIResponsesModel | None = None
262+
endpoint_config = self.config.get("openai_endpoint")
263+
if endpoint_config is not None:
264+
self._validate_openai_endpoint_config()
265+
self._openai_delegate = self._build_openai_delegate()
266+
self.client = None
267+
logger.debug(
268+
"region=<%s>, api=<%s> | bedrock openai-compatible delegate created",
269+
resolved_region,
270+
endpoint_config["api"],
271+
)
272+
return
273+
163274
# Add strands-agents to the request user agent
164275
if boto_client_config:
165276
existing_user_agent = getattr(boto_client_config, "user_agent_extra", None)
@@ -183,6 +294,98 @@ def __init__(
183294

184295
logger.debug("region=<%s> | bedrock client created", self.client.meta.region_name)
185296

297+
def _validate_openai_endpoint_config(self) -> None:
298+
"""Validate that ``openai_endpoint`` is not combined with Converse-only fields.
299+
300+
Raises:
301+
ValueError: If a Converse-only config field is set alongside ``openai_endpoint``
302+
or if ``api`` is missing from ``openai_endpoint``.
303+
"""
304+
endpoint = cast(OpenAIEndpointConfig, self.config["openai_endpoint"])
305+
api = endpoint.get("api")
306+
if api not in ("responses", "chat_completions"):
307+
raise ValueError(f'openai_endpoint requires "api" to be "responses" or "chat_completions", got {api!r}')
308+
309+
conflicting = sorted(k for k in _CONVERSE_ONLY_CONFIG_KEYS if self.config.get(k) is not None)
310+
if conflicting:
311+
raise ValueError(
312+
"openai_endpoint cannot be combined with Converse-only config fields: "
313+
f"{conflicting}. Remove these fields or drop openai_endpoint to use the Converse API."
314+
)
315+
316+
# ``stateful`` is only meaningful for the Responses API.
317+
if endpoint.get("stateful") and api != "responses":
318+
raise ValueError(f'openai_endpoint.stateful is only supported when api="responses". Got api={api!r}.')
319+
320+
def _build_openai_delegate(self) -> "OpenAIModel | OpenAIResponsesModel":
    """Construct the OpenAI-compatible delegate for the Mantle endpoint.

    Forwards generic inference params from :class:`BedrockConfig` into the OpenAI SDK
    ``params`` dict, translating names where the OpenAI and Bedrock conventions differ.
    Values already present in ``openai_endpoint["params"]`` are never overwritten.

    ``stop_sequences`` is forwarded as ``stop`` for Chat Completions only. On the Responses
    path it is dropped and a warning is logged so the omission is visible to the caller.

    Expects ``_validate_openai_endpoint_config`` to have been called first; ``api`` is read
    without further validation.

    Returns:
        An :class:`OpenAIResponsesModel` or :class:`OpenAIModel` configured to talk to
        ``bedrock-mantle.<region>.api.aws``.
    """
    endpoint = cast(OpenAIEndpointConfig, self.config["openai_endpoint"])
    api = endpoint["api"]

    # The Mantle base URL is fully determined by region; AWS owns the endpoint list:
    # https://docs.aws.amazon.com/bedrock/latest/userguide/inference-openai.html
    base_url = _BEDROCK_MANTLE_BASE_URL_TEMPLATE.format(region=self._resolved_region)

    # Work on a copy so the caller's ``params`` dict is never mutated. ``setdefault`` below
    # lets explicit ``params`` entries win over the generic BedrockConfig values.
    params: dict[str, Any] = dict(endpoint.get("params") or {})

    # Forward generic inference params from BedrockConfig. Translate naming where the
    # Responses API differs from Chat Completions (max_tokens -> max_output_tokens).
    forwarded: dict[str, Any] = {
        "max_output_tokens" if api == "responses" else "max_tokens": self.config.get("max_tokens"),
        "temperature": self.config.get("temperature"),
        "top_p": self.config.get("top_p"),
    }
    for param_name, value in forwarded.items():
        if value is not None:
            params.setdefault(param_name, value)

    stop_sequences = self.config.get("stop_sequences")
    if stop_sequences is not None:
        # Chat Completions uses `stop`; Responses does not accept stop sequences.
        if api == "chat_completions":
            params.setdefault("stop", stop_sequences)
        else:
            # Warn rather than debug-log: the user explicitly configured a value that is
            # being dropped, and every other ignored field is surfaced loudly.
            logger.warning(
                "api=<%s> | stop_sequences is not supported by the Responses API and will be ignored",
                api,
            )

    # The OpenAI SDK's ``api_key`` parameter is just the bearer token it sends in the
    # Authorization header; when pointed at bedrock-mantle this is a Bedrock-issued key.
    # The AWS docs recommend setting OPENAI_API_KEY to the Bedrock API key so existing
    # OpenAI SDK code works unchanged. If the user does not pass ``api_key`` here, the
    # OpenAI SDK will read OPENAI_API_KEY from the environment on its own.
    client_args: dict[str, Any] = {"base_url": base_url}
    if api_key := endpoint.get("api_key"):
        client_args["api_key"] = api_key
    # User-supplied client_args win over derived defaults. This is the escape hatch for
    # plumbing a signed httpx client (SigV4), custom timeouts, etc.
    if extra_client_args := endpoint.get("client_args"):
        client_args.update(extra_client_args)

    # The delegate classes are imported lazily, only on the path that needs them.
    if api == "responses":
        from .openai_responses import OpenAIResponsesModel

        stateful = bool(endpoint.get("stateful"))
        return OpenAIResponsesModel(
            client_args=client_args,
            model_id=self.config["model_id"],
            params=params,
            stateful=stateful,
        )

    from .openai import OpenAIModel

    return OpenAIModel(
        client_args=client_args,
        model_id=self.config["model_id"],
        params=params,
    )
388+
186389
@property
187390
def _cache_strategy(self) -> str | None:
188391
"""The cache strategy for this model based on its model ID.
@@ -194,6 +397,18 @@ def _cache_strategy(self) -> str | None:
194397
return "anthropic"
195398
return None
196399

400+
@property
@override
def stateful(self) -> bool:
    """Report whether conversation state is kept on the server side.

    With ``openai_endpoint`` configured, the answer comes from the OpenAI-compatible delegate.
    Without it the model uses the Converse API, which is always stateless, so the result is
    ``False``.
    """
    delegate = self._openai_delegate
    return False if delegate is None else delegate.stateful
411+
197412
@override
198413
def update_config(self, **model_config: Unpack[BedrockConfig]) -> None: # type: ignore
199414
"""Update the Bedrock Model configuration with the provided arguments.
@@ -204,12 +419,19 @@ def update_config(self, **model_config: Unpack[BedrockConfig]) -> None: # type:
204419
validate_config_keys(model_config, self.BedrockConfig)
205420
self.config.update(model_config)
206421

422+
# If the delegate is already built and the caller changed anything the delegate depends on,
423+
# rebuild it so subsequent calls pick up the new config. Skipped during __init__ where
424+
# ``_openai_delegate`` is not yet set.
425+
if getattr(self, "_openai_delegate", None) is not None:
426+
self._validate_openai_endpoint_config()
427+
self._openai_delegate = self._build_openai_delegate()
428+
207429
@override
208430
def get_config(self) -> BedrockConfig:
209431
"""Get the current Bedrock Model configuration.
210432
211433
Returns:
212-
The Bedrock model configuration.
434+
The Bedrock model configuration.
213435
"""
214436
return self.config
215437

@@ -772,7 +994,15 @@ async def count_tokens(
772994
Returns:
773995
Total input token count.
774996
"""
997+
# The openai_endpoint path has no Bedrock Converse client and no equivalent native
998+
# count endpoint, so fall back to the base ``Model.count_tokens`` estimation
999+
# (tiktoken when available, heuristic otherwise).
1000+
if self._openai_delegate is not None:
1001+
return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content)
1002+
7751003
try:
1004+
# The openai_endpoint early-return above guarantees ``self.client`` exists here.
1005+
assert self.client is not None, "Bedrock Converse client is unavailable"
7761006
if system_prompt and system_prompt_content is None:
7771007
system_prompt_content = [{"text": system_prompt}]
7781008

@@ -836,6 +1066,16 @@ async def stream(
8361066
ContextWindowOverflowException: If the input exceeds the model's context window.
8371067
ModelThrottledException: If the model service is throttling requests.
8381068
"""
1069+
if self._openai_delegate is not None:
1070+
async for delegate_event in self._openai_delegate.stream(
1071+
messages,
1072+
tool_specs,
1073+
system_prompt,
1074+
tool_choice=tool_choice,
1075+
**kwargs,
1076+
):
1077+
yield delegate_event
1078+
return
8391079

8401080
def callback(event: StreamEvent | None = None) -> None:
8411081
loop.call_soon_threadsafe(queue.put_nowait, event)
@@ -885,6 +1125,8 @@ def _stream(
8851125
ContextWindowOverflowException: If the input exceeds the model's context window.
8861126
ModelThrottledException: If the model service is throttling requests.
8871127
"""
1128+
# Converse-only path; the Mantle delegate handles streaming directly in ``stream``.
1129+
assert self.client is not None, "Bedrock Converse client is unavailable"
8881130
try:
8891131
logger.debug("formatting request")
8901132
request = self._format_request(messages, tool_specs, system_prompt_content, tool_choice)
@@ -1110,6 +1352,13 @@ async def structured_output(
11101352
Yields:
11111353
Model events with the last being the structured output.
11121354
"""
1355+
if self._openai_delegate is not None:
1356+
async for delegate_event in self._openai_delegate.structured_output(
1357+
output_model, prompt, system_prompt, **kwargs
1358+
):
1359+
yield delegate_event
1360+
return
1361+
11131362
tool_spec = convert_pydantic_to_tool_spec(output_model)
11141363

11151364
response = self.stream(

0 commit comments

Comments
 (0)