|
9 | 9 | from data_designer.config.column_configs import ( |
10 | 10 | EmbeddingColumnConfig, |
11 | 11 | ExpressionColumnConfig, |
| 12 | + ImageColumnConfig, |
12 | 13 | LLMCodeColumnConfig, |
13 | 14 | LLMJudgeColumnConfig, |
14 | 15 | LLMStructuredColumnConfig, |
|
26 | 27 | is_plugin_column_type, |
27 | 28 | ) |
28 | 29 | from data_designer.config.errors import InvalidConfigError |
| 30 | +from data_designer.config.models import ImageContext |
29 | 31 | from data_designer.config.sampler_params import ( |
30 | 32 | CategorySamplerParams, |
31 | 33 | GaussianSamplerParams, |
@@ -122,6 +124,36 @@ def test_llm_text_column_config(): |
122 | 124 | ) |
123 | 125 |
|
124 | 126 |
|
| 127 | +def test_llm_text_column_config_required_columns_includes_multi_modal_context(): |
| 128 | + config = LLMTextColumnConfig( |
| 129 | + name="test_llm_text", |
| 130 | + prompt="Classify this image: {{ description }}", |
| 131 | + model_alias=stub_model_alias, |
| 132 | + multi_modal_context=[ImageContext(column_name="image_base64")], |
| 133 | + ) |
| 134 | + assert set(config.required_columns) == {"description", "image_base64"} |
| 135 | + |
| 136 | + |
| 137 | +def test_llm_text_column_config_required_columns_deduplicates_multi_modal_and_prompt(): |
| 138 | + config = LLMTextColumnConfig( |
| 139 | + name="test_llm_text", |
| 140 | + prompt="Classify this: {{ image_col }}", |
| 141 | + model_alias=stub_model_alias, |
| 142 | + multi_modal_context=[ImageContext(column_name="image_col")], |
| 143 | + ) |
| 144 | + assert config.required_columns == ["image_col"] |
| 145 | + |
| 146 | + |
| 147 | +def test_image_column_config_required_columns_includes_multi_modal_context(): |
| 148 | + config = ImageColumnConfig( |
| 149 | + name="test_image", |
| 150 | + prompt="Generate based on {{ style }}", |
| 151 | + model_alias=stub_model_alias, |
| 152 | + multi_modal_context=[ImageContext(column_name="reference_image")], |
| 153 | + ) |
| 154 | + assert set(config.required_columns) == {"style", "reference_image"} |
| 155 | + |
| 156 | + |
125 | 157 | def test_llm_text_column_config_with_trace_serialization() -> None: |
126 | 158 | """Test that with_trace field serializes and deserializes correctly.""" |
127 | 159 | config = LLMTextColumnConfig( |
|
0 commit comments