1515
1616import logging
1717import os
18- import time
1918import uuid
2019from typing import Any , Iterator
2120
@@ -53,9 +52,7 @@ class StreamingDataset(IterableDataset):
5352 ... required_fields=["input_ids", "attention_mask"],
5453 ... partition_id="train",
5554 ... task_name="update_actor",
56- ... data_replica_group=data_replica_group_id, # Same for all ranks in data replica group
57- ... data_replica_rank=local_rank, # local rank in data replica group
58- ... data_replica_world_size=world_size/dp_world_size, # size of data replica group
55+ ... dp_rank=dp_rank, # Same for all ranks in data replica group
5956 ... )
6057 >>> dataloader = StreamingDataLoader(
6158 ... dataset,
@@ -71,13 +68,15 @@ class StreamingDataset(IterableDataset):
7168 def __init__ (
7269 self ,
7370 config : dict [str , Any ],
71+ batch_size : int ,
7472 micro_batch_size : int ,
75- required_fields : list [str ],
73+ data_fields : list [str ],
7674 partition_id : str ,
7775 task_name : str ,
78- data_replica_group : int ,
79- data_replica_rank : int ,
80- data_replica_world_size : int ,
76+ dp_rank : int ,
77+ n_samples_per_prompt : int ,
78+ custom_get_batch_func : Any = None ,
79+ custom_post_process_for_micro_func : Any = None ,
8180 ):
8281 """Initialize the StreamingDataset.
8382
@@ -86,20 +85,22 @@ def __init__(
8685 - controller_info: ZMQServerInfo for the TransferQueueController
8786 - storage_backend: Storage backend type (e.g., "AsyncSimpleStorageManager")
8887 - Other backend-specific configuration
88+ batch_size: Batch size for data loading per iter.
8989 micro_batch_size: Number of samples per micro-batch. This is the batch size
9090 that will be requested from TransferQueue for each iteration.
91- required_fields : List of field names to retrieve from storage. Only these
91+ data_fields : List of field names to retrieve from storage. Only these
9292 fields will be included in the returned batch.
9393 partition_id: Partition ID for data versioning. Different partitions can
9494 be used for different data versions or splits (e.g., "train", "val").
9595 task_name: Unique identifier for the training task. This is used to track
9696 which samples have been consumed by which task.
97- data_replica_group: The group ID of the current data replica group. All
98- ranks with the same data_replica_group will receive identical samples.
99- data_replica_rank: Local rank index within the data_replica_group. Range:
100- [0, data_replica_world_size - 1]
101- data_replica_world_size: Total number of ranks in this data_replica_group.
102- Must be >= 1.
97+ dp_rank: The group ID of the current data group. All
98+ ranks with the same dp_rank will receive identical samples.
99+ n_samples_per_prompt: Number of samples generated per prompt for training.
100+ custom_get_batch_func: Optional custom function to retrieve batch data.
101+ If None, uses default_get_batch function.
102+ custom_post_process_for_micro_func: Optional custom function to post-process
103+ and split data into micro-batches. If None, uses default_post_process_for_micro_func.
103104
104105 Raises:
105106 ValueError: If input parameters are invalid.
@@ -108,41 +109,49 @@ def __init__(
108109 if micro_batch_size < 1 :
109110 raise ValueError (f"micro_batch_size must be >= 1, got { micro_batch_size } " )
110111
111- if len (required_fields ) < 1 :
112- raise ValueError (f"required_fields must be a list with at least one field name, got { required_fields } " )
112+ if len (data_fields ) < 1 :
113+ raise ValueError (f"required_fields must be a list with at least one field name, got { data_fields } " )
113114
114- if data_replica_world_size < 1 :
115- raise ValueError (f"data_replica_world_size { data_replica_world_size } must >= 1" )
116-
117- if data_replica_rank >= data_replica_world_size or data_replica_rank < 0 :
118- raise ValueError (
119- f"data_replica_rank { data_replica_rank } must be greater than or equal to 0 and less than "
120- f"data_replica_world_size { data_replica_world_size } "
121- )
115+ if dp_rank < 0 :
116+ raise ValueError (f"dp_rank { dp_rank } must be greater than or equal to 0" )
122117
123118 self .config = config
119+ self .batch_size = batch_size
124120 self .micro_batch_size = micro_batch_size
125- self .required_fields = required_fields
121+ self .data_fields = data_fields
126122 self .partition_id = partition_id
127123 self .task_name = task_name
128- self .data_replica_group = data_replica_group
129- self .data_replica_rank = data_replica_rank
130- self .data_replica_world_size = data_replica_world_size
124+ self .dp_rank = dp_rank
125+ self .n_samples_per_prompt = n_samples_per_prompt
126+ self .get_batch_func = custom_get_batch_func if custom_get_batch_func else default_get_batch
127+ self .post_process_for_micro_func = (
128+ custom_post_process_for_micro_func
129+ if custom_post_process_for_micro_func
130+ else default_post_process_for_micro_func
131+ )
131132
132133 # Build sampling config for controller
133134 self .sampling_config = {
134- "data_replica_group" : self .data_replica_group ,
135- "data_replica_rank" : self .data_replica_rank ,
136- "data_replica_world_size" : self .data_replica_world_size ,
135+ "dp_rank" : self .dp_rank ,
137136 "task_name" : self .task_name ,
138- "partition_id " : self .partition_id ,
137+ "n_samples_per_prompt " : self .n_samples_per_prompt ,
139138 }
140139
141140 self ._tq_client = None
141+ self .buffer : list [tuple ] = []
142+ self .batch_index = 0
142143
143144 super ().__init__ ()
144145
145146 def _create_client (self ):
147+ """Create and initialize a TransferQueue client.
148+
149+ This method initializes the TransferQueueClient with the provided configuration
150+ and storage backend, and sets up the storage manager for data retrieval.
151+
152+ Raises:
153+ ValueError: If controller_info or storage_backend is missing or invalid.
154+ """
146155 client_id = uuid .uuid4 ().hex [:8 ]
147156 controller_info = self .config .get ("controller_info" , None )
148157 if not controller_info or not isinstance (controller_info , ZMQServerInfo ):
@@ -175,30 +184,141 @@ def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]:
175184 # Note: For fully streamed production-consumption, please set the environment variable
176185 # TQ_PRE_ALLOC_SAMPLE_NUM to the required global_batch_size to make sure consumers can accurately
177186 # determine consumption status even before producers have generated the samples.
178- while not self ._tq_client .check_consumption_status (self .task_name , self .partition_id ):
187+ while (
188+ not self ._tq_client .check_consumption_status (self .task_name , self .partition_id )
189+ or self .batch_index <= len (self .buffer ) - 1
190+ ):
179191 try :
180- # Get metadata from controller
181- batch_meta = self ._tq_client .get_meta (
182- data_fields = self .required_fields ,
183- batch_size = self .micro_batch_size ,
184- partition_id = self .partition_id ,
185- task_name = self .task_name ,
186- sampling_config = self .sampling_config ,
187- )
188-
189- # Check if we got valid data
190- if batch_meta .size == 0 :
191- logger .debug (
192- f"[StreamingDataset]: Received empty batch, waiting for more data... "
193- f"Required batch_size={ self .micro_batch_size } , data_fields={ self .required_fields } ,"
194- f"partition_id={ self .partition_id } , task_name={ self .task_name } ."
195- )
192+ if self .batch_index <= len (self .buffer ) - 1 :
193+ current_data = self .buffer [self .batch_index ]
194+ self .batch_index += 1
195+ yield from self .post_process_for_micro_func (* current_data , micro_batch_size = self .micro_batch_size )
196196
197- time .sleep (TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL )
198197 else :
199- batch = self ._tq_client .get_data (batch_meta )
200- yield (batch , batch_meta )
198+ batch_data , batch_meta = self .get_batch_func (
199+ self ._tq_client ,
200+ self .data_fields ,
201+ self .batch_size ,
202+ self .partition_id ,
203+ self .task_name ,
204+ self .sampling_config ,
205+ self .batch_index ,
206+ )
207+ if batch_data is not None :
208+ self .buffer .append ((batch_data , batch_meta ))
201209
202210 except Exception as e :
203211 logger .error (f"[StreamingDataset]: Error in data iteration: { e } " )
204212 raise
213+
def reset(self) -> None:
    """Rewind the dataset to the first buffered batch.

    Only ``batch_index`` is set back to 0. ``self.buffer`` is intentionally
    left untouched, so batches that were already fetched are served again
    from the buffer on the next iteration instead of being re-requested.
    Use ``step()`` when the buffered batches must be discarded as well.
    """
    # Do not clear self.buffer here: iteration replays buffer[batch_index:]
    # before it asks the controller for anything new.
    self.batch_index = 0
220+
def step(self, partition_id):
    """Point the dataset at another partition and start over.

    Buffered batches belong to the partition that was active when they were
    fetched, so they are dropped and the replay cursor goes back to 0 before
    the new partition ID takes effect (e.g. moving from "train" to "val").

    Args:
        partition_id: ID of the partition subsequent fetches should target.
    """
    self.partition_id = partition_id
    # A fresh list (not .clear()) so any outside reference to the old buffer
    # keeps seeing the old contents.
    self.buffer, self.batch_index = [], 0
233+
234+
def default_get_batch(tq_client, data_fields, batch_size, partition_id, task_name, sampling_config, batch_index):
    """Retrieve one batch of data from TransferQueue.

    Asks the controller for batch metadata and, when that metadata is
    non-empty, fetches the matching data. An empty batch is not an error: it
    is reported by returning ``None`` in place of the data.

    Args:
        tq_client: Client object exposing ``get_meta(...)`` and ``get_data(meta)``.
        data_fields: List of field names to retrieve.
        batch_size: Requested batch size.
        partition_id: Partition ID to read from.
        task_name: Identifier of the consuming task.
        sampling_config: Mapping with sampler settings. It is NOT modified;
            ``batch_index`` and ``partition_id`` are added to a per-request copy.
        batch_index: Index of the batch being requested.

    Returns:
        tuple: ``(batch, batch_meta)`` where ``batch`` is whatever
        ``tq_client.get_data`` returned, or ``None`` when ``batch_meta.size``
        is 0, and ``batch_meta`` is the metadata returned by ``get_meta``.
    """
    # Build a fresh dict per request instead of writing into the caller's
    # mapping, so the per-call keys never leak into shared state such as the
    # dataset's long-lived sampling_config.
    request_config = {
        **sampling_config,
        "batch_index": batch_index,
        "partition_id": partition_id,
    }

    # Get metadata from controller
    batch_meta = tq_client.get_meta(
        data_fields=data_fields,
        batch_size=batch_size,
        partition_id=partition_id,
        task_name=task_name,
        sampling_config=request_config,
    )

    # Nothing available yet: let the caller decide how to wait/retry.
    if batch_meta.size == 0:
        logger.debug(
            f"[StreamingDataset]: Received empty batch, waiting for more data... "
            f"Required batch_size={batch_size}, data_fields={data_fields},"
            f"partition_id={partition_id}, task_name={task_name}."
        )
        return None, batch_meta

    return tq_client.get_data(batch_meta), batch_meta
277+
278+
def default_post_process_for_micro_func(td, batch_meta, micro_batch_size=1):
    """Cut a batch into micro-batches along the first batch dimension.

    The TensorDict is sliced into consecutive pieces of ``micro_batch_size``
    rows (the last piece may be shorter), and ``batch_meta`` is split into the
    same number of chunks via ``batch_meta.chunk``. Piece *i* is paired with
    chunk *i*.

    Args:
        td: TensorDict with a non-empty ``batch_size``.
        batch_meta: Metadata object providing ``chunk(n)``.
        micro_batch_size: Positive number of rows per micro-batch (default: 1).

    Returns:
        list: ``(micro_batch_td, micro_batch_meta)`` tuples in batch order.

    Raises:
        TypeError: If ``td`` is not a TensorDict.
        ValueError: If ``micro_batch_size`` is not a positive integer, if
            ``td.batch_size`` is empty, or if ``micro_batch_size`` is larger
            than the first batch dimension.
    """
    if not isinstance(td, TensorDict):
        raise TypeError(f"Expected TensorDict, got {type(td).__name__}")

    size_is_valid = isinstance(micro_batch_size, int) and micro_batch_size > 0
    if not size_is_valid:
        raise ValueError(f"micro_batch_size must be a positive integer, got {micro_batch_size}")

    if len(td.batch_size) == 0:
        raise ValueError("Input TensorDict must have non-empty batch_size")

    total_size = td.batch_size[0]
    if total_size < micro_batch_size:
        raise ValueError(f"micro_batch_size ({micro_batch_size}) exceeds total batch size ({total_size})")

    # Ceiling division: a trailing partial micro-batch still counts as one.
    full_chunks, leftover = divmod(total_size, micro_batch_size)
    num_splits = full_chunks + (1 if leftover else 0)

    # NOTE(review): this pairs fixed-size TensorDict slices with whatever
    # batch_meta.chunk(num_splits) returns. If chunk() balances chunk sizes
    # instead of filling each to micro_batch_size, data and metadata can fall
    # out of step when total_size is not a multiple of micro_batch_size
    # (e.g. 7 rows / size 3) -- confirm against BatchMeta.chunk.
    meta_chunks = batch_meta.chunk(num_splits)

    return [
        (td[begin : min(begin + micro_batch_size, total_size)], meta_chunks[idx])
        for idx, begin in enumerate(range(0, total_size, micro_batch_size))
    ]
0 commit comments