
Commit a3c62c2

ChenhanYu and claude committed

refactor(examples): update imports, docstrings, and README for dataset/

- _specdec_aug → conversation_utils: update imports in ptv2/ptv3 scripts and generalize the module docstring
- augmentations.yaml: reference both ptv2 and ptv3 in header comment
- speculative_decoding/README.md: update all paths to ../dataset/

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: chenhany <chenhany@nvidia.com>

1 parent 510c4b2 commit a3c62c2

File tree: 5 files changed, +23 −17 lines

examples/dataset/augmentations.yaml (1 addition, 1 deletion)

```diff
@@ -1,4 +1,4 @@
-# Augmentation specs for make_nemotron_ptv2_dataset.py
+# Augmentation specs for make_nemotron_ptv2_dataset.py and make_nemotron_ptv3_dataset.py
 #
 # Each entry defines one augmentation variant applied cyclically across the dataset.
 # The augmented copy is the same size as the source — each row gets exactly one variant.
```
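The header's "applied cyclically" behavior can be sketched as a simple index-modulo assignment. This is a minimal illustration, not the real script: the variant names below are hypothetical placeholders, and the actual selection logic lives in `make_augment_fn` in `conversation_utils.py`.

```python
# Cyclic variant assignment: each dataset row gets exactly one augmentation
# variant, so the augmented copy is the same size as the source.
# Variant names are made up for illustration only.
variants = ["redirect_japanese", "style_concise", "format_markdown"]


def pick_variant(idx: int) -> str:
    """Cycle through the variant list by row index."""
    return variants[idx % len(variants)]


# Rows 0 and 3 receive the same variant; the cycle repeats every len(variants) rows.
assignments = [pick_variant(i) for i in range(6)]
```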

examples/dataset/conversation_utils.py (7 additions, 5 deletions)

```diff
@@ -14,7 +14,7 @@
 # limitations under the License.

 """
-Shared augmentation utilities for Nemotron speculative-decoding dataset scripts.
+Shared conversation manipulation and augmentation utilities for dataset preparation.

 Imported by make_nemotron_ptv2_dataset.py and make_nemotron_ptv3_dataset.py.
@@ -24,10 +24,10 @@

 Conversation format
 -------------------
-Each conversation is kept as a full message list (system + user + assistant turns)
-with only the *last* assistant turn stripped — that is the response the target model
-will generate. All prior assistant turns are preserved so the model has the full
-multi-turn context it needs to produce a coherent next response.
+Each conversation is stripped down to a skeleton of system + user turns only — all
+assistant turns are removed. The downstream generation pipeline (query.py) feeds this
+skeleton to the target model turn-by-turn, appending each generated response before
+sending the next user turn, so the model produces coherent multi-turn continuations.

 Augmentations are applied only to the *last* user message (the new prompt), not to
 earlier user turns that are already part of the established context.
@@ -152,6 +152,8 @@ def strip_assistant_turns(example: dict[str, Any], idx: int) -> dict[str, Any]:
     Rows with no user turns are returned empty and filtered out by the caller.
     """
     messages = [m for m in example["messages"] if m["role"] in ("system", "user")]
+    if not any(m["role"] == "user" for m in messages):
+        return {"messages": []}
     return {"messages": messages}
```
examples/dataset/make_nemotron_ptv2_dataset.py (1 addition, 1 deletion)

```diff
@@ -61,7 +61,7 @@

 from datasets import concatenate_datasets, load_dataset

-from _specdec_aug import (
+from conversation_utils import (
     has_tool_turns,
     load_augmentations,
     make_augment_fn,
```

examples/dataset/make_nemotron_ptv3_dataset.py (5 additions, 1 deletion)

```diff
@@ -85,7 +85,7 @@
 import yaml
 from datasets import concatenate_datasets, load_dataset

-from _specdec_aug import has_tool_turns, load_augmentations, make_augment_fn, normalize_messages, strip_assistant_turns
+from conversation_utils import has_tool_turns, load_augmentations, make_augment_fn, normalize_messages, strip_assistant_turns

 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", datefmt="%H:%M:%S"
@@ -270,6 +270,10 @@ def main() -> None:
     if non_augmentable is not None:
         parts_to_combine.append(non_augmentable)

+    if not parts_to_combine:
+        logger.warning("No data to combine — all rows were filtered out. Exiting.")
+        return
+
     combined = concatenate_datasets(parts_to_combine)
     logger.info("Combined (pre-shuffle): %d rows", len(combined))
```
examples/speculative_decoding/README.md (9 additions, 9 deletions)

````diff
@@ -48,7 +48,7 @@ pip install -r requirements.txt
 We support a range of input datasets. In this example, we will use the [UltraChat-200k](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k) dataset.

 ```bash
-python prepare_input_conversations/make_dataset.py -f prepare_input_conversations/example_data_config.yaml --full-conversations
+python ../dataset/make_dataset.py -f ../dataset/example_data_config.yaml --full-conversations
 ```

 See [other-datasets](#other-datasets) section for other dataset options and instruction for user-provided data.
@@ -203,7 +203,7 @@ See more details on deployment of quantized model to TRTLLM [here](../llm_ptq/RE

 ### Other Datasets

-In addition to the default dataset, we support adding several other commonly used datasets in `prepare_input_conversations/make_dataset.py`:
+In addition to the default dataset, we support adding several other commonly used datasets in `../dataset/make_dataset.py`:

 - MTBench (for debugging)
 - ShareGPT
@@ -232,10 +232,10 @@ For large-scale training we provide dedicated scripts for NVIDIA's Nemotron Post

 ```bash
 # Synthetic data generation (~3.3M rows):
-python prepare_input_conversations/make_nemotron_ptv2_dataset.py --output-dir /tmp/ptv2_gen
+python ../dataset/make_nemotron_ptv2_dataset.py --output-dir /tmp/ptv2_gen

 # Direct SFT training mix (~1.9M rows):
-python prepare_input_conversations/make_nemotron_ptv2_dataset.py --mode train --output-dir /tmp/ptv2_train
+python ../dataset/make_nemotron_ptv2_dataset.py --mode train --output-dir /tmp/ptv2_train
 ```

 Covers: `stem`, `chat`, `math`, `code` + 5 multilingual splits (ja/de/it/es/fr, capped at 100K each).
@@ -244,19 +244,19 @@ Covers: `stem`, `chat`, `math`, `code` + 5 multilingual splits (ja/de/it/es/fr,

 ```bash
 # Synthetic data generation (~3.4M rows):
-python prepare_input_conversations/make_nemotron_ptv3_dataset.py --output-dir /tmp/ptv3_gen
+python ../dataset/make_nemotron_ptv3_dataset.py --output-dir /tmp/ptv3_gen

 # Direct SFT training mix (~3.9M rows, includes agentic/tool-use datasets):
-python prepare_input_conversations/make_nemotron_ptv3_dataset.py --mode train --output-dir /tmp/ptv3_train
+python ../dataset/make_nemotron_ptv3_dataset.py --mode train --output-dir /tmp/ptv3_train
 ```

-Covers: math, code, science, instruction-following, agentic/tool-use, safety, finance, and multilingual data. The dataset mix and per-split row caps are configurable via `prepare_input_conversations/nemotron_ptv3_datasets.yaml`.
+Covers: math, code, science, instruction-following, agentic/tool-use, safety, finance, and multilingual data. The dataset mix and per-split row caps are configurable via `../dataset/nemotron_ptv3_datasets.yaml`.

-**Augmentation** (generate mode only) is controlled by `prepare_input_conversations/augmentations.yaml`. By default it includes 12 language-redirect variants and several style/format hints. The `/no_think` system-prompt variant is disabled by default (enable it for models that support it, e.g. Qwen3):
+**Augmentation** (generate mode only) is controlled by `../dataset/augmentations.yaml`. By default it includes 12 language-redirect variants and several style/format hints. The `/no_think` system-prompt variant is disabled by default (enable it for models that support it, e.g. Qwen3):

 ```bash
 # Custom augmentation config:
-python prepare_input_conversations/make_nemotron_ptv2_dataset.py \
+python ../dataset/make_nemotron_ptv2_dataset.py \
     --augmentations-config my_augs.yaml --output-dir /tmp/ptv2_gen
 ```
````