From 10aa7f461ea5252c5702652ce7d361118f33112e Mon Sep 17 00:00:00 2001 From: opieter-aws Date: Wed, 6 May 2026 11:47:54 -0400 Subject: [PATCH 1/2] feat: add useNativeTokenCount flag to skip token counting API calls --- src/strands/models/anthropic.py | 7 +++++++ src/strands/models/bedrock.py | 7 +++++++ src/strands/models/gemini.py | 7 +++++++ src/strands/models/llamacpp.py | 7 +++++++ src/strands/models/openai_responses.py | 7 +++++++ tests/strands/models/test_anthropic.py | 13 +++++++++++++ tests/strands/models/test_bedrock.py | 11 +++++++++++ tests/strands/models/test_gemini.py | 11 +++++++++++ tests/strands/models/test_llamacpp.py | 11 +++++++++++ tests/strands/models/test_model.py | 1 - tests/strands/models/test_openai_responses.py | 11 +++++++++++ 11 files changed, 92 insertions(+), 1 deletion(-) diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index ece7cd8d1..c41a5fc76 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -57,11 +57,15 @@ class AnthropicConfig(BaseModelConfig, total=False): https://docs.anthropic.com/en/docs/about-claude/models/all-models. params: Additional model parameters (e.g., temperature). For a complete list of supported parameters, see https://docs.anthropic.com/en/api/messages. + native_token_counting: Whether to use the native Anthropic count_tokens API. + When True (default), count_tokens() calls the Anthropic API for accurate counts. + When False, skips the API call and uses the local estimator. """ max_tokens: Required[int] model_id: Required[str] params: dict[str, Any] | None + native_token_counting: bool def __init__(self, *, client_args: dict[str, Any] | None = None, **model_config: Unpack[AnthropicConfig]): """Initialize provider instance. @@ -394,6 +398,9 @@ async def count_tokens( Returns: Total input token count. """ + if self.config.get("native_token_counting") is False: + return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) + try: # system_prompt_content is not used; this provider only accepts system_prompt as a plain string, # matching the behavior of stream(). The caller always provides system_prompt alongside diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index c1cbfa265..adba5bb8f 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -117,6 +117,9 @@ class BedrockConfig(BaseModelConfig, total=False): See https://docs.aws.amazon.com/bedrock/latest/userguide/structured-output.html temperature: Controls randomness in generation (higher = more random) top_p: Controls diversity via nucleus sampling (alternative to temperature) + native_token_counting: Whether to use the native Bedrock CountTokens API. + When True (default), count_tokens() calls the Bedrock API for accurate counts. + When False, skips the API call and uses the local estimator. """ additional_args: dict[str, Any] | None @@ -143,6 +146,7 @@ class BedrockConfig(BaseModelConfig, total=False): strict_tools: bool | None temperature: float | None top_p: float | None + native_token_counting: bool def __init__( self, @@ -794,6 +798,9 @@ async def count_tokens( Returns: Total input token count. """ + if self.config.get("native_token_counting") is False: + return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) + model_id: str = self.config["model_id"] if model_id in _UNSUPPORTED_COUNT_TOKENS_MODELS: diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index 65b925c6d..cd97404c9 100644 --- a/src/strands/models/gemini.py +++ b/src/strands/models/gemini.py @@ -49,11 +49,15 @@ class GeminiConfig(BaseModelConfig, total=False): Use the standard tools interface for function calling tools. For a complete list of supported tools, see https://ai.google.dev/api/caching#Tool + native_token_counting: Whether to use the native Gemini count_tokens API. + When True (default), count_tokens() calls the Gemini API for accurate counts. + When False, skips the API call and uses the local estimator. """ model_id: Required[str] params: dict[str, Any] gemini_tools: list[genai.types.Tool] + native_token_counting: bool def __init__( self, @@ -457,6 +461,9 @@ async def count_tokens( Returns: Total input token count. """ + if self.config.get("native_token_counting") is False: + return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) + try: contents = list(self._format_request_content(messages)) diff --git a/src/strands/models/llamacpp.py b/src/strands/models/llamacpp.py index c31ba11bc..c9fafa28d 100644 --- a/src/strands/models/llamacpp.py +++ b/src/strands/models/llamacpp.py @@ -125,10 +125,14 @@ class LlamaCppConfig(BaseModelConfig, total=False): - cache_prompt: Cache the prompt for faster generation - slot_id: Slot ID for parallel inference - samplers: Custom sampler order + native_token_counting: Whether to use the native llama.cpp /tokenize endpoint. + When True (default), count_tokens() calls the server's tokenize endpoint for accurate counts. + When False, skips the API call and uses the local estimator. """ model_id: str params: dict[str, Any] | None + native_token_counting: bool def __init__( self, @@ -533,6 +537,9 @@ async def count_tokens( Returns: Total input token count. """ + if self.config.get("native_token_counting") is False: + return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) + try: # system_prompt_content is not used; this provider only accepts system_prompt as a plain string, # matching the behavior of stream(). The caller always provides system_prompt alongside diff --git a/src/strands/models/openai_responses.py b/src/strands/models/openai_responses.py index a78cef73a..b898d1c28 100644 --- a/src/strands/models/openai_responses.py +++ b/src/strands/models/openai_responses.py @@ -136,11 +136,15 @@ class OpenAIResponsesConfig(BaseModelConfig, total=False): stateful: Whether to enable server-side conversation state management. When True, the server stores conversation history and the client does not need to send the full message history with each request. Defaults to False. + native_token_counting: Whether to use the native OpenAI input_tokens.count API. + When True (default), count_tokens() calls the OpenAI API for accurate counts. + When False, skips the API call and uses the local estimator. """ model_id: str params: dict[str, Any] | None stateful: bool + native_token_counting: bool def __init__( self, @@ -238,6 +242,9 @@ async def count_tokens( Returns: Total input token count. """ + if self.config.get("native_token_counting") is False: + return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) + try: # system_prompt_content is not used; this provider only accepts system_prompt as a plain string, # matching the behavior of stream(). The caller always provides system_prompt alongside diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index abb56a441..129eb30ce 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -1162,3 +1162,16 @@ async def test_fallback_logs_debug(self, model_with_client, anthropic_client, me await model_with_client.count_tokens(messages=messages) assert any("native token counting failed" in record.message for record in caplog.records) + + @pytest.mark.asyncio + async def test_skip_native_api_when_native_token_counting_false( + self, anthropic_client, model_id, max_tokens, messages + ): + _ = anthropic_client + model = AnthropicModel(model_id=model_id, max_tokens=max_tokens, native_token_counting=False) + + result = await model.count_tokens(messages=messages) + + anthropic_client.messages.count_tokens.assert_not_called() + assert isinstance(result, int) + assert result >= 0 diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index f177a8a17..a87285abf 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -3484,3 +3484,14 @@ async def test_does_not_cache_model_id_for_other_errors(self, bedrock_client, me # Second call should still attempt the API await model.count_tokens(messages=messages) assert bedrock_client.count_tokens.call_count == 2 + + @pytest.mark.asyncio + async def test_skip_native_api_when_native_token_counting_false(self, bedrock_client, model_id, messages): + _ = bedrock_client + model = BedrockModel(model_id=model_id, native_token_counting=False) + + result = await model.count_tokens(messages=messages) + + bedrock_client.count_tokens.assert_not_called() + assert isinstance(result, int) + assert result >= 0 diff --git a/tests/strands/models/test_gemini.py b/tests/strands/models/test_gemini.py index 91a55d899..45840ca25 100644 --- a/tests/strands/models/test_gemini.py +++ b/tests/strands/models/test_gemini.py @@ -1228,3 +1228,14 @@ async def test_fallback_logs_debug(self, model, gemini_client, messages, caplog) await model.count_tokens(messages=messages) assert any("native token counting failed" in record.message for record in caplog.records) + + @pytest.mark.asyncio + async def test_skip_native_api_when_native_token_counting_false(self, gemini_client, messages): + _ = gemini_client + model = GeminiModel(model_id="m1", native_token_counting=False) + + result = await model.count_tokens(messages=messages) + + gemini_client.aio.models.count_tokens.assert_not_called() + assert isinstance(result, int) + assert result >= 0 diff --git a/tests/strands/models/test_llamacpp.py b/tests/strands/models/test_llamacpp.py index a891ec929..7edd63f8b 100644 --- a/tests/strands/models/test_llamacpp.py +++ b/tests/strands/models/test_llamacpp.py @@ -803,3 +803,14 @@ async def test_fallback_logs_debug(self, model, messages, caplog): await model.count_tokens(messages=messages) assert any("native token counting failed" in record.message for record in caplog.records) + + @pytest.mark.asyncio + async def test_skip_native_api_when_native_token_counting_false(self, messages): + model = LlamaCppModel(base_url="http://localhost:8080", native_token_counting=False) + model.client.post = AsyncMock() + + result = await model.count_tokens(messages=messages) + + model.client.post.assert_not_called() + assert isinstance(result, int) + assert result >= 0 diff --git a/tests/strands/models/test_model.py b/tests/strands/models/test_model.py index b362740b5..34f4ef328 100644 --- a/tests/strands/models/test_model.py +++ b/tests/strands/models/test_model.py @@ -509,7 +509,6 @@ async def test_count_tokens_all_inputs(model): assert result == 50 - class TestHeuristicEstimation: """Tests for _estimate_tokens_with_heuristic.""" diff --git a/tests/strands/models/test_openai_responses.py b/tests/strands/models/test_openai_responses.py index 97ee9e305..05e639e0d 100644 --- a/tests/strands/models/test_openai_responses.py +++ b/tests/strands/models/test_openai_responses.py @@ -1318,6 +1318,17 @@ async def test_fallback_logs_debug(self, model, openai_client, messages, caplog) assert any("native token counting failed" in record.message for record in caplog.records) + @pytest.mark.asyncio + async def test_skip_native_api_when_native_token_counting_false(self, openai_client, messages): + _ = openai_client + model = OpenAIResponsesModel(model_id="gpt-4o", native_token_counting=False) + + result = await model.count_tokens(messages=messages) + + openai_client.responses.input_tokens.count.assert_not_called() + assert isinstance(result, int) + assert result >= 0 + # ============================================================================= # Bedrock Mantle (bedrock_mantle_config) integration with OpenAIResponsesModel From cee3b4bb8eadd85c7172a020d96892de58553624 Mon Sep 17 00:00:00 2001 From: opieter-aws Date: Wed, 6 May 2026 11:58:00 -0400 Subject: [PATCH 2/2] feat: add useNativeTokenCount flag to skip token counting API calls --- src/strands/models/anthropic.py | 6 +++--- src/strands/models/bedrock.py | 6 +++--- src/strands/models/gemini.py | 6 +++--- src/strands/models/llamacpp.py | 6 +++--- src/strands/models/model.py | 2 +- src/strands/models/openai_responses.py | 6 +++--- tests/strands/models/test_anthropic.py | 4 ++-- tests/strands/models/test_bedrock.py | 4 ++-- tests/strands/models/test_gemini.py | 4 ++-- tests/strands/models/test_llamacpp.py | 4 ++-- tests/strands/models/test_openai_responses.py | 4 ++-- 11 files changed, 26 insertions(+), 26 deletions(-) diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index c41a5fc76..04fae220d 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -57,7 +57,7 @@ class AnthropicConfig(BaseModelConfig, total=False): https://docs.anthropic.com/en/docs/about-claude/models/all-models. params: Additional model parameters (e.g., temperature). For a complete list of supported parameters, see https://docs.anthropic.com/en/api/messages. - native_token_counting: Whether to use the native Anthropic count_tokens API. + use_native_token_count: Whether to use the native Anthropic count_tokens API. When True (default), count_tokens() calls the Anthropic API for accurate counts. When False, skips the API call and uses the local estimator. """ @@ -65,7 +65,7 @@ class AnthropicConfig(BaseModelConfig, total=False): max_tokens: Required[int] model_id: Required[str] params: dict[str, Any] | None - native_token_counting: bool + use_native_token_count: bool def __init__(self, *, client_args: dict[str, Any] | None = None, **model_config: Unpack[AnthropicConfig]): """Initialize provider instance. @@ -398,7 +398,7 @@ async def count_tokens( Returns: Total input token count. """ - if self.config.get("native_token_counting") is False: + if self.config.get("use_native_token_count") is False: return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) try: diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index adba5bb8f..c74a63a3b 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -117,7 +117,7 @@ class BedrockConfig(BaseModelConfig, total=False): See https://docs.aws.amazon.com/bedrock/latest/userguide/structured-output.html temperature: Controls randomness in generation (higher = more random) top_p: Controls diversity via nucleus sampling (alternative to temperature) - native_token_counting: Whether to use the native Bedrock CountTokens API. + use_native_token_count: Whether to use the native Bedrock CountTokens API. When True (default), count_tokens() calls the Bedrock API for accurate counts. When False, skips the API call and uses the local estimator. """ @@ -146,7 +146,7 @@ class BedrockConfig(BaseModelConfig, total=False): strict_tools: bool | None temperature: float | None top_p: float | None - native_token_counting: bool + use_native_token_count: bool def __init__( self, @@ -798,7 +798,7 @@ async def count_tokens( Returns: Total input token count. """ - if self.config.get("native_token_counting") is False: + if self.config.get("use_native_token_count") is False: return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) model_id: str = self.config["model_id"] diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index cd97404c9..8ed579d38 100644 --- a/src/strands/models/gemini.py +++ b/src/strands/models/gemini.py @@ -49,7 +49,7 @@ class GeminiConfig(BaseModelConfig, total=False): Use the standard tools interface for function calling tools. For a complete list of supported tools, see https://ai.google.dev/api/caching#Tool - native_token_counting: Whether to use the native Gemini count_tokens API. + use_native_token_count: Whether to use the native Gemini count_tokens API. When True (default), count_tokens() calls the Gemini API for accurate counts. When False, skips the API call and uses the local estimator. """ @@ -57,7 +57,7 @@ class GeminiConfig(BaseModelConfig, total=False): model_id: Required[str] params: dict[str, Any] gemini_tools: list[genai.types.Tool] - native_token_counting: bool + use_native_token_count: bool def __init__( self, @@ -461,7 +461,7 @@ async def count_tokens( Returns: Total input token count. """ - if self.config.get("native_token_counting") is False: + if self.config.get("use_native_token_count") is False: return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) try: diff --git a/src/strands/models/llamacpp.py b/src/strands/models/llamacpp.py index c9fafa28d..531cf6b50 100644 --- a/src/strands/models/llamacpp.py +++ b/src/strands/models/llamacpp.py @@ -125,14 +125,14 @@ class LlamaCppConfig(BaseModelConfig, total=False): - cache_prompt: Cache the prompt for faster generation - slot_id: Slot ID for parallel inference - samplers: Custom sampler order - native_token_counting: Whether to use the native llama.cpp /tokenize endpoint. + use_native_token_count: Whether to use the native llama.cpp /tokenize endpoint. When True (default), count_tokens() calls the server's tokenize endpoint for accurate counts. When False, skips the API call and uses the local estimator. """ model_id: str params: dict[str, Any] | None - native_token_counting: bool + use_native_token_count: bool def __init__( self, @@ -537,7 +537,7 @@ async def count_tokens( Returns: Total input token count. """ - if self.config.get("native_token_counting") is False: + if self.config.get("use_native_token_count") is False: return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) try: diff --git a/src/strands/models/model.py b/src/strands/models/model.py index 3ded11a28..dd2f9eed2 100644 --- a/src/strands/models/model.py +++ b/src/strands/models/model.py @@ -23,6 +23,7 @@ T = TypeVar("T", bound=BaseModel) + def _heuristic_estimate_text(text: str) -> int: """Estimate token count from text using characters / 4 heuristic.""" return math.ceil(len(text) / 4) @@ -84,7 +85,6 @@ def _count_content_block_tokens( return total - def _estimate_tokens_with_heuristic( messages: Messages, tool_specs: list[ToolSpec] | None = None, diff --git a/src/strands/models/openai_responses.py b/src/strands/models/openai_responses.py index b898d1c28..c6ddbb9d6 100644 --- a/src/strands/models/openai_responses.py +++ b/src/strands/models/openai_responses.py @@ -136,7 +136,7 @@ class OpenAIResponsesConfig(BaseModelConfig, total=False): stateful: Whether to enable server-side conversation state management. When True, the server stores conversation history and the client does not need to send the full message history with each request. Defaults to False. - native_token_counting: Whether to use the native OpenAI input_tokens.count API. + use_native_token_count: Whether to use the native OpenAI input_tokens.count API. When True (default), count_tokens() calls the OpenAI API for accurate counts. When False, skips the API call and uses the local estimator. """ @@ -144,7 +144,7 @@ class OpenAIResponsesConfig(BaseModelConfig, total=False): model_id: str params: dict[str, Any] | None stateful: bool - native_token_counting: bool + use_native_token_count: bool def __init__( self, @@ -242,7 +242,7 @@ async def count_tokens( Returns: Total input token count. """ - if self.config.get("native_token_counting") is False: + if self.config.get("use_native_token_count") is False: return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) try: diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index 129eb30ce..6de821e90 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -1164,11 +1164,11 @@ async def test_fallback_logs_debug(self, model_with_client, anthropic_client, me assert any("native token counting failed" in record.message for record in caplog.records) @pytest.mark.asyncio - async def test_skip_native_api_when_native_token_counting_false( + async def test_skip_native_api_when_use_native_token_count_false( self, anthropic_client, model_id, max_tokens, messages ): _ = anthropic_client - model = AnthropicModel(model_id=model_id, max_tokens=max_tokens, native_token_counting=False) + model = AnthropicModel(model_id=model_id, max_tokens=max_tokens, use_native_token_count=False) result = await model.count_tokens(messages=messages) diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index a87285abf..2f1f7d1f1 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -3486,9 +3486,9 @@ async def test_does_not_cache_model_id_for_other_errors(self, bedrock_client, me assert bedrock_client.count_tokens.call_count == 2 @pytest.mark.asyncio - async def test_skip_native_api_when_native_token_counting_false(self, bedrock_client, model_id, messages): + async def test_skip_native_api_when_use_native_token_count_false(self, bedrock_client, model_id, messages): _ = bedrock_client - model = BedrockModel(model_id=model_id, native_token_counting=False) + model = BedrockModel(model_id=model_id, use_native_token_count=False) result = await model.count_tokens(messages=messages) diff --git a/tests/strands/models/test_gemini.py b/tests/strands/models/test_gemini.py index 45840ca25..b846bfcdf 100644 --- a/tests/strands/models/test_gemini.py +++ b/tests/strands/models/test_gemini.py @@ -1230,9 +1230,9 @@ async def test_fallback_logs_debug(self, model, gemini_client, messages, caplog) assert any("native token counting failed" in record.message for record in caplog.records) @pytest.mark.asyncio - async def test_skip_native_api_when_native_token_counting_false(self, gemini_client, messages): + async def test_skip_native_api_when_use_native_token_count_false(self, gemini_client, messages): _ = gemini_client - model = GeminiModel(model_id="m1", native_token_counting=False) + model = GeminiModel(model_id="m1", use_native_token_count=False) result = await model.count_tokens(messages=messages) diff --git a/tests/strands/models/test_llamacpp.py b/tests/strands/models/test_llamacpp.py index 7edd63f8b..43fb03629 100644 --- a/tests/strands/models/test_llamacpp.py +++ b/tests/strands/models/test_llamacpp.py @@ -805,8 +805,8 @@ async def test_fallback_logs_debug(self, model, messages, caplog): assert any("native token counting failed" in record.message for record in caplog.records) @pytest.mark.asyncio - async def test_skip_native_api_when_native_token_counting_false(self, messages): - model = LlamaCppModel(base_url="http://localhost:8080", native_token_counting=False) + async def test_skip_native_api_when_use_native_token_count_false(self, messages): + model = LlamaCppModel(base_url="http://localhost:8080", use_native_token_count=False) model.client.post = AsyncMock() result = await model.count_tokens(messages=messages) diff --git a/tests/strands/models/test_openai_responses.py b/tests/strands/models/test_openai_responses.py index 05e639e0d..47acfded4 100644 --- a/tests/strands/models/test_openai_responses.py +++ b/tests/strands/models/test_openai_responses.py @@ -1319,9 +1319,9 @@ async def test_fallback_logs_debug(self, model, openai_client, messages, caplog) assert any("native token counting failed" in record.message for record in caplog.records) @pytest.mark.asyncio - async def test_skip_native_api_when_native_token_counting_false(self, openai_client, messages): + async def test_skip_native_api_when_use_native_token_count_false(self, openai_client, messages): _ = openai_client - model = OpenAIResponsesModel(model_id="gpt-4o", native_token_counting=False) + model = OpenAIResponsesModel(model_id="gpt-4o", use_native_token_count=False) result = await model.count_tokens(messages=messages)