7 changes: 7 additions & 0 deletions src/strands/models/anthropic.py
@@ -57,11 +57,15 @@ class AnthropicConfig(BaseModelConfig, total=False):
https://docs.anthropic.com/en/docs/about-claude/models/all-models.
params: Additional model parameters (e.g., temperature).
For a complete list of supported parameters, see https://docs.anthropic.com/en/api/messages.
use_native_token_count: Whether to use the native Anthropic count_tokens API.
When True (default), count_tokens() calls the Anthropic API for accurate counts.
When False, skips the API call and uses the local estimator.
"""

max_tokens: Required[int]
model_id: Required[str]
params: dict[str, Any] | None
use_native_token_count: bool

def __init__(self, *, client_args: dict[str, Any] | None = None, **model_config: Unpack[AnthropicConfig]):
"""Initialize provider instance.
@@ -394,6 +398,9 @@ async def count_tokens(
Returns:
Total input token count.
"""
if self.config.get("use_native_token_count") is False:
    return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content)

try:
# system_prompt_content is not used; this provider only accepts system_prompt as a plain string,
# matching the behavior of stream(). The caller always provides system_prompt alongside
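Taken together, the Anthropic hunks add an opt-out switch rather than a change to the default path. A minimal usage sketch follows, under stated assumptions: the strands.models.anthropic import path, the model id, and the message shape are inferred from this diff and its tests, not verified against the released package. The same use_native_token_count=False keyword applies to the Bedrock, Gemini, llama.cpp, and OpenAI Responses providers changed below.

import asyncio

from strands.models.anthropic import AnthropicModel  # import path assumed from the file layout above

# With use_native_token_count=False, count_tokens() never calls the Anthropic
# count_tokens API and returns the base-class heuristic estimate instead.
model = AnthropicModel(
    client_args={"api_key": "sk-dummy"},  # client is constructed but never used for counting here
    model_id="claude-sonnet-4-20250514",  # hypothetical model id
    max_tokens=1024,
    use_native_token_count=False,
)

messages = [{"role": "user", "content": [{"text": "Hello, world"}]}]  # assumed Messages shape

estimate = asyncio.run(model.count_tokens(messages=messages))
print(estimate)  # deterministic local estimate; no network call is made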
7 changes: 7 additions & 0 deletions src/strands/models/bedrock.py
@@ -117,6 +117,9 @@ class BedrockConfig(BaseModelConfig, total=False):
See https://docs.aws.amazon.com/bedrock/latest/userguide/structured-output.html
temperature: Controls randomness in generation (higher = more random)
top_p: Controls diversity via nucleus sampling (alternative to temperature)
use_native_token_count: Whether to use the native Bedrock CountTokens API.
When True (default), count_tokens() calls the Bedrock API for accurate counts.
When False, skips the API call and uses the local estimator.
"""

additional_args: dict[str, Any] | None
@@ -143,6 +146,7 @@ class BedrockConfig(BaseModelConfig, total=False):
strict_tools: bool | None
temperature: float | None
top_p: float | None
use_native_token_count: bool

def __init__(
self,
@@ -794,6 +798,9 @@ async def count_tokens(
Returns:
Total input token count.
"""
if self.config.get("use_native_token_count") is False:
    return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content)

model_id: str = self.config["model_id"]

if model_id in _UNSUPPORTED_COUNT_TOKENS_MODELS:
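One detail worth noting in the Bedrock hunk: the early return sits above the model_id lookup in _UNSUPPORTED_COUNT_TOKENS_MODELS, so setting the flag to False bypasses the CountTokens API for every model id, supported or not. A short sketch under the same assumptions as the Anthropic example (import path and model id are illustrative only):

from strands.models.bedrock import BedrockModel  # import path assumed from the file layout above

# The flag check runs before the unsupported-models check, so no CountTokens
# request is attempted regardless of which Bedrock model id is configured.
model = BedrockModel(model_id="us.amazon.nova-lite-v1:0", use_native_token_count=False)  # illustrative model id

# estimate = await model.count_tokens(messages=messages)  # falls back to the local estimator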
7 changes: 7 additions & 0 deletions src/strands/models/gemini.py
@@ -49,11 +49,15 @@ class GeminiConfig(BaseModelConfig, total=False):
Use the standard tools interface for function calling tools.
For a complete list of supported tools, see
https://ai.google.dev/api/caching#Tool
use_native_token_count: Whether to use the native Gemini count_tokens API.
When True (default), count_tokens() calls the Gemini API for accurate counts.
When False, skips the API call and uses the local estimator.
"""

model_id: Required[str]
params: dict[str, Any]
gemini_tools: list[genai.types.Tool]
use_native_token_count: bool

def __init__(
self,
@@ -457,6 +461,9 @@ async def count_tokens(
Returns:
Total input token count.
"""
if self.config.get("use_native_token_count") is False:
    return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content)

try:
contents = list(self._format_request_content(messages))

7 changes: 7 additions & 0 deletions src/strands/models/llamacpp.py
@@ -125,10 +125,14 @@ class LlamaCppConfig(BaseModelConfig, total=False):
- cache_prompt: Cache the prompt for faster generation
- slot_id: Slot ID for parallel inference
- samplers: Custom sampler order
use_native_token_count: Whether to use the native llama.cpp /tokenize endpoint.
When True (default), count_tokens() calls the server's tokenize endpoint for accurate counts.
When False, skips the API call and uses the local estimator.
"""

model_id: str
params: dict[str, Any] | None
use_native_token_count: bool

def __init__(
self,
@@ -533,6 +537,9 @@ async def count_tokens(
Returns:
Total input token count.
"""
if self.config.get("use_native_token_count") is False:
    return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content)

try:
# system_prompt_content is not used; this provider only accepts system_prompt as a plain string,
# matching the behavior of stream(). The caller always provides system_prompt alongside
2 changes: 1 addition & 1 deletion src/strands/models/model.py
@@ -23,6 +23,7 @@

T = TypeVar("T", bound=BaseModel)


def _heuristic_estimate_text(text: str) -> int:
"""Estimate token count from text using characters / 4 heuristic."""
return math.ceil(len(text) / 4)
@@ -84,7 +85,6 @@ def _count_content_block_tokens(
return total



def _estimate_tokens_with_heuristic(
messages: Messages,
tool_specs: list[ToolSpec] | None = None,
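The model.py hunk above is whitespace only, but it points at the estimator that the False path falls back to: _heuristic_estimate_text, which rounds characters / 4 up to the nearest integer. A standalone illustration of that heuristic (the full _estimate_tokens_with_heuristic path also walks messages, tool specs, and the system prompt, which is not reproduced here):

import math

def _heuristic_estimate_text(text: str) -> int:
    """Estimate token count from text using the characters / 4 heuristic."""
    return math.ceil(len(text) / 4)

assert _heuristic_estimate_text("") == 0
assert _heuristic_estimate_text("Hello, world") == 3   # 12 characters -> ceil(12 / 4)
assert _heuristic_estimate_text("Hello, world!") == 4  # 13 characters -> ceil(13 / 4)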
7 changes: 7 additions & 0 deletions src/strands/models/openai_responses.py
@@ -136,11 +136,15 @@ class OpenAIResponsesConfig(BaseModelConfig, total=False):
stateful: Whether to enable server-side conversation state management.
When True, the server stores conversation history and the client does not need to
send the full message history with each request. Defaults to False.
use_native_token_count: Whether to use the native OpenAI input_tokens.count API.
When True (default), count_tokens() calls the OpenAI API for accurate counts.
When False, skips the API call and uses the local estimator.
"""

model_id: str
params: dict[str, Any] | None
stateful: bool
use_native_token_count: bool

def __init__(
self,
@@ -238,6 +242,9 @@ async def count_tokens(
Returns:
Total input token count.
"""
if self.config.get("use_native_token_count") is False:
    return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content)

try:
# system_prompt_content is not used; this provider only accepts system_prompt as a plain string,
# matching the behavior of stream(). The caller always provides system_prompt alongside
13 changes: 13 additions & 0 deletions tests/strands/models/test_anthropic.py
@@ -1162,3 +1162,16 @@ async def test_fallback_logs_debug(self, model_with_client, anthropic_client, me
await model_with_client.count_tokens(messages=messages)

assert any("native token counting failed" in record.message for record in caplog.records)

@pytest.mark.asyncio
async def test_skip_native_api_when_use_native_token_count_false(
self, anthropic_client, model_id, max_tokens, messages
):
_ = anthropic_client
model = AnthropicModel(model_id=model_id, max_tokens=max_tokens, use_native_token_count=False)

result = await model.count_tokens(messages=messages)

anthropic_client.messages.count_tokens.assert_not_called()
assert isinstance(result, int)
assert result >= 0
11 changes: 11 additions & 0 deletions tests/strands/models/test_bedrock.py
@@ -3484,3 +3484,14 @@ async def test_does_not_cache_model_id_for_other_errors(self, bedrock_client, me
# Second call should still attempt the API
await model.count_tokens(messages=messages)
assert bedrock_client.count_tokens.call_count == 2

@pytest.mark.asyncio
async def test_skip_native_api_when_use_native_token_count_false(self, bedrock_client, model_id, messages):
_ = bedrock_client
model = BedrockModel(model_id=model_id, use_native_token_count=False)

result = await model.count_tokens(messages=messages)

bedrock_client.count_tokens.assert_not_called()
assert isinstance(result, int)
assert result >= 0
11 changes: 11 additions & 0 deletions tests/strands/models/test_gemini.py
@@ -1228,3 +1228,14 @@ async def test_fallback_logs_debug(self, model, gemini_client, messages, caplog)
await model.count_tokens(messages=messages)

assert any("native token counting failed" in record.message for record in caplog.records)

@pytest.mark.asyncio
async def test_skip_native_api_when_use_native_token_count_false(self, gemini_client, messages):
_ = gemini_client
model = GeminiModel(model_id="m1", use_native_token_count=False)

result = await model.count_tokens(messages=messages)

gemini_client.aio.models.count_tokens.assert_not_called()
assert isinstance(result, int)
assert result >= 0
11 changes: 11 additions & 0 deletions tests/strands/models/test_llamacpp.py
@@ -803,3 +803,14 @@ async def test_fallback_logs_debug(self, model, messages, caplog):
await model.count_tokens(messages=messages)

assert any("native token counting failed" in record.message for record in caplog.records)

@pytest.mark.asyncio
async def test_skip_native_api_when_use_native_token_count_false(self, messages):
model = LlamaCppModel(base_url="http://localhost:8080", use_native_token_count=False)
model.client.post = AsyncMock()

result = await model.count_tokens(messages=messages)

model.client.post.assert_not_called()
assert isinstance(result, int)
assert result >= 0
1 change: 0 additions & 1 deletion tests/strands/models/test_model.py
@@ -509,7 +509,6 @@ async def test_count_tokens_all_inputs(model):
assert result == 50



class TestHeuristicEstimation:
"""Tests for _estimate_tokens_with_heuristic."""

11 changes: 11 additions & 0 deletions tests/strands/models/test_openai_responses.py
@@ -1318,6 +1318,17 @@ async def test_fallback_logs_debug(self, model, openai_client, messages, caplog)

assert any("native token counting failed" in record.message for record in caplog.records)

@pytest.mark.asyncio
async def test_skip_native_api_when_use_native_token_count_false(self, openai_client, messages):
_ = openai_client
model = OpenAIResponsesModel(model_id="gpt-4o", use_native_token_count=False)

result = await model.count_tokens(messages=messages)

openai_client.responses.input_tokens.count.assert_not_called()
assert isinstance(result, int)
assert result >= 0


# =============================================================================
# Bedrock Mantle (bedrock_mantle_config) integration with OpenAIResponsesModel