Skip to content

Commit ec0135a

Browse files
committed
[StreamingDataLoader, 5/N] feat: Refactor the StreamingDataLoader implementation to support fully asynchronous mode
1 parent aea4d2f commit ec0135a

10 files changed

Lines changed: 454 additions & 267 deletions

File tree

tests/test_samplers.py

Lines changed: 168 additions & 152 deletions
Large diffs are not rendered by default.

transfer_queue/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import os
1717

1818
from .client import (
19+
AsyncTransferQueueClient,
1920
TransferQueueClient,
2021
process_zmq_server_info,
2122
)
@@ -32,6 +33,7 @@
3233

3334
__all__ = [
3435
"TransferQueueClient",
36+
"AsyncTransferQueueClient",
3537
"StreamingDataset",
3638
"StreamingDataLoader",
3739
"BatchMeta",

transfer_queue/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -698,7 +698,7 @@ async def async_check_consumption_status(
698698
partition_id=partition_id,
699699
)
700700

701-
if consumption_status is None:
701+
if consumption_status is None or consumption_status.numel() == 0:
702702
return False
703703
return torch.all(consumption_status == 1).item()
704704

transfer_queue/controller.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,12 +1015,12 @@ def get_metadata(
10151015
Raises:
10161016
TimeoutError: If waiting for sufficient data times out in fetch mode
10171017
"""
1018-
if partition_id not in self.partitions:
1019-
self.create_partition(partition_id)
10201018

10211019
if mode == "insert":
1022-
partition = self._get_partition(partition_id)
1020+
if partition_id not in self.partitions:
1021+
self.create_partition(partition_id)
10231022

1023+
partition = self._get_partition(partition_id)
10241024
if data_fields:
10251025
# This is called during put_data call without providing metadata.
10261026
# try to use pre-allocated global index first
@@ -1083,6 +1083,7 @@ def get_metadata(
10831083
ready_for_consume_indexes,
10841084
batch_size,
10851085
**(sampling_config or {}),
1086+
**kwargs,
10861087
)
10871088

10881089
# Check if we got valid results from the sampler
@@ -1240,6 +1241,7 @@ def clear_partition(self, partition_id: str, clear_consumption: bool = True):
12401241
partition.clear_data(global_indexes_range, clear_consumption)
12411242
self.index_manager.release_partition(partition_id)
12421243
self.partitions.pop(partition_id)
1244+
self.sampler.clear_cache(partition_id)
12431245

12441246
def clear_meta(
12451247
self,

transfer_queue/dataloader/streaming_dataloader.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def __init__(
111111
parameter in PyTorch DataLoader is set to None because batching is managed
112112
by the StreamingDataset in coordination with RankAwareSampler.
113113
"""
114+
self.dataset: StreamingDataset = dataset
114115

115116
if collate_fn is None:
116117
# use identical collate function to directly return the self-defined
@@ -137,3 +138,32 @@ def __init__(
137138
persistent_workers=persistent_workers,
138139
pin_memory_device=pin_memory_device,
139140
)
141+
142+
def reset(self) -> None:
    """Rewind iteration over the underlying dataset.

    Delegates to ``self.dataset.reset()``. For the StreamingDataset in this
    change, that call only sets ``batch_index`` back to 0: the buffer of
    already-fetched batches is kept (not cleared), so the next iteration
    serves the buffered batches again from the start.
    """
    self.dataset.reset()
148+
149+
def step(self, partition_id: str) -> None:
    """Point the underlying dataset at a different partition.

    Delegates to ``self.dataset.step(partition_id)``. For the StreamingDataset
    in this change, that call drops the buffered batches, sets the batch index
    back to 0 and stores the new ``partition_id`` (e.g. switching from
    "train" to "val").

    Args:
        partition_id: The new partition ID to switch to.
    """
    self.dataset.step(partition_id)
159+
160+
def get_buffer(self) -> list[tuple]:
    """Return the batch buffer held by the underlying dataset.

    The buffer is the list that StreamingDataset appends each fetched
    ``(batch_data, batch_meta)`` pair to while iterating.

    Returns:
        list: The dataset's ``buffer`` attribute itself, not a copy, so
        mutating the returned list mutates the dataset's buffer.
    """
    return self.dataset.buffer

transfer_queue/dataloader/streaming_dataset.py

Lines changed: 173 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import logging
1717
import os
18-
import time
1918
import uuid
2019
from typing import Any, Iterator
2120

@@ -53,9 +52,7 @@ class StreamingDataset(IterableDataset):
5352
... required_fields=["input_ids", "attention_mask"],
5453
... partition_id="train",
5554
... task_name="update_actor",
56-
... data_replica_group=data_replica_group_id, # Same for all ranks in data replica group
57-
... data_replica_rank=local_rank, # local rank in data replica group
58-
... data_replica_world_size=world_size/dp_world_size, # size of data replica group
55+
... dp_rank=dp_rank, # Same for all ranks in data replica group
5956
... )
6057
>>> dataloader = StreamingDataLoader(
6158
... dataset,
@@ -71,13 +68,15 @@ class StreamingDataset(IterableDataset):
7168
def __init__(
7269
self,
7370
config: dict[str, Any],
71+
batch_size: int,
7472
micro_batch_size: int,
75-
required_fields: list[str],
73+
data_fields: list[str],
7674
partition_id: str,
7775
task_name: str,
78-
data_replica_group: int,
79-
data_replica_rank: int,
80-
data_replica_world_size: int,
76+
dp_rank: int,
77+
n_samples_per_prompt: int,
78+
custom_get_batch_func: Any = None,
79+
custom_post_process_for_micro_func: Any = None,
8180
):
8281
"""Initialize the StreamingDataset.
8382
@@ -86,20 +85,22 @@ def __init__(
8685
- controller_info: ZMQServerInfo for the TransferQueueController
8786
- storage_backend: Storage backend type (e.g., "AsyncSimpleStorageManager")
8887
- Other backend-specific configuration
88+
batch_size: Batch size for data loading per iter.
8989
micro_batch_size: Number of samples per micro-batch. This is the batch size
9090
that will be requested from TransferQueue for each iteration.
91-
required_fields: List of field names to retrieve from storage. Only these
91+
data_fields: List of field names to retrieve from storage. Only these
9292
fields will be included in the returned batch.
9393
partition_id: Partition ID for data versioning. Different partitions can
9494
be used for different data versions or splits (e.g., "train", "val").
9595
task_name: Unique identifier for the training task. This is used to track
9696
which samples have been consumed by which task.
97-
data_replica_group: The group ID of the current data replica group. All
98-
ranks with the same data_replica_group will receive identical samples.
99-
data_replica_rank: Local rank index within the data_replica_group. Range:
100-
[0, data_replica_world_size - 1]
101-
data_replica_world_size: Total number of ranks in this data_replica_group.
102-
Must be >= 1.
97+
dp_rank: The group ID of the current data group. All
98+
ranks with the same dp_rank will receive identical samples.
99+
n_samples_per_prompt: Number of samples generated per prompt for training.
100+
custom_get_batch_func: Optional custom function to retrieve batch data.
101+
If None, uses default_get_batch function.
102+
custom_post_process_for_micro_func: Optional custom function to post-process
103+
and split data into micro-batches. If None, uses default_post_process_for_micro_func.
103104
104105
Raises:
105106
ValueError: If input parameters are invalid.
@@ -108,41 +109,49 @@ def __init__(
108109
if micro_batch_size < 1:
109110
raise ValueError(f"micro_batch_size must be >= 1, got {micro_batch_size}")
110111

111-
if len(required_fields) < 1:
112-
raise ValueError(f"required_fields must be a list with at least one field name, got {required_fields}")
112+
if len(data_fields) < 1:
113+
raise ValueError(f"required_fields must be a list with at least one field name, got {data_fields}")
113114

114-
if data_replica_world_size < 1:
115-
raise ValueError(f"data_replica_world_size {data_replica_world_size} must >= 1")
116-
117-
if data_replica_rank >= data_replica_world_size or data_replica_rank < 0:
118-
raise ValueError(
119-
f"data_replica_rank {data_replica_rank} must be greater than or equal to 0 and less than "
120-
f"data_replica_world_size {data_replica_world_size}"
121-
)
115+
if dp_rank < 0:
116+
raise ValueError(f"dp_rank {dp_rank} must be greater than or equal to 0")
122117

123118
self.config = config
119+
self.batch_size = batch_size
124120
self.micro_batch_size = micro_batch_size
125-
self.required_fields = required_fields
121+
self.data_fields = data_fields
126122
self.partition_id = partition_id
127123
self.task_name = task_name
128-
self.data_replica_group = data_replica_group
129-
self.data_replica_rank = data_replica_rank
130-
self.data_replica_world_size = data_replica_world_size
124+
self.dp_rank = dp_rank
125+
self.n_samples_per_prompt = n_samples_per_prompt
126+
self.get_batch_func = custom_get_batch_func if custom_get_batch_func else default_get_batch
127+
self.post_process_for_micro_func = (
128+
custom_post_process_for_micro_func
129+
if custom_post_process_for_micro_func
130+
else default_post_process_for_micro_func
131+
)
131132

132133
# Build sampling config for controller
133134
self.sampling_config = {
134-
"data_replica_group": self.data_replica_group,
135-
"data_replica_rank": self.data_replica_rank,
136-
"data_replica_world_size": self.data_replica_world_size,
135+
"dp_rank": self.dp_rank,
137136
"task_name": self.task_name,
138-
"partition_id": self.partition_id,
137+
"n_samples_per_prompt": self.n_samples_per_prompt,
139138
}
140139

141140
self._tq_client = None
141+
self.buffer: list[tuple] = []
142+
self.batch_index = 0
142143

143144
super().__init__()
144145

145146
def _create_client(self):
147+
"""Create and initialize a TransferQueue client.
148+
149+
This method initializes the TransferQueueClient with the provided configuration
150+
and storage backend, and sets up the storage manager for data retrieval.
151+
152+
Raises:
153+
ValueError: If controller_info or storage_backend is missing or invalid.
154+
"""
146155
client_id = uuid.uuid4().hex[:8]
147156
controller_info = self.config.get("controller_info", None)
148157
if not controller_info or not isinstance(controller_info, ZMQServerInfo):
@@ -175,30 +184,141 @@ def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]:
175184
# Note: For fully streamed production-consumption, please set the environment variable
176185
# TQ_PRE_ALLOC_SAMPLE_NUM to the required global_batch_size to make sure consumers can accurately
177186
# determine consumption status even before producers have generated the samples.
178-
while not self._tq_client.check_consumption_status(self.task_name, self.partition_id):
187+
while (
188+
not self._tq_client.check_consumption_status(self.task_name, self.partition_id)
189+
or self.batch_index <= len(self.buffer) - 1
190+
):
179191
try:
180-
# Get metadata from controller
181-
batch_meta = self._tq_client.get_meta(
182-
data_fields=self.required_fields,
183-
batch_size=self.micro_batch_size,
184-
partition_id=self.partition_id,
185-
task_name=self.task_name,
186-
sampling_config=self.sampling_config,
187-
)
188-
189-
# Check if we got valid data
190-
if batch_meta.size == 0:
191-
logger.debug(
192-
f"[StreamingDataset]: Received empty batch, waiting for more data... "
193-
f"Required batch_size={self.micro_batch_size}, data_fields={self.required_fields},"
194-
f"partition_id={self.partition_id}, task_name={self.task_name}."
195-
)
192+
if self.batch_index <= len(self.buffer) - 1:
193+
current_data = self.buffer[self.batch_index]
194+
self.batch_index += 1
195+
yield from self.post_process_for_micro_func(*current_data, micro_batch_size=self.micro_batch_size)
196196

197-
time.sleep(TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL)
198197
else:
199-
batch = self._tq_client.get_data(batch_meta)
200-
yield (batch, batch_meta)
198+
batch_data, batch_meta = self.get_batch_func(
199+
self._tq_client,
200+
self.data_fields,
201+
self.batch_size,
202+
self.partition_id,
203+
self.task_name,
204+
self.sampling_config,
205+
self.batch_index,
206+
)
207+
if batch_data is not None:
208+
self.buffer.append((batch_data, batch_meta))
201209

202210
except Exception as e:
203211
logger.error(f"[StreamingDataset]: Error in data iteration: {e}")
204212
raise
213+
214+
def reset(self) -> None:
    """Rewind the batch index to the beginning.

    Only ``batch_index`` is set back to 0. The buffer is left untouched, so
    batches that were already fetched are served again from the buffer on the
    next iteration.
    """
    self.batch_index = 0
220+
221+
def step(self, partition_id):
    """Move the dataset to another partition, starting from a clean state.

    The new partition ID is stored, the batch index goes back to 0 and the
    buffer is replaced with a fresh empty list, so nothing fetched from the
    previous partition is served again (e.g. when going from "train" to "val").

    Args:
        partition_id: The partition ID to fetch data from next.
    """
    self.partition_id = partition_id
    self.batch_index = 0
    # Rebind rather than clear in place: a reference obtained earlier keeps
    # its old contents and no longer tracks this dataset's buffer.
    self.buffer = []
233+
234+
235+
def default_get_batch(tq_client, data_fields, batch_size, partition_id, task_name, sampling_config, batch_index):
    """Retrieve a batch of data from TransferQueue.

    Asks the controller (via ``tq_client.get_meta``) for the metadata of the
    next batch and, when that metadata is non-empty, fetches the data itself
    with ``tq_client.get_data``.

    Args:
        tq_client: Client object exposing ``get_meta`` and ``get_data``.
        data_fields: List of field names to retrieve from the batch.
        batch_size: The requested batch size.
        partition_id: The partition ID for data versioning.
        task_name: Unique identifier for the training task.
        sampling_config: Sampling options forwarded to the controller. May be
            None. The mapping is never modified; a per-request copy is sent.
        batch_index: Current batch index, forwarded to the controller inside
            the sampling configuration.

    Returns:
        tuple: ``(batch, batch_meta)`` where ``batch`` is whatever
        ``tq_client.get_data`` returns, or None when ``batch_meta.size`` is 0,
        and ``batch_meta`` is the metadata returned by ``tq_client.get_meta``.
    """
    # Build a per-request copy so the caller's dict (for example a dataset's
    # long-lived sampling_config) does not silently accumulate request state.
    request_config = {
        **(sampling_config or {}),
        "batch_index": batch_index,
        "partition_id": partition_id,
    }

    # Get metadata from controller
    batch_meta = tq_client.get_meta(
        data_fields=data_fields,
        batch_size=batch_size,
        partition_id=partition_id,
        task_name=task_name,
        sampling_config=request_config,
    )

    # Nothing is ready yet: report it and let the caller decide how to wait.
    if batch_meta.size == 0:
        logger.debug(
            f"[StreamingDataset]: Received empty batch, waiting for more data... "
            f"Required batch_size={batch_size}, data_fields={data_fields},"
            f"partition_id={partition_id}, task_name={task_name}."
        )
        return None, batch_meta

    return tq_client.get_data(batch_meta), batch_meta
277+
278+
279+
def default_post_process_for_micro_func(td, batch_meta, micro_batch_size=1):
    """Cut a batch into consecutive micro-batches along the first dimension.

    The TensorDict is sliced into windows of ``micro_batch_size`` rows (the
    last window may be shorter), and each window is paired with the metadata
    chunk at the same position in ``batch_meta.chunk(num_splits)``.

    Args:
        td: TensorDict whose ``batch_size`` has at least one dimension.
        batch_meta: Metadata object exposing ``chunk(n)``; chunk ``i`` is
            paired with the ``i``-th slice of ``td``.
        micro_batch_size: Rows per micro-batch (positive integer, default: 1).

    Returns:
        list: ``(micro_batch_td, micro_batch_meta)`` tuples in batch order.

    Raises:
        TypeError: If ``td`` is not a TensorDict.
        ValueError: If ``micro_batch_size`` is not a positive integer, if
            ``td`` has an empty ``batch_size``, or if ``micro_batch_size`` is
            larger than the first batch dimension.
    """
    if not isinstance(td, TensorDict):
        raise TypeError(f"Expected TensorDict, got {type(td).__name__}")

    if not (isinstance(micro_batch_size, int) and micro_batch_size > 0):
        raise ValueError(f"micro_batch_size must be a positive integer, got {micro_batch_size}")

    if len(td.batch_size) < 1:
        raise ValueError("Input TensorDict must have non-empty batch_size")

    total_size = td.batch_size[0]
    if total_size < micro_batch_size:
        raise ValueError(f"micro_batch_size ({micro_batch_size}) exceeds total batch size ({total_size})")

    # Ceiling division: a trailing partial micro-batch still counts as a split.
    num_splits = -(-total_size // micro_batch_size)
    # NOTE(review): assumes chunk(num_splits) returns num_splits pieces whose
    # sizes match the row windows below when the division is uneven - confirm
    # against BatchMeta.chunk.
    meta_chunks = batch_meta.chunk(num_splits)

    return [
        (td[offset : min(offset + micro_batch_size, total_size)], meta_chunks[position])
        for position, offset in enumerate(range(0, total_size, micro_batch_size))
    ]

0 commit comments

Comments
 (0)