@@ -1,15 +1,15 @@
 # Standard
 from typing import List, Union
+import logging
 
 # Third Party
 from datasets import Dataset
 from datasets import IterableDataset as HFIterableDataset
 from datasets import interleave_datasets
 from torch.utils.data import IterableDataset
-from transformers.utils import logging
 import torch
 
-logger = logging.get_logger("transformers")
+logger = logging.getLogger(__name__)
 
 
 class ConstantLengthHybridDataset(
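This hunk replaces the transformers logging wrapper with the standard library logger keyed by `__name__`. Verbosity is therefore now controlled through the stdlib API rather than `transformers.utils.logging.set_verbosity_*`; a minimal sketch, with an illustrative module path:

import logging

# Application-level configuration for stdlib logging.
logging.basicConfig(level=logging.WARNING)

# Loggers created with logging.getLogger(__name__) are keyed by module path,
# so this module's logger can be tuned individually (path is illustrative).
logging.getLogger("my_package.hybrid_dataset").setLevel(logging.INFO)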
@@ -26,6 +26,7 @@ def __init__( # pylint: disable=super-init-not-called
         text_field="contents",
         add_bos_token=True,
         add_eos_token=True,
+        infinite=False,
     ):
         """packing for pretokenized datasets for pretraining only
         since all tokens are attended upon packing.
@@ -47,6 +48,8 @@ def __init__( # pylint: disable=super-init-not-called
                 Defaults to True.
             add_eos_token (bool, optional): add eos token at the end of each sample.
                 Defaults to True.
+            infinite (`bool`, *optional*, defaults to `False`):
+                If `True`, an exhausted dataset's iterator is reset so iteration continues; otherwise it stops.
         """
         self.datasets = datasets
         self.sampling_probs = sampling_probs
@@ -62,6 +65,12 @@ def __init__( # pylint: disable=super-init-not-called
         self.add_eos_token = add_eos_token
         self.dataset = interleave_datasets(datasets=self.datasets, split="train")
         self.column_names = self.dataset.column_names
+        self.infinite = infinite
+        if self.infinite:
+            logger.warning(
+                "Samples will be provided infinitely. "
+                "Datasets that are exhausted will be reiterated from the start."
+            )
         # self._info = self.dataset._info
         # self._epoch = 0
         logger.warning("add_bos_token: {}".format(self.add_bos_token))
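The branch added in the next hunk follows the standard Python idiom of catching `StopIteration` and re-creating the iterator from its source. A standalone sketch of the pattern (the `cycle_forever` helper is hypothetical, not part of this change):

def cycle_forever(iterable):
    # Yield items forever, restarting from the source whenever it is exhausted.
    # Unlike itertools.cycle, this re-iterates the source rather than caching
    # the first pass, so a re-shuffling iterable yields a fresh order each pass.
    it = iter(iterable)
    while True:
        try:
            yield next(it)
        except StopIteration:
            it = iter(iterable)  # reset, mirroring the infinite=True branch

One caveat worth noting: an empty source would make such a loop spin without ever yielding, so `infinite=True` assumes each dataset produces at least one sample.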
@@ -125,8 +134,16 @@ def __iter__(self):
                     )
                     buffer_len = len(buffer)
                 except StopIteration:
-                    more_examples = False
-                    break
+                    if self.infinite:
+                        iterators[dataset_id_which_needs_more_tokens] = iter(
+                            self.datasets[dataset_id_which_needs_more_tokens]
+                        )
+                        logger.warning(
+                            "iterator is reset for one of the datasets since it is exhausted."
+                        )
+                    else:
+                        more_examples = False
+                        break
             all_token_ids = buffer
             examples = []
             for i in range(0, len(all_token_ids), self.seq_length):
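For reference, a minimal usage sketch of the new flag. Only parameters visible in this diff are passed; any other required constructor arguments are omitted, and the pretokenized toy data is an assumption:

from datasets import Dataset

# Two small pretokenized datasets; "contents" matches the default text_field.
ds_a = Dataset.from_dict({"contents": [[1, 2, 3, 4]] * 8})
ds_b = Dataset.from_dict({"contents": [[5, 6, 7, 8]] * 8})

packed = ConstantLengthHybridDataset(
    datasets=[ds_a, ds_b],
    sampling_probs=[0.5, 0.5],
    seq_length=8,  # assumed from self.seq_length in __iter__
    text_field="contents",
    add_bos_token=True,
    add_eos_token=True,
    infinite=True,  # exhausted datasets restart instead of ending iteration
)

for step, example in enumerate(packed):
    if step >= 100:  # with infinite=True the stream never ends on its own
        break

With `infinite=False` (the default) the previous behavior is preserved: iteration stops as soon as a dataset that needs to supply more tokens is exhausted.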