Skip to content

Commit 2504e08

Browse files
committed
Enable inherent casting of the datasets.
Handles jumbled columns and mismatching dtypes Signed-off-by: Dushyant Behl <dushyantbehl@in.ibm.com>
1 parent 03c7058 commit 2504e08

11 files changed

Lines changed: 105 additions & 65 deletions

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ dependencies = [
3737
"trl>=0.13,<0.18",
3838
"peft>=0.8.0,<0.14",
3939
"protobuf>=5.28.0,<6.0.0",
40-
"datasets>=2.15.0,<4.0",
40+
"datasets>=3.5.0,<4.0",
4141
"simpleeval>=0.9.13,<2.0",
4242
"pillow>=11.0.0,<12.0",
4343
]

tests/artifacts/testdata/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,12 @@
7474
CHAT_DATA_MULTI_TURN_GRANITE_3_1B = os.path.join(
7575
JSONL_DATA_DIR, "multi_turn_chat_granite_instruct.jsonl"
7676
)
77+
CHAT_DATASET_LARGELIST = os.path.join(
78+
PARQUET_DATA_DIR, "chat_dataset_tokenized_largelist.parquet"
79+
)
80+
CHAT_DATASET_SEQUENCE = os.path.join(
81+
PARQUET_DATA_DIR, "chat_dataset_tokenized_sequence.parquet"
82+
)
7783
IMAGE_DATASET = os.path.join(JSONL_DATA_DIR, "image_dataset.jsonl")
7884
EMPTY_DATA = os.path.join(JSON_DATA_DIR, "empty_data.json")
7985
MALFORMATTED_DATA = os.path.join(JSON_DATA_DIR, "malformatted_data.json")
Binary file not shown.
Binary file not shown.

tests/test_sft_trainer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,15 @@
5151
DATA_CONFIG_YAML_STREAMING_INPUT_OUTPUT,
5252
DATA_CONFIG_YAML_STREAMING_PRETOKENIZED,
5353
GRANITE_3_1_B_CHAT_TEMPLATE,
54+
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML
5455
)
5556
from tests.artifacts.testdata import (
5657
CHAT_DATA_MULTI_TURN,
5758
CHAT_DATA_MULTI_TURN_CONVERSATIONS,
5859
CHAT_DATA_MULTI_TURN_GRANITE_3_1B,
5960
CHAT_DATA_SINGLE_TURN,
61+
CHAT_DATASET_LARGELIST,
62+
CHAT_DATASET_SEQUENCE,
6063
CUSTOM_TOKENIZER_TINYLLAMA,
6164
EMPTY_DATA,
6265
MALFORMATTED_DATA,
@@ -820,7 +823,6 @@ def test_run_causallm_ft_pretokenized(dataset_path, packing):
820823
assert len(output_inference) > 0
821824
assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference
822825

823-
824826
@pytest.mark.parametrize(
825827
"datafiles, datasetconfigname",
826828
[
@@ -836,9 +838,13 @@ def test_run_causallm_ft_pretokenized(dataset_path, packing):
836838
[TWITTER_COMPLAINTS_TOKENIZED_JSON],
837839
DATA_CONFIG_YAML_STREAMING_PRETOKENIZED,
838840
),
841+
(
842+
[CHAT_DATASET_LARGELIST, CHAT_DATASET_SEQUENCE],
843+
DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
844+
),
839845
],
840846
)
841-
def test_run_causallm_ft_and_inference_streaming(datasetconfigname, datafiles):
847+
def test_run_causallm_ft_and_inference(datasetconfigname, datafiles):
842848
"""Check if we can finetune causallm models using multiple datasets with multiple files"""
843849
with tempfile.TemporaryDirectory() as tempdir:
844850
data_formatting_args = copy.deepcopy(DATA_ARGS)

tests/utils/test_config_utils.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
# Standard
1919
import base64
20-
import logging
2120
import os
2221
import pickle
2322

@@ -31,7 +30,8 @@
3130

3231
# Local
3332
from tuning.config import peft_config
34-
from tuning.utils import config_utils, utils
33+
from tuning.data import utils
34+
from tuning.utils import config_utils
3535

3636

3737
def test_get_hf_peft_config_returns_None_for_tuning_config_None():
@@ -236,7 +236,7 @@ def test_get_json_config_can_load_from_envvar():
236236
assert job_config["model_name_or_path"] == "foobar"
237237

238238

239-
def test_validate_datasets_logs_warnings_on_mismatch(caplog):
239+
def test_validate_datasets_throws_error_on_mismatch():
240240
"""Test that `validate_mergeable_datasets` logs warnings when
241241
datasets have different columns or dtypes."""
242242
# Create a reference dataset with columns col1:int64 and col2:string
@@ -251,12 +251,5 @@ def test_validate_datasets_logs_warnings_on_mismatch(caplog):
251251
features=Features({"col1": Value("float64"), "col3": Value("string")}),
252252
)
253253

