forked from Ascend/TransferQueue
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathstreaming_dataloader.py
More file actions
161 lines (134 loc) · 5.98 KB
/
streaming_dataloader.py
File metadata and controls
161 lines (134 loc) · 5.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2025 The TransferQueue Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
from tensordict import TensorDict
from transfer_queue.dataloader.streaming_dataset import StreamingDataset
from transfer_queue.metadata import BatchMeta
from transfer_queue.utils.logging_utils import get_logger
# Module-level logger, keyed by this module's ``__name__``.
logger = get_logger(__name__)
def _identity_collate_fn(data: tuple[TensorDict, BatchMeta]) -> tuple[TensorDict, BatchMeta]:
"""Identity collate function for TransferQueue.
This function acts as a pass-through, preserving the `(TensorDict, BatchMeta)`
structure yielded by `StreamingDataset`. It prevents PyTorch from attempting
to stack or modify the already-batched data.
"""
return data
class StreamingDataLoader(torch.utils.data.DataLoader):
    """PyTorch ``DataLoader`` front-end for a TransferQueue ``StreamingDataset``.

    The loader lets a regular ``for`` loop drive a ``StreamingDataset``. It
    always constructs the parent ``DataLoader`` with ``batch_size=None`` and
    without any sampler or batch sampler, because the dataset is expected to
    produce already-batched ``(TensorDict, BatchMeta)`` pairs itself (sized by
    its ``micro_batch_size`` and, per the upstream design notes, coordinated
    with TransferQueue's ``RankAwareSampler`` for distributed training).

    Besides iteration, the loader forwards a few control calls to the wrapped
    dataset: :meth:`reset`, :meth:`step` and :meth:`get_buffer`.

    Example:
        >>> dataset = StreamingDataset(
        ...     config=config,
        ...     micro_batch_size=4,
        ...     required_fields=["input_ids", "attention_mask"],
        ...     partition_id="train",
        ...     task_name="update_actor",
        ...     data_replica_group=0,
        ...     data_replica_rank=0,
        ...     data_replica_world_size=1,
        ... )
        >>> dataloader = StreamingDataLoader(dataset, num_workers=0)
        >>> for batch, batch_meta in dataloader:
        ...     # batch: TensorDict with requested fields
        ...     # batch_meta: Metadata for TransferQueue coordination
        ...     loss = model(batch)
        ...     loss.backward()
    """

    def __init__(
        self,
        dataset: StreamingDataset,
        num_workers: int = 0,
        collate_fn=None,
        pin_memory: bool = False,
        worker_init_fn=None,
        multiprocessing_context=None,
        prefetch_factor: Optional[int] = None,
        persistent_workers: bool = False,
        pin_memory_device: str = "",
    ):
        """Build the loader around ``dataset``.

        Args:
            dataset: The ``StreamingDataset`` to iterate over.
            num_workers: Number of loader subprocesses (``0`` by default).
            collate_fn: Optional callable applied to each item the dataset
                yields. When ``None``, ``_identity_collate_fn`` is used so the
                ``(TensorDict, BatchMeta)`` pair is passed through unchanged.
            pin_memory: Forwarded to ``DataLoader`` (pin memory for GPU copy).
            worker_init_fn: Forwarded to ``DataLoader``.
            multiprocessing_context: Forwarded to ``DataLoader``.
            prefetch_factor: Forwarded to ``DataLoader`` (batches prefetched
                per worker).
            persistent_workers: Forwarded to ``DataLoader`` (keep workers
                alive between epochs).
            pin_memory_device: Forwarded to ``DataLoader``.

        Note:
            ``batch_size``, ``shuffle``, ``sampler``, ``batch_sampler``,
            ``drop_last``, ``timeout`` and ``generator`` are not configurable
            here; they are pinned to the values shown below because batching
            is owned by the dataset.
        """
        # Typed handle on the dataset; assigned before the parent constructor
        # runs, which stores the same object under the same attribute.
        self.dataset: StreamingDataset = dataset

        # Default to a pass-through so the dataset's own
        # (TensorDict, BatchMeta) output reaches the training loop as-is.
        effective_collate_fn = _identity_collate_fn if collate_fn is None else collate_fn

        # Settings this class fixes on behalf of the caller.
        pinned_options = dict(
            batch_size=None,  # Batch size is handled by the dataset
            shuffle=None,
            sampler=None,
            batch_sampler=None,
            drop_last=False,
            timeout=0,
            generator=None,
        )
        # Settings relayed verbatim from the constructor arguments.
        relayed_options = dict(
            num_workers=num_workers,
            collate_fn=effective_collate_fn,
            pin_memory=pin_memory,
            worker_init_fn=worker_init_fn,
            multiprocessing_context=multiprocessing_context,
            prefetch_factor=prefetch_factor,
            persistent_workers=persistent_workers,
            pin_memory_device=pin_memory_device,
        )
        super().__init__(dataset=dataset, **pinned_options, **relayed_options)

    def reset(self):
        """Restart iteration by delegating to ``dataset.reset()``.

        According to the dataset's contract this clears its buffer and
        rewinds its batch index so the next pass starts fresh.
        """
        wrapped = self.dataset
        wrapped.reset()

    def step(self, partition_id):
        """Point the dataset at another partition via ``dataset.step()``.

        The dataset is expected to clear its buffer, reset its batch index
        and start fetching from ``partition_id`` (for instance when moving
        from ``"train"`` to ``"val"``).

        Args:
            partition_id: Identifier of the partition to switch to.
        """
        wrapped = self.dataset
        wrapped.step(partition_id)

    def get_buffer(self):
        """Expose the wrapped dataset's ``buffer`` attribute.

        Returns:
            list: The dataset's buffer of pre-fetched
            ``(TensorDict, BatchMeta)`` tuples.
        """
        wrapped = self.dataset
        return wrapped.buffer