diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 7f7113e83..94df5a84d 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -27,6 +27,7 @@ from ..types.exceptions import ( ContextWindowOverflowException, ModelThrottledException, + ProviderTokenCountError, ) from ..types.streaming import CitationsDelta, StreamEvent from ..types.tools import ToolChoice, ToolSpec @@ -789,7 +790,10 @@ async def count_tokens( modelId=self.config["model_id"], input={"converse": converse_input}, ) - total_tokens: int = response["inputTokens"] + input_tokens = response.get("inputTokens") + if input_tokens is None: + raise ProviderTokenCountError("Bedrock count_tokens returned None for inputTokens") + total_tokens: int = input_tokens logger.debug("model_id=<%s>, total_tokens=<%d> | native token count", self.config["model_id"], total_tokens) return total_tokens diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index 04e98f359..2ce1c0b42 100644 --- a/src/strands/models/gemini.py +++ b/src/strands/models/gemini.py @@ -16,7 +16,7 @@ from typing_extensions import Required, Unpack, override from ..types.content import ContentBlock, ContentBlockStartToolUse, Messages, SystemContentBlock -from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException +from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException, ProviderTokenCountError from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolSpec from ._validation import _has_location_source, validate_config_keys @@ -465,7 +465,7 @@ async def count_tokens( contents=contents, ) if response.total_tokens is None: - raise ValueError("Gemini count_tokens returned None for total_tokens") + raise ProviderTokenCountError("Gemini count_tokens returned None for total_tokens") total_tokens: int = response.total_tokens # The google-genai SDK explicitly raises ValueError for system_instruction, tools, 
and diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index 5db80a26e..7ad49eb24 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -83,6 +83,16 @@ class SnapshotException(Exception): pass +class ProviderTokenCountError(Exception): + """Raised when a model provider's native token counting API fails to return a usable count. + + This error is used as internal control flow within provider ``count_tokens()`` overrides. + When caught, the provider falls back to the base class heuristic estimation. + """ + + pass + + class ToolProviderException(Exception): """Exception raised when a tool provider fails to load or cleanup tools.""" diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index b8e41d20a..d63838182 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -3210,6 +3210,15 @@ async def test_fallback_on_generic_exception(self, model_with_client, bedrock_cl assert isinstance(result, int) assert result >= 0 + @pytest.mark.asyncio + async def test_fallback_on_none_input_tokens(self, model_with_client, bedrock_client, messages): + bedrock_client.count_tokens.return_value = {} + + result = await model_with_client.count_tokens(messages=messages) + + assert isinstance(result, int) + assert result >= 0 + @pytest.mark.asyncio async def test_fallback_logs_warning(self, model_with_client, bedrock_client, messages, caplog): bedrock_client.count_tokens.side_effect = RuntimeError("API down")