Skip to content

Commit 145fea3

Browse files
committed
[StreamingDataLoader, 5/N] feat: Refactor the StreamDataLoader implementation to support fully asynchronous mode
Signed-off-by: 宁本哲 <ningbenzhe@xiaohongshu.com>
1 parent aea4d2f commit 145fea3

12 files changed

Lines changed: 485 additions & 384 deletions

File tree

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,6 @@ batch_meta = client.get_meta(
247247
batch_size=8,
248248
partition_id="train_0",
249249
task_name="generate_sequences",
250-
sampling_config={"n_samples_per_prompt": 4} # Put the required sampling parameters here
251250
)
252251
```
253252

recipe/simple_use_case/async_demo.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,6 @@ def _initialize_data_system(self):
226226
# self.data_system_controller = TransferQueueController.remote(sampler=grpo_sampler)
227227

228228
# Then use sampling_config in get_meta calls:
229-
# sampling_config={"n_samples_per_prompt": 4}
230229
self.data_system_controller = TransferQueueController.remote()
231230
logger.info("TransferQueueController has been created.")
232231

recipe/simple_use_case/sync_demo.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ def initialize_data_system(config):
6868
# data_system_controller = TransferQueueController.remote(sampler=grpo_sampler)
6969

7070
# Then use sampling_config in get_meta calls:
71-
# sampling_config={"n_samples_per_prompt": 4}
7271
data_system_controller = TransferQueueController.remote()
7372
logger.info("TransferQueueController has been created.")
7473

tests/test_samplers.py

Lines changed: 173 additions & 242 deletions
Large diffs are not rendered by default.

transfer_queue/client.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,6 @@ async def async_get_meta(
179179
- 'insert': Internal usage - should not be used by users
180180
task_name: Optional task name associated with the request
181181
sampling_config: Optional sampling configuration for custom samplers.
182-
For GRPOGroupNSampler, should include "n_samples_per_prompt": int
183182
socket: ZMQ async socket for message transmission (injected by decorator)
184183
185184
Returns:
@@ -206,7 +205,6 @@ async def async_get_meta(
206205
... partition_id="train_0",
207206
... mode="fetch",
208207
... task_name="generate_sequences",
209-
... sampling_config={"n_samples_per_prompt": 4}
210208
... ))
211209
>>> print(batch_meta.is_ready) # True if all samples ready
212210
>>>
@@ -698,7 +696,7 @@ async def async_check_consumption_status(
698696
partition_id=partition_id,
699697
)
700698

701-
if consumption_status is None:
699+
if consumption_status is None or consumption_status.numel() == 0:
702700
return False
703701
return torch.all(consumption_status == 1).item()
704702

@@ -883,7 +881,6 @@ def get_meta(
883881
partition_id: Target data partition id
884882
task_name: Optional task name associated with the request
885883
sampling_config: Optional sampling configuration for custom samplers.
886-
For GRPOGroupNSampler, should include "n_samples_per_prompt": int
887884
888885
Returns:
889886
BatchMeta: Batch metadata containing data location information

transfer_queue/controller.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -781,7 +781,7 @@ def __init__(
781781
- If a BaseSampler subclass is provided, it will be instantiated
782782
- Defaults to SequentialSampler for simple sequential sampling
783783
- Example: sampler=GRPOGroupNSampler() (instance)
784-
- Example: sampler=GRPOGroupNSampler (class)
784+
- Example: sampler=SequentialSampler (class)
785785
polling_mode: Whether to use polling mode for TransferQueue controller.
786786
- If False, the controller will raise an error when no enough data is available.
787787
- If True, the controller will return an empty BatchMeta when no enough data is available.
@@ -1015,12 +1015,12 @@ def get_metadata(
10151015
Raises:
10161016
TimeoutError: If waiting for sufficient data times out in fetch mode
10171017
"""
1018-
if partition_id not in self.partitions:
1019-
self.create_partition(partition_id)
10201018

10211019
if mode == "insert":
1022-
partition = self._get_partition(partition_id)
1020+
if partition_id not in self.partitions:
1021+
self.create_partition(partition_id)
10231022

1023+
partition = self._get_partition(partition_id)
10241024
if data_fields:
10251025
# This is called during put_data call without providing metadata.
10261026
# try to use pre-allocated global index first
@@ -1083,6 +1083,7 @@ def get_metadata(
10831083
ready_for_consume_indexes,
10841084
batch_size,
10851085
**(sampling_config or {}),
1086+
**kwargs,
10861087
)
10871088

10881089
# Check if we got valid results from the sampler
@@ -1240,6 +1241,7 @@ def clear_partition(self, partition_id: str, clear_consumption: bool = True):
12401241
partition.clear_data(global_indexes_range, clear_consumption)
12411242
self.index_manager.release_partition(partition_id)
12421243
self.partitions.pop(partition_id)
1244+
self.sampler.clear_cache(partition_id)
12431245

12441246
def clear_meta(
12451247
self,

transfer_queue/dataloader/streaming_dataloader.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def __init__(
111111
parameter in PyTorch DataLoader is set to None because batching is managed
112112
by the StreamingDataset in coordination with RankAwareSampler.
113113
"""
114+
self.dataset: StreamingDataset = dataset
114115

115116
if collate_fn is None:
116117
# use identical collate function to directly return the self-defined
@@ -137,3 +138,32 @@ def __init__(
137138
persistent_workers=persistent_workers,
138139
pin_memory_device=pin_memory_device,
139140
)
141+
142+
def reset(self):
143+
"""Reset the dataset iterator to the beginning.
144+
145+
Clears the buffer and resets the batch index for a fresh iteration.
146+
"""
147+
self.dataset.reset()
148+
149+
def step(self, partition_id):
150+
"""Switch to a new partition and reset the dataset state.
151+
152+
This method clears the buffer, resets the batch index, and updates the partition_id
153+
to fetch data from a different partition (e.g., switching from "train" to "val").
154+
155+
Args:
156+
partition_id: The new partition ID to switch to.
157+
"""
158+
self.dataset.step(partition_id)
159+
160+
def get_buffer(self):
161+
"""Get the current buffer from the underlying dataset.
162+
163+
Returns the batch buffer maintained by StreamingDataset, which stores
164+
pre-fetched batches for efficient data access.
165+
166+
Returns:
167+
list: Buffer containing pre-fetched (TensorDict, BatchMeta) tuples.
168+
"""
169+
return self.dataset.buffer

0 commit comments

Comments
 (0)