Skip to content

Commit d3748c2

Browse files
AAnoosheh, ChenhanYu, and claude
authored
Allow basename of dataset paths to match registered names (#997)
### What does this PR do? Allow local dataset paths to match registered dataset configs Type of change: Bug fix <!-- Details about the change. --> ### Usage ```python # Add a code snippet demonstrating how to use this ``` ### Testing <!-- Mention how have you tested your change if applicable. --> ### Before your PR is "*Ready for review*" Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, using `torch.load(..., weights_only=True)`, avoiding `pickle`, etc.). - Is this change backward compatible?: ✅ / ❌ / N/A <!--- If ❌, explain why. --> - If you copied code from any other source, did you follow IP policy in [CONTRIBUTING.md](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md#-copying-code-from-other-sources)?: ✅ / ❌ / N/A <!--- Mandatory --> - Did you write any new necessary tests?: ✅ / ❌ / N/A <!--- Mandatory for new features or examples. --> - Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: ✅ / ❌ / N/A <!--- Only for new features, API changes, critical bug fixes or backward incompatible changes. --> ### Additional Information <!-- E.g. related issue. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added a small sample dataset entry (minipile_100_samples) and support for loading datasets from local filesystem paths with automatic detection and config override. * **Chores** * Improved local-path resolution and substring-based matching against registered dataset keys for consistent behavior. * **Tests** * Added a unit test to verify loading samples from a local dataset snapshot. 
* **Documentation** * Updated docs to describe local-path support, matching behavior, and updated function docstring. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Chenhan Yu <chenhany@nvidia.com> Signed-off-by: Asha Anoosheh <aanoosheh@nvidia.com> Co-authored-by: Chenhan Yu <chenhany@nvidia.com> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 6d77ce7 commit d3748c2

5 files changed

Lines changed: 153 additions & 127 deletions

File tree

examples/megatron_bridge/prune_minitron.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,10 @@ def get_args() -> argparse.Namespace:
8989
"--calib_dataset_name",
9090
type=str,
9191
default="nemotron-post-training-dataset-v2",
92-
choices=get_supported_datasets(),
93-
help="Dataset name for calibration",
92+
help=(
93+
f"HF Dataset name or local path for calibration (supported options: {', '.join(get_supported_datasets())}. "
94+
"You can also pass any other dataset and see if auto-detection for your dataset works."
95+
),
9496
)
9597
parser.add_argument(
9698
"--calib_num_samples", type=int, default=1024, help="Number of samples for calibration"

modelopt/torch/utils/dataset_utils.py

Lines changed: 114 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,13 @@
1717

1818
import copy
1919
import json
20+
import os
2021
from collections.abc import Callable
22+
from pathlib import Path
2123
from typing import TYPE_CHECKING, Any
2224
from warnings import warn
2325

26+
import requests
2427
import torch
2528
from torch.utils.data import DataLoader
2629
from tqdm import tqdm
@@ -48,9 +51,9 @@
4851
"name": "SFT",
4952
"split": ["code", "math", "science", "chat", "safety"],
5053
},
51-
"preprocess": lambda sample: "\n".join(turn["content"] for turn in sample["input"])
52-
+ "\n"
53-
+ sample["output"],
54+
"preprocess": lambda sample: (
55+
"\n".join(turn["content"] for turn in sample["input"]) + "\n" + sample["output"]
56+
),
5457
},
5558
"nemotron-post-training-dataset-v2": {
5659
"config": {
@@ -104,14 +107,16 @@
104107

105108
__all__ = [
106109
"create_forward_loop",
110+
"download_hf_dataset_as_jsonl",
107111
"get_dataset_dataloader",
108112
"get_dataset_samples",
113+
"get_jsonl_text_samples",
109114
"get_max_batch_size",
110115
"get_supported_datasets",
111116
]
112117

113118

114-
def _get_jsonl_text_samples(jsonl_path: str, num_samples: int) -> list[str]:
119+
def get_jsonl_text_samples(jsonl_path: str, num_samples: int, key: str = "text") -> list[str]:
115120
"""Load up to ``num_samples`` entries from a JSONL file using the ``text`` field.
116121
117122
Each non-empty line must be a JSON object containing a ``text`` field.
@@ -142,12 +147,12 @@ def _get_jsonl_text_samples(jsonl_path: str, num_samples: int) -> list[str]:
142147
f"got {type(obj)}."
143148
)
144149

145-
if "text" not in obj:
150+
if key not in obj:
146151
raise ValueError(
147-
f"Missing required field 'text' in JSONL file {jsonl_path} at line {line_idx}."
152+
f"Missing required field '{key}' in JSONL file {jsonl_path} at line {line_idx}."
148153
)
149154

150-
samples.append(str(obj["text"]))
155+
samples.append(str(obj[key]))
151156

152157
return samples
153158

@@ -158,9 +163,7 @@ def _normalize_splits(split: str | list[str]) -> list[str]:
158163

159164

160165
def _auto_preprocess_sample(
161-
sample: dict,
162-
dataset_name: str,
163-
tokenizer: "PreTrainedTokenizerBase | None" = None,
166+
sample: dict, dataset_name: str, tokenizer: "PreTrainedTokenizerBase | None" = None
164167
) -> str:
165168
"""Auto-detect dataset format and preprocess a single sample based on column conventions.
166169
@@ -223,7 +226,10 @@ def get_dataset_samples(
223226
``messages``/``conversations`` (chat), ``prompt``, ``text``, or ``input``.
224227
225228
Args:
226-
dataset_name: Name or HuggingFace path of the dataset to load, or a path to a ``.jsonl``/``.jsonl.gz`` file.
229+
dataset_name: Name or HuggingFace path of the dataset to load, a local directory path,
230+
or a path to a ``.jsonl`` file. For local directory paths, the
231+
predefined config from ``SUPPORTED_DATASET_CONFIG`` is matched if the base folder name
232+
matches a registered key (e.g. ``/hf-local/abisee/cnn_dailymail`` matches ``cnn_dailymail`` key).
227233
num_samples: Number of samples to load from the dataset.
228234
apply_chat_template: Whether to apply the chat template to the samples
229235
(if supported by the dataset). For unregistered datasets with a
@@ -240,15 +246,22 @@ def get_dataset_samples(
240246
"""
241247
# Local JSONL file path support (each line is a JSON object with a `text` field).
242248
if dataset_name.endswith(".jsonl"):
243-
return _get_jsonl_text_samples(dataset_name, num_samples)
249+
return get_jsonl_text_samples(dataset_name, num_samples, key="text")
244250

245251
from datasets import load_dataset
246252

253+
local_dataset_path = None
254+
if os.path.exists(dataset_name): # Local path
255+
local_dataset_path = dataset_name
256+
dataset_name = os.path.basename(os.path.normpath(local_dataset_path))
257+
247258
is_registered = dataset_name in SUPPORTED_DATASET_CONFIG
248259

249260
if is_registered:
250261
dataset_config = SUPPORTED_DATASET_CONFIG[dataset_name]
251262
config = dataset_config["config"].copy()
263+
if local_dataset_path:
264+
config["path"] = local_dataset_path
252265
splits = _normalize_splits(split) if split is not None else config.pop("split", [None])
253266
if split is not None:
254267
config.pop("split", None)
@@ -274,17 +287,18 @@ def _preprocess(sample: dict) -> str:
274287
return dataset_config["preprocess"](sample)
275288

276289
else:
277-
warn(
290+
print(
278291
f"Dataset '{dataset_name}' is not in SUPPORTED_DATASET_CONFIG. "
279292
"Auto-detecting format from column names."
280293
)
281-
config = {"path": dataset_name}
294+
config = {"path": local_dataset_path or dataset_name}
282295
splits = _normalize_splits(split) if split is not None else ["train"]
283296

284297
def _preprocess(sample: dict) -> str:
285298
return _auto_preprocess_sample(sample, dataset_name, tokenizer)
286299

287300
# load_dataset does not support a list of splits while streaming, so load each separately.
301+
print(f"Loading dataset with {config=} and {splits=}")
288302
dataset_splits = [load_dataset(streaming=True, **config, split=s) for s in splits]
289303

290304
num_per_split = [num_samples // len(dataset_splits)] * len(dataset_splits)
@@ -649,3 +663,89 @@ def create_forward_loop(
649663
def model_type_is_enc_dec(model):
650664
enc_dec_model_list = ["t5", "bart", "whisper"]
651665
return any(model_name in model.__class__.__name__.lower() for model_name in enc_dec_model_list)
666+
667+
668+
def download_hf_dataset_as_jsonl(
669+
dataset_name: str,
670+
output_dir: str | Path,
671+
json_keys: list[str] = ["text"],
672+
name: str | None = None,
673+
split: str | None = "train",
674+
max_samples_per_split: int | None = None,
675+
) -> list[str]:
676+
"""Download a Hugging Face dataset and save as JSONL files.
677+
678+
Args:
679+
dataset_name: Name or HuggingFace path of the dataset to download
680+
output_dir: Directory to save the JSONL files
681+
json_keys: List of keys to extract from the dataset. Defaults to ["text"].
682+
name: Name of the subset to download
683+
split: Split of the dataset to download. Defaults to "train".
684+
max_samples_per_split: Maximum number of samples to download per split. Defaults to None.
685+
686+
Returns:
687+
List of paths to downloaded JSONL files.
688+
"""
689+
from datasets import load_dataset
690+
from huggingface_hub.utils import build_hf_headers
691+
692+
print(f"Downloading dataset {dataset_name} from Hugging Face")
693+
jsonl_paths: list[str] = []
694+
695+
try:
696+
response = requests.get(
697+
f"https://datasets-server.huggingface.co/splits?dataset={dataset_name}",
698+
headers=build_hf_headers(),
699+
timeout=10,
700+
)
701+
response.raise_for_status()
702+
except requests.RequestException as e:
703+
raise RuntimeError(f"Failed to fetch dataset splits for {dataset_name}: {e}") from e
704+
705+
response_json = response.json()
706+
print(f"\nFound {len(response_json['splits'])} total splits for {dataset_name}:")
707+
for entry in response_json["splits"]:
708+
print(f"\t{entry}")
709+
710+
splits_to_process = []
711+
for entry in response_json["splits"]:
712+
if name is not None and name != entry.get("config", None):
713+
continue
714+
if split is not None and split != entry["split"]:
715+
continue
716+
splits_to_process.append(entry)
717+
718+
print(f"\nFound {len(splits_to_process)} splits to process:")
719+
for entry in splits_to_process:
720+
print(f"\t{entry}")
721+
722+
for entry in splits_to_process:
723+
skip_processing = False
724+
path = entry["dataset"]
725+
name = entry.get("config", None)
726+
split = entry["split"]
727+
if max_samples_per_split is not None:
728+
split = f"{split}[:{max_samples_per_split}]"
729+
jsonl_file_path = f"{output_dir}/{path.replace('/', '--')}_{name}_{split}.jsonl"
730+
731+
print(f"\nLoading HF dataset {path=}, {name=}, {split=}")
732+
if os.path.exists(jsonl_file_path):
733+
jsonl_paths.append(jsonl_file_path)
734+
print(f"\t[SKIP] Raw dataset {jsonl_file_path} already exists")
735+
continue
736+
ds = load_dataset(path=path, name=name, split=split)
737+
738+
for key in json_keys:
739+
if key not in ds.features:
740+
warn(f"[SKIP] {key=} not found in {ds.features=}")
741+
skip_processing = True
742+
break
743+
744+
if skip_processing:
745+
continue
746+
747+
print(f"Saving raw dataset to {jsonl_file_path}")
748+
ds.to_json(jsonl_file_path)
749+
jsonl_paths.append(jsonl_file_path)
750+
751+
return jsonl_paths

modelopt/torch/utils/plugins/megatron_preprocess_data.py

Lines changed: 5 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -56,17 +56,13 @@
5656
import argparse
5757
import json
5858
import multiprocessing
59-
import os
6059
from pathlib import Path
61-
from warnings import warn
6260

63-
import requests
64-
from datasets import load_dataset
65-
from huggingface_hub.utils import build_hf_headers
6661
from megatron.core.datasets import indexed_dataset
6762
from transformers import AutoTokenizer
6863

6964
from modelopt.torch.utils import num2hrb
65+
from modelopt.torch.utils.dataset_utils import download_hf_dataset_as_jsonl
7066

7167
__all__ = ["megatron_preprocess_data"]
7268

@@ -188,82 +184,6 @@ def process_json_file(
188184
return final_enc_len
189185

190186

191-
def _download_hf_dataset(
192-
dataset: str,
193-
output_dir: str | Path,
194-
json_keys: list[str],
195-
name: str | None = None,
196-
split: str | None = "train",
197-
max_samples_per_split: int | None = None,
198-
) -> list[str]:
199-
"""Download a Hugging Face dataset and save as JSONL files.
200-
201-
Returns:
202-
List of paths to downloaded JSONL files.
203-
"""
204-
print(f"Downloading dataset {dataset} from Hugging Face")
205-
jsonl_paths: list[str] = []
206-
207-
try:
208-
response = requests.get(
209-
f"https://datasets-server.huggingface.co/splits?dataset={dataset}",
210-
headers=build_hf_headers(),
211-
timeout=10,
212-
)
213-
response.raise_for_status()
214-
except requests.RequestException as e:
215-
raise RuntimeError(f"Failed to fetch dataset splits for {dataset}: {e}") from e
216-
217-
response_json = response.json()
218-
print(f"\nFound {len(response_json['splits'])} total splits for {dataset}:")
219-
for entry in response_json["splits"]:
220-
print(f"\t{entry}")
221-
222-
splits_to_process = []
223-
for entry in response_json["splits"]:
224-
if name is not None and name != entry.get("config", None):
225-
continue
226-
if split is not None and split != entry["split"]:
227-
continue
228-
splits_to_process.append(entry)
229-
230-
print(f"\nFound {len(splits_to_process)} splits to process:")
231-
for entry in splits_to_process:
232-
print(f"\t{entry}")
233-
234-
for entry in splits_to_process:
235-
skip_processing = False
236-
path = entry["dataset"]
237-
name = entry.get("config", None)
238-
split = entry["split"]
239-
if max_samples_per_split is not None:
240-
split = f"{split}[:{max_samples_per_split}]"
241-
jsonl_file_path = f"{output_dir}/raw/{path.replace('/', '--')}_{name}_{split}.jsonl"
242-
243-
print(f"\nLoading HF dataset {path=}, {name=}, {split=}")
244-
if os.path.exists(jsonl_file_path):
245-
jsonl_paths.append(jsonl_file_path)
246-
print(f"\t[SKIP] Raw dataset {jsonl_file_path} already exists")
247-
continue
248-
ds = load_dataset(path=path, name=name, split=split)
249-
250-
for key in json_keys:
251-
if key not in ds.features:
252-
warn(f"[SKIP] {key=} not found in {ds.features=}")
253-
skip_processing = True
254-
break
255-
256-
if skip_processing:
257-
continue
258-
259-
print(f"Saving raw dataset to {jsonl_file_path}")
260-
ds.to_json(jsonl_file_path)
261-
jsonl_paths.append(jsonl_file_path)
262-
263-
print(f"\n\nTokenizing JSONL paths: {jsonl_paths}\n")
264-
return jsonl_paths
265-
266-
267187
def megatron_preprocess_data(
268188
*,
269189
input_dir: str | Path | None = None,
@@ -309,14 +229,15 @@ def megatron_preprocess_data(
309229
)
310230

311231
if hf_dataset is not None:
312-
jsonl_paths = _download_hf_dataset(
232+
jsonl_paths = download_hf_dataset_as_jsonl(
313233
hf_dataset,
314-
output_dir,
234+
f"{output_dir}/raw",
315235
json_keys,
316236
name=hf_name,
317237
split=hf_split,
318238
max_samples_per_split=hf_max_samples_per_split,
319239
)
240+
print(f"\n\nTokenizing downloaded JSONL files: {jsonl_paths}\n")
320241

321242
if input_dir is not None:
322243
file_names = sorted(Path(input_dir).glob("*.jsonl"))
@@ -338,7 +259,7 @@ def megatron_preprocess_data(
338259
num_tokens = partition.process_json_file(name, output_dir, encoder)
339260
final_enc_len += num_tokens
340261

341-
print(f"\n\n>>> Total number of tokens currently processed: {num2hrb(final_enc_len)}")
262+
print(f"\n\n>>> Total number of tokens currently processed: {num2hrb(final_enc_len)}\nDone!")
342263

343264

344265
def main():

0 commit comments

Comments (0)