Skip to content

Commit 46eb4ba

Browse files
committed
fix: cache unsupported models for Bedrock's token counting
1 parent 8638fc2 commit 46eb4ba

2 files changed

Lines changed: 62 additions & 5 deletions

File tree

src/strands/models/bedrock.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,15 @@
5454
"anthropic.claude",
5555
]
5656

57+
# Cache of model IDs that do not support the CountTokens API.
58+
_UNSUPPORTED_COUNT_TOKENS_MODELS: set[str] = set()
59+
60+
61+
def _clear_unsupported_count_tokens_cache() -> None:
62+
"""Clear the cache of model IDs that do not support the CountTokens API."""
63+
_UNSUPPORTED_COUNT_TOKENS_MODELS.clear()
64+
65+
5766
T = TypeVar("T", bound=BaseModel)
5867

5968
DEFAULT_READ_TIMEOUT = 120
@@ -784,6 +793,11 @@ async def count_tokens(
784793
Returns:
785794
Total input token count.
786795
"""
796+
model_id: str = self.config["model_id"]
797+
798+
if model_id in _UNSUPPORTED_COUNT_TOKENS_MODELS:
799+
return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content)
800+
787801
try:
788802
if system_prompt and system_prompt_content is None:
789803
system_prompt_content = [{"text": system_prompt}]
@@ -810,11 +824,23 @@ async def count_tokens(
810824
logger.debug("model_id=<%s>, total_tokens=<%d> | native token count", self.config["model_id"], total_tokens)
811825
return total_tokens
812826
except Exception as e:
813-
logger.debug(
814-
"model_id=<%s>, error=<%s> | native token counting failed, falling back to estimation",
815-
self.config["model_id"],
816-
e,
817-
)
827+
if (
828+
isinstance(e, ClientError)
829+
and e.response.get("Error", {}).get("Code") == "ValidationException"
830+
and "doesn't support counting tokens" in str(e)
831+
):
832+
logger.debug(
833+
"model_id=<%s> | model does not support CountTokens, caching for future calls,"
834+
" falling back to estimation",
835+
model_id,
836+
)
837+
_UNSUPPORTED_COUNT_TOKENS_MODELS.add(model_id)
838+
else:
839+
logger.debug(
840+
"model_id=<%s>, error=<%s> | native token counting failed, falling back to estimation",
841+
model_id,
842+
e,
843+
)
818844
return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content)
819845

820846
@override

tests/strands/models/test_bedrock.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
DEFAULT_BEDROCK_MODEL_ID,
2020
DEFAULT_BEDROCK_REGION,
2121
DEFAULT_READ_TIMEOUT,
22+
_clear_unsupported_count_tokens_cache,
2223
)
2324
from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException
2425
from strands.types.tools import ToolSpec
@@ -3409,3 +3410,33 @@ async def test_fallback_logs_debug(self, model_with_client, bedrock_client, mess
34093410
await model_with_client.count_tokens(messages=messages)
34103411

34113412
assert any("native token counting failed" in record.message for record in caplog.records)
3413+
3414+
@pytest.mark.asyncio
3415+
async def test_caches_model_id_when_count_tokens_unsupported(self, bedrock_client, messages):
3416+
_clear_unsupported_count_tokens_cache()
3417+
model = BedrockModel(model_id="unsupported-cache-test-model")
3418+
bedrock_client.count_tokens.side_effect = ClientError(
3419+
{"Error": {"Code": "ValidationException", "Message": "The provided model doesn't support counting tokens"}},
3420+
"CountTokens",
3421+
)
3422+
3423+
# First call: hits API, gets error, caches
3424+
await model.count_tokens(messages=messages)
3425+
assert bedrock_client.count_tokens.call_count == 1
3426+
3427+
# Second call: skips API entirely
3428+
await model.count_tokens(messages=messages)
3429+
assert bedrock_client.count_tokens.call_count == 1
3430+
3431+
@pytest.mark.asyncio
3432+
async def test_does_not_cache_model_id_for_other_errors(self, bedrock_client, messages):
3433+
_clear_unsupported_count_tokens_cache()
3434+
model = BedrockModel(model_id="transient-error-test-model")
3435+
bedrock_client.count_tokens.side_effect = RuntimeError("Transient network error")
3436+
3437+
await model.count_tokens(messages=messages)
3438+
assert bedrock_client.count_tokens.call_count == 1
3439+
3440+
# Second call should still attempt the API
3441+
await model.count_tokens(messages=messages)
3442+
assert bedrock_client.count_tokens.call_count == 2

0 commit comments

Comments
 (0)