Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from ..types.exceptions import (
ContextWindowOverflowException,
ModelThrottledException,
ProviderTokenCountError,
)
from ..types.streaming import CitationsDelta, StreamEvent
from ..types.tools import ToolChoice, ToolSpec
Expand Down Expand Up @@ -789,7 +790,10 @@ async def count_tokens(
modelId=self.config["model_id"],
input={"converse": converse_input},
)
total_tokens: int = response["inputTokens"]
input_tokens = response.get("inputTokens")
if input_tokens is None:
raise ProviderTokenCountError("Bedrock count_tokens returned None for inputTokens")
total_tokens: int = input_tokens

logger.debug("model_id=<%s>, total_tokens=<%d> | native token count", self.config["model_id"], total_tokens)
return total_tokens
Expand Down
4 changes: 2 additions & 2 deletions src/strands/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing_extensions import Required, Unpack, override

from ..types.content import ContentBlock, ContentBlockStartToolUse, Messages, SystemContentBlock
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException, ProviderTokenCountError
from ..types.streaming import StreamEvent
from ..types.tools import ToolChoice, ToolSpec
from ._validation import _has_location_source, validate_config_keys
Expand Down Expand Up @@ -465,7 +465,7 @@ async def count_tokens(
contents=contents,
)
if response.total_tokens is None:
raise ValueError("Gemini count_tokens returned None for total_tokens")
raise ProviderTokenCountError("Gemini count_tokens returned None for total_tokens")
total_tokens: int = response.total_tokens

# The google-genai SDK explicitly raises ValueError for system_instruction, tools, and
Expand Down
10 changes: 10 additions & 0 deletions src/strands/types/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,16 @@ class SnapshotException(Exception):
pass


class ProviderTokenCountError(Exception):
    """Raised when a model provider's native token counting API fails.

    This error is used as internal control flow within provider ``count_tokens()`` overrides.
    When caught, the provider falls back to the base class heuristic estimation.
    """

    pass
Comment thread
opieter-aws marked this conversation as resolved.


class ToolProviderException(Exception):
"""Exception raised when a tool provider fails to load or cleanup tools."""

Expand Down
9 changes: 9 additions & 0 deletions tests/strands/models/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -3210,6 +3210,15 @@ async def test_fallback_on_generic_exception(self, model_with_client, bedrock_cl
assert isinstance(result, int)
assert result >= 0

@pytest.mark.asyncio
async def test_fallback_on_none_input_tokens(self, model_with_client, bedrock_client, messages):
    """A native response without ``inputTokens`` still yields a non-negative int."""
    # An empty payload carries no "inputTokens" key at all.
    bedrock_client.count_tokens.return_value = {}

    token_count = await model_with_client.count_tokens(messages=messages)

    assert isinstance(token_count, int)
    assert token_count >= 0

@pytest.mark.asyncio
async def test_fallback_logs_warning(self, model_with_client, bedrock_client, messages, caplog):
bedrock_client.count_tokens.side_effect = RuntimeError("API down")
Expand Down
Loading