Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion src/strands/models/llamaapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing_extensions import TypedDict, Unpack, override

from ..types.content import ContentBlock, Messages
from ..types.exceptions import ModelThrottledException
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
from ..types.streaming import StreamEvent, Usage
from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse
from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported
Expand Down Expand Up @@ -71,6 +71,14 @@ def __init__(
else:
self.client = LlamaAPIClient(**client_args)

# Lowercase marker substrings that identify a Llama API error as a
# context-window overflow; stream() compares them against
# str(error).lower() before re-raising.
OVERFLOW_MESSAGES = {
    "token limit",
    "prompt is too long",
    "max context length",
    "context window",
    "context length exceeded",
}

@override
def update_config(self, **model_config: Unpack[LlamaConfig]) -> None: # type: ignore
"""Update the Llama API Model configuration with the provided arguments.
Expand Down Expand Up @@ -368,6 +376,11 @@ async def stream(
response = self.client.chat.completions.create(**request)
except llama_api_client.RateLimitError as e:
raise ModelThrottledException(str(e)) from e
except Exception as error:
error_str = str(error).lower()
if any(msg in error_str for msg in self.OVERFLOW_MESSAGES):
raise ContextWindowOverflowException(str(error)) from error
raise

logger.debug("got response from model")
yield self.format_chunk({"chunk_type": "message_start"})
Expand Down
15 changes: 13 additions & 2 deletions src/strands/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing_extensions import TypedDict, Unpack, override

from ..types.content import ContentBlock, Messages
from ..types.exceptions import ModelThrottledException
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
from ..types.streaming import StopReason, StreamEvent
from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse
from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported
Expand Down Expand Up @@ -97,6 +97,14 @@ def __init__(
if api_key:
self.client_args["api_key"] = api_key

# Substrings (all lowercase) whose presence in a lowered error message
# marks the failure as a context-window overflow rather than a generic
# error; consulted by stream()'s exception handler.
OVERFLOW_MESSAGES = {
    "context window",
    "token limit",
    "context length exceeded",
    "prompt is too long",
    "max context length",
}

@override
def update_config(self, **model_config: Unpack[MistralConfig]) -> None: # type: ignore
"""Update the Mistral Model configuration with the provided arguments.
Expand Down Expand Up @@ -500,7 +508,10 @@ async def stream(
yield self.format_chunk({"chunk_type": "metadata", "data": chunk.data.usage})

except Exception as e:
if "rate" in str(e).lower() or "429" in str(e):
error_str = str(e).lower()
if any(msg in error_str for msg in self.OVERFLOW_MESSAGES):
raise ContextWindowOverflowException(str(e)) from e
if "rate" in error_str or "429" in str(e):
raise ModelThrottledException(str(e)) from e
raise

Expand Down
17 changes: 16 additions & 1 deletion src/strands/models/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing_extensions import TypedDict, Unpack, override

from ..types.content import ContentBlock, Messages
from ..types.exceptions import ContextWindowOverflowException
from ..types.streaming import StopReason, StreamEvent
from ..types.tools import ToolChoice, ToolSpec
from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported
Expand All @@ -33,6 +34,13 @@ class OllamaModel(Model):
- Tool/function calling
"""

# Lowercase substrings used to classify a provider error as a
# context-window overflow; stream() checks them against
# str(error).lower() and re-raises as ContextWindowOverflowException.
# "token limit" is included for consistency with the equivalent sets in
# the llamaapi, mistral, and writer models, so token-limit errors map to
# the same typed exception across providers.
OVERFLOW_MESSAGES = {
    "context length exceeded",
    "context window",
    "max context length",
    "prompt is too long",
    "token limit",
}

class OllamaConfig(TypedDict, total=False):
"""Configuration parameters for Ollama models.

Expand Down Expand Up @@ -319,7 +327,14 @@ async def stream(
tool_requested = False

client = ollama.AsyncClient(self.host, **self.client_args)
response = await client.chat(**request)

# Issue the chat request; the Ollama client raises generic exceptions, so
# overflow conditions are detected by inspecting the message text below.
try:
    response = await client.chat(**request)
except Exception as error:
    # Lowercase once so the OVERFLOW_MESSAGES markers (all lowercase)
    # match case-insensitively.
    error_str = str(error).lower()
    if any(msg in error_str for msg in self.OVERFLOW_MESSAGES):
        # Re-raise as the SDK's typed overflow exception, chaining the
        # original error for debuggability.
        raise ContextWindowOverflowException(str(error)) from error
    # Anything not recognized as an overflow propagates unchanged.
    raise

logger.debug("got response from model")
yield self.format_chunk({"chunk_type": "message_start"})
Expand Down
15 changes: 14 additions & 1 deletion src/strands/models/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing_extensions import Unpack, override

from ..types.content import ContentBlock, Messages
from ..types.exceptions import ModelThrottledException
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
from ..types.streaming import StreamEvent
from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse
from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported
Expand Down Expand Up @@ -63,6 +63,14 @@ def __init__(self, client_args: dict[str, Any] | None = None, **model_config: Un
client_args = client_args or {}
self.client = writerai.AsyncClient(**client_args)

# Marker substrings (lowercase) that flag a Writer API error as a
# context-window overflow; the stream() handler scans the lowered error
# text for any of these before deciding how to re-raise.
OVERFLOW_MESSAGES = {
    "max context length",
    "context length exceeded",
    "token limit",
    "context window",
    "prompt is too long",
}

@override
def update_config(self, **model_config: Unpack[WriterConfig]) -> None: # type: ignore[override]
"""Update the Writer Model configuration with the provided arguments.
Expand Down Expand Up @@ -397,6 +405,11 @@ async def stream(
response = await self.client.chat.chat(**request)
except writerai.RateLimitError as e:
raise ModelThrottledException(str(e)) from e
except Exception as error:
error_str = str(error).lower()
if any(msg in error_str for msg in self.OVERFLOW_MESSAGES):
raise ContextWindowOverflowException(str(error)) from error
raise

yield self.format_chunk({"chunk_type": "message_start"})
yield self.format_chunk({"chunk_type": "content_block_start", "data_type": "text"})
Expand Down