Skip to content

Commit 7199762

Browse files
authored
feat: track reasoning token usage (#670)
* feat: track reasoning token usage

  Capture provider-reported reasoning-token breakdowns alongside output tokens without changing output token totals. Carry the field through model usage aggregation and add coverage for parsing, facade tracking, and deltas.

  Refs #665

* fix: show reasoning tokens in usage summary

  Include reasoning token counts in the local model usage summary while preserving output and total token semantics. Telemetry remains unchanged.

  Refs #665

* fix: estimate missing reasoning token counts

  When providers return reasoning content without a numeric usage breakdown, estimate reasoning tokens from that content while preserving provider-reported output and total token counts.

  Refs #665

* fix: track reasoning token count source
* fix: simplify reasoning token source
* fix: omit unknown reasoning tokens from logs
* refactor: clarify reasoning token count helpers
* test: move token counting tests
* fix: enforce reasoning token source
* fix: address reasoning usage review
1 parent 387be6f commit 7199762

19 files changed

Lines changed: 601 additions & 50 deletions

File tree

packages/data-designer-config/src/data_designer/lazy_heavy_imports.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,7 @@
4040
"jsonschema": "jsonschema",
4141
"PIL": "PIL",
4242
"Image": "PIL.Image",
43+
"tiktoken": "tiktoken",
4344
}
4445

4546

packages/data-designer-engine/src/data_designer/engine/analysis/utils/column_statistics_calculations.py

Lines changed: 3 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -4,12 +4,9 @@
44
from __future__ import annotations
55

66
import logging
7-
from functools import lru_cache
87
from numbers import Number
98
from typing import TYPE_CHECKING, Any
109

11-
import tiktoken
12-
1310
import data_designer.lazy_heavy_imports as lazy
1411
from data_designer.config.analysis.column_statistics import (
1512
CategoricalDistribution,
@@ -25,6 +22,7 @@
2522
RecordBasedPromptRenderer,
2623
create_response_recipe,
2724
)
25+
from data_designer.engine.utils.token_counting import count_text_tokens
2826

2927
if TYPE_CHECKING:
3028
import pandas as pd
@@ -38,12 +36,6 @@
3836
logger = logging.getLogger(__name__)
3937

4038

