
Commit 800e7c4

feat: add useNativeTokenCount flag to skip token counting API calls (#2255)
1 parent 6b0df9a commit 800e7c4

12 files changed: 93 additions & 2 deletions
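
In usage terms, the new flag is an opt-out: constructing a model with use_native_token_count=False makes count_tokens() return a local estimate instead of calling the provider. A minimal sketch, assuming the strands.models.bedrock.BedrockModel constructor and message shape used in the tests below; the model id is a placeholder, not a value from this commit:

import asyncio

from strands.models.bedrock import BedrockModel

# Placeholder model id for illustration; the flag behaves the same for any model.
model = BedrockModel(model_id="example-model-id", use_native_token_count=False)

async def main() -> None:
    messages = [{"role": "user", "content": [{"text": "hello"}]}]
    # No CountTokens API call is made here; the count comes from the
    # local characters / 4 estimator in model.py instead.
    print(await model.count_tokens(messages=messages))

asyncio.run(main())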

src/strands/models/anthropic.py

Lines changed: 7 additions & 0 deletions
@@ -57,11 +57,15 @@ class AnthropicConfig(BaseModelConfig, total=False):
                 https://docs.anthropic.com/en/docs/about-claude/models/all-models.
             params: Additional model parameters (e.g., temperature).
                 For a complete list of supported parameters, see https://docs.anthropic.com/en/api/messages.
+            use_native_token_count: Whether to use the native Anthropic count_tokens API.
+                When True (default), count_tokens() calls the Anthropic API for accurate counts.
+                When False, skips the API call and uses the local estimator.
         """
 
         max_tokens: Required[int]
         model_id: Required[str]
         params: dict[str, Any] | None
+        use_native_token_count: bool
 
     def __init__(self, *, client_args: dict[str, Any] | None = None, **model_config: Unpack[AnthropicConfig]):
         """Initialize provider instance.
@@ -394,6 +398,9 @@ async def count_tokens(
         Returns:
             Total input token count.
         """
+        if self.config.get("use_native_token_count") is False:
+            return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content)
+
         try:
             # system_prompt_content is not used; this provider only accepts system_prompt as a plain string,
             # matching the behavior of stream(). The caller always provides system_prompt alongside

src/strands/models/bedrock.py

Lines changed: 7 additions & 0 deletions
@@ -117,6 +117,9 @@ class BedrockConfig(BaseModelConfig, total=False):
                 See https://docs.aws.amazon.com/bedrock/latest/userguide/structured-output.html
             temperature: Controls randomness in generation (higher = more random)
             top_p: Controls diversity via nucleus sampling (alternative to temperature)
+            use_native_token_count: Whether to use the native Bedrock CountTokens API.
+                When True (default), count_tokens() calls the Bedrock API for accurate counts.
+                When False, skips the API call and uses the local estimator.
         """
 
         additional_args: dict[str, Any] | None
@@ -143,6 +146,7 @@ class BedrockConfig(BaseModelConfig, total=False):
         strict_tools: bool | None
         temperature: float | None
         top_p: float | None
+        use_native_token_count: bool
 
     def __init__(
         self,
@@ -794,6 +798,9 @@ async def count_tokens(
         Returns:
             Total input token count.
         """
+        if self.config.get("use_native_token_count") is False:
+            return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content)
+
         model_id: str = self.config["model_id"]
 
         if model_id in _UNSUPPORTED_COUNT_TOKENS_MODELS:
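
Note that each provider checks self.config.get("use_native_token_count") is False, so a flag that was never set (get returns None) still takes the native path; only an explicit False opts out. A small illustration of that lookup on a plain dict, with hypothetical values:

config = {"model_id": "example"}                        # flag not set
print(config.get("use_native_token_count") is False)    # False -> native API is called

config["use_native_token_count"] = False
print(config.get("use_native_token_count") is False)    # True -> API skipped, local estimator used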

src/strands/models/gemini.py

Lines changed: 7 additions & 0 deletions
@@ -49,11 +49,15 @@ class GeminiConfig(BaseModelConfig, total=False):
                 Use the standard tools interface for function calling tools.
                 For a complete list of supported tools, see
                 https://ai.google.dev/api/caching#Tool
+            use_native_token_count: Whether to use the native Gemini count_tokens API.
+                When True (default), count_tokens() calls the Gemini API for accurate counts.
+                When False, skips the API call and uses the local estimator.
         """
 
         model_id: Required[str]
         params: dict[str, Any]
         gemini_tools: list[genai.types.Tool]
+        use_native_token_count: bool
 
     def __init__(
         self,
@@ -457,6 +461,9 @@ async def count_tokens(
         Returns:
             Total input token count.
         """
+        if self.config.get("use_native_token_count") is False:
+            return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content)
+
         try:
             contents = list(self._format_request_content(messages))
 

src/strands/models/llamacpp.py

Lines changed: 7 additions & 0 deletions
@@ -125,10 +125,14 @@ class LlamaCppConfig(BaseModelConfig, total=False):
                 - cache_prompt: Cache the prompt for faster generation
                 - slot_id: Slot ID for parallel inference
                 - samplers: Custom sampler order
+            use_native_token_count: Whether to use the native llama.cpp /tokenize endpoint.
+                When True (default), count_tokens() calls the server's tokenize endpoint for accurate counts.
+                When False, skips the API call and uses the local estimator.
         """
 
         model_id: str
         params: dict[str, Any] | None
+        use_native_token_count: bool
 
     def __init__(
         self,
@@ -533,6 +537,9 @@ async def count_tokens(
         Returns:
             Total input token count.
         """
+        if self.config.get("use_native_token_count") is False:
+            return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content)
+
         try:
             # system_prompt_content is not used; this provider only accepts system_prompt as a plain string,
             # matching the behavior of stream(). The caller always provides system_prompt alongside

src/strands/models/model.py

Lines changed: 1 addition & 1 deletion
@@ -23,6 +23,7 @@
 
 T = TypeVar("T", bound=BaseModel)
 
+
 def _heuristic_estimate_text(text: str) -> int:
     """Estimate token count from text using characters / 4 heuristic."""
     return math.ceil(len(text) / 4)
@@ -84,7 +85,6 @@ def _count_content_block_tokens(
     return total
 
 
-
 def _estimate_tokens_with_heuristic(
     messages: Messages,
     tool_specs: list[ToolSpec] | None = None,
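
The local estimator that the new flag falls back to is the heuristic shown above. A minimal sketch of the characters / 4 rule (the full _estimate_tokens_with_heuristic in model.py also walks tool specs and individual content blocks, which is omitted here):

import math

def _heuristic_estimate_text(text: str) -> int:
    """Estimate token count from text using characters / 4 heuristic."""
    return math.ceil(len(text) / 4)

# Roughly four characters per token: an 11-character string estimates to 3 tokens.
print(_heuristic_estimate_text("hello world"))  # 3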

src/strands/models/openai_responses.py

Lines changed: 7 additions & 0 deletions
@@ -136,11 +136,15 @@ class OpenAIResponsesConfig(BaseModelConfig, total=False):
             stateful: Whether to enable server-side conversation state management.
                 When True, the server stores conversation history and the client does not need to
                 send the full message history with each request. Defaults to False.
+            use_native_token_count: Whether to use the native OpenAI input_tokens.count API.
+                When True (default), count_tokens() calls the OpenAI API for accurate counts.
+                When False, skips the API call and uses the local estimator.
         """
 
         model_id: str
         params: dict[str, Any] | None
         stateful: bool
+        use_native_token_count: bool
 
     def __init__(
         self,
@@ -238,6 +242,9 @@ async def count_tokens(
         Returns:
             Total input token count.
         """
+        if self.config.get("use_native_token_count") is False:
+            return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content)
+
         try:
             # system_prompt_content is not used; this provider only accepts system_prompt as a plain string,
             # matching the behavior of stream(). The caller always provides system_prompt alongside

tests/strands/models/test_anthropic.py

Lines changed: 13 additions & 0 deletions
@@ -1162,3 +1162,16 @@ async def test_fallback_logs_debug(self, model_with_client, anthropic_client, me
         await model_with_client.count_tokens(messages=messages)
 
         assert any("native token counting failed" in record.message for record in caplog.records)
+
+    @pytest.mark.asyncio
+    async def test_skip_native_api_when_use_native_token_count_false(
+        self, anthropic_client, model_id, max_tokens, messages
+    ):
+        _ = anthropic_client
+        model = AnthropicModel(model_id=model_id, max_tokens=max_tokens, use_native_token_count=False)
+
+        result = await model.count_tokens(messages=messages)
+
+        anthropic_client.messages.count_tokens.assert_not_called()
+        assert isinstance(result, int)
+        assert result >= 0

tests/strands/models/test_bedrock.py

Lines changed: 11 additions & 0 deletions
@@ -3484,3 +3484,14 @@ async def test_does_not_cache_model_id_for_other_errors(self, bedrock_client, me
         # Second call should still attempt the API
         await model.count_tokens(messages=messages)
         assert bedrock_client.count_tokens.call_count == 2
+
+    @pytest.mark.asyncio
+    async def test_skip_native_api_when_use_native_token_count_false(self, bedrock_client, model_id, messages):
+        _ = bedrock_client
+        model = BedrockModel(model_id=model_id, use_native_token_count=False)
+
+        result = await model.count_tokens(messages=messages)
+
+        bedrock_client.count_tokens.assert_not_called()
+        assert isinstance(result, int)
+        assert result >= 0

tests/strands/models/test_gemini.py

Lines changed: 11 additions & 0 deletions
@@ -1228,3 +1228,14 @@ async def test_fallback_logs_debug(self, model, gemini_client, messages, caplog)
         await model.count_tokens(messages=messages)
 
         assert any("native token counting failed" in record.message for record in caplog.records)
+
+    @pytest.mark.asyncio
+    async def test_skip_native_api_when_use_native_token_count_false(self, gemini_client, messages):
+        _ = gemini_client
+        model = GeminiModel(model_id="m1", use_native_token_count=False)
+
+        result = await model.count_tokens(messages=messages)
+
+        gemini_client.aio.models.count_tokens.assert_not_called()
+        assert isinstance(result, int)
+        assert result >= 0

tests/strands/models/test_llamacpp.py

Lines changed: 11 additions & 0 deletions
@@ -803,3 +803,14 @@ async def test_fallback_logs_debug(self, model, messages, caplog):
         await model.count_tokens(messages=messages)
 
         assert any("native token counting failed" in record.message for record in caplog.records)
+
+    @pytest.mark.asyncio
+    async def test_skip_native_api_when_use_native_token_count_false(self, messages):
+        model = LlamaCppModel(base_url="http://localhost:8080", use_native_token_count=False)
+        model.client.post = AsyncMock()
+
+        result = await model.count_tokens(messages=messages)
+
+        model.client.post.assert_not_called()
+        assert isinstance(result, int)
+        assert result >= 0
