|
19 | 19 | DEFAULT_BEDROCK_MODEL_ID, |
20 | 20 | DEFAULT_BEDROCK_REGION, |
21 | 21 | DEFAULT_READ_TIMEOUT, |
| 22 | + _clear_unsupported_count_tokens_cache, |
22 | 23 | ) |
23 | 24 | from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException |
24 | 25 | from strands.types.tools import ToolSpec |
@@ -3293,6 +3294,12 @@ async def test_non_streaming_citations_with_only_location(bedrock_client, model, |
3293 | 3294 | class TestCountTokens: |
3294 | 3295 | """Tests for BedrockModel.count_tokens native token counting.""" |
3295 | 3296 |
|
@pytest.fixture(autouse=True)
def clean_cache(self):
    """Call ``_clear_unsupported_count_tokens_cache`` before and after a test.

    Declared with ``autouse=True``, so tests in scope receive it without
    requesting it by name.
    """
    # Setup: clear before handing control to the test body.
    _clear_unsupported_count_tokens_cache()
    yield None
    # Teardown: clear again once the test body has finished.
    _clear_unsupported_count_tokens_cache()
| 3302 | + |
3296 | 3303 | @pytest.fixture |
3297 | 3304 | def model_with_client(self, bedrock_client, model_id): |
3298 | 3305 | _ = bedrock_client |
@@ -3409,3 +3416,31 @@ async def test_fallback_logs_debug(self, model_with_client, bedrock_client, mess |
3409 | 3416 | await model_with_client.count_tokens(messages=messages) |
3410 | 3417 |
|
3411 | 3418 | assert any("native token counting failed" in record.message for record in caplog.records) |
| 3419 | + |
@pytest.mark.asyncio
async def test_caches_model_id_when_count_tokens_unsupported(self, bedrock_client, messages):
    """The CountTokens API is attempted only once after an "unsupported" error.

    The mocked client raises a ``ValidationException`` whose message says the
    model doesn't support counting tokens. Two consecutive ``count_tokens``
    calls must leave the mock's call count at 1.
    """
    error_response = {
        "Error": {
            "Code": "ValidationException",
            "Message": "The provided model doesn't support counting tokens",
        }
    }
    unsupported_error = ClientError(error_response, "CountTokens")

    model = BedrockModel(model_id="unsupported-cache-test-model")
    api_mock = bedrock_client.count_tokens
    api_mock.side_effect = unsupported_error

    # Initial attempt reaches the API and receives the error.
    await model.count_tokens(messages=messages)
    assert api_mock.call_count == 1

    # Repeat attempt must not reach the API: the call count is unchanged.
    await model.count_tokens(messages=messages)
    assert api_mock.call_count == 1
| 3435 | + |
@pytest.mark.asyncio
async def test_does_not_cache_model_id_for_other_errors(self, bedrock_client, messages):
    """A generic ``RuntimeError`` from the API does not suppress later attempts.

    With the mocked client raising ``RuntimeError``, each ``count_tokens``
    call must still reach the API, so the mock's call count grows by one per
    call.
    """
    model = BedrockModel(model_id="transient-error-test-model")
    api_mock = bedrock_client.count_tokens
    api_mock.side_effect = RuntimeError("Transient network error")

    # First pass expects 1 API call in total, second pass expects 2.
    for expected_calls in (1, 2):
        await model.count_tokens(messages=messages)
        assert api_mock.call_count == expected_calls
0 commit comments