41-
@lru_cache(maxsize=1)
42-
def _get_tokenizer() -> tiktoken.Encoding:
43-
"""Lazily initialize tokenizer to avoid import-time side effects."""
44-
return tiktoken.get_encoding("cl100k_base")
45-
46-
4739
def calculate_column_distribution(
4840
column_name: str, df: pd.DataFrame, distribution_type: ColumnDistributionType
4941
) -> dict[str, CategoricalDistribution | NumericalDistribution | MissingValue | None]:
@@ -106,7 +98,6 @@ def calculate_input_token_stats(
10698
column_config: LLMTextColumnConfig, df: pd.DataFrame
10799
) -> dict[str, float | MissingValue]:
108100
try:
109-
tokenizer = _get_tokenizer()
110101
num_tokens = []
111102
num_samples = min(MAX_PROMPT_SAMPLE_SIZE, len(df))
112103
renderer = RecordBasedPromptRenderer(response_recipe=create_response_recipe(column_config))
@@ -118,7 +109,7 @@ def calculate_input_token_stats(
118109
prompt_template=column_config.prompt, record=record, prompt_type=PromptType.USER_PROMPT
119110
)
120111
concatenated_prompt = str(system_prompt + "\n\n" + prompt)
121-
num_tokens.append(len(tokenizer.encode(concatenated_prompt, disallowed_special=())))
112+
num_tokens.append(count_text_tokens(concatenated_prompt))
122113
except Exception as e:
123114
logger.warning(f"{WARNING_PREFIX} failed to calculate input token stats for column {column_config.name!r}: {e}")
124115
return {
@@ -137,10 +128,7 @@ def calculate_output_token_stats(
137128
column_config: LLMTextColumnConfig, df: pd.DataFrame
138129
) -> dict[str, float | MissingValue]:
139130
try:
140-
tokenizer = _get_tokenizer()
141-
tokens_per_record = df[column_config.name].apply(
142-
lambda value: len(tokenizer.encode(str(value), disallowed_special=()))
143-
)
131+
tokens_per_record = df[column_config.name].apply(lambda value: count_text_tokens(str(value)))
144132
return {
145133
"output_tokens_mean": tokens_per_record.mean(),
146134
"output_tokens_median": tokens_per_record.median(),

packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic_translation.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
import re
88
from typing import Any
99

10-
from data_designer.engine.models.clients.parsing import extract_usage
10+
from data_designer.engine.models.clients.parsing import extract_usage, fill_reasoning_token_count_from_content
1111
from data_designer.engine.models.clients.types import (
1212
AssistantMessage,
1313
ChatCompletionRequest,
@@ -100,6 +100,7 @@ def parse_anthropic_response(response_json: dict[str, Any]) -> ChatCompletionRes
100100
usage: Usage | None = None
101101
if raw_usage:
102102
usage = extract_usage(raw_usage)
103+
usage = fill_reasoning_token_count_from_content(usage, message.reasoning_content)
103104

104105
return ChatCompletionResponse(message=message, usage=usage, raw=response_json)
105106

packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py

Lines changed: 51 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
import json
99
import logging
1010
import uuid
11+
from dataclasses import replace
1112
from typing import Any
1213

1314
from data_designer.config.utils.image_helpers import (
@@ -23,6 +24,8 @@
2324
ToolCall,
2425
Usage,
2526
)
27+
from data_designer.engine.models.usage import TokenCountSource
28+
from data_designer.engine.utils.token_counting import count_text_tokens
2629

2730
logger = logging.getLogger(__name__)
2831

@@ -44,6 +47,7 @@ def parse_chat_completion_response(response: Any) -> ChatCompletionResponse:
4447
images=images,
4548
)
4649
usage = extract_usage(get_value_from(response, "usage"), generated_images=len(images) if images else None)
50+
usage = fill_reasoning_token_count_from_content(usage, assistant_message.reasoning_content)
4751
return ChatCompletionResponse(message=assistant_message, usage=usage, raw=response)
4852

4953

@@ -59,6 +63,7 @@ async def aparse_chat_completion_response(response: Any) -> ChatCompletionRespon
5963
images=images,
6064
)
6165
usage = extract_usage(get_value_from(response, "usage"), generated_images=len(images) if images else None)
66+
usage = fill_reasoning_token_count_from_content(usage, assistant_message.reasoning_content)
6267
return ChatCompletionResponse(message=assistant_message, usage=usage, raw=response)
6368

6469

@@ -260,6 +265,7 @@ def extract_usage(raw_usage: Any, generated_images: int | None = None) -> Usage
260265
input_tokens = get_value_from(raw_usage, "prompt_tokens")
261266
output_tokens = get_value_from(raw_usage, "completion_tokens")
262267
total_tokens = get_value_from(raw_usage, "total_tokens")
268+
reasoning_token_count = extract_reasoning_token_count(raw_usage)
263269

264270
if input_tokens is None:
265271
input_tokens = get_value_from(raw_usage, "input_tokens")
@@ -269,6 +275,7 @@ def extract_usage(raw_usage: Any, generated_images: int | None = None) -> Usage
269275
input_tokens = coerce_to_int_or_none(input_tokens)
270276
output_tokens = coerce_to_int_or_none(output_tokens)
271277
total_tokens = coerce_to_int_or_none(total_tokens)
278+
reasoning_token_count_source = TokenCountSource.PROVIDER if reasoning_token_count is not None else None
272279

273280
if total_tokens is None and input_tokens is not None and output_tokens is not None:
274281
total_tokens = input_tokens + output_tokens
@@ -280,17 +287,60 @@ def extract_usage(raw_usage: Any, generated_images: int | None = None) -> Usage
280287

281288
generated_images = coerce_to_int_or_none(generated_images)
282289

283-
if input_tokens is None and output_tokens is None and total_tokens is None and generated_images is None:
290+
if (
291+
input_tokens is None
292+
and output_tokens is None
293+
and total_tokens is None
294+
and reasoning_token_count is None
295+
and generated_images is None
296+
):
284297
return None
285298

286299
return Usage(
287300
input_tokens=input_tokens,
288301
output_tokens=output_tokens,
289302
total_tokens=total_tokens,
303+
reasoning_tokens=reasoning_token_count,
304+
reasoning_token_count_source=reasoning_token_count_source,
290305
generated_images=generated_images,
291306
)
292307

293308

309+
def extract_reasoning_token_count(raw_usage: Any) -> int | None:
310+
if raw_usage is None:
311+
return None
312+
313+
top_level = get_value_from(raw_usage, "reasoning_tokens")
314+
if top_level is not None:
315+
return coerce_to_int_or_none(top_level)
316+
317+
for details_key in ("completion_tokens_details", "output_tokens_details"):
318+
details = get_value_from(raw_usage, details_key)
319+
reasoning_token_count = get_value_from(details, "reasoning_tokens")
320+
if reasoning_token_count is not None:
321+
return coerce_to_int_or_none(reasoning_token_count)
322+
323+
return None
324+
325+
326+
def fill_reasoning_token_count_from_content(usage: Usage | None, reasoning_content: str | None) -> Usage | None:
327+
if usage is None:
328+
return None
329+
if usage.reasoning_tokens is not None or not reasoning_content:
330+
return usage
331+
332+
try:
333+
reasoning_token_count = count_text_tokens(reasoning_content)
334+
except Exception:
335+
logger.debug("Failed to estimate reasoning token count", exc_info=True)
336+
return usage
337+
return replace(
338+
usage,
339+
reasoning_tokens=reasoning_token_count,
340+
reasoning_token_count_source=TokenCountSource.ESTIMATED,
341+
)
342+
343+
294344
def extract_embedding_vector(item: Any) -> list[float]:
295345
value = get_value_from(item, "embedding")
296346
if isinstance(value, list):

packages/data-designer-engine/src/data_designer/engine/models/clients/types.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,8 @@
66
from dataclasses import dataclass, field, fields
77
from typing import Any, ClassVar, Protocol
88

9+
from data_designer.engine.models.usage import TokenCountSource
10+
911

1012
class HttpResponse(Protocol):
1113
"""Structural type for HTTP response objects (httpx, requests, etc.)."""
@@ -21,8 +23,16 @@ class Usage:
2123
input_tokens: int | None = None
2224
output_tokens: int | None = None
2325
total_tokens: int | None = None
26+
reasoning_tokens: int | None = None
27+
reasoning_token_count_source: TokenCountSource | None = None
2428
generated_images: int | None = None
2529

30+
def __post_init__(self) -> None:
31+
if self.reasoning_tokens is None and self.reasoning_token_count_source is not None:
32+
raise ValueError("reasoning_token_count_source requires reasoning_tokens")
33+
if self.reasoning_tokens is not None and self.reasoning_token_count_source is None:
34+
raise ValueError("reasoning_tokens requires reasoning_token_count_source")
35+
2636

2737
@dataclass
2838
class ImagePayload:

packages/data-designer-engine/src/data_designer/engine/models/facade.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,12 @@
3636
)
3737
from data_designer.engine.models.parsers.errors import ParserException
3838
from data_designer.engine.models.telemetry import TELEMETRY_ENABLED
39-
from data_designer.engine.models.usage import ImageUsageStats, ModelUsageStats, RequestUsageStats, TokenUsageStats
39+
from data_designer.engine.models.usage import (
40+
ImageUsageStats,
41+
ModelUsageStats,
42+
RequestUsageStats,
43+
TokenUsageStats,
44+
)
4045
from data_designer.engine.models.utils import ChatMessage, prompt_to_messages
4146

