
Commit d94d516

fix: fix count tokens for bedrock models (#2254)
1 parent 559b2a0 commit d94d516

3 files changed: 26 additions & 152 deletions


src/strands/models/model.py

Lines changed: 1 addition & 72 deletions
@@ -1,7 +1,6 @@
 """Abstract base class for Agent model providers."""

 import abc
-import functools
 import json
 import logging
 import math
@@ -24,9 +23,6 @@

 T = TypeVar("T", bound=BaseModel)

-_DEFAULT_ENCODING = "cl100k_base"
-
-
 def _heuristic_estimate_text(text: str) -> int:
     """Estimate token count from text using characters / 4 heuristic."""
     return math.ceil(len(text) / 4)
@@ -40,22 +36,6 @@ def _heuristic_estimate_json(obj: Any) -> int:
     return 0


-@functools.lru_cache(maxsize=1)
-def _get_encoding() -> Any:
-    """Get the default tiktoken encoding, caching to avoid repeated lookups.
-
-    Returns:
-        The tiktoken encoding, or None if tiktoken is not installed.
-    """
-    try:
-        import tiktoken
-
-        return tiktoken.get_encoding(_DEFAULT_ENCODING)
-    except ImportError:
-        logger.debug("tiktoken not available, falling back to heuristic token estimation")
-        return None
-
-
 def _count_content_block_tokens(
     block: ContentBlock, count_text: Callable[[str], int], count_json: Callable[[Any], int]
 ) -> int:
@@ -104,54 +84,6 @@ def _count_content_block_tokens(
     return total


-def _estimate_tokens_with_tiktoken(
-    messages: Messages,
-    tool_specs: list[ToolSpec] | None = None,
-    system_prompt: str | None = None,
-    system_prompt_content: list[SystemContentBlock] | None = None,
-) -> int:
-    """Estimate tokens by serializing messages/tools to text and counting with tiktoken.
-
-    This is a best-effort fallback for providers that don't expose native counting.
-    Accuracy varies by model but is sufficient for threshold-based decisions.
-
-    Raises:
-        ImportError: If tiktoken is not installed.
-    """
-    encoding = _get_encoding()
-    if encoding is None:
-        raise ImportError("tiktoken is not available")
-
-    def count_text(text: str) -> int:
-        return len(encoding.encode(text))
-
-    def count_json(obj: Any) -> int:
-        try:
-            return len(encoding.encode(json.dumps(obj)))
-        except (TypeError, ValueError):
-            return 0
-
-    total = 0
-
-    # Prefer system_prompt_content (structured) over system_prompt (plain string) to avoid double-counting,
-    # since providers wrap system_prompt into system_prompt_content when both are provided.
-    if system_prompt_content:
-        for block in system_prompt_content:
-            if "text" in block:
-                total += count_text(block["text"])
-    elif system_prompt:
-        total += count_text(system_prompt)
-
-    for message in messages:
-        for block in message["content"]:
-            total += _count_content_block_tokens(block, count_text, count_json)
-
-    if tool_specs:
-        for spec in tool_specs:
-            total += count_json(spec)
-
-    return total
-

 def _estimate_tokens_with_heuristic(
     messages: Messages,
@@ -338,10 +270,7 @@ async def count_tokens(
         Returns:
             Estimated total input tokens.
         """
-        try:
-            return _estimate_tokens_with_tiktoken(messages, tool_specs, system_prompt, system_prompt_content)
-        except ImportError:
-            return _estimate_tokens_with_heuristic(messages, tool_specs, system_prompt, system_prompt_content)
+        return _estimate_tokens_with_heuristic(messages, tool_specs, system_prompt, system_prompt_content)


 class _ModelPlugin(Plugin):
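For reference, a minimal standalone sketch of the heuristic rules this commit standardizes on. The characters / 4 text rule is shown in the retained `_heuristic_estimate_text` above; the characters / 2 JSON rule is an assumption inferred from the updated test comments (e.g. `ceil(164/2)`, `ceil(17/2)`), since only the tail of `_heuristic_estimate_json` is visible in this diff.

```python
import json
import math
from typing import Any


def heuristic_estimate_text(text: str) -> int:
    """Characters / 4 heuristic, as in the retained _heuristic_estimate_text."""
    return math.ceil(len(text) / 4)


def heuristic_estimate_json(obj: Any) -> int:
    """Assumed characters / 2 heuristic over the JSON serialization."""
    try:
        return math.ceil(len(json.dumps(obj)) / 2)
    except (TypeError, ValueError):
        return 0  # non-serializable inputs contribute nothing


print(heuristic_estimate_text("hello"))            # 2, matching the updated tests
print(heuristic_estimate_json({"key": "value"}))   # ceil(16/2) = 8
```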

src/strands/vended_plugins/context_offloader/plugin.py

Lines changed: 1 addition & 10 deletions
@@ -37,7 +37,6 @@
 from typing import TYPE_CHECKING

 from ...hooks.events import AfterToolCallEvent
-from ...models.model import _get_encoding
 from ...plugins import Plugin, hook
 from ...tools.decorator import tool
 from ...types.content import Message
@@ -318,20 +317,12 @@ async def _handle_tool_result(self, event: AfterToolCallEvent) -> None:
         )

     def _slice_preview(self, text: str) -> str:
-        """Slice text to approximately preview_tokens.
-
-        Uses tiktoken for exact token-level slicing when available,
-        falls back to characters (tokens * 4) otherwise.
+        """Slice text to approximately preview_tokens using character-based estimation.

         Args:
             text: The full text to slice.

         Returns:
             The preview text.
         """
-        encoding = _get_encoding()
-        if encoding is not None:
-            tokens = encoding.encode(text)
-            preview: str = encoding.decode(tokens[: self._preview_tokens])
-            return preview
         return text[: self._preview_tokens * _CHARS_PER_TOKEN]
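The simplified `_slice_preview` now always slices by characters. A small sketch of the behavior, assuming `_CHARS_PER_TOKEN` is 4 to match the chars/4 heuristic (the constant is defined elsewhere in the plugin and not shown in this diff):

```python
_CHARS_PER_TOKEN = 4  # assumed value; defined elsewhere in the plugin


def slice_preview(text: str, preview_tokens: int) -> str:
    """Approximate a preview_tokens-sized prefix via character count."""
    return text[: preview_tokens * _CHARS_PER_TOKEN]


print(len(slice_preview("x" * 100, preview_tokens=10)))  # 40
```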

tests/strands/models/test_model.py

Lines changed: 24 additions & 70 deletions
@@ -244,35 +244,35 @@ async def test_count_tokens_empty_messages(model):
 @pytest.mark.asyncio
 async def test_count_tokens_system_prompt_only(model):
     result = await model.count_tokens(messages=[], system_prompt="You are a helpful assistant.")
-    assert result == 6
+    assert result == 7  # ceil(28/4)


 @pytest.mark.asyncio
 async def test_count_tokens_text_messages(model, messages):
     result = await model.count_tokens(messages=messages)
-    assert result == 1  # "hello"
+    assert result == 2  # ceil(5/4)


 @pytest.mark.asyncio
 async def test_count_tokens_with_tool_specs(model, messages, tool_specs):
     without_tools = await model.count_tokens(messages=messages)
     with_tools = await model.count_tokens(messages=messages, tool_specs=tool_specs)
-    assert without_tools == 1  # "hello"
-    assert with_tools == 49  # "hello" (1) + tool_spec (48)
+    assert without_tools == 2  # ceil(5/4)
+    assert with_tools == 84  # ceil(5/4) + ceil(164/2)


 @pytest.mark.asyncio
 async def test_count_tokens_with_system_prompt(model, messages, system_prompt):
     without_prompt = await model.count_tokens(messages=messages)
     with_prompt = await model.count_tokens(messages=messages, system_prompt=system_prompt)
-    assert without_prompt == 1  # "hello"
-    assert with_prompt == 3  # "hello" (1) + "s1" (2)
+    assert without_prompt == 2  # ceil(5/4)
+    assert with_prompt == 3  # ceil(5/4) + ceil(2/4)


 @pytest.mark.asyncio
 async def test_count_tokens_combined(model, messages, tool_specs, system_prompt):
     result = await model.count_tokens(messages=messages, tool_specs=tool_specs, system_prompt=system_prompt)
-    assert result == 51  # "hello" (1) + tool_spec (48) + "s1" (2)
+    assert result == 85  # ceil(5/4) + ceil(164/2) + ceil(2/4)


 @pytest.mark.asyncio
@@ -292,8 +292,8 @@ async def test_count_tokens_tool_use_block(model):
         }
     ]
     result = await model.count_tokens(messages=messages)
-    # name "my_tool" (2) + json.dumps(input) (6) = 8
-    assert result == 8
+    # name "my_tool" ceil(7/4)=2 + json.dumps(input) ceil(17/2)=9 = 11
+    assert result == 11


 @pytest.mark.asyncio
@@ -313,7 +313,7 @@ async def test_count_tokens_tool_result_block(model):
         }
     ]
     result = await model.count_tokens(messages=messages)
-    assert result == 3  # "tool output here"
+    assert result == 4  # ceil(16/4)


 @pytest.mark.asyncio
@@ -333,7 +333,7 @@ async def test_count_tokens_reasoning_block(model):
         }
     ]
     result = await model.count_tokens(messages=messages)
-    assert result == 9  # "Let me think about this step by step."
+    assert result == 10  # ceil(37/4)


 @pytest.mark.asyncio
@@ -399,7 +399,7 @@ async def test_count_tokens_guard_content_block(model):
         }
     ]
     result = await model.count_tokens(messages=messages)
-    assert result == 8  # "This content was filtered by guardrails."
+    assert result == 10  # ceil(40/4)


 @pytest.mark.asyncio
@@ -420,7 +420,7 @@ async def test_count_tokens_tool_use_with_bytes(model):
     ]
     result = await model.count_tokens(messages=messages)
     # Should still count the tool name even though input has non-serializable bytes
-    assert result == 2  # "my_tool" name only
+    assert result == 2  # ceil(7/4) name only


 @pytest.mark.asyncio
@@ -434,7 +434,7 @@ async def test_count_tokens_non_serializable_tool_spec(model, messages):
     ]
     result = await model.count_tokens(messages=messages, tool_specs=tool_specs)
     # Should still count the message tokens even though tool spec fails
-    assert result == 1  # "hello" only, tool spec skipped
+    assert result == 2  # ceil(5/4) only, tool spec skipped


 @pytest.mark.asyncio
@@ -453,7 +453,7 @@ async def test_count_tokens_citations_block(model):
         }
     ]
     result = await model.count_tokens(messages=messages)
-    assert result == 11  # "According to the document, the answer is 42."
+    assert result == 11  # ceil(44/4)


 @pytest.mark.asyncio
@@ -462,7 +462,7 @@ async def test_count_tokens_system_prompt_content(model):
         messages=[],
         system_prompt_content=[{"text": "You are a helpful assistant."}],
     )
-    assert result == 6  # "You are a helpful assistant."
+    assert result == 7  # ceil(28/4)


 @pytest.mark.asyncio
@@ -474,7 +474,7 @@ async def test_count_tokens_system_prompt_content_with_cache_point(model):
             {"cachePoint": {"type": "default"}},
         ],
     )
-    assert result == 6  # "You are a helpful assistant.", cachePoint adds 0
+    assert result == 7  # ceil(28/4), cachePoint adds 0


 @pytest.mark.asyncio
@@ -489,7 +489,7 @@ async def test_count_tokens_system_prompt_content_takes_priority(model):
         system_prompt="This is a much longer system prompt that should have more tokens.",
         system_prompt_content=[{"text": "Short."}],
     )
-    assert content_only == 2  # "Short."
+    assert content_only == 2  # ceil(6/4)
     assert content_only == both


@@ -505,41 +505,10 @@ async def test_count_tokens_all_inputs(model):
         system_prompt="Be helpful.",
         system_prompt_content=[{"text": "Additional system context."}],
     )
-    # system_prompt_content (4) + "hello world" (2) + "hi there" (2) + tool_spec (23) = 31
-    assert result == 31
+    # system_prompt_content (7) + "hello world" (3) + "hi there" (2) + tool_spec (38) = 50
+    assert result == 50


-def test__get_encoding_falls_back_without_tiktoken(monkeypatch):
-    """Test that _get_encoding returns None and count_tokens falls back to heuristic."""
-    import strands.models.model as model_module
-
-    model_module._get_encoding.cache_clear()
-    original_import = __builtins__["__import__"] if isinstance(__builtins__, dict) else __builtins__.__import__
-
-    def _block_tiktoken(name, *args, **kwargs):
-        if name == "tiktoken":
-            raise ImportError("No module named 'tiktoken'")
-        return original_import(name, *args, **kwargs)
-
-    monkeypatch.setattr("builtins.__import__", _block_tiktoken)
-
-    try:
-        assert model_module._get_encoding() is None
-
-        # _estimate_tokens_with_tiktoken should raise when tiktoken is unavailable
-        with pytest.raises(ImportError):
-            model_module._estimate_tokens_with_tiktoken(
-                messages=[{"role": "user", "content": [{"text": "hello world!"}]}],
-            )
-
-        # _estimate_tokens_with_heuristic uses chars/4 for text
-        result = model_module._estimate_tokens_with_heuristic(
-            messages=[{"role": "user", "content": [{"text": "hello world!"}]}],
-        )
-        assert result == 3  # ceil(12 / 4)
-    finally:
-        model_module._get_encoding.cache_clear()
-

 class TestHeuristicEstimation:
     """Tests for _estimate_tokens_with_heuristic."""
@@ -592,22 +561,7 @@ def test_non_serializable_inputs(self):
         assert result == 2  # only tool name counted: ceil(len("my_tool") / 4)

     @pytest.mark.asyncio
-    async def test_model_falls_back_to_heuristic(self, monkeypatch, model):
-        """Model.count_tokens falls back to heuristic when tiktoken unavailable."""
-        import strands.models.model as model_module
-
-        model_module._get_encoding.cache_clear()
-        original_import = __builtins__["__import__"] if isinstance(__builtins__, dict) else __builtins__.__import__
-
-        def _block_tiktoken(name, *args, **kwargs):
-            if name == "tiktoken":
-                raise ImportError("No module named 'tiktoken'")
-            return original_import(name, *args, **kwargs)
-
-        monkeypatch.setattr("builtins.__import__", _block_tiktoken)
-
-        try:
-            result = await model.count_tokens(messages=[{"role": "user", "content": [{"text": "hello world!"}]}])
-            assert result == 3  # ceil(12 / 4)
-        finally:
-            model_module._get_encoding.cache_clear()
+    async def test_model_uses_heuristic(self, model):
+        """Model.count_tokens uses heuristic estimation."""
+        result = await model.count_tokens(messages=[{"role": "user", "content": [{"text": "hello world!"}]}])
+        assert result == 3  # ceil(12 / 4)
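The updated expectations all follow the `ceil(len(text) / 4)` rule; a quick spot-check of the arithmetic in the comments above:

```python
import math

# Each value matches the corresponding updated assertion.
assert math.ceil(len("hello") / 4) == 2
assert math.ceil(len("You are a helpful assistant.") / 4) == 7
assert math.ceil(len("Let me think about this step by step.") / 4) == 10
assert math.ceil(len("hello world!") / 4) == 3
```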
