diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py
index ece7cd8d1..04fae220d 100644
--- a/src/strands/models/anthropic.py
+++ b/src/strands/models/anthropic.py
@@ -57,11 +57,15 @@ class AnthropicConfig(BaseModelConfig, total=False):
             https://docs.anthropic.com/en/docs/about-claude/models/all-models.
         params: Additional model parameters (e.g., temperature).
             For a complete list of supported parameters, see https://docs.anthropic.com/en/api/messages.
+        use_native_token_count: Whether to use the native Anthropic count_tokens API.
+            When True (default), count_tokens() calls the Anthropic API for accurate counts.
+            When False, skips the API call and uses the local estimator.
     """
 
     max_tokens: Required[int]
     model_id: Required[str]
     params: dict[str, Any] | None
+    use_native_token_count: bool
 
     def __init__(self, *, client_args: dict[str, Any] | None = None, **model_config: Unpack[AnthropicConfig]):
         """Initialize provider instance.
@@ -394,6 +398,9 @@ async def count_tokens(
         Returns:
             Total input token count.
         """
+        if self.config.get("use_native_token_count") is False:
+            return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content)
+
         try:
             # system_prompt_content is not used; this provider only accepts system_prompt as a plain string,
             # matching the behavior of stream(). The caller always provides system_prompt alongside
diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py
index c1cbfa265..c74a63a3b 100644
--- a/src/strands/models/bedrock.py
+++ b/src/strands/models/bedrock.py
@@ -117,6 +117,9 @@ class BedrockConfig(BaseModelConfig, total=False):
             See https://docs.aws.amazon.com/bedrock/latest/userguide/structured-output.html
         temperature: Controls randomness in generation (higher = more random)
         top_p: Controls diversity via nucleus sampling (alternative to temperature)
+        use_native_token_count: Whether to use the native Bedrock CountTokens API.
+            When True (default), count_tokens() calls the Bedrock API for accurate counts.
+            When False, skips the API call and uses the local estimator.
     """
 
     additional_args: dict[str, Any] | None
@@ -143,6 +146,7 @@ class BedrockConfig(BaseModelConfig, total=False):
     strict_tools: bool | None
     temperature: float | None
     top_p: float | None
+    use_native_token_count: bool
 
     def __init__(
         self,
@@ -794,6 +798,9 @@ async def count_tokens(
         Returns:
             Total input token count.
         """
+        if self.config.get("use_native_token_count") is False:
+            return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content)
+
         model_id: str = self.config["model_id"]
 
         if model_id in _UNSUPPORTED_COUNT_TOKENS_MODELS:
diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py
index 65b925c6d..8ed579d38 100644
--- a/src/strands/models/gemini.py
+++ b/src/strands/models/gemini.py
@@ -49,11 +49,15 @@ class GeminiConfig(BaseModelConfig, total=False):
             Use the standard tools interface for function calling tools.
             For a complete list of supported tools, see
             https://ai.google.dev/api/caching#Tool
+        use_native_token_count: Whether to use the native Gemini count_tokens API.
+            When True (default), count_tokens() calls the Gemini API for accurate counts.
+            When False, skips the API call and uses the local estimator.
     """
 
     model_id: Required[str]
     params: dict[str, Any]
     gemini_tools: list[genai.types.Tool]
+    use_native_token_count: bool
 
     def __init__(
         self,
@@ -457,6 +461,9 @@ async def count_tokens(
         Returns:
             Total input token count.
         """
+        if self.config.get("use_native_token_count") is False:
+            return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content)
+
         try:
             contents = list(self._format_request_content(messages))
diff --git a/src/strands/models/llamacpp.py b/src/strands/models/llamacpp.py
index c31ba11bc..531cf6b50 100644
--- a/src/strands/models/llamacpp.py
+++ b/src/strands/models/llamacpp.py
@@ -125,10 +125,14 @@ class LlamaCppConfig(BaseModelConfig, total=False):
             - cache_prompt: Cache the prompt for faster generation
             - slot_id: Slot ID for parallel inference
             - samplers: Custom sampler order
+        use_native_token_count: Whether to use the native llama.cpp /tokenize endpoint.
+            When True (default), count_tokens() calls the server's tokenize endpoint for accurate counts.
+            When False, skips the API call and uses the local estimator.
     """
 
     model_id: str
     params: dict[str, Any] | None
+    use_native_token_count: bool
 
     def __init__(
         self,
@@ -533,6 +537,9 @@ async def count_tokens(
         Returns:
             Total input token count.
         """
+        if self.config.get("use_native_token_count") is False:
+            return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content)
+
         try:
             # system_prompt_content is not used; this provider only accepts system_prompt as a plain string,
             # matching the behavior of stream(). The caller always provides system_prompt alongside
diff --git a/src/strands/models/model.py b/src/strands/models/model.py
index 3ded11a28..dd2f9eed2 100644
--- a/src/strands/models/model.py
+++ b/src/strands/models/model.py
@@ -23,6 +23,7 @@
 T = TypeVar("T", bound=BaseModel)
 
+
 def _heuristic_estimate_text(text: str) -> int:
     """Estimate token count from text using characters / 4 heuristic."""
     return math.ceil(len(text) / 4)
 
@@ -84,7 +85,6 @@ def _count_content_block_tokens(
 
     return total
 
-
 def _estimate_tokens_with_heuristic(
     messages: Messages,
     tool_specs: list[ToolSpec] | None = None,
diff --git a/src/strands/models/openai_responses.py b/src/strands/models/openai_responses.py
index a78cef73a..c6ddbb9d6 100644
--- a/src/strands/models/openai_responses.py
+++ b/src/strands/models/openai_responses.py
@@ -136,11 +136,15 @@ class OpenAIResponsesConfig(BaseModelConfig, total=False):
         stateful: Whether to enable server-side conversation state management.
             When True, the server stores conversation history and the client does not
             need to send the full message history with each request. Defaults to False.
+        use_native_token_count: Whether to use the native OpenAI input_tokens.count API.
+            When True (default), count_tokens() calls the OpenAI API for accurate counts.
+            When False, skips the API call and uses the local estimator.
     """
 
     model_id: str
     params: dict[str, Any] | None
     stateful: bool
+    use_native_token_count: bool
 
     def __init__(
         self,
@@ -238,6 +242,9 @@ async def count_tokens(
         Returns:
             Total input token count.
         """
+        if self.config.get("use_native_token_count") is False:
+            return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content)
+
         try:
             # system_prompt_content is not used; this provider only accepts system_prompt as a plain string,
             # matching the behavior of stream(). The caller always provides system_prompt alongside
diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py
index abb56a441..6de821e90 100644
--- a/tests/strands/models/test_anthropic.py
+++ b/tests/strands/models/test_anthropic.py
@@ -1162,3 +1162,16 @@ async def test_fallback_logs_debug(self, model_with_client, anthropic_client, me
         await model_with_client.count_tokens(messages=messages)
 
         assert any("native token counting failed" in record.message for record in caplog.records)
+
+    @pytest.mark.asyncio
+    async def test_skip_native_api_when_use_native_token_count_false(
+        self, anthropic_client, model_id, max_tokens, messages
+    ):
+        _ = anthropic_client
+        model = AnthropicModel(model_id=model_id, max_tokens=max_tokens, use_native_token_count=False)
+
+        result = await model.count_tokens(messages=messages)
+
+        anthropic_client.messages.count_tokens.assert_not_called()
+        assert isinstance(result, int)
+        assert result >= 0
diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py
index f177a8a17..2f1f7d1f1 100644
--- a/tests/strands/models/test_bedrock.py
+++ b/tests/strands/models/test_bedrock.py
@@ -3484,3 +3484,14 @@ async def test_does_not_cache_model_id_for_other_errors(self, bedrock_client, me
         # Second call should still attempt the API
         await model.count_tokens(messages=messages)
         assert bedrock_client.count_tokens.call_count == 2
+
+    @pytest.mark.asyncio
+    async def test_skip_native_api_when_use_native_token_count_false(self, bedrock_client, model_id, messages):
+        _ = bedrock_client
+        model = BedrockModel(model_id=model_id, use_native_token_count=False)
+
+        result = await model.count_tokens(messages=messages)
+
+        bedrock_client.count_tokens.assert_not_called()
+        assert isinstance(result, int)
+        assert result >= 0
diff --git a/tests/strands/models/test_gemini.py b/tests/strands/models/test_gemini.py
index 91a55d899..b846bfcdf 100644
--- a/tests/strands/models/test_gemini.py
+++ b/tests/strands/models/test_gemini.py
@@ -1228,3 +1228,14 @@ async def test_fallback_logs_debug(self, model, gemini_client, messages, caplog)
         await model.count_tokens(messages=messages)
 
         assert any("native token counting failed" in record.message for record in caplog.records)
+
+    @pytest.mark.asyncio
+    async def test_skip_native_api_when_use_native_token_count_false(self, gemini_client, messages):
+        _ = gemini_client
+        model = GeminiModel(model_id="m1", use_native_token_count=False)
+
+        result = await model.count_tokens(messages=messages)
+
+        gemini_client.aio.models.count_tokens.assert_not_called()
+        assert isinstance(result, int)
+        assert result >= 0
diff --git a/tests/strands/models/test_llamacpp.py b/tests/strands/models/test_llamacpp.py
index a891ec929..43fb03629 100644
--- a/tests/strands/models/test_llamacpp.py
+++ b/tests/strands/models/test_llamacpp.py
@@ -803,3 +803,14 @@ async def test_fallback_logs_debug(self, model, messages, caplog):
         await model.count_tokens(messages=messages)
 
         assert any("native token counting failed" in record.message for record in caplog.records)
+
+    @pytest.mark.asyncio
+    async def test_skip_native_api_when_use_native_token_count_false(self, messages):
+        model = LlamaCppModel(base_url="http://localhost:8080", use_native_token_count=False)
+        model.client.post = AsyncMock()
+
+        result = await model.count_tokens(messages=messages)
+
+        model.client.post.assert_not_called()
+        assert isinstance(result, int)
+        assert result >= 0
diff --git a/tests/strands/models/test_model.py b/tests/strands/models/test_model.py
index b362740b5..34f4ef328 100644
--- a/tests/strands/models/test_model.py
+++ b/tests/strands/models/test_model.py
@@ -509,7 +509,6 @@ async def test_count_tokens_all_inputs(model):
 
     assert result == 50
 
-
 class TestHeuristicEstimation:
     """Tests for _estimate_tokens_with_heuristic."""
 
diff --git a/tests/strands/models/test_openai_responses.py b/tests/strands/models/test_openai_responses.py
index 97ee9e305..47acfded4 100644
--- a/tests/strands/models/test_openai_responses.py
+++ b/tests/strands/models/test_openai_responses.py
@@ -1318,6 +1318,17 @@ async def test_fallback_logs_debug(self, model, openai_client, messages, caplog)
 
         assert any("native token counting failed" in record.message for record in caplog.records)
 
+    @pytest.mark.asyncio
+    async def test_skip_native_api_when_use_native_token_count_false(self, openai_client, messages):
+        _ = openai_client
+        model = OpenAIResponsesModel(model_id="gpt-4o", use_native_token_count=False)
+
+        result = await model.count_tokens(messages=messages)
+
+        openai_client.responses.input_tokens.count.assert_not_called()
+        assert isinstance(result, int)
+        assert result >= 0
+
 
 # =============================================================================
 # Bedrock Mantle (bedrock_mantle_config) integration with OpenAIResponsesModel
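
Usage sketch (not part of the patch): assuming the public constructors exercised by the tests above, the keyword-only count_tokens(messages=...) call, and an illustrative model ID and message payload, opting out of the native token-counting API looks roughly like this:

import asyncio

from strands.models.bedrock import BedrockModel  # module path taken from the patch above


async def main() -> None:
    # Illustrative message payload (assumed shape; the tests above use a `messages` fixture).
    messages = [{"role": "user", "content": [{"text": "Hello, world"}]}]

    # With use_native_token_count=False the provider skips the Bedrock CountTokens API
    # and falls back to the local characters/4 heuristic in strands/models/model.py.
    # Omitting the flag (or passing True) keeps the default native-API behavior.
    model = BedrockModel(model_id="example-model-id", use_native_token_count=False)

    estimated = await model.count_tokens(messages=messages)
    print(estimated)


asyncio.run(main())

The same flag is accepted by AnthropicModel, GeminiModel, LlamaCppModel, and OpenAIResponsesModel, as the corresponding tests show.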