Skip to content

Commit 7db79bb

Browse files
authored
fix(openai): Handles Bedrock-style context overflow errors for OpenAI-compatible endpoints (strands-agents#1529)
1 parent ea1ea1c commit 7db79bb

2 files changed

Lines changed: 111 additions & 0 deletions

File tree

src/strands/models/openai.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,15 @@
2727

2828
T = TypeVar("T", bound=BaseModel)
2929

30+
# Alternative context overflow error messages.
#
# These are commonly returned by OpenAI-compatible endpoints wrapping other providers
# (e.g., Databricks serving Bedrock models). Each entry is matched as a case-sensitive
# substring of the stringified API error, so keep the fragments exactly as the provider
# emits them.
#
# Declared as a tuple so the set of recognized messages cannot be mutated at runtime.
_CONTEXT_OVERFLOW_MESSAGES: tuple[str, ...] = (
    "Input is too long for requested model",
    "input length and `max_tokens` exceed context limit",
    "too many total text bytes",
)
38+
3039

3140
class Client(Protocol):
3241
"""Protocol defining the OpenAI-compatible interface for the underlying provider client."""
@@ -600,6 +609,14 @@ async def stream(
600609
# Rate limits (including TPM) require waiting/retrying, not context reduction
601610
logger.warning("OpenAI threw rate limit error")
602611
raise ModelThrottledException(str(e)) from e
612+
except openai.APIError as e:
613+
# Check for alternative context overflow error messages
614+
error_message = str(e)
615+
if any(overflow_msg in error_message for overflow_msg in _CONTEXT_OVERFLOW_MESSAGES):
616+
logger.warning("context window overflow error detected")
617+
raise ContextWindowOverflowException(error_message) from e
618+
# Re-raise other APIError exceptions
619+
raise
603620

604621
logger.debug("got response from model")
605622
yield self.format_chunk({"chunk_type": "message_start"})
@@ -723,6 +740,14 @@ async def structured_output(
723740
# Rate limits (including TPM) require waiting/retrying, not context reduction
724741
logger.warning("OpenAI threw rate limit error")
725742
raise ModelThrottledException(str(e)) from e
743+
except openai.APIError as e:
744+
# Check for alternative context overflow error messages
745+
error_message = str(e)
746+
if any(overflow_msg in error_message for overflow_msg in _CONTEXT_OVERFLOW_MESSAGES):
747+
logger.warning("context window overflow error detected")
748+
raise ContextWindowOverflowException(error_message) from e
749+
# Re-raise other APIError exceptions
750+
raise
726751

727752
parsed: T | None = None
728753
# Find the first choice with tool_calls

tests/strands/models/test_openai.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,6 +1035,92 @@ async def test_stream_context_overflow_exception(openai_client, model, messages)
10351035
assert exc_info.value.__cause__ == mock_error
10361036

10371037

1038+
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "error_message",
    [
        "Input is too long for requested model",
        "input length and `max_tokens` exceed context limit",
        "too many total text bytes",
    ],
)
async def test_stream_alternative_context_overflow_messages(openai_client, model, messages, error_message):
    """Check that stream() turns an APIError carrying an alternative overflow message into
    ContextWindowOverflowException.
    """
    # Build the APIError the mocked client will raise for this parametrized message.
    api_error = openai.APIError(
        message=error_message,
        request=unittest.mock.MagicMock(),
        body={"error": {"message": error_message}},
    )
    openai_client.chat.completions.create.side_effect = api_error

    # Draining the stream must surface the converted exception.
    with pytest.raises(ContextWindowOverflowException) as raised:
        async for _event in model.stream(messages):
            pass

    # The converted exception keeps the provider text and chains the original error.
    converted = raised.value
    assert error_message in str(converted)
    assert converted.__cause__ == api_error
1067+
1068+
1069+
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "error_message",
    [
        "Input is too long for requested model",
        "input length and `max_tokens` exceed context limit",
        "too many total text bytes",
    ],
)
async def test_structured_output_alternative_context_overflow_messages(
    openai_client, model, messages, test_output_model_cls, error_message
):
    """Check that structured_output() turns an APIError carrying an alternative overflow message into
    ContextWindowOverflowException.
    """
    # Build the APIError the mocked parse endpoint will raise for this parametrized message.
    api_error = openai.APIError(
        message=error_message,
        request=unittest.mock.MagicMock(),
        body={"error": {"message": error_message}},
    )
    openai_client.beta.chat.completions.parse.side_effect = api_error

    # Draining the structured-output stream must surface the converted exception.
    with pytest.raises(ContextWindowOverflowException) as raised:
        async for _event in model.structured_output(test_output_model_cls, messages):
            pass

    # The converted exception keeps the provider text and chains the original error.
    converted = raised.value
    assert error_message in str(converted)
    assert converted.__cause__ == api_error
1100+
1101+
1102+
@pytest.mark.asyncio
async def test_stream_api_error_passthrough(openai_client, model, messages):
    """Check that stream() re-raises an APIError whose text matches no overflow message."""
    # An APIError whose message contains none of the overflow fragments.
    unrelated_text = "Some other API error"
    api_error = openai.APIError(
        message=unrelated_text,
        request=unittest.mock.MagicMock(),
        body={"error": {"message": unrelated_text}},
    )
    openai_client.chat.completions.create.side_effect = api_error

    # The error must reach the caller as an openai.APIError.
    with pytest.raises(openai.APIError) as raised:
        async for _event in model.stream(messages):
            pass

    # It is the original error, not a converted ContextWindowOverflowException.
    assert raised.value == api_error
1122+
1123+
10381124
@pytest.mark.asyncio
10391125
async def test_stream_other_bad_request_errors_passthrough(openai_client, model, messages):
10401126
"""Test that other BadRequestError exceptions are not converted to ContextWindowOverflowException."""

0 commit comments

Comments
 (0)