Skip to content

Commit 6b0df9a

Browse files
authored
fix: cache unsupported models for bedrocks token counting (#2250)
1 parent d94d516 commit 6b0df9a

2 files changed

Lines changed: 66 additions & 5 deletions

File tree

src/strands/models/bedrock.py

Lines changed: 31 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,15 @@
5555
"anthropic.claude",
5656
]
5757

58+
# Module-level cache of model IDs known not to support the CountTokens API.
# count_tokens checks membership here before calling the native API, and adds a
# model ID when the service answers with the "doesn't support counting tokens"
# ValidationException, so later calls go straight to the estimation fallback.
_UNSUPPORTED_COUNT_TOKENS_MODELS: set[str] = set()


def _clear_unsupported_count_tokens_cache() -> None:
    """Empty the cache of model IDs that do not support the CountTokens API.

    The set is cleared in place, so every reference to the module-level cache
    observes the reset. The test suite calls this around each token-counting
    test to keep cached entries from leaking between tests.
    """
    _UNSUPPORTED_COUNT_TOKENS_MODELS.clear()
65+
66+
5867
T = TypeVar("T", bound=BaseModel)
5968

6069
DEFAULT_READ_TIMEOUT = 120
@@ -785,6 +794,11 @@ async def count_tokens(
785794
Returns:
786795
Total input token count.
787796
"""
797+
model_id: str = self.config["model_id"]
798+
799+
if model_id in _UNSUPPORTED_COUNT_TOKENS_MODELS:
800+
return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content)
801+
788802
try:
789803
if system_prompt and system_prompt_content is None:
790804
system_prompt_content = [{"text": system_prompt}]
@@ -811,11 +825,23 @@ async def count_tokens(
811825
logger.debug("model_id=<%s>, total_tokens=<%d> | native token count", self.config["model_id"], total_tokens)
812826
return total_tokens
813827
except Exception as e:
814-
logger.debug(
815-
"model_id=<%s>, error=<%s> | native token counting failed, falling back to estimation",
816-
self.config["model_id"],
817-
e,
818-
)
828+
if (
829+
isinstance(e, ClientError)
830+
and e.response.get("Error", {}).get("Code") == "ValidationException"
831+
and "doesn't support counting tokens" in str(e)
832+
):
833+
logger.debug(
834+
"model_id=<%s> | model does not support CountTokens, caching for future calls,"
835+
" falling back to estimation",
836+
model_id,
837+
)
838+
_UNSUPPORTED_COUNT_TOKENS_MODELS.add(model_id)
839+
else:
840+
logger.debug(
841+
"model_id=<%s>, error=<%s> | native token counting failed, falling back to estimation",
842+
model_id,
843+
e,
844+
)
819845
return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content)
820846

821847
@override

tests/strands/models/test_bedrock.py

Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
DEFAULT_BEDROCK_MODEL_ID,
2020
DEFAULT_BEDROCK_REGION,
2121
DEFAULT_READ_TIMEOUT,
22+
_clear_unsupported_count_tokens_cache,
2223
)
2324
from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException
2425
from strands.types.tools import ToolSpec
@@ -3333,6 +3334,12 @@ async def test_non_streaming_citations_with_only_location(bedrock_client, model,
33333334
class TestCountTokens:
33343335
"""Tests for BedrockModel.count_tokens native token counting."""
33353336

3337+
@pytest.fixture(autouse=True)
def clean_cache(self):
    """Reset the unsupported-model cache before and after every test in this class.

    The cache is module-level state, so an entry added by one test would
    otherwise still be present when the next test runs.
    """
    # Setup: start each test with an empty cache.
    _clear_unsupported_count_tokens_cache()
    yield
    # Teardown: drop anything the test cached.
    _clear_unsupported_count_tokens_cache()
3342+
33363343
@pytest.fixture
33373344
def model_with_client(self, bedrock_client, model_id):
33383345
_ = bedrock_client
@@ -3449,3 +3456,31 @@ async def test_fallback_logs_debug(self, model_with_client, bedrock_client, mess
34493456
await model_with_client.count_tokens(messages=messages)
34503457

34513458
assert any("native token counting failed" in record.message for record in caplog.records)
3459+
3460+
@pytest.mark.asyncio
async def test_caches_model_id_when_count_tokens_unsupported(self, bedrock_client, messages):
    """The unsupported-model ValidationException is cached, so later calls skip the native API."""
    model = BedrockModel(model_id="unsupported-cache-test-model")
    unsupported_error = ClientError(
        {"Error": {"Code": "ValidationException", "Message": "The provided model doesn't support counting tokens"}},
        "CountTokens",
    )
    bedrock_client.count_tokens.side_effect = unsupported_error

    # First call reaches the API, receives the error, and records the model ID.
    await model.count_tokens(messages=messages)
    assert bedrock_client.count_tokens.call_count == 1

    # Second call must not touch the API at all: the call count stays at one.
    await model.count_tokens(messages=messages)
    assert bedrock_client.count_tokens.call_count == 1
3475+
3476+
@pytest.mark.asyncio
async def test_does_not_cache_model_id_for_other_errors(self, bedrock_client, messages):
    """Errors other than the unsupported-model response are not cached, so each call retries the API."""
    model = BedrockModel(model_id="transient-error-test-model")
    bedrock_client.count_tokens.side_effect = RuntimeError("Transient network error")

    # Every attempt should reach the native API again, so the mock's call
    # count grows by one per count_tokens invocation.
    for expected_calls in (1, 2):
        await model.count_tokens(messages=messages)
        assert bedrock_client.count_tokens.call_count == expected_calls

0 commit comments

Comments
 (0)