|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +from typing import Literal |
| 5 | + |
| 6 | +from data_designer.config.base import SingleColumnConfig |
| 7 | +from data_designer.config.models import ModalityDataType |
| 8 | +from data_designer.config.utils.constants import REASONING_CONTENT_COLUMN_POSTFIX, TRACE_COLUMN_POSTFIX |
| 9 | +from data_designer.config.utils.image_helpers import ImageFormat |
| 10 | +from data_designer.config.utils.misc import assert_valid_jinja2_template, extract_keywords_from_jinja2_template |
| 11 | +from data_designer.config.utils.trace_type import TraceType |
| 12 | +from pydantic import Field, model_validator |
| 13 | +from typing_extensions import Self |
| 14 | + |
| 15 | +VisualSearchToolName = Literal[ |
| 16 | + "open_image", |
| 17 | + "get_image_info", |
| 18 | + "list_images", |
| 19 | + "crop_image", |
| 20 | + "transform_image", |
| 21 | + "edit_color", |
| 22 | +] |
| 23 | + |
| 24 | + |
| 25 | +class VisualSearchColumnConfig(SingleColumnConfig): |
| 26 | + """Configuration for image-grounded visual search with local image-operation tools. |
| 27 | +
|
| 28 | + The column runs a vision-capable chat model with built-in image tools. Each tool |
| 29 | + returns an image ID, and subsequent calls may operate on any previous image ID, |
| 30 | + which lets the model branch from earlier points in the image history. |
| 31 | + """ |
| 32 | + |
| 33 | + column_type: Literal["visual-search"] = "visual-search" |
| 34 | + |
| 35 | + image_column: str = Field(description="Column containing a local image path, URL, base64 string, or data URI.") |
| 36 | + prompt: str = Field(description="Jinja2 prompt template for the visual search task.") |
| 37 | + model_alias: str = Field(description="Alias of the vision-capable chat model to use.") |
| 38 | + system_prompt: str | None = Field(default=None, description="Optional Jinja2 system prompt template.") |
| 39 | + image_data_type: ModalityDataType | None = Field( |
| 40 | + default=None, |
| 41 | + description="Optional explicit format for values in image_column. Leave unset for auto-detection.", |
| 42 | + ) |
| 43 | + image_format: ImageFormat | None = Field( |
| 44 | + default=None, |
| 45 | + description="Required when image_data_type is base64 and the image format cannot be auto-detected.", |
| 46 | + ) |
| 47 | + image_placeholder: str | None = Field( |
| 48 | + default=None, |
| 49 | + description="Optional model-specific image token to include in text for endpoints that require it.", |
| 50 | + ) |
| 51 | + max_tool_call_turns: int = Field( |
| 52 | + default=6, |
| 53 | + ge=1, |
| 54 | + description="Maximum tool-calling turns allowed for each row before the model must answer.", |
| 55 | + ) |
| 56 | + allowed_tools: list[VisualSearchToolName] | None = Field( |
| 57 | + default=None, |
| 58 | + description="Optional allowlist of built-in visual tools. Defaults to all tools.", |
| 59 | + ) |
| 60 | + attach_images_after_tool_calls: bool = Field( |
| 61 | + default=True, |
| 62 | + description="Attach resulting tool images back into the next model turn.", |
| 63 | + ) |
| 64 | + include_image_history: bool = Field( |
| 65 | + default=True, |
| 66 | + description="Add a side-effect column with the tree of image operations and IDs.", |
| 67 | + ) |
| 68 | + with_trace: TraceType = Field(default=TraceType.NONE, description="Optional chat trace capture mode.") |
| 69 | + extract_reasoning_content: bool = Field( |
| 70 | + default=False, |
| 71 | + description="If True, capture reasoning_content from the final assistant message.", |
| 72 | + ) |
| 73 | + use_default_system_prompt: bool = Field( |
| 74 | + default=True, |
| 75 | + description="Prepend built-in instructions explaining image IDs and visual tools.", |
| 76 | + ) |
| 77 | + |
| 78 | + @staticmethod |
| 79 | + def get_column_emoji() -> str: |
| 80 | + return "🔎" |
| 81 | + |
| 82 | + @property |
| 83 | + def required_columns(self) -> list[str]: |
| 84 | + required_cols = [self.image_column, *extract_keywords_from_jinja2_template(self.prompt)] |
| 85 | + if self.system_prompt: |
| 86 | + required_cols.extend(extract_keywords_from_jinja2_template(self.system_prompt)) |
| 87 | + return list(dict.fromkeys(required_cols)) |
| 88 | + |
| 89 | + @property |
| 90 | + def side_effect_columns(self) -> list[str]: |
| 91 | + return [ |
| 92 | + *([f"{self.name}__image_history"] if self.include_image_history else []), |
| 93 | + *([f"{self.name}{TRACE_COLUMN_POSTFIX}"] if self.with_trace != TraceType.NONE else []), |
| 94 | + *([f"{self.name}{REASONING_CONTENT_COLUMN_POSTFIX}"] if self.extract_reasoning_content else []), |
| 95 | + ] |
| 96 | + |
| 97 | + @model_validator(mode="after") |
| 98 | + def validate_templates_and_image_format(self) -> Self: |
| 99 | + """Validate prompt templates and image modality settings.""" |
| 100 | + assert_valid_jinja2_template(self.prompt) |
| 101 | + if self.system_prompt: |
| 102 | + assert_valid_jinja2_template(self.system_prompt) |
| 103 | + if self.image_data_type == ModalityDataType.BASE64 and self.image_format is None: |
| 104 | + raise ValueError("image_format is required when image_data_type is base64") |
| 105 | + return self |
0 commit comments