|
19 | 19 | DEFAULT_BEDROCK_MODEL_ID, |
20 | 20 | DEFAULT_BEDROCK_REGION, |
21 | 21 | DEFAULT_READ_TIMEOUT, |
22 | | - _clear_unsupported_count_tokens_cache, |
| 22 | + _clear_skip_count_tokens_cache, |
23 | 23 | ) |
24 | 24 | from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException |
25 | 25 | from strands.types.tools import ToolSpec |
@@ -3336,9 +3336,9 @@ class TestCountTokens: |
3336 | 3336 |
|
3337 | 3337 | @pytest.fixture(autouse=True) |
3338 | 3338 | def clean_cache(self): |
3339 | | - _clear_unsupported_count_tokens_cache() |
| 3339 | + _clear_skip_count_tokens_cache() |
3340 | 3340 | yield |
3341 | | - _clear_unsupported_count_tokens_cache() |
| 3341 | + _clear_skip_count_tokens_cache() |
3342 | 3342 |
|
3343 | 3343 | @pytest.fixture |
3344 | 3344 | def model_with_client(self, bedrock_client, model_id): |
@@ -3473,6 +3473,54 @@ async def test_caches_model_id_when_count_tokens_unsupported(self, bedrock_clien |
3473 | 3473 | await model.count_tokens(messages=messages) |
3474 | 3474 | assert bedrock_client.count_tokens.call_count == 1 |
3475 | 3475 |
|
| 3476 | + @pytest.mark.asyncio |
| 3477 | + async def test_caches_model_id_when_access_denied(self, bedrock_client, messages): |
| 3478 | + model = BedrockModel(model_id="access-denied-cache-test-model") |
| 3479 | + bedrock_client.count_tokens.side_effect = ClientError( |
| 3480 | + { |
| 3481 | + "Error": { |
| 3482 | + "Code": "AccessDeniedException", |
| 3483 | + "Message": "User: arn:aws:sts::123456789012:assumed-role/role is not authorized" |
| 3484 | + " to perform: bedrock:CountTokens", |
| 3485 | + } |
| 3486 | + }, |
| 3487 | + "CountTokens", |
| 3488 | + ) |
| 3489 | + |
| 3490 | + # First call: hits API, gets error, caches |
| 3491 | + await model.count_tokens(messages=messages) |
| 3492 | + bedrock_client.count_tokens.assert_called_once() |
| 3493 | + |
| 3494 | + # Reset mock to clearly verify second call doesn't hit the API |
| 3495 | + bedrock_client.count_tokens.reset_mock() |
| 3496 | + |
| 3497 | + # Second call: skips API entirely due to caching |
| 3498 | + result = await model.count_tokens(messages=messages) |
| 3499 | + bedrock_client.count_tokens.assert_not_called() |
| 3500 | + assert isinstance(result, int) |
| 3501 | + assert result >= 0 |
| 3502 | + |
| 3503 | + @pytest.mark.asyncio |
| 3504 | + async def test_access_denied_logs_warning_with_full_error( |
| 3505 | + self, model_with_client, bedrock_client, messages, caplog |
| 3506 | + ): |
| 3507 | + error_message = ( |
| 3508 | + "User: arn:aws:sts::123456789012:assumed-role/role is not authorized" |
| 3509 | + " to perform: bedrock:CountTokens" |
| 3510 | + ) |
| 3511 | + bedrock_client.count_tokens.side_effect = ClientError( |
| 3512 | + {"Error": {"Code": "AccessDeniedException", "Message": error_message}}, |
| 3513 | + "CountTokens", |
| 3514 | + ) |
| 3515 | + |
| 3516 | + with caplog.at_level(logging.WARNING, logger="strands.models.bedrock"): |
| 3517 | + await model_with_client.count_tokens(messages=messages) |
| 3518 | + |
| 3519 | + warning_records = [r for r in caplog.records if r.levelno == logging.WARNING] |
| 3520 | + assert len(warning_records) == 1 |
| 3521 | + assert "bedrock:CountTokens permission denied" in warning_records[0].message |
| 3522 | + assert error_message in warning_records[0].message |
| 3523 | + |
3476 | 3524 | @pytest.mark.asyncio |
3477 | 3525 | async def test_does_not_cache_model_id_for_other_errors(self, bedrock_client, messages): |
3478 | 3526 | model = BedrockModel(model_id="transient-error-test-model") |
|
0 commit comments