Skip to content

Commit 1847fae

Browse files
authored
feat: cache AccessDenied error for count tokens (#2279)
1 parent f862185 commit 1847fae

2 files changed

Lines changed: 69 additions & 10 deletions

File tree

src/strands/models/bedrock.py

Lines changed: 18 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -55,13 +55,13 @@
5555
"anthropic.claude",
5656
]
5757

58-
# Cache of model IDs that do not support the CountTokens API.
59-
_UNSUPPORTED_COUNT_TOKENS_MODELS: set[str] = set()
58+
# Cache of model IDs for which CountTokens API calls should be skipped.
59+
_SKIP_COUNT_TOKENS_MODELS: set[str] = set()
6060

6161

62-
def _clear_unsupported_count_tokens_cache() -> None:
63-
"""Clear the cache of model IDs that do not support the CountTokens API."""
64-
_UNSUPPORTED_COUNT_TOKENS_MODELS.clear()
62+
def _clear_skip_count_tokens_cache() -> None:
63+
"""Clear the cache of model IDs for which CountTokens API calls should be skipped."""
64+
_SKIP_COUNT_TOKENS_MODELS.clear()
6565

6666

6767
T = TypeVar("T", bound=BaseModel)
@@ -803,7 +803,7 @@ async def count_tokens(
803803

804804
model_id: str = self.config["model_id"]
805805

806-
if model_id in _UNSUPPORTED_COUNT_TOKENS_MODELS:
806+
if model_id in _SKIP_COUNT_TOKENS_MODELS:
807807
return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content)
808808

809809
try:
@@ -833,6 +833,17 @@ async def count_tokens(
833833
return total_tokens
834834
except Exception as e:
835835
if (
836+
isinstance(e, ClientError)
837+
and e.response.get("Error", {}).get("Code") == "AccessDeniedException"
838+
):
839+
logger.warning(
840+
"model_id=<%s> | bedrock:CountTokens permission denied,"
841+
" falling back to heuristic estimation: %s",
842+
model_id,
843+
e,
844+
)
845+
_SKIP_COUNT_TOKENS_MODELS.add(model_id)
846+
elif (
836847
isinstance(e, ClientError)
837848
and e.response.get("Error", {}).get("Code") == "ValidationException"
838849
and "doesn't support counting tokens" in str(e)
@@ -842,7 +853,7 @@ async def count_tokens(
842853
" falling back to estimation",
843854
model_id,
844855
)
845-
_UNSUPPORTED_COUNT_TOKENS_MODELS.add(model_id)
856+
_SKIP_COUNT_TOKENS_MODELS.add(model_id)
846857
else:
847858
logger.debug(
848859
"model_id=<%s>, error=<%s> | native token counting failed, falling back to estimation",

tests/strands/models/test_bedrock.py

Lines changed: 51 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919
DEFAULT_BEDROCK_MODEL_ID,
2020
DEFAULT_BEDROCK_REGION,
2121
DEFAULT_READ_TIMEOUT,
22-
_clear_unsupported_count_tokens_cache,
22+
_clear_skip_count_tokens_cache,
2323
)
2424
from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException
2525
from strands.types.tools import ToolSpec
@@ -3336,9 +3336,9 @@ class TestCountTokens:
33363336

33373337
@pytest.fixture(autouse=True)
33383338
def clean_cache(self):
3339-
_clear_unsupported_count_tokens_cache()
3339+
_clear_skip_count_tokens_cache()
33403340
yield
3341-
_clear_unsupported_count_tokens_cache()
3341+
_clear_skip_count_tokens_cache()
33423342

33433343
@pytest.fixture
33443344
def model_with_client(self, bedrock_client, model_id):
@@ -3473,6 +3473,54 @@ async def test_caches_model_id_when_count_tokens_unsupported(self, bedrock_clien
34733473
await model.count_tokens(messages=messages)
34743474
assert bedrock_client.count_tokens.call_count == 1
34753475

3476+
@pytest.mark.asyncio
3477+
async def test_caches_model_id_when_access_denied(self, bedrock_client, messages):
3478+
model = BedrockModel(model_id="access-denied-cache-test-model")
3479+
bedrock_client.count_tokens.side_effect = ClientError(
3480+
{
3481+
"Error": {
3482+
"Code": "AccessDeniedException",
3483+
"Message": "User: arn:aws:sts::123456789012:assumed-role/role is not authorized"
3484+
" to perform: bedrock:CountTokens",
3485+
}
3486+
},
3487+
"CountTokens",
3488+
)
3489+
3490+
# First call: hits API, gets error, caches
3491+
await model.count_tokens(messages=messages)
3492+
bedrock_client.count_tokens.assert_called_once()
3493+
3494+
# Reset mock to clearly verify second call doesn't hit the API
3495+
bedrock_client.count_tokens.reset_mock()
3496+
3497+
# Second call: skips API entirely due to caching
3498+
result = await model.count_tokens(messages=messages)
3499+
bedrock_client.count_tokens.assert_not_called()
3500+
assert isinstance(result, int)
3501+
assert result >= 0
3502+
3503+
@pytest.mark.asyncio
3504+
async def test_access_denied_logs_warning_with_full_error(
3505+
self, model_with_client, bedrock_client, messages, caplog
3506+
):
3507+
error_message = (
3508+
"User: arn:aws:sts::123456789012:assumed-role/role is not authorized"
3509+
" to perform: bedrock:CountTokens"
3510+
)
3511+
bedrock_client.count_tokens.side_effect = ClientError(
3512+
{"Error": {"Code": "AccessDeniedException", "Message": error_message}},
3513+
"CountTokens",
3514+
)
3515+
3516+
with caplog.at_level(logging.WARNING, logger="strands.models.bedrock"):
3517+
await model_with_client.count_tokens(messages=messages)
3518+
3519+
warning_records = [r for r in caplog.records if r.levelno == logging.WARNING]
3520+
assert len(warning_records) == 1
3521+
assert "bedrock:CountTokens permission denied" in warning_records[0].message
3522+
assert error_message in warning_records[0].message
3523+
34763524
@pytest.mark.asyncio
34773525
async def test_does_not_cache_model_id_for_other_errors(self, bedrock_client, messages):
34783526
model = BedrockModel(model_id="transient-error-test-model")

0 commit comments

Comments
 (0)