Skip to content

Commit cca50d1

Browse files
fix: handle PreservingDataset in dataset concatenation (#2116)
`concatenate_datasets` from HuggingFace only accepts HF Dataset objects, causing a ValueError when `use_preserving_dataset=True`. Added a `merge_datasets` helper that detects PreservingDataset instances and merges their data lists directly instead. Closes #2116 Signed-off-by: Adi Krish <143638558+RudimentaryChef@users.noreply.github.com>
1 parent fe3c4fc commit cca50d1

3 files changed

Lines changed: 78 additions & 6 deletions

File tree

examples/run_sft.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import pprint
1818
from functools import partial
1919

20-
from datasets import concatenate_datasets
2120
from omegaconf import OmegaConf
2221
from transformers import AutoTokenizer
2322

@@ -29,6 +28,7 @@
2928
load_response_dataset,
3029
update_single_dataset_config,
3130
)
31+
from nemo_rl.data.utils import merge_datasets
3232
from nemo_rl.distributed.virtual_cluster import init_ray
3333
from nemo_rl.utils.config import (
3434
load_config,
@@ -89,7 +89,7 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):
8989
if hasattr(data, "preprocessor") and data.preprocessor is not None:
9090
task_data_preprocessors[data.task_name] = data.preprocessor
9191

92-
merged_data = concatenate_datasets([data.dataset for data in data_list])
92+
merged_data = merge_datasets([data.dataset for data in data_list])
9393
dataset = AllTaskProcessedDataset(
9494
merged_data,
9595
tokenizer,
@@ -144,7 +144,7 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):
144144

145145
val_dataset = None
146146
if len(val_data_list) > 0:
147-
merged_val_data = concatenate_datasets(val_data_list)
147+
merged_val_data = merge_datasets(val_data_list)
148148
val_dataset = AllTaskProcessedDataset(
149149
merged_val_data,
150150
tokenizer,

nemo_rl/data/utils.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from typing import Any, Optional, Union
1616

17-
from datasets import concatenate_datasets
17+
from datasets import Dataset, concatenate_datasets
1818
from transformers import AutoProcessor, AutoTokenizer
1919

2020
from nemo_rl.data import DataConfig
@@ -25,11 +25,27 @@
2525
load_response_dataset,
2626
update_single_dataset_config,
2727
)
28+
from nemo_rl.data.datasets.response_datasets.oai_format_dataset import (
29+
PreservingDataset,
30+
)
2831
from nemo_rl.data.processors import preference_preprocessor
2932
from nemo_rl.environments.interfaces import EnvironmentInterface
3033
from nemo_rl.environments.utils import create_env
3134

3235

36+
def merge_datasets(datasets: list) -> Union[Dataset, "PreservingDataset"]:
    """Merge a list of datasets, handling both HuggingFace ``Dataset`` and ``PreservingDataset``.

    HuggingFace's ``concatenate_datasets`` does not accept ``PreservingDataset``
    objects. When every input is a ``PreservingDataset``, this helper merges
    their underlying ``data`` lists directly, which also keeps heterogeneous row
    schemas intact (no None-filling of missing keys). Otherwise it defers to
    ``concatenate_datasets``.

    Args:
        datasets: Datasets to merge. Must be homogeneous — either all
            ``PreservingDataset`` instances or all HF ``Dataset`` objects.
            An empty list yields an empty ``PreservingDataset``.

    Returns:
        A single merged dataset of the same kind as the inputs.

    Raises:
        TypeError: If the list mixes ``PreservingDataset`` with other
            dataset types.
    """
    if all(isinstance(d, PreservingDataset) for d in datasets):
        # Flatten the per-dataset row lists in input order.
        merged_data = [item for d in datasets for item in d.data]
        return PreservingDataset(merged_data)

    if any(isinstance(d, PreservingDataset) for d in datasets):
        # Fail fast with a clear message instead of letting concatenate_datasets
        # reject the PreservingDataset entries with an opaque error.
        raise TypeError(
            "merge_datasets received a mix of PreservingDataset and other "
            f"dataset types: {[type(d).__name__ for d in datasets]}. "
            "All datasets must be of the same kind to be merged."
        )

    return concatenate_datasets(datasets)
47+
48+
3349
# TODO: @yukih: unify to setup_data after dataset refactored
3450
def setup_response_data(
3551
tokenizer: AutoProcessor | AutoTokenizer,
@@ -134,7 +150,7 @@ def setup_response_data(
134150
}
135151
else:
136152
# merge datasets into a single dataset
137-
merged_data = concatenate_datasets([data.dataset for data in data_list])
153+
merged_data = merge_datasets([data.dataset for data in data_list])
138154
dataset = AllTaskProcessedDataset(
139155
merged_data,
140156
tokenizer,
@@ -199,7 +215,7 @@ def setup_response_data(
199215
# merge datasets
200216
val_dataset = None
201217
if len(val_data_list) > 0:
202-
merged_val_data = concatenate_datasets(val_data_list)
218+
merged_val_data = merge_datasets(val_data_list)
203219
val_dataset = AllTaskProcessedDataset(
204220
merged_val_data,
205221
tokenizer,

tests/unit/data/datasets/test_preserving_dataset.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,3 +313,59 @@ def test_comparison_with_standard_dataset(self):
313313
preserving_dataset = PreservingDataset(data)
314314
assert preserving_dataset[0]["tool_id"] == "123"
315315
assert "tool_id" not in preserving_dataset[1] # Key doesn't exist
316+
317+
318+
class TestMergeDatasets:
    """Tests for the merge_datasets helper covering HF Dataset and PreservingDataset inputs."""

    def test_merge_preserving_datasets(self):
        """Merging several PreservingDatasets yields one PreservingDataset with rows in order."""
        from nemo_rl.data.utils import merge_datasets

        left = PreservingDataset([{"a": 1}, {"b": 2}])
        right = PreservingDataset([{"c": 3}])

        combined = merge_datasets([left, right])

        assert isinstance(combined, PreservingDataset)
        assert len(combined) == 3
        expected_rows = [{"a": 1}, {"b": 2}, {"c": 3}]
        for idx, row in enumerate(expected_rows):
            assert combined[idx] == row

    def test_merge_hf_datasets(self):
        """Plain HuggingFace Datasets are still merged via concatenation."""
        from nemo_rl.data.utils import merge_datasets

        left = Dataset.from_list([{"x": 1}, {"x": 2}])
        right = Dataset.from_list([{"x": 3}])

        combined = merge_datasets([left, right])

        assert isinstance(combined, Dataset)
        assert len(combined) == 3
        assert combined[0]["x"] == 1
        assert combined[2]["x"] == 3

    def test_merge_single_preserving_dataset(self):
        """A one-element list merges to an equivalent PreservingDataset."""
        from nemo_rl.data.utils import merge_datasets

        only = PreservingDataset([{"a": 1, "b": 2}, {"c": 3}])

        combined = merge_datasets([only])

        assert isinstance(combined, PreservingDataset)
        assert len(combined) == 2

    def test_merge_preserving_datasets_preserves_heterogeneous_structure(self):
        """Rows with differing keys keep their own schema — merging introduces no None-filling."""
        from nemo_rl.data.utils import merge_datasets

        with_tool = PreservingDataset([{"role": "user", "content": "hi", "tool_id": "1"}])
        without_tool = PreservingDataset([{"role": "assistant", "content": "hello"}])

        combined = merge_datasets([with_tool, without_tool])

        assert "tool_id" in combined[0]
        assert "tool_id" not in combined[1]  # no None-filling of missing keys

0 commit comments

Comments
 (0)