Skip to content

Commit fd477a6

Browse files
authored
fix: include multi_modal_context columns in required_columns (#520) (#522)
`LLMTextColumnConfig.required_columns` and `ImageColumnConfig.required_columns` only extracted dependencies from Jinja2 prompt templates, missing columns referenced via `multi_modal_context`. This caused the async engine's execution graph to dispatch LLM tasks before their multi-modal seed columns were loaded, resulting in KeyError failures under DATA_DESIGNER_ASYNC_ENGINE=1.
1 parent 4a28136 commit fd477a6

File tree

2 files changed

+44
-5
lines changed

2 files changed

+44
-5
lines changed

packages/data-designer-config/src/data_designer/config/column_configs.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -185,14 +185,17 @@ def get_column_emoji() -> str:
185185

186186
@property
187187
def required_columns(self) -> list[str]:
188-
"""Get columns referenced in the prompt and system_prompt templates.
188+
"""Get columns referenced in prompt templates and multi-modal context.
189189
190190
Returns:
191-
List of unique column names referenced in Jinja2 templates.
191+
List of unique column names referenced in Jinja2 templates
192+
and multi-modal context configurations.
192193
"""
193194
required_cols = list(extract_keywords_from_jinja2_template(self.prompt))
194195
if self.system_prompt:
195196
required_cols.extend(list(extract_keywords_from_jinja2_template(self.system_prompt)))
197+
if self.multi_modal_context:
198+
required_cols.extend(ctx.column_name for ctx in self.multi_modal_context)
196199
return list(set(required_cols))
197200

198201
@property
@@ -593,12 +596,16 @@ def get_column_emoji() -> str:
593596

594597
@property
595598
def required_columns(self) -> list[str]:
596-
"""Get columns referenced in the prompt template.
599+
"""Get columns referenced in the prompt template and multi-modal context.
597600
598601
Returns:
599-
List of unique column names referenced in Jinja2 templates.
602+
List of unique column names referenced in Jinja2 templates
603+
and multi-modal context configurations.
600604
"""
601-
return list(extract_keywords_from_jinja2_template(self.prompt))
605+
required_cols = list(extract_keywords_from_jinja2_template(self.prompt))
606+
if self.multi_modal_context:
607+
required_cols.extend(ctx.column_name for ctx in self.multi_modal_context)
608+
return list(set(required_cols))
602609

603610
@model_validator(mode="after")
604611
def assert_prompt_valid_jinja(self) -> Self:

packages/data-designer-config/tests/config/test_columns.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from data_designer.config.column_configs import (
1010
EmbeddingColumnConfig,
1111
ExpressionColumnConfig,
12+
ImageColumnConfig,
1213
LLMCodeColumnConfig,
1314
LLMJudgeColumnConfig,
1415
LLMStructuredColumnConfig,
@@ -26,6 +27,7 @@
2627
is_plugin_column_type,
2728
)
2829
from data_designer.config.errors import InvalidConfigError
30+
from data_designer.config.models import ImageContext
2931
from data_designer.config.sampler_params import (
3032
CategorySamplerParams,
3133
GaussianSamplerParams,
@@ -122,6 +124,36 @@ def test_llm_text_column_config():
122124
)
123125

124126

127+
def test_llm_text_column_config_required_columns_includes_multi_modal_context():
128+
config = LLMTextColumnConfig(
129+
name="test_llm_text",
130+
prompt="Classify this image: {{ description }}",
131+
model_alias=stub_model_alias,
132+
multi_modal_context=[ImageContext(column_name="image_base64")],
133+
)
134+
assert set(config.required_columns) == {"description", "image_base64"}
135+
136+
137+
def test_llm_text_column_config_required_columns_deduplicates_multi_modal_and_prompt():
138+
config = LLMTextColumnConfig(
139+
name="test_llm_text",
140+
prompt="Classify this: {{ image_col }}",
141+
model_alias=stub_model_alias,
142+
multi_modal_context=[ImageContext(column_name="image_col")],
143+
)
144+
assert config.required_columns == ["image_col"]
145+
146+
147+
def test_image_column_config_required_columns_includes_multi_modal_context():
148+
config = ImageColumnConfig(
149+
name="test_image",
150+
prompt="Generate based on {{ style }}",
151+
model_alias=stub_model_alias,
152+
multi_modal_context=[ImageContext(column_name="reference_image")],
153+
)
154+
assert set(config.required_columns) == {"style", "reference_image"}
155+
156+
125157
def test_llm_text_column_config_with_trace_serialization() -> None:
126158
"""Test that with_trace field serializes and deserializes correctly."""
127159
config = LLMTextColumnConfig(

0 commit comments

Comments
 (0)