
Commit 6448afb

feat: support infinite dataset and move to logging
Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
1 parent: a8d2057

Showing 2 changed files with 36 additions and 7 deletions.

tuning/config/configs.py

Lines changed: 15 additions & 3 deletions
@@ -15,13 +15,12 @@
 # Standard
 from dataclasses import dataclass, field
 from typing import List, Optional, Union
+import logging
 import os
 
 # Third Party
 from datasets import Dataset, IterableDataset, interleave_datasets
-from pyarrow.lib import ArrowInvalid
 from tqdm import tqdm
-from transformers.utils import logging
 import datasets
 import torch
 import transformers
@@ -42,7 +41,7 @@
 DEFAULT_UNK_TOKEN = "<unk>"
 
 
-logger = logging.get_logger("sft_trainer")
+logger = logging.getLogger(__name__)
 
 
 def _load_data(data_path, split, streaming, config_kwargs):
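
This commit swaps the transformers.utils logging wrapper for the standard library throughout, so these loggers now obey whatever configuration the host application installs. A minimal sketch of such a setup, assuming a typical basicConfig call (the format string, level, and logger name below are illustrative, not part of the commit):

import logging

# Configure the root logger once at application startup; module-level
# loggers created with logging.getLogger(__name__) inherit this handler
# and level through the logging hierarchy.
logging.basicConfig(
    level=logging.WARNING,
    format="%(asctime)s %(name)s %(levelname)s: %(message)s",
)

# A hypothetical caller; the module's own logger name comes from __name__.
logger = logging.getLogger("tuning.config.configs")
logger.warning("routed through stdlib logging")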
@@ -366,6 +365,11 @@ def __post_init__(self):
         ) = load_multi_dataset_with_sampling(
             data_config=data_config, column_name_options=column_name_options
         )
+        if self.packing:
+            logger.warning(
+                "packing is enabled; strictly avoid packing for non-pretraining "
+                "use cases such as fine-tuning, to prevent cross-contamination."
+            )
         if data_config.data_sampler == "tokens_based":
             if not self.packing:
                 raise ValueError(
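
The warning exists because packing concatenates tokenized samples into fixed-length blocks under an all-ones attention mask, so tokens from one document attend to tokens from their neighbors. A toy illustration of the effect (the token ids are made up; this is not code from the repo):

# Two unrelated tokenized samples (ids are illustrative).
sample_a = [101, 7592, 2088, 102]
sample_b = [101, 3976, 1024, 102]

seq_length = 8
packed = (sample_a + sample_b)[:seq_length]

# Packing uses a single all-ones mask over the block, so tokens of
# sample_b attend to tokens of sample_a. Harmless for pretraining,
# but in fine-tuning this cross-contamination leaks one example
# into another's context.
attention_mask = [1] * len(packed)
print(packed, attention_mask)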
@@ -395,6 +399,13 @@ def __post_init__(self):
             cache_dir=self.cache_dir,
             use_fast=True,
         )
+        if self.max_steps > 0:
+            logger.warning(
+                f"dataset will be iterated infinitely until max_steps {self.max_steps} is reached."
+            )
+            logger.warning(
+                f"num_train_epochs {self.num_train_epochs} is ignored by the trainer."
+            )
         self.train_dataset = ConstantLengthHybridDataset(
             train_datasets,
             train_probs,
@@ -405,6 +416,7 @@ def __post_init__(self):
             self.dataset_text_field,
             self.add_bos_token,
             self.add_eos_token,
+            self.max_steps > 0,  # infinite: stream until max_steps is reached
         )
         if validation_datasets:
             self.validation_dataset = ConstantLengthHybridDataset(
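
Together these changes make the training stream endless whenever max_steps is positive, which matches what the Hugging Face Trainer expects for iterable datasets: it cannot count epochs over a stream, so max_steps becomes the sole stopping criterion. A sketch of the matching arguments (the values and output path are illustrative):

from transformers import TrainingArguments

# With an infinite IterableDataset the trainer ignores num_train_epochs
# and stops only when max_steps is reached, exactly what the warnings
# added above tell the user.
training_args = TrainingArguments(
    output_dir="./out",            # illustrative path
    max_steps=10_000,              # hard stop for the endless stream
    per_device_train_batch_size=8,
)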

tuning/utils/data_loaders.py

Lines changed: 21 additions & 4 deletions
@@ -1,15 +1,15 @@
 # Standard
 from typing import List, Union
+import logging
 
 # Third Party
 from datasets import Dataset
 from datasets import IterableDataset as HFIterableDataset
 from datasets import interleave_datasets
 from torch.utils.data import IterableDataset
-from transformers.utils import logging
 import torch
 
-logger = logging.get_logger("transformers")
+logger = logging.getLogger(__name__)
 
 
 class ConstantLengthHybridDataset(
@@ -26,6 +26,7 @@ def __init__(  # pylint: disable=super-init-not-called
         text_field="contents",
         add_bos_token=True,
         add_eos_token=True,
+        infinite=False,
     ):
         """packing for pretokenized datasets for pretraining only
         since all tokens are attended upon packing.
@@ -47,6 +48,8 @@ def __init__(  # pylint: disable=super-init-not-called
                 Defaults to True.
             add_eos_token (bool, optional): add eos token at the end of each sample.
                 Defaults to True.
+            infinite (`bool`, *optional*, defaults to `False`):
+                If True, the iterator is reset when a dataset is exhausted; otherwise iteration stops.
         """
         self.datasets = datasets
         self.sampling_probs = sampling_probs
@@ -62,6 +65,12 @@ def __init__(  # pylint: disable=super-init-not-called
         self.add_eos_token = add_eos_token
         self.dataset = interleave_datasets(datasets=self.datasets, split="train")
         self.column_names = self.dataset.column_names
+        self.infinite = infinite
+        if self.infinite:
+            logger.warning(
+                "samples will be provided infinitely; datasets that are "
+                "exhausted will be re-iterated from the start."
+            )
         # self._info = self.dataset._info
         # self._epoch = 0
         logger.warning("add_bos_token: {}".format(self.add_bos_token))
@@ -125,8 +134,16 @@ def __iter__(self):
                 )
                 buffer_len = len(buffer)
             except StopIteration:
-                more_examples = False
-                break
+                if self.infinite:
+                    iterators[dataset_id_which_needs_more_tokens] = iter(
+                        self.datasets[dataset_id_which_needs_more_tokens]
+                    )
+                    logger.warning(
+                        "iterator is reset for one of the datasets since it is exhausted."
+                    )
+                else:
+                    more_examples = False
+                    break
         all_token_ids = buffer
         examples = []
         for i in range(0, len(all_token_ids), self.seq_length):
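
The reset-on-StopIteration pattern is independent of this class and works for any multi-source stream. A standalone sketch of the same idea under invented names (infinite_interleave is mine, not part of the repo):

import random

def infinite_interleave(sources, probs, seed=0):
    # Yield items forever: pick a source by probability and restart any
    # source that raises StopIteration, mirroring the reset logic in
    # ConstantLengthHybridDataset.__iter__.
    rng = random.Random(seed)
    iterators = [iter(s) for s in sources]
    while True:
        idx = rng.choices(range(len(sources)), weights=probs)[0]
        try:
            yield next(iterators[idx])
        except StopIteration:
            iterators[idx] = iter(sources[idx])  # reset exhausted source
            yield next(iterators[idx])

# Usage: two finite lists behave as one endless mixed stream.
stream = infinite_interleave([["a1", "a2"], ["b1"]], probs=[0.5, 0.5])
print([next(stream) for _ in range(6)])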
