diff --git a/pyproject.toml b/pyproject.toml index 83a7bbf4d..8a017cd07 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ litellm = ["litellm>=1.75.9,<=1.83.13", "openai>=1.68.0,<3.0.0"] llamaapi = ["llama-api-client>=0.1.0,<1.0.0"] mistral = ["mistralai>=1.8.2,<2.0.0"] ollama = ["ollama>=0.4.8,<1.0.0"] -openai = ["openai>=1.68.0,<3.0.0"] +openai = ["openai>=1.68.0,<3.0.0", "aws-bedrock-token-generator>=1.1.0,<2.0.0"] writer = ["writer-sdk>=2.2.0,<3.0.0"] sagemaker = [ "boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0", diff --git a/src/strands/models/_openai_bedrock.py b/src/strands/models/_openai_bedrock.py new file mode 100644 index 000000000..149a47ec5 --- /dev/null +++ b/src/strands/models/_openai_bedrock.py @@ -0,0 +1,126 @@ +"""Internal helpers for routing OpenAI-compatible clients to Bedrock Mantle. + +Converts a ``bedrock_mantle_config`` dict into the ``base_url`` and ``api_key`` that the +OpenAI Python SDK consumes. Tokens are minted on demand via +``aws_bedrock_token_generator.provide_token`` so long-running agents survive the +bearer token's maximum lifetime. + +``aws_bedrock_token_generator`` is part of the ``openai`` extras group +(``pip install strands-agents[openai]``) but is *not* included in the ``litellm`` +or ``sagemaker`` extras, which also pull in the ``openai`` package. The import is +therefore lazy — it happens inside :func:`resolve_bedrock_client_args` so that +those other extras never trigger an ``ImportError`` at module load. +""" + +from __future__ import annotations + +from datetime import timedelta +from typing import Any, TypedDict + +import boto3 +from botocore.credentials import CredentialProvider + +_MANTLE_BASE_URL_TEMPLATE = "https://bedrock-mantle.{region}.api.aws/v1" +_MANTLE_DOCS_URL = "https://docs.aws.amazon.com/bedrock/latest/userguide/inference-openai.html" + + +class BedrockMantleConfig(TypedDict, total=False): + """Config for routing an OpenAI-compatible client through Bedrock Mantle. + + Attributes: + region: AWS region hosting the Bedrock Mantle endpoint. If omitted, resolved + from ``boto_session`` (if provided) or the standard boto3 chain + (``AWS_REGION`` / ``AWS_DEFAULT_REGION`` / active profile / EC2 metadata). + A :class:`ValueError` is raised if none resolve. + boto_session: Optional :class:`boto3.Session` used to resolve the region when + ``region`` is not provided. Useful for picking up a non-default profile + without exporting env vars. + credentials_provider: Optional botocore :class:`~botocore.credentials.CredentialProvider` + forwarded to ``provide_token``. Omit to let the token generator use the + standard AWS credential chain. + expiry: Optional ``timedelta`` for the bearer token's lifetime, forwarded to + ``provide_token``. Defaults to the generator's built-in lifetime when + omitted. + """ + + region: str + boto_session: boto3.Session + credentials_provider: CredentialProvider + expiry: timedelta + + +def _resolve_region(config: BedrockMantleConfig) -> str: + """Resolve the AWS region, preferring explicit config then falling back to boto3. + + Raises: + ValueError: If no region can be resolved from the config, an attached session, + or the standard boto3 credential chain. + """ + region = config.get("region") + if region: + return region + + session = config.get("boto_session") + if session is not None and session.region_name: + return str(session.region_name) + + # ``boto3.Session()`` with no args reads ``AWS_REGION`` / ``AWS_DEFAULT_REGION``, + # the active profile, and falls back to EC2 instance metadata — the same chain + # :class:`BedrockModel` uses. + default_region = boto3.Session().region_name + if default_region: + return str(default_region) + + raise ValueError( + "Could not resolve an AWS region for Bedrock Mantle. Pass 'region' in " + "bedrock_mantle_config, attach a boto_session with a configured region, or set " + f"AWS_REGION in the environment. See {_MANTLE_DOCS_URL} for supported regions." + ) + + +def resolve_bedrock_client_args( + config: BedrockMantleConfig, client_args: dict[str, Any] | None = None +) -> dict[str, Any]: + """Resolve a ``BedrockMantleConfig`` (plus optional ``client_args``) into OpenAI client kwargs. + + Mints a fresh bearer token on every call. Callers are expected to validate that + ``client_args`` does not contain ``base_url`` or ``api_key`` before calling this + function (typically at ``__init__`` time for fail-fast behavior). + + Raises: + ValueError: If no region can be resolved. + ImportError: If ``aws-bedrock-token-generator`` is not installed. + RuntimeError: If token minting fails (e.g. missing AWS credentials). + """ + region = _resolve_region(config) + + # ``aws-bedrock-token-generator`` is included in the ``openai`` extras group but not in + # ``litellm`` or ``sagemaker`` (which also depend on the ``openai`` package). The lazy + # import keeps those extras from hitting an ImportError at module load. + try: + from aws_bedrock_token_generator import provide_token + except ImportError as e: + raise ImportError( + "bedrock_mantle_config requires the 'aws-bedrock-token-generator' package. " + "Install it with: pip install strands-agents[openai]" + ) from e + + # Only forward kwargs the user set; provide_token rejects expiry=None. + token_kwargs: dict[str, Any] = {"region": region} + if "credentials_provider" in config: + token_kwargs["aws_credentials_provider"] = config["credentials_provider"] + if "expiry" in config: + token_kwargs["expiry"] = config["expiry"] + + try: + token = provide_token(**token_kwargs) + except Exception as e: + raise RuntimeError( + f"Failed to mint Bedrock Mantle bearer token for region '{region}'. " + "Verify your AWS credentials and network connectivity." + ) from e + + resolved: dict[str, Any] = dict(client_args or {}) + resolved["base_url"] = _MANTLE_BASE_URL_TEMPLATE.format(region=region) + resolved["api_key"] = token + return resolved diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index c4be7d360..ea16c7713 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -21,6 +21,7 @@ from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse +from ._openai_bedrock import BedrockMantleConfig, resolve_bedrock_client_args from ._validation import _has_location_source, validate_config_keys from .model import BaseModelConfig, Model @@ -71,6 +72,7 @@ def __init__( self, client: Client | None = None, client_args: dict[str, Any] | None = None, + bedrock_mantle_config: BedrockMantleConfig | None = None, **model_config: Unpack[OpenAIConfig], ) -> None: """Initialize provider instance. @@ -87,23 +89,50 @@ def __init__( Note: The client should not be shared across different asyncio event loops. client_args: Arguments for the OpenAI client (legacy approach). For a complete list of supported arguments, see https://pypi.org/project/openai/. + May be combined with ``bedrock_mantle_config``; when both are set, + ``bedrock_mantle_config`` derives ``base_url`` and ``api_key`` (which must not + appear in ``client_args``). + bedrock_mantle_config: Route requests through Amazon Bedrock's Mantle + (OpenAI-compatible) endpoint. See :class:`BedrockMantleConfig` for accepted + keys. When set, a fresh bearer token is minted on every request. Cannot be + combined with a pre-built ``client``. **model_config: Configuration options for the OpenAI model. Raises: - ValueError: If both `client` and `client_args` are provided. + ValueError: If ``client`` is combined with ``client_args`` or ``bedrock_mantle_config``. """ validate_config_keys(model_config, self.OpenAIConfig) self.config = dict(model_config) - # Validate that only one client configuration method is provided - if client is not None and client_args is not None and len(client_args) > 0: + # client_args + bedrock_mantle_config is allowed; the config derives base_url / api_key. + client_args_provided = client_args is not None and len(client_args) > 0 + if client is not None and client_args_provided: raise ValueError("Only one of 'client' or 'client_args' should be provided, not both.") + if bedrock_mantle_config is not None and client is not None: + raise ValueError("'bedrock_mantle_config' cannot be combined with a pre-built 'client'.") + if bedrock_mantle_config is not None and client_args: + conflicting = [k for k in ("api_key", "base_url") if k in client_args] + if conflicting: + raise ValueError( + f"client_args must not contain {conflicting} when bedrock_mantle_config is set; " + "these are derived from the Mantle config automatically." + ) self._custom_client = client self.client_args = client_args or {} + self._bedrock_mantle_config = bedrock_mantle_config logger.debug("config=<%s> | initializing", self.config) + def _resolve_client_args(self) -> dict[str, Any]: + """Return the kwargs to pass to ``openai.AsyncOpenAI`` for the current request. + + Delegates to :func:`resolve_bedrock_client_args` when ``bedrock_mantle_config`` is set. + """ + if self._bedrock_mantle_config is not None: + return resolve_bedrock_client_args(self._bedrock_mantle_config, self.client_args) + return self.client_args + @override def update_config(self, **model_config: Unpack[OpenAIConfig]) -> None: # type: ignore[override] """Update the OpenAI model configuration with the provided arguments. @@ -590,11 +619,10 @@ async def _get_client(self) -> AsyncIterator[Any]: # Use the injected client (caller manages lifecycle) yield self._custom_client else: - # Create a new client from client_args # We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying # httpx client. The asyncio event loop does not allow connections to be shared. For more details, please # refer to https://github.com/encode/httpx/discussions/2959. - async with openai.AsyncOpenAI(**self.client_args) as client: + async with openai.AsyncOpenAI(**self._resolve_client_args()) as client: yield client @override diff --git a/src/strands/models/openai_responses.py b/src/strands/models/openai_responses.py index 73a889aad..4aff07ccd 100644 --- a/src/strands/models/openai_responses.py +++ b/src/strands/models/openai_responses.py @@ -58,6 +58,7 @@ from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException # noqa: E402 from ..types.streaming import StreamEvent # noqa: E402 from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse # noqa: E402 +from ._openai_bedrock import BedrockMantleConfig, resolve_bedrock_client_args # noqa: E402 from ._validation import validate_config_keys # noqa: E402 from .model import BaseModelConfig, Model # noqa: E402 @@ -141,21 +142,48 @@ class OpenAIResponsesConfig(BaseModelConfig, total=False): stateful: bool def __init__( - self, client_args: dict[str, Any] | None = None, **model_config: Unpack[OpenAIResponsesConfig] + self, + client_args: dict[str, Any] | None = None, + bedrock_mantle_config: BedrockMantleConfig | None = None, + **model_config: Unpack[OpenAIResponsesConfig], ) -> None: """Initialize provider instance. Args: client_args: Arguments for the OpenAI client. For a complete list of supported arguments, see https://pypi.org/project/openai/. + May be combined with ``bedrock_mantle_config``; when both are set, the config + derives ``base_url`` and ``api_key`` (which must not appear in ``client_args``). + bedrock_mantle_config: Route requests through Amazon Bedrock's Mantle + (OpenAI-compatible) endpoint. See :class:`BedrockMantleConfig` for accepted + keys. When set, a fresh bearer token is minted on every request. **model_config: Configuration options for the OpenAI Responses API model. """ validate_config_keys(model_config, self.OpenAIResponsesConfig) self.config = dict(model_config) + self.client_args = client_args or {} + self._bedrock_mantle_config = bedrock_mantle_config + + if bedrock_mantle_config is not None and client_args: + conflicting = [k for k in ("api_key", "base_url") if k in client_args] + if conflicting: + raise ValueError( + f"client_args must not contain {conflicting} when bedrock_mantle_config is set; " + "these are derived from the Mantle config automatically." + ) logger.debug("config=<%s> | initializing", self.config) + def _resolve_client_args(self) -> dict[str, Any]: + """Return the kwargs to pass to ``openai.AsyncOpenAI`` for the current request. + + Delegates to :func:`resolve_bedrock_client_args` when ``bedrock_mantle_config`` is set. + """ + if self._bedrock_mantle_config is not None: + return resolve_bedrock_client_args(self._bedrock_mantle_config, self.client_args) + return self.client_args + @property @override def stateful(self) -> bool: @@ -215,7 +243,7 @@ async def count_tokens( count_tokens_fields = {"model", "input", "instructions", "tools"} request = {k: request[k] for k in request.keys() & count_tokens_fields} - async with openai.AsyncOpenAI(**self.client_args) as client: + async with openai.AsyncOpenAI(**self._resolve_client_args()) as client: response = await client.responses.input_tokens.count(**request) total_tokens: int = response.input_tokens @@ -267,7 +295,7 @@ async def stream( logger.debug("invoking OpenAI Responses API model") - async with openai.AsyncOpenAI(**self.client_args) as client: + async with openai.AsyncOpenAI(**self._resolve_client_args()) as client: try: response = await client.responses.create(**request) @@ -447,7 +475,7 @@ async def structured_output( ContextWindowOverflowException: If the input exceeds the model's context window. ModelThrottledException: If the request is throttled by OpenAI (rate limits). """ - async with openai.AsyncOpenAI(**self.client_args) as client: + async with openai.AsyncOpenAI(**self._resolve_client_args()) as client: try: response = await client.responses.parse( model=self.get_config()["model_id"], diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index 94e4caa3f..b43915b07 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ -1,4 +1,5 @@ import logging +import os import unittest.mock import openai @@ -1710,3 +1711,159 @@ def test_format_request_messages_multiple_tool_calls_with_images(): }, ] assert tru_result == exp_result + + +# ============================================================================= +# Bedrock Mantle (bedrock_mantle_config) integration with OpenAIModel +# ============================================================================= + + +class TestOpenAIModelBedrockMantleConfig: + @pytest.fixture + def mock_provide_token(self): + with unittest.mock.patch("aws_bedrock_token_generator.provide_token") as mock: + mock.return_value = "bedrock-api-key-deadbeef&Version=1" + yield mock + + def test_bedrock_mantle_config_sets_base_url_and_api_key(self, openai_client, mock_provide_token): + """bedrock_mantle_config produces the Mantle base_url and a minted bearer token as api_key.""" + _ = openai_client + model = OpenAIModel(model_id="openai.gpt-oss-120b", bedrock_mantle_config={"region": "us-east-1"}) + + # Token is minted lazily per request, so inspect the resolved kwargs. + resolved = model._resolve_client_args() + assert resolved["base_url"] == "https://bedrock-mantle.us-east-1.api.aws/v1" + assert resolved["api_key"] == "bedrock-api-key-deadbeef&Version=1" + # Optional kwargs aren't forwarded so provide_token's own defaults apply. + mock_provide_token.assert_called_once_with(region="us-east-1") + + def test_bedrock_mantle_config_forwards_credentials_provider_and_expiry(self, openai_client, mock_provide_token): + """Optional credentials_provider and expiry are forwarded to provide_token.""" + _ = openai_client + from datetime import timedelta + + provider = unittest.mock.Mock() + model = OpenAIModel( + model_id="openai.gpt-oss-120b", + bedrock_mantle_config={ + "region": "us-west-2", + "credentials_provider": provider, + "expiry": timedelta(minutes=15), + }, + ) + model._resolve_client_args() + mock_provide_token.assert_called_once_with( + region="us-west-2", + aws_credentials_provider=provider, + expiry=timedelta(minutes=15), + ) + + def test_bedrock_mantle_config_mints_token_per_request(self, openai_client, mock_provide_token): + """Each call to _resolve_client_args mints a fresh token (long-lived processes).""" + _ = openai_client + model = OpenAIModel(model_id="openai.gpt-oss-120b", bedrock_mantle_config={"region": "us-east-1"}) + model._resolve_client_args() + model._resolve_client_args() + model._resolve_client_args() + assert mock_provide_token.call_count == 3 + + def test_bedrock_mantle_config_conflicts_with_custom_client(self, openai_client): + """Cannot pass both bedrock_mantle_config and a pre-built client.""" + _ = openai_client + custom_client = unittest.mock.Mock() + with pytest.raises(ValueError, match="bedrock_mantle_config"): + OpenAIModel( + model_id="openai.gpt-oss-120b", + client=custom_client, + bedrock_mantle_config={"region": "us-east-1"}, + ) + + def test_bedrock_mantle_config_merges_with_client_args(self, openai_client, mock_provide_token): + """bedrock_mantle_config composes with client_args; transport options are preserved.""" + _ = openai_client + sentinel_http_client = unittest.mock.Mock() + model = OpenAIModel( + model_id="openai.gpt-oss-120b", + client_args={ + "timeout": 42, + "http_client": sentinel_http_client, + "default_headers": {"X-Trace-Id": "abc"}, + }, + bedrock_mantle_config={"region": "us-east-1"}, + ) + resolved = model._resolve_client_args() + assert resolved["base_url"] == "https://bedrock-mantle.us-east-1.api.aws/v1" + assert resolved["api_key"] == "bedrock-api-key-deadbeef&Version=1" + assert resolved["timeout"] == 42 + assert resolved["http_client"] is sentinel_http_client + assert resolved["default_headers"] == {"X-Trace-Id": "abc"} + + def test_bedrock_mantle_config_rejects_base_url_in_client_args(self, openai_client): + """client_args must not contain base_url or api_key when bedrock_mantle_config is set.""" + _ = openai_client + with pytest.raises(ValueError, match="client_args must not contain"): + OpenAIModel( + model_id="openai.gpt-oss-120b", + client_args={"base_url": "https://custom.example.com"}, + bedrock_mantle_config={"region": "us-east-1"}, + ) + + def test_bedrock_mantle_config_requires_region(self, openai_client): + """bedrock_mantle_config raises when no region can be resolved from config, session, or env.""" + _ = openai_client + with ( + unittest.mock.patch("boto3.Session") as mock_session_cls, + unittest.mock.patch.dict(os.environ, {}, clear=True), + ): + mock_session_cls.return_value.region_name = None + model = OpenAIModel(model_id="openai.gpt-oss-120b", bedrock_mantle_config={}) + with pytest.raises(ValueError, match="Could not resolve an AWS region"): + model._resolve_client_args() + + def test_bedrock_mantle_config_region_resolved_from_boto3_default(self, openai_client, mock_provide_token): + """When region is omitted, the default boto3 session chain resolves it.""" + _ = openai_client + with unittest.mock.patch("boto3.Session") as mock_session_cls: + mock_session_cls.return_value.region_name = "eu-west-1" + model = OpenAIModel(model_id="openai.gpt-oss-120b", bedrock_mantle_config={}) + resolved = model._resolve_client_args() + + assert resolved["base_url"] == "https://bedrock-mantle.eu-west-1.api.aws/v1" + mock_provide_token.assert_called_once_with(region="eu-west-1") + + def test_bedrock_mantle_config_region_resolved_from_boto_session(self, openai_client, mock_provide_token): + """An explicit ``boto_session`` supplies the region when ``region`` is omitted.""" + _ = openai_client + session = unittest.mock.Mock() + session.region_name = "ap-southeast-2" + model = OpenAIModel( + model_id="openai.gpt-oss-120b", + bedrock_mantle_config={"boto_session": session}, + ) + + resolved = model._resolve_client_args() + + assert resolved["base_url"] == "https://bedrock-mantle.ap-southeast-2.api.aws/v1" + mock_provide_token.assert_called_once_with(region="ap-southeast-2") + + def test_bedrock_mantle_config_explicit_region_wins_over_boto_session(self, openai_client, mock_provide_token): + """``region`` takes precedence over a session's region.""" + _ = openai_client + session = unittest.mock.Mock() + session.region_name = "ap-southeast-2" + model = OpenAIModel( + model_id="openai.gpt-oss-120b", + bedrock_mantle_config={"region": "us-east-1", "boto_session": session}, + ) + + model._resolve_client_args() + + mock_provide_token.assert_called_once_with(region="us-east-1") + + def test_bedrock_mantle_config_wraps_token_failures_with_context(self, openai_client, mock_provide_token): + """provide_token failures are wrapped in a RuntimeError with actionable context.""" + _ = openai_client + mock_provide_token.side_effect = RuntimeError("no credentials in chain") + model = OpenAIModel(model_id="openai.gpt-oss-120b", bedrock_mantle_config={"region": "us-east-1"}) + with pytest.raises(RuntimeError, match="Bedrock Mantle bearer token.*us-east-1"): + model._resolve_client_args() diff --git a/tests/strands/models/test_openai_responses.py b/tests/strands/models/test_openai_responses.py index 88cbee326..b35d2d0de 100644 --- a/tests/strands/models/test_openai_responses.py +++ b/tests/strands/models/test_openai_responses.py @@ -1,3 +1,4 @@ +import os import unittest.mock import openai @@ -1298,3 +1299,125 @@ async def test_fallback_logs_debug(self, model, openai_client, messages, caplog) await model.count_tokens(messages=messages) assert any("native token counting failed" in record.message for record in caplog.records) + + +# ============================================================================= +# Bedrock Mantle (bedrock_mantle_config) integration with OpenAIResponsesModel +# ============================================================================= + + +class TestOpenAIResponsesModelBedrockMantleConfig: + @pytest.fixture + def mock_provide_token(self): + with unittest.mock.patch("aws_bedrock_token_generator.provide_token") as mock: + mock.return_value = "bedrock-api-key-deadbeef&Version=1" + yield mock + + def test_bedrock_mantle_config_sets_base_url_and_api_key(self, openai_client, mock_provide_token): + _ = openai_client + model = OpenAIResponsesModel(model_id="openai.gpt-oss-120b", bedrock_mantle_config={"region": "us-east-1"}) + resolved = model._resolve_client_args() + assert resolved["base_url"] == "https://bedrock-mantle.us-east-1.api.aws/v1" + assert resolved["api_key"] == "bedrock-api-key-deadbeef&Version=1" + mock_provide_token.assert_called_once_with(region="us-east-1") + + def test_bedrock_mantle_config_forwards_credentials_provider_and_expiry(self, openai_client, mock_provide_token): + _ = openai_client + from datetime import timedelta + + provider = unittest.mock.Mock() + model = OpenAIResponsesModel( + model_id="openai.gpt-oss-120b", + bedrock_mantle_config={ + "region": "us-west-2", + "credentials_provider": provider, + "expiry": timedelta(minutes=15), + }, + ) + model._resolve_client_args() + mock_provide_token.assert_called_once_with( + region="us-west-2", + aws_credentials_provider=provider, + expiry=timedelta(minutes=15), + ) + + def test_bedrock_mantle_config_mints_token_per_request(self, openai_client, mock_provide_token): + _ = openai_client + model = OpenAIResponsesModel(model_id="openai.gpt-oss-120b", bedrock_mantle_config={"region": "us-east-1"}) + model._resolve_client_args() + model._resolve_client_args() + assert mock_provide_token.call_count == 2 + + def test_bedrock_mantle_config_merges_with_client_args(self, openai_client, mock_provide_token): + """bedrock_mantle_config composes with client_args; transport options are preserved.""" + _ = openai_client + sentinel_http_client = unittest.mock.Mock() + model = OpenAIResponsesModel( + model_id="openai.gpt-oss-120b", + client_args={ + "timeout": 42, + "http_client": sentinel_http_client, + }, + bedrock_mantle_config={"region": "us-east-1"}, + ) + resolved = model._resolve_client_args() + assert resolved["base_url"] == "https://bedrock-mantle.us-east-1.api.aws/v1" + assert resolved["api_key"] == "bedrock-api-key-deadbeef&Version=1" + assert resolved["timeout"] == 42 + assert resolved["http_client"] is sentinel_http_client + + def test_bedrock_mantle_config_rejects_base_url_in_client_args(self, openai_client): + """client_args must not contain base_url or api_key when bedrock_mantle_config is set.""" + _ = openai_client + with pytest.raises(ValueError, match="client_args must not contain"): + OpenAIResponsesModel( + model_id="openai.gpt-oss-120b", + client_args={"api_key": "should-not-be-here"}, + bedrock_mantle_config={"region": "us-east-1"}, + ) + + def test_bedrock_mantle_config_requires_region(self, openai_client): + """bedrock_mantle_config raises when no region can be resolved from config, session, or env.""" + _ = openai_client + with ( + unittest.mock.patch("boto3.Session") as mock_session_cls, + unittest.mock.patch.dict(os.environ, {}, clear=True), + ): + mock_session_cls.return_value.region_name = None + model = OpenAIResponsesModel(model_id="openai.gpt-oss-120b", bedrock_mantle_config={}) + with pytest.raises(ValueError, match="Could not resolve an AWS region"): + model._resolve_client_args() + + def test_bedrock_mantle_config_region_resolved_from_boto3_default(self, openai_client, mock_provide_token): + """When region is omitted, the default boto3 session chain resolves it.""" + _ = openai_client + with unittest.mock.patch("boto3.Session") as mock_session_cls: + mock_session_cls.return_value.region_name = "eu-west-1" + model = OpenAIResponsesModel(model_id="openai.gpt-oss-120b", bedrock_mantle_config={}) + resolved = model._resolve_client_args() + + assert resolved["base_url"] == "https://bedrock-mantle.eu-west-1.api.aws/v1" + mock_provide_token.assert_called_once_with(region="eu-west-1") + + def test_bedrock_mantle_config_region_resolved_from_boto_session(self, openai_client, mock_provide_token): + """An explicit ``boto_session`` supplies the region when ``region`` is omitted.""" + _ = openai_client + session = unittest.mock.Mock() + session.region_name = "ap-southeast-2" + model = OpenAIResponsesModel( + model_id="openai.gpt-oss-120b", + bedrock_mantle_config={"boto_session": session}, + ) + + resolved = model._resolve_client_args() + + assert resolved["base_url"] == "https://bedrock-mantle.ap-southeast-2.api.aws/v1" + mock_provide_token.assert_called_once_with(region="ap-southeast-2") + + def test_bedrock_mantle_config_wraps_token_failures_with_context(self, openai_client, mock_provide_token): + """provide_token failures are wrapped in a RuntimeError with actionable context.""" + _ = openai_client + mock_provide_token.side_effect = RuntimeError("no credentials in chain") + model = OpenAIResponsesModel(model_id="openai.gpt-oss-120b", bedrock_mantle_config={"region": "us-east-1"}) + with pytest.raises(RuntimeError, match="Bedrock Mantle bearer token.*us-east-1"): + model._resolve_client_args() diff --git a/tests_integ/models/test_model_mantle.py b/tests_integ/models/test_model_mantle.py index 1dc029344..7cc032146 100644 --- a/tests_integ/models/test_model_mantle.py +++ b/tests_integ/models/test_model_mantle.py @@ -1,61 +1,46 @@ -"""Integration tests for OpenAI Responses API on Bedrock Mantle with AWS credentials.""" +"""Integration tests for OpenAI-compatible APIs on Bedrock Mantle. + +Exercises the ``bedrock_mantle_config`` pathway on ``OpenAIModel`` (Chat Completions) and +``OpenAIResponsesModel`` (Responses API) against the live +``bedrock-mantle..api.aws/v1`` endpoint. Credentials come from the +ambient AWS credential chain; no explicit API key is passed by the user. +""" -import httpx import pytest -from botocore.auth import SigV4Auth -from botocore.awsrequest import AWSRequest -from botocore.session import Session as BotocoreSession from strands import Agent +from strands.models.openai import OpenAIModel from strands.models.openai_responses import OpenAIResponsesModel +_REGION = "us-east-1" +_MODEL_ID = "openai.gpt-oss-120b" -class _SigV4Auth(httpx.Auth): - """httpx Auth handler that signs requests with AWS SigV4.""" - - def __init__(self, region: str): - session = BotocoreSession() - self.credentials = session.get_credentials().get_frozen_credentials() - self.signer = SigV4Auth(self.credentials, "bedrock", region) - - def auth_flow(self, request: httpx.Request): - aws_request = AWSRequest( - method=request.method, - url=str(request.url), - headers=dict(request.headers), - data=request.content, - ) - self.signer.add_auth(aws_request) - for key, value in aws_request.headers.items(): - request.headers[key] = value - yield request +@pytest.fixture +def bedrock_mantle_config(): + return {"region": _REGION} -class _NonClosingAsyncClient(httpx.AsyncClient): - """AsyncClient that survives the OpenAI SDK's context manager lifecycle.""" - async def aclose(self) -> None: - pass +@pytest.fixture +def chat_completions_model(bedrock_mantle_config): + return OpenAIModel(model_id=_MODEL_ID, bedrock_mantle_config=bedrock_mantle_config) @pytest.fixture -def client_args(): - region = "us-east-1" - return { - "api_key": "unused", - "base_url": f"https://bedrock-mantle.{region}.api.aws/v1", - "http_client": _NonClosingAsyncClient(auth=_SigV4Auth(region)), - } +def model(bedrock_mantle_config): + return OpenAIResponsesModel(model_id=_MODEL_ID, bedrock_mantle_config=bedrock_mantle_config) @pytest.fixture -def model(client_args): - return OpenAIResponsesModel(model_id="openai.gpt-oss-120b", client_args=client_args) +def stateful_model(bedrock_mantle_config): + return OpenAIResponsesModel(model_id=_MODEL_ID, stateful=True, bedrock_mantle_config=bedrock_mantle_config) -@pytest.fixture -def stateful_model(client_args): - return OpenAIResponsesModel(model_id="openai.gpt-oss-120b", stateful=True, client_args=client_args) +def test_chat_completions_agent_invoke(chat_completions_model): + """OpenAIModel (Chat Completions) reaches Mantle via bedrock_mantle_config.""" + agent = Agent(model=chat_completions_model, system_prompt="Reply in one short sentence.", callback_handler=None) + result = agent("What is 2+2?") + assert "4" in str(result) def test_agent_invoke(model): @@ -74,11 +59,11 @@ def test_responses_server_side_conversation(stateful_model): assert "alice" in str(result).lower() -def test_reasoning_content_multi_turn(client_args): +def test_reasoning_content_multi_turn(bedrock_mantle_config): """Test that reasoning content from gpt-oss models doesn't break multi-turn conversations.""" model = OpenAIResponsesModel( - model_id="openai.gpt-oss-120b", - client_args=client_args, + model_id=_MODEL_ID, + bedrock_mantle_config=bedrock_mantle_config, params={"reasoning": {"effort": "low"}}, ) agent = Agent(model=model, system_prompt="Reply in one short sentence.", callback_handler=None)