|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +"""Typer-backed Data Designer recipe entry point for retrieval SDG.""" |
| 5 | + |
| 6 | +from __future__ import annotations |
| 7 | + |
| 8 | +from pathlib import Path |
| 9 | +from typing import Annotated |
| 10 | + |
| 11 | +import click |
| 12 | +import data_designer.config as dd |
| 13 | +import typer |
| 14 | + |
| 15 | +from data_designer_retrieval_sdg.pipeline import ( |
| 16 | + DEFAULT_CHAT_MODEL, |
| 17 | + DEFAULT_EMBED_MODEL, |
| 18 | + DEFAULT_PROVIDER, |
| 19 | + build_qa_generation_pipeline, |
| 20 | +) |
| 21 | +from data_designer_retrieval_sdg.seed_source import DocumentChunkerSeedSource |
| 22 | + |
| 23 | + |
def load_config_builder(params: dd.DataDesignerScriptParams | None = None) -> dd.DataDesignerConfigBuilder:
    """Build the retrieval SDG pipeline from forwarded Data Designer CLI args.

    Args:
        params: Data Designer script parameters. ``params.argv`` contains the
            arguments supplied after ``data-designer preview/create --recipe
            retrieval-sdg --``. ``None`` params, a missing ``argv`` attribute,
            and an ``argv`` of ``None`` are all treated as "no arguments".

    Returns:
        A configured Data Designer config builder for retrieval SDG generation.

    Raises:
        SystemExit: With code 0 when the command returned ``0`` and a help flag
            (``--help`` or ``-h``) was among the forwarded arguments.
        TypeError: If the command returned anything else that is not a
            ``DataDesignerConfigBuilder``.
    """
    # ``or ()`` covers an explicit ``argv=None`` in addition to the missing-attribute case.
    argv = list(getattr(params, "argv", None) or ())
    command = typer.main.get_command(build_typer_app())
    # Non-standalone mode: the command's result is handed back to this function
    # instead of the process exiting, so it can be inspected below.
    result = command.main(
        args=argv,
        prog_name="data-designer preview/create --recipe retrieval-sdg --",
        standalone_mode=False,
    )

    # Success path first, so a real builder is never compared against an exit code.
    if isinstance(result, dd.DataDesignerConfigBuilder):
        return result
    # A 0 result together with a help flag means help was printed; exit cleanly.
    if result == 0 and any(arg in {"--help", "-h"} for arg in argv):
        raise SystemExit(0)
    raise TypeError(f"Recipe returned {type(result).__name__}, expected DataDesignerConfigBuilder")
| 48 | + |
| 49 | + |
def build_typer_app() -> typer.Typer:
    """Create the Typer application that exposes the retrieval SDG recipe.

    The app has shell completion disabled and registers ``recipe_command`` as
    its command, using the same help text for the app and the command.

    Returns:
        Typer app describing the retrieval SDG recipe interface.
    """
    summary = "Build the retrieval SDG Data Designer workflow."
    recipe_app = typer.Typer(add_completion=False, help=summary)
    # ``command(...)`` returns a decorator; apply it to the recipe callback explicitly.
    register = recipe_app.command(name=None, help=summary)
    register(recipe_command)
    return recipe_app
