From a2d9c8a5aa318b68717bd5f3516a015b22049426 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Thu, 9 Apr 2026 11:10:22 -0600 Subject: [PATCH] fix: include multi_modal_context columns in required_columns (#520) `LLMTextColumnConfig.required_columns` and `ImageColumnConfig.required_columns` only extracted dependencies from Jinja2 prompt templates, missing columns referenced via `multi_modal_context`. This caused the async engine's execution graph to dispatch LLM tasks before their multi-modal seed columns were loaded, resulting in KeyError failures under DATA_DESIGNER_ASYNC_ENGINE=1. Made-with: Cursor --- .../data_designer/config/column_configs.py | 17 +++++++--- .../tests/config/test_columns.py | 32 +++++++++++++++++++ 2 files changed, 44 insertions(+), 5 deletions(-) diff --git a/packages/data-designer-config/src/data_designer/config/column_configs.py b/packages/data-designer-config/src/data_designer/config/column_configs.py index 27d1b3b31..59bd9e39a 100644 --- a/packages/data-designer-config/src/data_designer/config/column_configs.py +++ b/packages/data-designer-config/src/data_designer/config/column_configs.py @@ -185,14 +185,17 @@ def get_column_emoji() -> str: @property def required_columns(self) -> list[str]: - """Get columns referenced in the prompt and system_prompt templates. + """Get columns referenced in prompt templates and multi-modal context. Returns: - List of unique column names referenced in Jinja2 templates. + List of unique column names referenced in Jinja2 templates + and multi-modal context configurations. """ required_cols = list(extract_keywords_from_jinja2_template(self.prompt)) if self.system_prompt: required_cols.extend(list(extract_keywords_from_jinja2_template(self.system_prompt))) + if self.multi_modal_context: + required_cols.extend(ctx.column_name for ctx in self.multi_modal_context) return list(set(required_cols)) @property @@ -593,12 +596,16 @@ def get_column_emoji() -> str: @property def required_columns(self) -> list[str]: - """Get columns referenced in the prompt template. + """Get columns referenced in the prompt template and multi-modal context. Returns: - List of unique column names referenced in Jinja2 templates. + List of unique column names referenced in Jinja2 templates + and multi-modal context configurations. """ - return list(extract_keywords_from_jinja2_template(self.prompt)) + required_cols = list(extract_keywords_from_jinja2_template(self.prompt)) + if self.multi_modal_context: + required_cols.extend(ctx.column_name for ctx in self.multi_modal_context) + return list(set(required_cols)) @model_validator(mode="after") def assert_prompt_valid_jinja(self) -> Self: diff --git a/packages/data-designer-config/tests/config/test_columns.py b/packages/data-designer-config/tests/config/test_columns.py index 239ff0862..987a158f6 100644 --- a/packages/data-designer-config/tests/config/test_columns.py +++ b/packages/data-designer-config/tests/config/test_columns.py @@ -9,6 +9,7 @@ from data_designer.config.column_configs import ( EmbeddingColumnConfig, ExpressionColumnConfig, + ImageColumnConfig, LLMCodeColumnConfig, LLMJudgeColumnConfig, LLMStructuredColumnConfig, @@ -26,6 +27,7 @@ is_plugin_column_type, ) from data_designer.config.errors import InvalidConfigError +from data_designer.config.models import ImageContext from data_designer.config.sampler_params import ( CategorySamplerParams, GaussianSamplerParams, @@ -122,6 +124,36 @@ def test_llm_text_column_config(): ) +def test_llm_text_column_config_required_columns_includes_multi_modal_context(): + config = LLMTextColumnConfig( + name="test_llm_text", + prompt="Classify this image: {{ description }}", + model_alias=stub_model_alias, + multi_modal_context=[ImageContext(column_name="image_base64")], + ) + assert set(config.required_columns) == {"description", "image_base64"} + + +def test_llm_text_column_config_required_columns_deduplicates_multi_modal_and_prompt(): + config = LLMTextColumnConfig( + name="test_llm_text", + prompt="Classify this: {{ image_col }}", + model_alias=stub_model_alias, + multi_modal_context=[ImageContext(column_name="image_col")], + ) + assert config.required_columns == ["image_col"] + + +def test_image_column_config_required_columns_includes_multi_modal_context(): + config = ImageColumnConfig( + name="test_image", + prompt="Generate based on {{ style }}", + model_alias=stub_model_alias, + multi_modal_context=[ImageContext(column_name="reference_image")], + ) + assert set(config.required_columns) == {"style", "reference_image"} + + def test_llm_text_column_config_with_trace_serialization() -> None: """Test that with_trace field serializes and deserializes correctly.""" config = LLMTextColumnConfig(