diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index b1ed4563a..563345f03 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -17,7 +17,7 @@ from typing_extensions import TypedDict, Unpack, override from ..types.content import ContentBlock, Messages -from ..types.exceptions import ModelThrottledException +from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent, Usage from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported @@ -71,6 +71,14 @@ def __init__( else: self.client = LlamaAPIClient(**client_args) + OVERFLOW_MESSAGES = { + "context length exceeded", + "context window", + "max context length", + "prompt is too long", + "token limit", + } + @override def update_config(self, **model_config: Unpack[LlamaConfig]) -> None: # type: ignore """Update the Llama API Model configuration with the provided arguments. 
@@ -368,6 +376,11 @@ async def stream( response = self.client.chat.completions.create(**request) except llama_api_client.RateLimitError as e: raise ModelThrottledException(str(e)) from e + except Exception as error: + error_str = str(error).lower() + if any(msg in error_str for msg in self.OVERFLOW_MESSAGES): + raise ContextWindowOverflowException(str(error)) from error + raise logger.debug("got response from model") yield self.format_chunk({"chunk_type": "message_start"}) diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index f44a11d30..7cfe10f8c 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -14,7 +14,7 @@ from typing_extensions import TypedDict, Unpack, override from ..types.content import ContentBlock, Messages -from ..types.exceptions import ModelThrottledException +from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StopReason, StreamEvent from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported @@ -97,6 +97,14 @@ def __init__( if api_key: self.client_args["api_key"] = api_key + OVERFLOW_MESSAGES = { + "context length exceeded", + "context window", + "max context length", + "prompt is too long", + "token limit", + } + @override def update_config(self, **model_config: Unpack[MistralConfig]) -> None: # type: ignore """Update the Mistral Model configuration with the provided arguments. 
@@ -500,7 +508,10 @@ async def stream( yield self.format_chunk({"chunk_type": "metadata", "data": chunk.data.usage}) except Exception as e: - if "rate" in str(e).lower() or "429" in str(e): + error_str = str(e).lower() + if any(msg in error_str for msg in self.OVERFLOW_MESSAGES): + raise ContextWindowOverflowException(str(e)) from e + if "rate" in error_str or "429" in str(e): raise ModelThrottledException(str(e)) from e raise diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index 97cb7948a..3666942b8 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -13,6 +13,7 @@ from typing_extensions import TypedDict, Unpack, override from ..types.content import ContentBlock, Messages +from ..types.exceptions import ContextWindowOverflowException from ..types.streaming import StopReason, StreamEvent from ..types.tools import ToolChoice, ToolSpec from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported @@ -33,6 +34,14 @@ class OllamaModel(Model): - Tool/function calling """ + OVERFLOW_MESSAGES = { + "context length exceeded", + "context window", + "max context length", + "prompt is too long", + "token limit", + } + class OllamaConfig(TypedDict, total=False): """Configuration parameters for Ollama models.
@@ -319,7 +327,14 @@ async def stream( tool_requested = False client = ollama.AsyncClient(self.host, **self.client_args) - response = await client.chat(**request) + + try: + response = await client.chat(**request) + except Exception as error: + error_str = str(error).lower() + if any(msg in error_str for msg in self.OVERFLOW_MESSAGES): + raise ContextWindowOverflowException(str(error)) from error + raise logger.debug("got response from model") yield self.format_chunk({"chunk_type": "message_start"}) diff --git a/src/strands/models/writer.py b/src/strands/models/writer.py index 94774b363..32207eac0 100644 --- a/src/strands/models/writer.py +++ b/src/strands/models/writer.py @@ -15,7 +15,7 @@ from typing_extensions import Unpack, override from ..types.content import ContentBlock, Messages -from ..types.exceptions import ModelThrottledException +from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported @@ -63,6 +63,14 @@ def __init__(self, client_args: dict[str, Any] | None = None, **model_config: Un client_args = client_args or {} self.client = writerai.AsyncClient(**client_args) + OVERFLOW_MESSAGES = { + "context length exceeded", + "context window", + "max context length", + "prompt is too long", + "token limit", + } + @override def update_config(self, **model_config: Unpack[WriterConfig]) -> None: # type: ignore[override] """Update the Writer Model configuration with the provided arguments. 
@@ -397,6 +405,11 @@ async def stream( response = await self.client.chat.chat(**request) except writerai.RateLimitError as e: raise ModelThrottledException(str(e)) from e + except Exception as error: + error_str = str(error).lower() + if any(msg in error_str for msg in self.OVERFLOW_MESSAGES): + raise ContextWindowOverflowException(str(error)) from error + raise yield self.format_chunk({"chunk_type": "message_start"}) yield self.format_chunk({"chunk_type": "content_block_start", "data_type": "text"})