4247
if TYPE_CHECKING:
@@ -814,6 +819,8 @@ def _track_usage(self, usage: Usage | None, *, is_request_successful: bool) -> N
814819
token_usage = TokenUsageStats(
815820
input_tokens=usage.input_tokens,
816821
output_tokens=usage.output_tokens or 0,
822+
reasoning_tokens=usage.reasoning_tokens,
823+
reasoning_token_count_source=usage.reasoning_token_count_source,
817824
)
818825

819826
self._usage_stats.extend(

packages/data-designer-engine/src/data_designer/engine/models/registry.py

Lines changed: 43 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88

99
from data_designer.config.models import GenerationType, ModelConfig
1010
from data_designer.engine.model_provider import ModelProvider, ModelProviderRegistry
11-
from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenUsageStats
11+
from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenCountSource, TokenUsageStats
1212
from data_designer.engine.secret_resolver import SecretResolver
1313
from data_designer.logging import LOG_INDENT
1414

@@ -27,6 +27,18 @@
2727
logger = logging.getLogger(__name__)
2828

2929

30+
def format_reasoning_token_count(reasoning_token_count: int, source: TokenCountSource | str | None) -> str:
31+
if source == TokenCountSource.ESTIMATED or source == TokenCountSource.ESTIMATED.value:
32+
return f"{reasoning_token_count} (estimated)"
33+
return str(reasoning_token_count)
34+
35+
36+
def get_token_count_delta(current: int | None, previous: int | None) -> int | None:
37+
if current is None:
38+
return None
39+
return current - (previous or 0)
40+
41+
3042
class ModelRegistry:
3143
def __init__(
3244
self,
@@ -115,9 +127,17 @@ def log_model_usage(self, total_time_elapsed: float) -> None:
115127
output_tokens = token_usage["output_tokens"]
116128
total_tokens = token_usage["total_tokens"]
117129
tokens_per_second = stats["tokens_per_second"]
118-
logger.info(
119-
f"{LOG_INDENT}tokens: input={input_tokens}, output={output_tokens}, total={total_tokens}, tps={tokens_per_second}"
120-
)
130+
token_parts = [f"input={input_tokens}", f"output={output_tokens}"]
131+
if (reasoning_token_count := token_usage.get("reasoning_tokens")) is not None:
132+
formatted_reasoning_token_count = format_reasoning_token_count(
133+
reasoning_token_count,
134+
token_usage.get("reasoning_token_count_source"),
135+
)
136+
token_parts.append(f"reasoning={formatted_reasoning_token_count}")
137+
token_parts.extend([f"total={total_tokens}", f"tps={tokens_per_second}"])
138+
logger.info(f"{LOG_INDENT}tokens: {', '.join(token_parts)}")
139+
if token_usage.get("reasoning_token_count_source") == TokenCountSource.ESTIMATED.value:
140+
logger.info(f"{LOG_INDENT}reasoning token count estimated with tiktoken")
121141

122142
request_usage = stats["request_usage"]
123143
successful_requests = request_usage["successful_requests"]
@@ -160,14 +180,31 @@ def get_usage_deltas(self, snapshot: dict[str, ModelUsageStats]) -> dict[str, Mo
160180
prev = snapshot.get(model_name)
161181
delta_input = current.token_usage.input_tokens - (prev.token_usage.input_tokens if prev else 0)
162182
delta_output = current.token_usage.output_tokens - (prev.token_usage.output_tokens if prev else 0)
183+
delta_reasoning_token_count = get_token_count_delta(
184+
current.token_usage.reasoning_tokens,
185+
prev.token_usage.reasoning_tokens if prev else None,
186+
)
163187
delta_successful = current.request_usage.successful_requests - (
164188
prev.request_usage.successful_requests if prev else 0
165189
)
166190
delta_failed = current.request_usage.failed_requests - (prev.request_usage.failed_requests if prev else 0)
167191

168-
if delta_input > 0 or delta_output > 0 or delta_successful > 0 or delta_failed > 0:
192+
if (
193+
delta_input > 0
194+
or delta_output > 0
195+
or (delta_reasoning_token_count is not None and delta_reasoning_token_count > 0)
196+
or delta_successful > 0
197+
or delta_failed > 0
198+
):
169199
deltas[model_name] = ModelUsageStats(
170-
token_usage=TokenUsageStats(input_tokens=delta_input, output_tokens=delta_output),
200+
token_usage=TokenUsageStats(
201+
input_tokens=delta_input,
202+
output_tokens=delta_output,
203+
reasoning_tokens=delta_reasoning_token_count,
204+
reasoning_token_count_source=current.token_usage.reasoning_token_count_source
205+
if delta_reasoning_token_count is not None
206+
else None,
207+
),
171208
request_usage=RequestUsageStats(successful_requests=delta_successful, failed_requests=delta_failed),
172209
)
173210
return deltas

0 commit comments

Comments
 (0)