Skip to content

Commit 009374f

Browse files
authored
feat: add ProviderTokenCountError for native token counting failures (#2211)
1 parent b340dc4 commit 009374f

4 files changed

Lines changed: 26 additions & 3 deletions

File tree

src/strands/models/bedrock.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from ..types.exceptions import (
2828
ContextWindowOverflowException,
2929
ModelThrottledException,
30+
ProviderTokenCountError,
3031
)
3132
from ..types.streaming import CitationsDelta, StreamEvent
3233
from ..types.tools import ToolChoice, ToolSpec
@@ -789,7 +790,10 @@ async def count_tokens(
789790
modelId=self.config["model_id"],
790791
input={"converse": converse_input},
791792
)
792-
total_tokens: int = response["inputTokens"]
793+
input_tokens = response.get("inputTokens")
794+
if input_tokens is None:
795+
raise ProviderTokenCountError("Bedrock count_tokens returned None for inputTokens")
796+
total_tokens: int = input_tokens
793797

794798
logger.debug("model_id=<%s>, total_tokens=<%d> | native token count", self.config["model_id"], total_tokens)
795799
return total_tokens

src/strands/models/gemini.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from typing_extensions import Required, Unpack, override
1717

1818
from ..types.content import ContentBlock, ContentBlockStartToolUse, Messages, SystemContentBlock
19-
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
19+
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException, ProviderTokenCountError
2020
from ..types.streaming import StreamEvent
2121
from ..types.tools import ToolChoice, ToolSpec
2222
from ._validation import _has_location_source, validate_config_keys
@@ -465,7 +465,7 @@ async def count_tokens(
465465
contents=contents,
466466
)
467467
if response.total_tokens is None:
468-
raise ValueError("Gemini count_tokens returned None for total_tokens")
468+
raise ProviderTokenCountError("Gemini count_tokens returned None for total_tokens")
469469
total_tokens: int = response.total_tokens
470470

471471
# The google-genai SDK explicitly raises ValueError for system_instruction, tools, and

src/strands/types/exceptions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,16 @@ class SnapshotException(Exception):
8383
pass
8484

8585

86+
class ProviderTokenCountError(Exception):
87+
"""Thrown when a model provider's native token counting API fails.
88+
89+
This error is used as internal control flow within provider ``count_tokens()`` overrides.
90+
When caught, the provider falls back to the base class heuristic estimation.
91+
"""
92+
93+
pass
94+
95+
8696
class ToolProviderException(Exception):
8797
"""Exception raised when a tool provider fails to load or cleanup tools."""
8898

tests/strands/models/test_bedrock.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3210,6 +3210,15 @@ async def test_fallback_on_generic_exception(self, model_with_client, bedrock_cl
32103210
assert isinstance(result, int)
32113211
assert result >= 0
32123212

3213+
@pytest.mark.asyncio
3214+
async def test_fallback_on_none_input_tokens(self, model_with_client, bedrock_client, messages):
3215+
bedrock_client.count_tokens.return_value = {}
3216+
3217+
result = await model_with_client.count_tokens(messages=messages)
3218+
3219+
assert isinstance(result, int)
3220+
assert result >= 0
3221+
32133222
@pytest.mark.asyncio
32143223
async def test_fallback_logs_warning(self, model_with_client, bedrock_client, messages, caplog):
32153224
bedrock_client.count_tokens.side_effect = RuntimeError("API down")

0 commit comments

Comments (0)