diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index d535bbc51..7444cfc70 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -54,6 +54,15 @@ "anthropic.claude", ] +# Cache of model IDs that do not support the CountTokens API. +_UNSUPPORTED_COUNT_TOKENS_MODELS: set[str] = set() + + +def _clear_unsupported_count_tokens_cache() -> None: + """Clear the cache of model IDs that do not support the CountTokens API.""" + _UNSUPPORTED_COUNT_TOKENS_MODELS.clear() + + T = TypeVar("T", bound=BaseModel) DEFAULT_READ_TIMEOUT = 120 @@ -784,6 +793,11 @@ async def count_tokens( Returns: Total input token count. """ + model_id: str = self.config["model_id"] + + if model_id in _UNSUPPORTED_COUNT_TOKENS_MODELS: + return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) + try: if system_prompt and system_prompt_content is None: system_prompt_content = [{"text": system_prompt}] @@ -810,11 +824,23 @@ async def count_tokens( logger.debug("model_id=<%s>, total_tokens=<%d> | native token count", self.config["model_id"], total_tokens) return total_tokens except Exception as e: - logger.debug( - "model_id=<%s>, error=<%s> | native token counting failed, falling back to estimation", - self.config["model_id"], - e, - ) + if ( + isinstance(e, ClientError) + and e.response.get("Error", {}).get("Code") == "ValidationException" + and "doesn't support counting tokens" in str(e) + ): + logger.debug( + "model_id=<%s> | model does not support CountTokens, caching for future calls," + " falling back to estimation", + model_id, + ) + _UNSUPPORTED_COUNT_TOKENS_MODELS.add(model_id) + else: + logger.debug( + "model_id=<%s>, error=<%s> | native token counting failed, falling back to estimation", + model_id, + e, + ) return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) @override diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index a80ca091e..5da399f40 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -19,6 +19,7 @@ DEFAULT_BEDROCK_MODEL_ID, DEFAULT_BEDROCK_REGION, DEFAULT_READ_TIMEOUT, + _clear_unsupported_count_tokens_cache, ) from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException from strands.types.tools import ToolSpec @@ -3293,6 +3294,12 @@ async def test_non_streaming_citations_with_only_location(bedrock_client, model, class TestCountTokens: """Tests for BedrockModel.count_tokens native token counting.""" + @pytest.fixture(autouse=True) + def clean_cache(self): + _clear_unsupported_count_tokens_cache() + yield + _clear_unsupported_count_tokens_cache() + @pytest.fixture def model_with_client(self, bedrock_client, model_id): _ = bedrock_client @@ -3409,3 +3416,31 @@ async def test_fallback_logs_debug(self, model_with_client, bedrock_client, mess await model_with_client.count_tokens(messages=messages) assert any("native token counting failed" in record.message for record in caplog.records) + + @pytest.mark.asyncio + async def test_caches_model_id_when_count_tokens_unsupported(self, bedrock_client, messages): + model = BedrockModel(model_id="unsupported-cache-test-model") + bedrock_client.count_tokens.side_effect = ClientError( + {"Error": {"Code": "ValidationException", "Message": "The provided model doesn't support counting tokens"}}, + "CountTokens", + ) + + # First call: hits API, gets error, caches + await model.count_tokens(messages=messages) + assert bedrock_client.count_tokens.call_count == 1 + + # Second call: skips API entirely + await model.count_tokens(messages=messages) + assert bedrock_client.count_tokens.call_count == 1 + + @pytest.mark.asyncio + async def test_does_not_cache_model_id_for_other_errors(self, bedrock_client, messages): + model = BedrockModel(model_id="transient-error-test-model") + bedrock_client.count_tokens.side_effect = RuntimeError("Transient network error") + + await model.count_tokens(messages=messages) + assert bedrock_client.count_tokens.call_count == 1 + + # Second call should still attempt the API + await model.count_tokens(messages=messages) + assert bedrock_client.count_tokens.call_count == 2