| 59 | + |
| 60 | + |
def recipe_command(
    input_dir: Annotated[Path, typer.Option("--input-dir", help="Directory containing text files")],
    file_pattern: Annotated[str, typer.Option("--file-pattern", help="Filename glob (basenames only)")] = "*",
    recursive: Annotated[
        bool,
        typer.Option("--recursive/--no-recursive", help="Enable recursive search"),
    ] = True,
    file_extensions: Annotated[
        list[str] | None,
        typer.Option(
            "--file-extensions",
            help="Allowed file extensions (use empty string '' to match files without extensions)",
        ),
    ] = None,
    min_text_length: Annotated[int, typer.Option("--min-text-length", help="Minimum document text length")] = 50,
    sentences_per_chunk: Annotated[int, typer.Option("--sentences-per-chunk", help="Sentences per chunk")] = 5,
    num_sections: Annotated[int, typer.Option("--num-sections", help="Sections to divide chunks into")] = 1,
    num_files: Annotated[int | None, typer.Option("--num-files", help="Max files to process")] = None,
    multi_doc: Annotated[bool, typer.Option("--multi-doc", help="Enable multi-doc bundling")] = False,
    bundle_size: Annotated[int, typer.Option("--bundle-size", help="Docs per bundle")] = 2,
    bundle_strategy: Annotated[
        str,
        typer.Option(
            "--bundle-strategy",
            help="Section splitting strategy",
            click_type=click.Choice(["sequential", "doc_balanced", "interleaved"]),
        ),
    ] = "sequential",
    max_docs_per_bundle: Annotated[int, typer.Option("--max-docs-per-bundle", help="Max docs per bundle")] = 3,
    multi_doc_manifest: Annotated[
        Path | None, typer.Option("--multi-doc-manifest", help="Manifest for explicit bundles")
    ] = None,
    start_index: Annotated[int, typer.Option("--start-index", help="Start seed row index")] = 0,
    end_index: Annotated[int, typer.Option("--end-index", help="End seed row index")] = 199,
    max_artifacts_per_type: Annotated[int, typer.Option("--max-artifacts-per-type", help="Max artifacts per type")] = 2,
    num_pairs: Annotated[int, typer.Option("--num-pairs", help="QA pairs per document")] = 7,
    min_hops: Annotated[int, typer.Option("--min-hops", help="Min hops for multi-hop questions")] = 2,
    max_hops: Annotated[int, typer.Option("--max-hops", help="Max hops for multi-hop questions")] = 4,
    min_complexity: Annotated[int, typer.Option("--min-complexity", help="Min question complexity")] = 4,
    similarity_threshold: Annotated[
        float, typer.Option("--similarity-threshold", help="Cosine threshold for QA-pair dedup")
    ] = 0.9,
    artifact_extraction_model: Annotated[
        str, typer.Option("--artifact-extraction-model", help="Artifact extraction model")
    ] = DEFAULT_CHAT_MODEL,
    artifact_extraction_provider: Annotated[
        str, typer.Option("--artifact-extraction-provider", help="Artifact extraction provider")
    ] = DEFAULT_PROVIDER,
    qa_generation_model: Annotated[str, typer.Option("--qa-generation-model", help="QA generation model")] = (
        DEFAULT_CHAT_MODEL
    ),
    qa_generation_provider: Annotated[str, typer.Option("--qa-generation-provider", help="QA generation provider")] = (
        DEFAULT_PROVIDER
    ),
    quality_judge_model: Annotated[str, typer.Option("--quality-judge-model", help="Quality judge model")] = (
        DEFAULT_CHAT_MODEL
    ),
    quality_judge_provider: Annotated[str, typer.Option("--quality-judge-provider", help="Quality judge provider")] = (
        DEFAULT_PROVIDER
    ),
    embed_model: Annotated[str, typer.Option("--embed-model", help="Embedding model")] = DEFAULT_EMBED_MODEL,
    embed_provider: Annotated[str, typer.Option("--embed-provider", help="Embedding provider")] = DEFAULT_PROVIDER,
    max_parallel_requests_for_gen: Annotated[
        int | None, typer.Option("--max-parallel-requests-for-gen", help="Max parallel generation requests")
    ] = None,
) -> dd.DataDesignerConfigBuilder:
    """Build the retrieval SDG Data Designer workflow.

    Each parameter maps to the CLI option named in its ``typer.Option``; the
    option's ``help`` text is its description. The seed-related options are
    passed to ``DocumentChunkerSeedSource`` and the remaining options to
    ``build_qa_generation_pipeline``.

    Returns:
        A configured Data Designer config builder.

    Raises:
        click.BadParameter: If ``end_index`` is smaller than ``start_index``.
    """
    # Reject an inverted seed-row range before constructing anything.
    if start_index > end_index:
        raise click.BadParameter("--end-index must be greater than or equal to --start-index")

    # Fall back to the default text extensions when none (or an empty list) were given.
    allowed_extensions = file_extensions if file_extensions else [".txt", ".md", ".text"]
    # The seed source receives plain strings rather than Path objects.
    source_dir = str(input_dir)
    manifest = None if not multi_doc_manifest else str(multi_doc_manifest)

    seed_source = DocumentChunkerSeedSource(
        path=source_dir,
        file_pattern=file_pattern,
        recursive=recursive,
        file_extensions=allowed_extensions,
        min_text_length=min_text_length,
        sentences_per_chunk=sentences_per_chunk,
        num_sections=num_sections,
        num_files=num_files,
        multi_doc=multi_doc,
        bundle_size=bundle_size,
        bundle_strategy=bundle_strategy,
        max_docs_per_bundle=max_docs_per_bundle,
        multi_doc_manifest=manifest,
    )

    # Row range and generation tuning knobs.
    generation_settings = {
        "start_index": start_index,
        "end_index": end_index,
        "max_artifacts_per_type": max_artifacts_per_type,
        "num_pairs": num_pairs,
        "min_hops": min_hops,
        "max_hops": max_hops,
        "min_complexity": min_complexity,
        "similarity_threshold": similarity_threshold,
        "max_parallel_requests_for_gen": max_parallel_requests_for_gen,
    }
    # Model/provider selection for each pipeline stage.
    model_routing = {
        "artifact_extraction_model": artifact_extraction_model,
        "artifact_extraction_provider": artifact_extraction_provider,
        "qa_generation_model": qa_generation_model,
        "qa_generation_provider": qa_generation_provider,
        "quality_judge_model": quality_judge_model,
        "quality_judge_provider": quality_judge_provider,
        "embed_model": embed_model,
        "embed_provider": embed_provider,
    }

    return build_qa_generation_pipeline(
        seed_source=seed_source,
        **generation_settings,
        **model_routing,
    )
0 commit comments