|
19 | 19 | DEFAULT_BEDROCK_MODEL_ID, |
20 | 20 | DEFAULT_BEDROCK_REGION, |
21 | 21 | DEFAULT_READ_TIMEOUT, |
| 22 | + _clear_unsupported_count_tokens_cache, |
22 | 23 | ) |
23 | 24 | from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException |
24 | 25 | from strands.types.tools import ToolSpec |
@@ -3333,6 +3334,12 @@ async def test_non_streaming_citations_with_only_location(bedrock_client, model, |
3333 | 3334 | class TestCountTokens: |
3334 | 3335 | """Tests for BedrockModel.count_tokens native token counting.""" |
3335 | 3336 |
|
| 3337 | + @pytest.fixture(autouse=True) |
| 3338 | + def clean_cache(self): |
| 3339 | + _clear_unsupported_count_tokens_cache() |
| 3340 | + yield |
| 3341 | + _clear_unsupported_count_tokens_cache() |
| 3342 | + |
3336 | 3343 | @pytest.fixture |
3337 | 3344 | def model_with_client(self, bedrock_client, model_id): |
3338 | 3345 | _ = bedrock_client |
@@ -3449,3 +3456,31 @@ async def test_fallback_logs_debug(self, model_with_client, bedrock_client, mess |
3449 | 3456 | await model_with_client.count_tokens(messages=messages) |
3450 | 3457 |
|
3451 | 3458 | assert any("native token counting failed" in record.message for record in caplog.records) |
| 3459 | + |
| 3460 | + @pytest.mark.asyncio |
| 3461 | + async def test_caches_model_id_when_count_tokens_unsupported(self, bedrock_client, messages): |
| 3462 | + model = BedrockModel(model_id="unsupported-cache-test-model") |
| 3463 | + bedrock_client.count_tokens.side_effect = ClientError( |
| 3464 | + {"Error": {"Code": "ValidationException", "Message": "The provided model doesn't support counting tokens"}}, |
| 3465 | + "CountTokens", |
| 3466 | + ) |
| 3467 | + |
| 3468 | + # First call: hits API, gets error, caches |
| 3469 | + await model.count_tokens(messages=messages) |
| 3470 | + assert bedrock_client.count_tokens.call_count == 1 |
| 3471 | + |
| 3472 | + # Second call: skips API entirely |
| 3473 | + await model.count_tokens(messages=messages) |
| 3474 | + assert bedrock_client.count_tokens.call_count == 1 |
| 3475 | + |
| 3476 | + @pytest.mark.asyncio |
| 3477 | + async def test_does_not_cache_model_id_for_other_errors(self, bedrock_client, messages): |
| 3478 | + model = BedrockModel(model_id="transient-error-test-model") |
| 3479 | + bedrock_client.count_tokens.side_effect = RuntimeError("Transient network error") |
| 3480 | + |
| 3481 | + await model.count_tokens(messages=messages) |
| 3482 | + assert bedrock_client.count_tokens.call_count == 1 |
| 3483 | + |
| 3484 | + # Second call should still attempt the API |
| 3485 | + await model.count_tokens(messages=messages) |
| 3486 | + assert bedrock_client.count_tokens.call_count == 2 |
0 commit comments