73 changes: 1 addition & 72 deletions src/strands/models/model.py
@@ -1,7 +1,6 @@
"""Abstract base class for Agent model providers."""

import abc
import functools
import json
import logging
import math
@@ -24,9 +23,6 @@

T = TypeVar("T", bound=BaseModel)

_DEFAULT_ENCODING = "cl100k_base"


def _heuristic_estimate_text(text: str) -> int:
"""Estimate token count from text using characters / 4 heuristic."""
return math.ceil(len(text) / 4)
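The surviving heuristic is deliberately simple; a standalone sanity check, using only the two-line body shown above:

import math

def _heuristic_estimate_text(text: str) -> int:
    """Estimate token count from text using characters / 4 heuristic."""
    return math.ceil(len(text) / 4)

assert _heuristic_estimate_text("") == 0
assert _heuristic_estimate_text("hello") == 2         # ceil(5 / 4)
assert _heuristic_estimate_text("hello world!") == 3  # ceil(12 / 4)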
@@ -40,22 +36,6 @@ def _heuristic_estimate_json(obj: Any) -> int:
return 0


@functools.lru_cache(maxsize=1)
def _get_encoding() -> Any:
"""Get the default tiktoken encoding, caching to avoid repeated lookups.

Returns:
The tiktoken encoding, or None if tiktoken is not installed.
"""
try:
import tiktoken

return tiktoken.get_encoding(_DEFAULT_ENCODING)
except ImportError:
logger.debug("tiktoken not available, falling back to heuristic token estimation")
return None


def _count_content_block_tokens(
block: ContentBlock, count_text: Callable[[str], int], count_json: Callable[[Any], int]
) -> int:
@@ -104,54 +84,6 @@ def _count_content_block_tokens(
return total


def _estimate_tokens_with_tiktoken(
messages: Messages,
tool_specs: list[ToolSpec] | None = None,
system_prompt: str | None = None,
system_prompt_content: list[SystemContentBlock] | None = None,
) -> int:
"""Estimate tokens by serializing messages/tools to text and counting with tiktoken.

This is a best-effort fallback for providers that don't expose native counting.
Accuracy varies by model but is sufficient for threshold-based decisions.

Raises:
ImportError: If tiktoken is not installed.
"""
encoding = _get_encoding()
if encoding is None:
raise ImportError("tiktoken is not available")

def count_text(text: str) -> int:
return len(encoding.encode(text))

def count_json(obj: Any) -> int:
try:
return len(encoding.encode(json.dumps(obj)))
except (TypeError, ValueError):
return 0

total = 0

# Prefer system_prompt_content (structured) over system_prompt (plain string) to avoid double-counting,
# since providers wrap system_prompt into system_prompt_content when both are provided.
if system_prompt_content:
for block in system_prompt_content:
if "text" in block:
total += count_text(block["text"])
elif system_prompt:
total += count_text(system_prompt)

for message in messages:
for block in message["content"]:
total += _count_content_block_tokens(block, count_text, count_json)

if tool_specs:
for spec in tool_specs:
total += count_json(spec)

return total


def _estimate_tokens_with_heuristic(
messages: Messages,
@@ -338,10 +270,7 @@ async def count_tokens(
Returns:
Estimated total input tokens.
"""
try:
return _estimate_tokens_with_tiktoken(messages, tool_specs, system_prompt, system_prompt_content)
except ImportError:
return _estimate_tokens_with_heuristic(messages, tool_specs, system_prompt, system_prompt_content)
return _estimate_tokens_with_heuristic(messages, tool_specs, system_prompt, system_prompt_content)


class _ModelPlugin(Plugin):
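Since count_tokens now always takes the heuristic path, a sketch of what that path plausibly computes: the body of _estimate_tokens_with_heuristic is collapsed in this diff, so this mirrors the structure of the deleted tiktoken variant with character-based counters swapped in; all names here are illustrative, and the /2 divisor for JSON is inferred from the updated test comments (e.g. ceil(164/2)).

import json
import math
from typing import Any

def count_text(text: str) -> int:
    return math.ceil(len(text) / 4)

def count_json(obj: Any) -> int:
    try:
        return math.ceil(len(json.dumps(obj)) / 2)  # /2 inferred from test comments
    except (TypeError, ValueError):
        return 0  # non-serializable input contributes nothing

def estimate_tokens(messages, tool_specs=None, system_prompt=None, system_prompt_content=None) -> int:
    total = 0
    # Prefer structured system content over the plain string to avoid double-counting,
    # matching the comment in the deleted tiktoken variant.
    if system_prompt_content:
        total += sum(count_text(b["text"]) for b in system_prompt_content if "text" in b)
    elif system_prompt:
        total += count_text(system_prompt)
    for message in messages:
        for block in message["content"]:
            if "text" in block:  # the real code dispatches through _count_content_block_tokens
                total += count_text(block["text"])
    if tool_specs:
        total += sum(count_json(spec) for spec in tool_specs)
    return total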
11 changes: 1 addition & 10 deletions src/strands/vended_plugins/context_offloader/plugin.py
@@ -37,7 +37,6 @@
from typing import TYPE_CHECKING

from ...hooks.events import AfterToolCallEvent
from ...models.model import _get_encoding
from ...plugins import Plugin, hook
from ...tools.decorator import tool
from ...types.content import Message
@@ -318,20 +317,12 @@ async def _handle_tool_result(self, event: AfterToolCallEvent) -> None:
)

def _slice_preview(self, text: str) -> str:
"""Slice text to approximately preview_tokens.

Uses tiktoken for exact token-level slicing when available,
falls back to characters (tokens * 4) otherwise.
"""Slice text to approximately preview_tokens using character-based estimation.

Args:
text: The full text to slice.

Returns:
The preview text.
"""
encoding = _get_encoding()
if encoding is not None:
tokens = encoding.encode(text)
preview: str = encoding.decode(tokens[: self._preview_tokens])
return preview
return text[: self._preview_tokens * _CHARS_PER_TOKEN]
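The new _slice_preview is a pure character slice; a minimal standalone sketch, assuming _CHARS_PER_TOKEN is 4 (the constant's value is not shown in this diff, but the deleted docstring's "tokens * 4" fallback implies it):

_CHARS_PER_TOKEN = 4  # assumed value, consistent with the characters / 4 heuristic

def slice_preview(text: str, preview_tokens: int) -> str:
    """Keep roughly the first preview_tokens tokens at ~4 characters per token."""
    return text[: preview_tokens * _CHARS_PER_TOKEN]

assert slice_preview("a" * 100, preview_tokens=10) == "a" * 40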
94 changes: 24 additions & 70 deletions tests/strands/models/test_model.py
@@ -244,35 +244,35 @@ async def test_count_tokens_empty_messages(model):
@pytest.mark.asyncio
async def test_count_tokens_system_prompt_only(model):
result = await model.count_tokens(messages=[], system_prompt="You are a helpful assistant.")
assert result == 6
assert result == 7 # ceil(28/4)


@pytest.mark.asyncio
async def test_count_tokens_text_messages(model, messages):
result = await model.count_tokens(messages=messages)
assert result == 1 # "hello"
assert result == 2 # ceil(5/4)


@pytest.mark.asyncio
async def test_count_tokens_with_tool_specs(model, messages, tool_specs):
without_tools = await model.count_tokens(messages=messages)
with_tools = await model.count_tokens(messages=messages, tool_specs=tool_specs)
assert without_tools == 1 # "hello"
assert with_tools == 49 # "hello" (1) + tool_spec (48)
assert without_tools == 2 # ceil(5/4)
assert with_tools == 84 # ceil(5/4) + ceil(164/2)


@pytest.mark.asyncio
async def test_count_tokens_with_system_prompt(model, messages, system_prompt):
without_prompt = await model.count_tokens(messages=messages)
with_prompt = await model.count_tokens(messages=messages, system_prompt=system_prompt)
assert without_prompt == 1 # "hello"
assert with_prompt == 3 # "hello" (1) + "s1" (2)
assert without_prompt == 2 # ceil(5/4)
assert with_prompt == 3 # ceil(5/4) + ceil(2/4)


@pytest.mark.asyncio
async def test_count_tokens_combined(model, messages, tool_specs, system_prompt):
result = await model.count_tokens(messages=messages, tool_specs=tool_specs, system_prompt=system_prompt)
assert result == 51 # "hello" (1) + tool_spec (48) + "s1" (2)
assert result == 85 # ceil(5/4) + ceil(164/2) + ceil(2/4)


@pytest.mark.asyncio
@@ -292,8 +292,8 @@ async def test_count_tokens_tool_use_block(model):
}
]
result = await model.count_tokens(messages=messages)
# name "my_tool" (2) + json.dumps(input) (6) = 8
assert result == 8
# name "my_tool" ceil(7/4)=2 + json.dumps(input) ceil(17/2)=9 = 11
assert result == 11


@pytest.mark.asyncio
@@ -313,7 +313,7 @@ async def test_count_tokens_tool_result_block(model):
}
]
result = await model.count_tokens(messages=messages)
assert result == 3 # "tool output here"
assert result == 4 # ceil(16/4)


@pytest.mark.asyncio
@@ -333,7 +333,7 @@ async def test_count_tokens_reasoning_block(model):
}
]
result = await model.count_tokens(messages=messages)
assert result == 9 # "Let me think about this step by step."
assert result == 10 # ceil(37/4)


@pytest.mark.asyncio
@@ -399,7 +399,7 @@ async def test_count_tokens_guard_content_block(model):
}
]
result = await model.count_tokens(messages=messages)
assert result == 8 # "This content was filtered by guardrails."
assert result == 10 # ceil(40/4)


@pytest.mark.asyncio
@@ -420,7 +420,7 @@ async def test_count_tokens_tool_use_with_bytes(model):
]
result = await model.count_tokens(messages=messages)
# Should still count the tool name even though input has non-serializable bytes
assert result == 2 # "my_tool" name only
assert result == 2 # ceil(7/4) name only


@pytest.mark.asyncio
@@ -434,7 +434,7 @@ async def test_count_tokens_non_serializable_tool_spec(model, messages):
]
result = await model.count_tokens(messages=messages, tool_specs=tool_specs)
# Should still count the message tokens even though tool spec fails
assert result == 1 # "hello" only, tool spec skipped
assert result == 2 # ceil(5/4) only, tool spec skipped


@pytest.mark.asyncio
@@ -453,7 +453,7 @@ async def test_count_tokens_citations_block(model):
}
]
result = await model.count_tokens(messages=messages)
assert result == 11 # "According to the document, the answer is 42."
assert result == 11 # ceil(44/4)


@pytest.mark.asyncio
@@ -462,7 +462,7 @@ async def test_count_tokens_system_prompt_content(model):
messages=[],
system_prompt_content=[{"text": "You are a helpful assistant."}],
)
assert result == 6 # "You are a helpful assistant."
assert result == 7 # ceil(28/4)


@pytest.mark.asyncio
@@ -474,7 +474,7 @@ async def test_count_tokens_system_prompt_content_with_cache_point(model):
{"cachePoint": {"type": "default"}},
],
)
assert result == 6 # "You are a helpful assistant.", cachePoint adds 0
assert result == 7 # ceil(28/4), cachePoint adds 0


@pytest.mark.asyncio
@@ -489,7 +489,7 @@ async def test_count_tokens_system_prompt_content_takes_priority(model):
system_prompt="This is a much longer system prompt that should have more tokens.",
system_prompt_content=[{"text": "Short."}],
)
assert content_only == 2 # "Short."
assert content_only == 2 # ceil(6/4)
assert content_only == both


@@ -505,41 +505,10 @@ async def test_count_tokens_all_inputs(model):
system_prompt="Be helpful.",
system_prompt_content=[{"text": "Additional system context."}],
)
# system_prompt_content (4) + "hello world" (2) + "hi there" (2) + tool_spec (23) = 31
assert result == 31
# system_prompt_content (7) + "hello world" (3) + "hi there" (2) + tool_spec (38) = 50
assert result == 50


def test__get_encoding_falls_back_without_tiktoken(monkeypatch):
"""Test that _get_encoding returns None and count_tokens falls back to heuristic."""
import strands.models.model as model_module

model_module._get_encoding.cache_clear()
original_import = __builtins__["__import__"] if isinstance(__builtins__, dict) else __builtins__.__import__

def _block_tiktoken(name, *args, **kwargs):
if name == "tiktoken":
raise ImportError("No module named 'tiktoken'")
return original_import(name, *args, **kwargs)

monkeypatch.setattr("builtins.__import__", _block_tiktoken)

try:
assert model_module._get_encoding() is None

# _estimate_tokens_with_tiktoken should raise when tiktoken is unavailable
with pytest.raises(ImportError):
model_module._estimate_tokens_with_tiktoken(
messages=[{"role": "user", "content": [{"text": "hello world!"}]}],
)

# _estimate_tokens_with_heuristic uses chars/4 for text
result = model_module._estimate_tokens_with_heuristic(
messages=[{"role": "user", "content": [{"text": "hello world!"}]}],
)
assert result == 3 # ceil(12 / 4)
finally:
model_module._get_encoding.cache_clear()


class TestHeuristicEstimation:
"""Tests for _estimate_tokens_with_heuristic."""
@@ -592,22 +561,7 @@ def test_non_serializable_inputs(self):
assert result == 2 # only tool name counted: ceil(len("my_tool") / 4)

@pytest.mark.asyncio
async def test_model_falls_back_to_heuristic(self, monkeypatch, model):
"""Model.count_tokens falls back to heuristic when tiktoken unavailable."""
import strands.models.model as model_module

model_module._get_encoding.cache_clear()
original_import = __builtins__["__import__"] if isinstance(__builtins__, dict) else __builtins__.__import__

def _block_tiktoken(name, *args, **kwargs):
if name == "tiktoken":
raise ImportError("No module named 'tiktoken'")
return original_import(name, *args, **kwargs)

monkeypatch.setattr("builtins.__import__", _block_tiktoken)

try:
result = await model.count_tokens(messages=[{"role": "user", "content": [{"text": "hello world!"}]}])
assert result == 3 # ceil(12 / 4)
finally:
model_module._get_encoding.cache_clear()
async def test_model_uses_heuristic(self, model):
"""Model.count_tokens uses heuristic estimation."""
result = await model.count_tokens(messages=[{"role": "user", "content": [{"text": "hello world!"}]}])
assert result == 3 # ceil(12 / 4)
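Every updated expectation in this file follows directly from the heuristic: text counts are ceil(len(text) / 4), and (per the test comments) serialized JSON counts are ceil(len(json.dumps(obj)) / 2). A quick check of the arithmetic behind the new values:

import math

assert math.ceil(len("hello") / 4) == 2
assert math.ceil(len("hello world!") / 4) == 3
assert math.ceil(len("You are a helpful assistant.") / 4) == 7
assert math.ceil(len("tool output here") / 4) == 4
assert math.ceil(len("Let me think about this step by step.") / 4) == 10
assert math.ceil(len("This content was filtered by guardrails.") / 4) == 10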