Skip to content

Commit b0076cd

Browse files
authored
fix: use run config for profiler token stats (#738)
1 parent: 7912544 · commit: b0076cd

6 files changed

Lines changed: 143 additions & 8 deletions

File tree

packages/data-designer-engine/src/data_designer/engine/analysis/column_profilers/base.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212

1313
from data_designer.config.base import ConfigBase, SingleColumnConfig
1414
from data_designer.config.column_types import DataDesignerColumnType
15+
from data_designer.config.run_config import JinjaRenderingEngine
1516
from data_designer.engine.configurable_task import ConfigurableTask, TaskConfigT
1617

1718
logger = logging.getLogger(__name__)
@@ -20,6 +21,7 @@
2021
class ColumnConfigWithDataFrame(ConfigBase):
2122
column_config: SingleColumnConfig
2223
df: pd.DataFrame
24+
jinja_rendering_engine: JinjaRenderingEngine = JinjaRenderingEngine.SECURE
2325

2426
@model_validator(mode="after")
2527
def validate_column_exists(self) -> Self:

packages/data-designer-engine/src/data_designer/engine/analysis/column_statistics.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,11 @@ def __repr__(self) -> str:
7373

7474
class LLMTextColumnStatisticsCalculator(GeneralColumnStatisticsCalculator):
7575
def calculate_token_stats(self) -> dict[str, Any]:
76-
return calculate_token_stats(self.column_config, self.df)
76+
return calculate_token_stats(
77+
self.column_config,
78+
self.df,
79+
jinja_rendering_engine=self.column_config_with_df.jinja_rendering_engine,
80+
)
7781

7882

7983
class LLMCodeColumnStatisticsCalculator(LLMTextColumnStatisticsCalculator): ...

packages/data-designer-engine/src/data_designer/engine/analysis/dataset_profiler.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def profile_dataset(
7676
logger.info(f"{LOG_INDENT}{c.get_column_emoji()} column: '{c.name}'")
7777
column_statistics.append(
7878
get_column_statistics_calculator(c.column_type)(
79-
column_config_with_df=ColumnConfigWithDataFrame(column_config=c, df=dataset)
79+
column_config_with_df=self._create_column_config_with_df(c, dataset)
8080
).calculate()
8181
)
8282

@@ -86,7 +86,7 @@ def profile_dataset(
8686
applicable_column_types = profiler.get_applicable_column_types()
8787
for c in self.config.column_configs:
8888
if c.column_type in applicable_column_types:
89-
params = ColumnConfigWithDataFrame(column_config=c, df=dataset)
89+
params = self._create_column_config_with_df(c, dataset)
9090
column_profiles.append(profiler.profile(params))
9191
if len(column_profiles) == 0:
9292
logger.warning(
@@ -128,6 +128,17 @@ def _create_column_profiler(self, profiler_config: ColumnProfilerConfigT) -> Col
128128
config=profiler_config, resource_provider=self.resource_provider
129129
)
130130

131+
def _create_column_config_with_df(
132+
self,
133+
column_config: ColumnConfigT,
134+
dataset: pd.DataFrame,
135+
) -> ColumnConfigWithDataFrame:
136+
return ColumnConfigWithDataFrame(
137+
column_config=column_config,
138+
df=dataset,
139+
jinja_rendering_engine=self.resource_provider.run_config.jinja_rendering_engine,
140+
)
141+
131142
def _validate_column_profiler_configs(self) -> None:
132143
if self.config.column_profiler_configs:
133144
if self.resource_provider.model_registry is None:

packages/data-designer-engine/src/data_designer/engine/analysis/utils/column_statistics_calculations.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from data_designer.config.column_configs import (
1818
LLMTextColumnConfig,
1919
)
20+
from data_designer.config.run_config import JinjaRenderingEngine
2021
from data_designer.engine.column_generators.utils.prompt_renderer import (
2122
PromptType,
2223
RecordBasedPromptRenderer,
@@ -95,12 +96,18 @@ def calculate_general_column_info(column_name: str, df: pd.DataFrame) -> dict[st
9596

9697

9798
def calculate_input_token_stats(
98-
column_config: LLMTextColumnConfig, df: pd.DataFrame
99+
column_config: LLMTextColumnConfig,
100+
df: pd.DataFrame,
101+
*,
102+
jinja_rendering_engine: JinjaRenderingEngine = JinjaRenderingEngine.SECURE,
99103
) -> dict[str, float | MissingValue]:
100104
try:
101105
num_tokens = []
102106
num_samples = min(MAX_PROMPT_SAMPLE_SIZE, len(df))
103-
renderer = RecordBasedPromptRenderer(response_recipe=create_response_recipe(column_config))
107+
renderer = RecordBasedPromptRenderer(
108+
response_recipe=create_response_recipe(column_config),
109+
jinja_rendering_engine=jinja_rendering_engine,
110+
)
104111
for record in df.sample(num_samples, random_state=RANDOM_SEED).to_dict(orient="records"):
105112
system_prompt = renderer.render(
106113
prompt_template=column_config.system_prompt, record=record, prompt_type=PromptType.SYSTEM_PROMPT
@@ -143,9 +150,14 @@ def calculate_output_token_stats(
143150
}
144151

145152

146-
def calculate_token_stats(column_config: LLMTextColumnConfig, df: pd.DataFrame) -> dict[str, float | MissingValue]:
153+
def calculate_token_stats(
154+
column_config: LLMTextColumnConfig,
155+
df: pd.DataFrame,
156+
*,
157+
jinja_rendering_engine: JinjaRenderingEngine = JinjaRenderingEngine.SECURE,
158+
) -> dict[str, float | MissingValue]:
147159
return {
148-
**calculate_input_token_stats(column_config, df),
160+
**calculate_input_token_stats(column_config, df, jinja_rendering_engine=jinja_rendering_engine),
149161
**calculate_output_token_stats(column_config, df),
150162
}
151163

packages/data-designer-engine/tests/engine/analysis/test_dataset_profiler.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,26 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
from __future__ import annotations
5+
6+
from pathlib import Path
47
from unittest.mock import patch
58

69
import pytest
710

11+
import data_designer.lazy_heavy_imports as lazy
12+
from data_designer.config.analysis.column_statistics import MissingValue
813
from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults
9-
from data_designer.config.column_configs import SamplerColumnConfig
14+
from data_designer.config.column_configs import LLMTextColumnConfig, SamplerColumnConfig
15+
from data_designer.config.run_config import JinjaRenderingEngine, RunConfig
1016
from data_designer.config.sampler_params import CategorySamplerParams, SamplerType
1117
from data_designer.engine.analysis.column_profilers.judge_score_profiler import JudgeScoreProfilerConfig
1218
from data_designer.engine.analysis.dataset_profiler import DataDesignerDatasetProfiler, DatasetProfilerConfig
1319
from data_designer.engine.analysis.errors import DatasetProfilerConfigurationError
1420
from data_designer.engine.analysis.utils.judge_score_processing import JudgeScoreSample
1521
from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig
22+
from data_designer.engine.resources.resource_provider import ResourceProvider
23+
from data_designer.engine.storage.artifact_storage import ArtifactStorage
1624

1725

1826
def test_dataset_profiler_config_flattens_multi_column_configs():
@@ -88,6 +96,62 @@ def test_dataset_profiler_profile_dataset_with_column_profilers(
8896
stub_model_facade.generate.assert_called()
8997

9098

99+
@pytest.mark.parametrize(
100+
(
101+
"jinja_rendering_engine",
102+
"expected_input_tokens_mean",
103+
"expected_input_tokens_median",
104+
"expected_input_tokens_stddev",
105+
),
106+
[
107+
(JinjaRenderingEngine.NATIVE, 10.0, 10.0, 0.0),
108+
(
109+
JinjaRenderingEngine.SECURE,
110+
MissingValue.CALCULATION_FAILED,
111+
MissingValue.CALCULATION_FAILED,
112+
MissingValue.CALCULATION_FAILED,
113+
),
114+
],
115+
)
116+
def test_dataset_profiler_uses_run_config_jinja_engine_for_input_token_stats(
117+
tmp_path: Path,
118+
jinja_rendering_engine: JinjaRenderingEngine,
119+
expected_input_tokens_mean: float | MissingValue,
120+
expected_input_tokens_median: float | MissingValue,
121+
expected_input_tokens_stddev: float | MissingValue,
122+
) -> None:
123+
column_config = LLMTextColumnConfig(
124+
name="summary",
125+
prompt="Trajectory: {{ messages }}",
126+
system_prompt="System prompt",
127+
model_alias="nano",
128+
)
129+
dataset = lazy.pd.DataFrame(
130+
{
131+
"summary": ["response"],
132+
"messages": ["x" * 512_001],
133+
}
134+
)
135+
profiler = DataDesignerDatasetProfiler(
136+
config=DatasetProfilerConfig(column_configs=[column_config]),
137+
resource_provider=ResourceProvider(
138+
artifact_storage=ArtifactStorage(artifact_path=tmp_path),
139+
run_config=RunConfig(jinja_rendering_engine=jinja_rendering_engine),
140+
),
141+
)
142+
143+
with patch(
144+
"data_designer.engine.analysis.utils.column_statistics_calculations.count_text_tokens",
145+
return_value=10,
146+
):
147+
profile = profiler.profile_dataset(target_num_records=1, dataset=dataset)
148+
149+
stats = profile.column_statistics[0]
150+
assert stats.input_tokens_mean == expected_input_tokens_mean
151+
assert stats.input_tokens_median == expected_input_tokens_median
152+
assert stats.input_tokens_stddev == expected_input_tokens_stddev
153+
154+
91155
@patch(
92156
"data_designer.engine.analysis.dataset_profiler.DataDesignerDatasetProfiler._validate_schema_consistency",
93157
autospec=True,

packages/data-designer-engine/tests/engine/analysis/utils/test_column_statistics_calculations.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from __future__ import annotations
55

66
from itertools import cycle
7+
from unittest.mock import patch
78

89
import pytest
910

@@ -16,6 +17,7 @@
1617
NumericalDistribution,
1718
)
1819
from data_designer.config.column_configs import LLMTextColumnConfig
20+
from data_designer.config.run_config import JinjaRenderingEngine
1921
from data_designer.config.utils.numerical_helpers import prepare_number_for_reporting
2022
from data_designer.engine.analysis.utils.column_statistics_calculations import (
2123
calculate_column_distribution,
@@ -188,6 +190,46 @@ def test_calculate_input_token_stats(mock_prompt_renderer_render, stub_column_co
188190
assert result["input_tokens_median"] == MissingValue.CALCULATION_FAILED
189191

190192

193+
@pytest.mark.parametrize(
194+
("prompt", "messages", "expected_token_count"),
195+
[
196+
("Joined: {{ messages | join('-') }}", ["Hello", "World"], 4),
197+
("Trajectory: {{ messages }}", "x" * 512_001, 10),
198+
],
199+
)
200+
def test_calculate_input_token_stats_respects_native_jinja_engine(
201+
prompt: str,
202+
messages: list[str] | str,
203+
expected_token_count: int,
204+
) -> None:
205+
column_config = LLMTextColumnConfig(
206+
name="test_column",
207+
prompt=prompt,
208+
system_prompt="System prompt",
209+
model_alias="test_model_alias",
210+
)
211+
df = lazy.pd.DataFrame(
212+
{
213+
"test_column": ["response"],
214+
"messages": [messages],
215+
}
216+
)
217+
218+
with patch(
219+
"data_designer.engine.analysis.utils.column_statistics_calculations.count_text_tokens",
220+
return_value=expected_token_count,
221+
):
222+
result = calculate_input_token_stats(
223+
column_config,
224+
df,
225+
jinja_rendering_engine=JinjaRenderingEngine.NATIVE,
226+
)
227+
228+
assert result["input_tokens_mean"] == float(expected_token_count)
229+
assert result["input_tokens_median"] == float(expected_token_count)
230+
assert result["input_tokens_stddev"] == 0.0
231+
232+
191233
def test_calculate_output_token_stats(stub_column_config, stub_df_responses):
192234
result = calculate_output_token_stats(stub_column_config, stub_df_responses)
193235
assert "output_tokens_mean" in result

0 commit comments

Comments (0)