Skip to content

Commit 1ee37bc

Browse files
authored
refactor: update single column base class (#206)
* make properties abstract
* add private column emoji attribute
* update e2e plugin tests
* throw error if not default string
* add unit tests
* make emoji a static method
* don't need that docstring
* update unit test
1 parent 3d9f518 commit 1ee37bc

14 files changed

Lines changed: 164 additions & 63 deletions

File tree

src/data_designer/config/analysis/utils/reporting.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@
1616

1717
from data_designer.config.analysis.column_statistics import CategoricalHistogramData
1818
from data_designer.config.analysis.utils.errors import AnalysisReportError
19-
from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP, DataDesignerColumnType, get_column_display_order
19+
from data_designer.config.column_types import (
20+
DataDesignerColumnType,
21+
get_column_display_order,
22+
get_column_emoji_from_type,
23+
)
2024
from data_designer.config.utils.visualization import (
2125
ColorPalette,
2226
convert_to_row_element,
@@ -101,7 +105,7 @@ def generate_analysis_report(
101105
displayed_column_types.add(column_type)
102106
column_label = column_type.replace("_", " ").title().replace("Llm", "LLM")
103107
table = Table(
104-
title=f"{COLUMN_TYPE_EMOJI_MAP[column_type]} {column_label} Columns",
108+
title=f"{get_column_emoji_from_type(column_type)} {column_label} Columns",
105109
**table_kws,
106110
)
107111

src/data_designer/config/column_configs.py

Lines changed: 75 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4-
from abc import ABC
4+
from abc import ABC, abstractmethod
55
from typing import Annotated, Literal
66

77
from pydantic import BaseModel, Discriminator, Field, model_validator
@@ -13,7 +13,7 @@
1313
from data_designer.config.sampler_params import SamplerParamsT, SamplerType
1414
from data_designer.config.utils.code_lang import CodeLang
1515
from data_designer.config.utils.constants import REASONING_TRACE_COLUMN_POSTFIX
16-
from data_designer.config.utils.misc import assert_valid_jinja2_template, get_prompt_template_keywords
16+
from data_designer.config.utils.misc import assert_valid_jinja2_template, extract_keywords_from_jinja2_template
1717
from data_designer.config.validator_params import ValidatorParamsT, ValidatorType
1818

1919

@@ -35,17 +35,22 @@ class SingleColumnConfig(ConfigBase, ABC):
3535
drop: bool = False
3636
column_type: str
3737

38+
@staticmethod
39+
def get_column_emoji() -> str:
40+
return "🎨"
41+
3842
@property
43+
@abstractmethod
3944
def required_columns(self) -> list[str]:
4045
"""Returns a list of column names that must exist before this column can be generated.
4146
4247
Returns:
4348
List of column names that this column depends on. Empty list indicates
4449
no dependencies. Override in subclasses to specify dependencies.
4550
"""
46-
return []
4751

4852
@property
53+
@abstractmethod
4954
def side_effect_columns(self) -> list[str]:
5055
"""Returns a list of additional columns that this column will create as a side effect.
5156
@@ -56,7 +61,6 @@ def side_effect_columns(self) -> list[str]:
5661
List of column names that this column will create as a side effect. Empty list
5762
indicates no side effect columns. Override in subclasses to specify side effects.
5863
"""
59-
return []
6064

6165

6266
class SamplerColumnConfig(SingleColumnConfig):
@@ -94,6 +98,18 @@ class SamplerColumnConfig(SingleColumnConfig):
9498
convert_to: str | None = None
9599
column_type: Literal["sampler"] = "sampler"
96100

101+
@staticmethod
102+
def get_column_emoji() -> str:
103+
return "🎲"
104+
105+
@property
106+
def required_columns(self) -> list[str]:
107+
return []
108+
109+
@property
110+
def side_effect_columns(self) -> list[str]:
111+
return []
112+
97113
@model_validator(mode="before")
98114
@classmethod
99115
def inject_sampler_type_into_params(cls, data: dict) -> dict:
@@ -150,16 +166,20 @@ class LLMTextColumnConfig(SingleColumnConfig):
150166
multi_modal_context: list[ImageContext] | None = None
151167
column_type: Literal["llm-text"] = "llm-text"
152168

169+
@staticmethod
170+
def get_column_emoji() -> str:
171+
return "📝"
172+
153173
@property
154174
def required_columns(self) -> list[str]:
155175
"""Get columns referenced in the prompt and system_prompt templates.
156176
157177
Returns:
158178
List of unique column names referenced in Jinja2 templates.
159179
"""
160-
required_cols = list(get_prompt_template_keywords(self.prompt))
180+
required_cols = list(extract_keywords_from_jinja2_template(self.prompt))
161181
if self.system_prompt:
162-
required_cols.extend(list(get_prompt_template_keywords(self.system_prompt)))
182+
required_cols.extend(list(extract_keywords_from_jinja2_template(self.system_prompt)))
163183
return list(set(required_cols))
164184

165185
@property
@@ -207,6 +227,10 @@ class LLMCodeColumnConfig(LLMTextColumnConfig):
207227
code_lang: CodeLang
208228
column_type: Literal["llm-code"] = "llm-code"
209229

230+
@staticmethod
231+
def get_column_emoji() -> str:
232+
return "💻"
233+
210234

211235
class LLMStructuredColumnConfig(LLMTextColumnConfig):
212236
"""Configuration for structured JSON generation columns using Large Language Models.
@@ -225,6 +249,10 @@ class LLMStructuredColumnConfig(LLMTextColumnConfig):
225249
output_format: dict | type[BaseModel]
226250
column_type: Literal["llm-structured"] = "llm-structured"
227251

252+
@staticmethod
253+
def get_column_emoji() -> str:
254+
return "🗂️"
255+
228256
@model_validator(mode="after")
229257
def validate_output_format(self) -> Self:
230258
"""Convert Pydantic model to JSON schema if needed.
@@ -275,6 +303,10 @@ class LLMJudgeColumnConfig(LLMTextColumnConfig):
275303
scores: list[Score] = Field(..., min_length=1)
276304
column_type: Literal["llm-judge"] = "llm-judge"
277305

306+
@staticmethod
307+
def get_column_emoji() -> str:
308+
return "⚖️"
309+
278310

279311
class ExpressionColumnConfig(SingleColumnConfig):
280312
"""Configuration for derived columns using Jinja2 expressions.
@@ -297,10 +329,18 @@ class ExpressionColumnConfig(SingleColumnConfig):
297329
dtype: Literal["int", "float", "str", "bool"] = "str"
298330
column_type: Literal["expression"] = "expression"
299331

332+
@staticmethod
333+
def get_column_emoji() -> str:
334+
return "🧩"
335+
300336
@property
301337
def required_columns(self) -> list[str]:
302338
"""Returns the columns referenced in the expression template."""
303-
return list(get_prompt_template_keywords(self.expr))
339+
return list(extract_keywords_from_jinja2_template(self.expr))
340+
341+
@property
342+
def side_effect_columns(self) -> list[str]:
343+
return []
304344

305345
@model_validator(mode="after")
306346
def assert_expression_valid_jinja(self) -> Self:
@@ -359,11 +399,19 @@ class ValidationColumnConfig(SingleColumnConfig):
359399
batch_size: int = Field(default=10, ge=1, description="Number of records to process in each batch")
360400
column_type: Literal["validation"] = "validation"
361401

402+
@staticmethod
403+
def get_column_emoji() -> str:
404+
return "🔍"
405+
362406
@property
363407
def required_columns(self) -> list[str]:
364408
"""Returns the columns that need to be validated."""
365409
return self.target_columns
366410

411+
@property
412+
def side_effect_columns(self) -> list[str]:
413+
return []
414+
367415

368416
class SeedDatasetColumnConfig(SingleColumnConfig):
369417
"""Configuration for columns sourced from seed datasets.
@@ -378,6 +426,18 @@ class SeedDatasetColumnConfig(SingleColumnConfig):
378426

379427
column_type: Literal["seed-dataset"] = "seed-dataset"
380428

429+
@staticmethod
430+
def get_column_emoji() -> str:
431+
return "🌱"
432+
433+
@property
434+
def required_columns(self) -> list[str]:
435+
return []
436+
437+
@property
438+
def side_effect_columns(self) -> list[str]:
439+
return []
440+
381441

382442
class EmbeddingColumnConfig(SingleColumnConfig):
383443
"""Configuration for embedding generation columns.
@@ -395,6 +455,14 @@ class EmbeddingColumnConfig(SingleColumnConfig):
395455
model_alias: str
396456
column_type: Literal["embedding"] = "embedding"
397457

458+
@staticmethod
459+
def get_column_emoji() -> str:
460+
return "🧬"
461+
398462
@property
399463
def required_columns(self) -> list[str]:
400464
return [self.target_column]
465+
466+
@property
467+
def side_effect_columns(self) -> list[str]:
468+
return []

src/data_designer/config/column_types.py

Lines changed: 32 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
SeedDatasetColumnConfig,
1616
ValidationColumnConfig,
1717
)
18-
from data_designer.config.errors import InvalidColumnTypeError, InvalidConfigError
18+
from data_designer.config.errors import InvalidConfigError
1919
from data_designer.config.sampler_params import SamplerType
2020
from data_designer.config.utils.type_helpers import (
2121
SAMPLER_PARAMS,
@@ -45,22 +45,6 @@
4545
discriminator_field_name="column_type",
4646
)
4747

48-
COLUMN_TYPE_EMOJI_MAP = {
49-
"general": "⚛️", # possible analysis column type
50-
DataDesignerColumnType.EXPRESSION: "🧩",
51-
DataDesignerColumnType.LLM_CODE: "💻",
52-
DataDesignerColumnType.LLM_JUDGE: "⚖️",
53-
DataDesignerColumnType.LLM_STRUCTURED: "🗂️",
54-
DataDesignerColumnType.LLM_TEXT: "📝",
55-
DataDesignerColumnType.SEED_DATASET: "🌱",
56-
DataDesignerColumnType.SAMPLER: "🎲",
57-
DataDesignerColumnType.VALIDATION: "🔍",
58-
DataDesignerColumnType.EMBEDDING: "🧬",
59-
}
60-
COLUMN_TYPE_EMOJI_MAP.update(
61-
{DataDesignerColumnType(p.name): p.emoji for p in plugin_manager.get_column_generator_plugins()}
62-
)
63-
6448

6549
def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType, **kwargs) -> ColumnConfigT:
6650
"""Create a Data Designer column config object from kwargs.
@@ -74,27 +58,20 @@ def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType
7458
Data Designer column object of the appropriate type.
7559
"""
7660
column_type = resolve_string_enum(column_type, DataDesignerColumnType)
77-
if column_type == DataDesignerColumnType.LLM_TEXT:
78-
return LLMTextColumnConfig(name=name, **kwargs)
79-
if column_type == DataDesignerColumnType.LLM_CODE:
80-
return LLMCodeColumnConfig(name=name, **kwargs)
81-
if column_type == DataDesignerColumnType.LLM_STRUCTURED:
82-
return LLMStructuredColumnConfig(name=name, **kwargs)
83-
if column_type == DataDesignerColumnType.LLM_JUDGE:
84-
return LLMJudgeColumnConfig(name=name, **kwargs)
85-
if column_type == DataDesignerColumnType.VALIDATION:
86-
return ValidationColumnConfig(name=name, **kwargs)
87-
if column_type == DataDesignerColumnType.EXPRESSION:
88-
return ExpressionColumnConfig(name=name, **kwargs)
61+
config_cls = get_column_config_cls_from_type(column_type)
8962
if column_type == DataDesignerColumnType.SAMPLER:
90-
return SamplerColumnConfig(name=name, **_resolve_sampler_kwargs(name, kwargs))
91-
if column_type == DataDesignerColumnType.SEED_DATASET:
92-
return SeedDatasetColumnConfig(name=name, **kwargs)
93-
if column_type == DataDesignerColumnType.EMBEDDING:
94-
return EmbeddingColumnConfig(name=name, **kwargs)
63+
kwargs = _resolve_sampler_kwargs(name, kwargs)
64+
return config_cls(name=name, **kwargs)
65+
66+
67+
def get_column_config_cls_from_type(column_type: DataDesignerColumnType) -> type[ColumnConfigT]:
68+
"""Get the column config class for a column type."""
69+
column_type = resolve_string_enum(column_type, DataDesignerColumnType)
70+
if column_type in _COLUMN_TYPE_CONFIG_CLS_MAP:
71+
return _COLUMN_TYPE_CONFIG_CLS_MAP[column_type]
9572
if plugin := plugin_manager.get_column_generator_plugin_if_exists(column_type.value):
96-
return plugin.config_cls(name=name, **kwargs)
97-
raise InvalidColumnTypeError(f"🛑 {column_type} is not a valid column type.") # pragma: no cover
73+
return plugin.config_cls
74+
raise InvalidConfigError(f"🛑 {column_type} is not a valid column type.")
9875

9976

10077
def get_column_display_order() -> list[DataDesignerColumnType]:
@@ -114,6 +91,12 @@ def get_column_display_order() -> list[DataDesignerColumnType]:
11491
return display_order
11592

11693

94+
def get_column_emoji_from_type(column_type: DataDesignerColumnType) -> str:
95+
"""Get the emoji for a column type."""
96+
config_cls = get_column_config_cls_from_type(resolve_string_enum(column_type, DataDesignerColumnType))
97+
return config_cls.get_column_emoji()
98+
99+
117100
def _resolve_sampler_kwargs(name: str, kwargs: dict) -> dict:
118101
if "sampler_type" not in kwargs:
119102
raise InvalidConfigError(f"🛑 `sampler_type` is required for sampler column '{name}'.")
@@ -142,3 +125,16 @@ def _resolve_sampler_kwargs(name: str, kwargs: dict) -> dict:
142125
"params": params,
143126
**{k: v for k, v in kwargs.items() if k not in ["sampler_type", "params"]},
144127
}
128+
129+
130+
_COLUMN_TYPE_CONFIG_CLS_MAP = {
131+
DataDesignerColumnType.LLM_TEXT: LLMTextColumnConfig,
132+
DataDesignerColumnType.LLM_CODE: LLMCodeColumnConfig,
133+
DataDesignerColumnType.LLM_STRUCTURED: LLMStructuredColumnConfig,
134+
DataDesignerColumnType.LLM_JUDGE: LLMJudgeColumnConfig,
135+
DataDesignerColumnType.VALIDATION: ValidationColumnConfig,
136+
DataDesignerColumnType.EXPRESSION: ExpressionColumnConfig,
137+
DataDesignerColumnType.SAMPLER: SamplerColumnConfig,
138+
DataDesignerColumnType.SEED_DATASET: SeedDatasetColumnConfig,
139+
DataDesignerColumnType.EMBEDDING: EmbeddingColumnConfig,
140+
}

src/data_designer/config/utils/misc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ def can_run_data_designer_locally() -> bool:
4848
return True
4949

5050

51-
def get_prompt_template_keywords(template: str) -> set[str]:
52-
"""Extract all keywords from a valid string template."""
51+
def extract_keywords_from_jinja2_template(template: str) -> set[str]:
52+
"""Extract all keywords from a valid Jinja2 template."""
5353
with template_error_handler():
5454
ast = ImmutableSandboxedEnvironment().parse(template)
5555
keywords = set(meta.find_undeclared_variables(ast))

src/data_designer/engine/analysis/column_profilers/judge_score_profiler.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
MissingValue,
2020
NumericalDistribution,
2121
)
22-
from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP, DataDesignerColumnType
22+
from data_designer.config.column_types import DataDesignerColumnType
2323
from data_designer.engine.analysis.column_profilers.base import ColumnConfigWithDataFrame, ColumnProfiler
2424
from data_designer.engine.analysis.utils.judge_score_processing import (
2525
extract_judge_score_distributions,
@@ -43,8 +43,7 @@ def profile(self, column_config_with_df: ColumnConfigWithDataFrame) -> JudgeScor
4343
column_config, df = column_config_with_df.as_tuple()
4444

4545
logger.info(
46-
f"{COLUMN_TYPE_EMOJI_MAP[column_config.column_type]} Analyzing LLM-as-judge "
47-
f"scores for column: '{column_config.name}'"
46+
f"{column_config.get_column_emoji()} Analyzing LLM-as-judge scores for column: '{column_config.name}'"
4847
)
4948

5049
score_summaries = {}

src/data_designer/engine/analysis/dataset_profiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults
1414
from data_designer.config.base import ConfigBase
1515
from data_designer.config.column_configs import SingleColumnConfig
16-
from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP, ColumnConfigT
16+
from data_designer.config.column_types import ColumnConfigT
1717
from data_designer.engine.analysis.column_profilers.base import ColumnConfigWithDataFrame, ColumnProfiler
1818
from data_designer.engine.analysis.column_statistics import get_column_statistics_calculator
1919
from data_designer.engine.analysis.errors import DatasetProfilerConfigurationError
@@ -68,7 +68,7 @@ def profile_dataset(
6868

6969
column_statistics = []
7070
for c in self.config.column_configs:
71-
logger.info(f" |-- {COLUMN_TYPE_EMOJI_MAP[c.column_type]} column: '{c.name}'")
71+
logger.info(f" |-- {c.get_column_emoji()} column: '{c.name}'")
7272
column_statistics.append(
7373
get_column_statistics_calculator(c.column_type)(
7474
column_config_with_df=ColumnConfigWithDataFrame(column_config=c, df=dataset)

src/data_designer/engine/column_generators/generators/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,9 @@ def inference_parameters(self) -> BaseInferenceParams:
9595
return self.model_config.inference_parameters
9696

9797
def log_pre_generation(self) -> None:
98-
logger.info(f"{self.config.column_type} model configuration for generating column '{self.config.name}'")
98+
logger.info(
99+
f"{self.config.get_column_emoji()} {self.config.column_type} model config for column '{self.config.name}'"
100+
)
99101
logger.info(f" |-- model: {self.model_config.model!r}")
100102
logger.info(f" |-- model alias: {self.config.model_alias!r}")
101103
logger.info(f" |-- model provider: {self.get_model_provider_name(model_alias=self.config.model_alias)!r}")

src/data_designer/engine/column_generators/utils/prompt_renderer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from data_designer.config.column_types import DataDesignerColumnType
99
from data_designer.config.models import ModelConfig
1010
from data_designer.config.utils.code_lang import CodeLang
11-
from data_designer.config.utils.misc import get_prompt_template_keywords
11+
from data_designer.config.utils.misc import extract_keywords_from_jinja2_template
1212
from data_designer.config.utils.type_helpers import StrEnum
1313
from data_designer.engine.column_generators.utils.errors import PromptTemplateRenderError
1414
from data_designer.engine.column_generators.utils.judge_score_factory import (
@@ -56,7 +56,7 @@ def _prepare_environment(self, *, prompt_template: str | None, record: dict, pro
5656
dataset_variables=list(record.keys()),
5757
)
5858
except (UserTemplateUnsupportedFiltersError, UserTemplateError) as exc:
59-
template_variables = get_prompt_template_keywords(prompt_template)
59+
template_variables = extract_keywords_from_jinja2_template(prompt_template)
6060
missing_columns = list(set(template_variables) - set(record.keys()))
6161

6262
error_msg = (

0 commit comments

Comments
 (0)