Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -185,14 +185,17 @@ def get_column_emoji() -> str:

@property
def required_columns(self) -> list[str]:
    """Get columns referenced in prompt templates and multi-modal context.

    Returns:
        Sorted list of unique column names referenced in the ``prompt`` and
        ``system_prompt`` Jinja2 templates and in the multi-modal context
        configurations.
    """
    columns = set(extract_keywords_from_jinja2_template(self.prompt))
    if self.system_prompt:
        columns.update(extract_keywords_from_jinja2_template(self.system_prompt))
    if self.multi_modal_context:
        # Context entries name the column they read from; that column is a
        # dependency even when no template mentions it.
        columns.update(ctx.column_name for ctx in self.multi_modal_context)
    # Sort so the result does not vary with set iteration order.
    return sorted(columns)

@property
Expand Down Expand Up @@ -593,12 +596,16 @@ def get_column_emoji() -> str:

@property
def required_columns(self) -> list[str]:
    """Get columns referenced in the prompt template and multi-modal context.

    Returns:
        Sorted list of unique column names referenced in the ``prompt``
        Jinja2 template and in the multi-modal context configurations.
    """
    columns = set(extract_keywords_from_jinja2_template(self.prompt))
    if self.multi_modal_context:
        # Context entries name the column they read from; that column is a
        # dependency even when the prompt template does not mention it.
        columns.update(ctx.column_name for ctx in self.multi_modal_context)
    # Sort so the result does not vary with set iteration order.
    return sorted(columns)

@model_validator(mode="after")
def assert_prompt_valid_jinja(self) -> Self:
Expand Down
32 changes: 32 additions & 0 deletions packages/data-designer-config/tests/config/test_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from data_designer.config.column_configs import (
EmbeddingColumnConfig,
ExpressionColumnConfig,
ImageColumnConfig,
LLMCodeColumnConfig,
LLMJudgeColumnConfig,
LLMStructuredColumnConfig,
Expand All @@ -26,6 +27,7 @@
is_plugin_column_type,
)
from data_designer.config.errors import InvalidConfigError
from data_designer.config.models import ImageContext
from data_designer.config.sampler_params import (
CategorySamplerParams,
GaussianSamplerParams,
Expand Down Expand Up @@ -122,6 +124,36 @@ def test_llm_text_column_config():
)


def test_llm_text_column_config_required_columns_includes_multi_modal_context():
    """required_columns reports the multi-modal context column next to the prompt column."""
    image_context = ImageContext(column_name="image_base64")
    column_config = LLMTextColumnConfig(
        name="test_llm_text",
        prompt="Classify this image: {{ description }}",
        model_alias=stub_model_alias,
        multi_modal_context=[image_context],
    )
    expected_columns = {"description", "image_base64"}
    # Compare as sets: only membership is being checked here, not order.
    assert set(column_config.required_columns) == expected_columns


def test_llm_text_column_config_required_columns_deduplicates_multi_modal_and_prompt():
    """A column used by both the prompt and the multi-modal context is listed once."""
    shared_context = ImageContext(column_name="image_col")
    column_config = LLMTextColumnConfig(
        name="test_llm_text",
        prompt="Classify this: {{ image_col }}",
        model_alias=stub_model_alias,
        multi_modal_context=[shared_context],
    )
    # Exact list comparison: a duplicate entry would make the list longer.
    assert column_config.required_columns == ["image_col"]


def test_image_column_config_required_columns_includes_multi_modal_context():
    """ImageColumnConfig.required_columns reports the multi-modal context column too."""
    reference_context = ImageContext(column_name="reference_image")
    column_config = ImageColumnConfig(
        name="test_image",
        prompt="Generate based on {{ style }}",
        model_alias=stub_model_alias,
        multi_modal_context=[reference_context],
    )
    expected_columns = {"style", "reference_image"}
    # Compare as sets: only membership is being checked here, not order.
    assert set(column_config.required_columns) == expected_columns


def test_llm_text_column_config_with_trace_serialization() -> None:
"""Test that with_trace field serializes and deserializes correctly."""
config = LLMTextColumnConfig(
Expand Down
Loading