Skip to content

Commit 70adadc

Browse files
committed
fix: cache unsupported models for bedrocks token counting
1 parent 8638fc2 commit 70adadc

2 files changed

Lines changed: 66 additions & 5 deletions

File tree

src/strands/models/bedrock.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,15 @@
5454
"anthropic.claude",
5555
]
5656

57+
# Cache of model IDs that do not support the CountTokens API.
58+
_UNSUPPORTED_COUNT_TOKENS_MODELS: set[str] = set()
59+
60+
61+
def _clear_unsupported_count_tokens_cache() -> None:
62+
"""Clear the cache of model IDs that do not support the CountTokens API."""
63+
_UNSUPPORTED_COUNT_TOKENS_MODELS.clear()
64+
65+
5766
T = TypeVar("T", bound=BaseModel)
5867

5968
DEFAULT_READ_TIMEOUT = 120
@@ -784,6 +793,11 @@ async def count_tokens(
784793
Returns:
785794
Total input token count.
786795
"""
796+
model_id: str = self.config["model_id"]
797+
798+
if model_id in _UNSUPPORTED_COUNT_TOKENS_MODELS:
799+
return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content)
800+
787801
try:
788802
if system_prompt and system_prompt_content is None:
789803
system_prompt_content = [{"text": system_prompt}]
@@ -810,11 +824,23 @@ async def count_tokens(
810824
logger.debug("model_id=<%s>, total_tokens=<%d> | native token count", self.config["model_id"], total_tokens)
811825
return total_tokens
812826
except Exception as e:
813-
logger.debug(
814-
"model_id=<%s>, error=<%s> | native token counting failed, falling back to estimation",
815-
self.config["model_id"],
816-
e,
817-
)
827+
if (
828+
isinstance(e, ClientError)
829+
and e.response.get("Error", {}).get("Code") == "ValidationException"
830+
and "doesn't support counting tokens" in str(e)
831+
):
832+
logger.debug(
833+
"model_id=<%s> | model does not support CountTokens, caching for future calls,"
834+
" falling back to estimation",
835+
model_id,
836+
)
837+
_UNSUPPORTED_COUNT_TOKENS_MODELS.add(model_id)
838+
else:
839+
logger.debug(
840+
"model_id=<%s>, error=<%s> | native token counting failed, falling back to estimation",
841+
model_id,
842+
e,
843+
)
818844
return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content)
819845

820846
@override

tests/strands/models/test_bedrock.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
DEFAULT_BEDROCK_MODEL_ID,
2020
DEFAULT_BEDROCK_REGION,
2121
DEFAULT_READ_TIMEOUT,
22+
_clear_unsupported_count_tokens_cache,
2223
)
2324
from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException
2425
from strands.types.tools import ToolSpec
@@ -3293,6 +3294,12 @@ async def test_non_streaming_citations_with_only_location(bedrock_client, model,
32933294
class TestCountTokens:
32943295
"""Tests for BedrockModel.count_tokens native token counting."""
32953296

3297+
@pytest.fixture(autouse=True)
3298+
def clean_cache(self):
3299+
_clear_unsupported_count_tokens_cache()
3300+
yield
3301+
_clear_unsupported_count_tokens_cache()
3302+
32963303
@pytest.fixture
32973304
def model_with_client(self, bedrock_client, model_id):
32983305
_ = bedrock_client
@@ -3409,3 +3416,31 @@ async def test_fallback_logs_debug(self, model_with_client, bedrock_client, mess
34093416
await model_with_client.count_tokens(messages=messages)
34103417

34113418
assert any("native token counting failed" in record.message for record in caplog.records)
3419+
3420+
@pytest.mark.asyncio
3421+
async def test_caches_model_id_when_count_tokens_unsupported(self, bedrock_client, messages):
3422+
model = BedrockModel(model_id="unsupported-cache-test-model")
3423+
bedrock_client.count_tokens.side_effect = ClientError(
3424+
{"Error": {"Code": "ValidationException", "Message": "The provided model doesn't support counting tokens"}},
3425+
"CountTokens",
3426+
)
3427+
3428+
# First call: hits API, gets error, caches
3429+
await model.count_tokens(messages=messages)
3430+
assert bedrock_client.count_tokens.call_count == 1
3431+
3432+
# Second call: skips API entirely
3433+
await model.count_tokens(messages=messages)
3434+
assert bedrock_client.count_tokens.call_count == 1
3435+
3436+
@pytest.mark.asyncio
3437+
async def test_does_not_cache_model_id_for_other_errors(self, bedrock_client, messages):
3438+
model = BedrockModel(model_id="transient-error-test-model")
3439+
bedrock_client.count_tokens.side_effect = RuntimeError("Transient network error")
3440+
3441+
await model.count_tokens(messages=messages)
3442+
assert bedrock_client.count_tokens.call_count == 1
3443+
3444+
# Second call should still attempt the API
3445+
await model.count_tokens(messages=messages)
3446+
assert bedrock_client.count_tokens.call_count == 2

0 commit comments

Comments
 (0)