254-
with caplog.at_level(logging.WARNING):
255-
utils.validate_mergeable_datasets([ds1, ds2])
256-
257-
assert (
258-
"different columns" in caplog.text
259-
), "Expected a warning about differing columns."
260-
assert (
261-
"expected int64" in caplog.text
262-
), "Expected a warning about mismatching column dtypes."
254+
with pytest.raises(ValueError):
255+
utils._validate_mergeable_datasets([ds1, ds2])

tuning/data/collators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
# Local
16-
from tuning.utils.utils import try_convert_bytes_dict_to_pil
16+
from tuning.data.utils import try_convert_bytes_dict_to_pil
1717

1818

1919
class VisionDataCollator:

tuning/data/data_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import os
2121

2222
# Local
23-
from tuning.utils.utils import load_yaml_or_json
23+
from tuning.data.utils import load_yaml_or_json
2424

2525
logger = logging.getLogger(__name__)
2626

tuning/data/data_handlers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@
3434
import torch
3535

3636
# Local
37+
from tuning.data.utils import try_convert_bytes_dict_to_pil, try_convert_image_to_rgb
3738
from tuning.utils.config_utils import process_jinja_placeholders
38-
from tuning.utils.utils import try_convert_bytes_dict_to_pil, try_convert_image_to_rgb
3939

4040
logger = logging.getLogger(__name__)
4141

tuning/data/data_processors.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@
3131
DataHandler,
3232
DataHandlerType,
3333
)
34-
from tuning.utils.utils import (
34+
from tuning.data.utils import (
3535
get_loader_for_filepath,
36+
maybe_align_datasets,
3637
resolve_iterable_dataset_features,
37-
validate_mergeable_datasets,
3838
)
3939

4040
logger = logging.getLogger(__name__)
@@ -223,31 +223,30 @@ def _try_load_dataset(dataset_path, dataset_builder, streaming):
223223

224224
for data_path in data_paths:
225225
dataset = _try_load_dataset(data_path, builder, streaming)
226-
if isinstance(dataset, IterableDataset):
227-
dataset = resolve_iterable_dataset_features(dataset)
228226
all_datasets.append(dataset)
229227

230-
# Logs warning if datasets have different columns
231-
validate_mergeable_datasets(all_datasets)
232-
233228
# Concatenate all datasets
234229
try:
235230
if len(all_datasets) == 1:
236231
return all_datasets[0]
237-
232+
maybe_align_datasets(all_datasets)
238233
raw_datasets = datasets.concatenate_datasets(all_datasets)
239234
logger.info(
240-
"Datasets concatenated from %s .Concatenated dataset columns: %s",
235+
"Datasets %s concatenated. Final column features: %s",
241236
datasetconfig.name,
242-
list(raw_datasets.features.keys()),
237+
str(list(raw_datasets.features)),
243238
)
244-
return raw_datasets
245-
246239
except Exception as e:
247240
raise ValueError(
248241
f"An error occurred while concatenating datasets from {datasetconfig.name}: {e}"
249242
) from e
250243

244+
# Need to resolve dataset features because data handlers use columns.
245+
if isinstance(raw_datasets, IterableDataset):
246+
raw_datasets = resolve_iterable_dataset_features(raw_datasets)
247+
248+
return raw_datasets
249+
251250
def __execute_rename_data_handler(self, raw_datasets, handler, **kwargs):
252251
"""
253252
Rename columns in the dataset using the provided column mapping.
@@ -456,9 +455,6 @@ def _process_dataset_configs(
456455
raw_dataset = self.load_dataset(
457456
d, self.processor_config.streaming, splitName
458457
)
459-
if isinstance(raw_dataset, IterableDataset):
460-
raw_dataset = resolve_iterable_dataset_features(raw_dataset)
461-
462458
logger.info("Loaded raw dataset : %s", str(raw_dataset))
463459

464460
if isinstance(raw_dataset, IterableDataset):
@@ -493,6 +489,9 @@ def _process_dataset_configs(
493489
else:
494490
final_datasets[k].append(v)
495491

492+
# Ensure again datasets are aligned before interleaving or concatenating
493+
maybe_align_datasets(final_datasets)
494+
496495
if sample_datasets:
497496
strategy = self.processor_config.sampling_stopping_strategy
498497
seed = self.processor_config.sampling_seed
@@ -517,6 +516,8 @@ def _process_dataset_configs(
517516
)
518517

519518
train_dataset = final_datasets.get("train", None)
519+
520+
# Just a failsafe in case this is required later.
520521
if isinstance(train_dataset, IterableDataset):
521522
train_dataset = resolve_iterable_dataset_features(train_dataset)
522523

0 commit comments

Comments
 (0)