Commit f049f4b

refine test dataloader ut

1 parent: ad5f962

3 files changed: 22 additions & 35 deletions

tests/datasets/test_dataloader.py

Lines changed: 19 additions & 32 deletions
@@ -10,9 +10,6 @@
     DatasetConfig,
     FTDPTokenizeFnConfig,
     build_dataloader,
-    build_datasets,
-    get_dataloader_state,
-    load_dataloader_state,
 )
 from xtuner.v1.train.toy_tokenizer import UTF8ByteTokenizer
 from torch.multiprocessing import spawn
@@ -197,25 +194,22 @@ def test_dataloader_resume_single_process(tmp_path, pack_level, num_workers, gro
     dataset_configs = [
         {
             "dataset": DatasetConfig(anno_path=str(data_dir1)),
-            "tokenize_fn": FTDPTokenizeFnConfig(max_length=1024)
+            "tokenize_fn": FTDPTokenizeFnConfig(max_length=1024),
         },
     ]

     dataloader_config = DataloaderConfig(
+        dataset_config_list=dataset_configs,
         pack_max_length=1024,
         pack_level=pack_level,
         num_workers=num_workers,
         group_by_length=group_by_length,
         pack_workers=pack_workers,
     )

-    datasets = build_datasets(
-        dataset_config=dataset_configs,
+    dataloader1 = dataloader_config.build(
         tokenizer=tokenizer,
-    )
-    dataloader1 = build_dataloader(
-        dataloader_config=dataloader_config,
-        datasets=datasets,
+        dp_mesh=None,
         global_batch_size=GLOBAL_BATCH_SIZE,
         micro_batch_size=BATCH_SIZE,
         seed=10,
@@ -225,26 +219,22 @@ def test_dataloader_resume_single_process(tmp_path, pack_level, num_workers, gro
     assert len(dataloader1) > 10

     dataloader_iter = iter(dataloader1)
-    consumed_sample = 0
     for _ in range(RESUME_ITER):
-        batch = next(dataloader_iter)
-        consumed_sample += len(batch)
+        next(dataloader_iter)

-    dataloader_state = get_dataloader_state(dataloader1, consumed_sample)
+    dataloader_state = dataloader1.get_state_dict()
     expected_data = []
     for _ in range(AFTER_RESUME_ITER):
-        batch = next(dataloader_iter)
-        consumed_sample += len(batch)
-        expected_data.append(batch)
+        expected_data.append(next(dataloader_iter))

-    new_dataloader1 = build_dataloader(
-        dataloader_config=dataloader_config,
-        datasets=datasets,
+    new_dataloader1 = dataloader_config.build(
+        tokenizer=tokenizer,
+        dp_mesh=None,
         global_batch_size=GLOBAL_BATCH_SIZE,
         micro_batch_size=BATCH_SIZE,
         seed=10,
     )
-    load_dataloader_state(new_dataloader1, dataloader_state)
+    new_dataloader1.load_state_dict(dataloader_state)
     new_dataloader_iter = iter(new_dataloader1)

     resume_data = []
@@ -257,32 +247,29 @@ def test_dataloader_resume_single_process(tmp_path, pack_level, num_workers, gro
     # 2. Test resume after consuming multiple epochs
     while True:
         try:
-            batch = next(dataloader_iter)
-            consumed_sample += len(batch)
+            next(dataloader_iter)
         except StopIteration:
             break

-
     dataloader_iter = iter(dataloader1)

-    for batch in range(RESUME_ITER):
-        batch = next(dataloader_iter)
-        consumed_sample += len(batch)
+    for _ in range(RESUME_ITER):
+        next(dataloader_iter)

-    dataloader_state = get_dataloader_state(dataloader1, consumed_sample)
+    dataloader_state = dataloader1.get_state_dict()

     expected_data = []
     for _ in range(AFTER_RESUME_ITER):
         expected_data.append(next(dataloader_iter))

-    new_dataloader2 = build_dataloader(
-        dataloader_config=dataloader_config,
-        datasets=datasets,
+    new_dataloader2 = dataloader_config.build(
+        tokenizer=tokenizer,
+        dp_mesh=None,
         global_batch_size=GLOBAL_BATCH_SIZE,
         micro_batch_size=BATCH_SIZE,
         seed=10,
     )
-    load_dataloader_state(new_dataloader2, dataloader_state)
+    new_dataloader2.load_state_dict(dataloader_state)
     new_dataloader_iter2 = iter(new_dataloader2)

     resume_data = []
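
In short, the refactor folds dataset construction into DataloaderConfig (via dataset_config_list) and replaces the free functions get_dataloader_state/load_dataloader_state with state-dict methods on the dataloader itself, which also removes the manual consumed_sample bookkeeping. A minimal sketch of the resulting resume flow, assuming only the names and signatures visible in the diff above:

    # Build, consume a few steps, snapshot, then resume from a fresh instance.
    # Mirrors the test; uses no arguments beyond those shown in the diff.
    dataloader = dataloader_config.build(
        tokenizer=tokenizer,
        dp_mesh=None,  # single-process test: no data-parallel mesh
        global_batch_size=GLOBAL_BATCH_SIZE,
        micro_batch_size=BATCH_SIZE,
        seed=10,
    )
    it = iter(dataloader)
    for _ in range(RESUME_ITER):
        next(it)
    state = dataloader.get_state_dict()  # progress is tracked internally now

    resumed = dataloader_config.build(
        tokenizer=tokenizer,
        dp_mesh=None,
        global_batch_size=GLOBAL_BATCH_SIZE,
        micro_batch_size=BATCH_SIZE,
        seed=10,
    )
    resumed.load_state_dict(state)  # continues where the snapshot left off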

xtuner/v1/datasets/preset_sampler.py

Lines changed: 1 addition & 1 deletion
@@ -163,7 +163,7 @@ def __init__(
     def __iter__(self) -> Iterator[int]:
         # load order from npy → global_order → rank_view; all of these are memmaps, and sub-views
         # keep memmap semantics (views, on-demand paging, file-backed); processes on one machine can share the same file page cache
-        yield from (int(idx) for idx in self.global_order[self.step + self.rank : self.total_size : self.world_size])
+        yield from self.global_order[self.step + self.rank : self.total_size : self.world_size]
         self.step = 0

     def __len__(self) -> int:
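
The removed generator expression converted each index to a Python int; after this change the memmap slice is yielded directly, so consumers receive numpy integer scalars and no per-element wrapper generator is created. The slice itself is the usual strided rank sharding: each rank walks the global order with stride world_size, starting after the step samples already consumed. A toy illustration with a plain array standing in for the memmap (sizes and values are made up):

    import numpy as np

    global_order = np.arange(16)   # stand-in for the file-backed np.memmap
    world_size, total_size, step = 4, 16, 4

    for rank in range(world_size):
        shard = global_order[step + rank : total_size : world_size]
        print(rank, shard.tolist())
    # 0 [4, 8, 12]
    # 1 [5, 9, 13]
    # 2 [6, 10, 14]
    # 3 [7, 11, 15]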

xtuner/v1/datasets/sampler.py

Lines changed: 2 additions & 2 deletions
@@ -113,7 +113,7 @@ def __iter__(self) -> Iterator[int]:
         # subsample
         indices = indices[self.step + self.rank : self.total_size : self.world_size]

-        yield from indices
+        yield from iter(indices)
         self.step = 0

     def __len__(self) -> int:
@@ -266,7 +266,7 @@ def __iter__(self) -> Iterator[int]:
         assert len(indices) == self.total_size
         indices = indices[self.step + self.rank : self.total_size : self.world_size]
         assert len(indices) == self.num_samples - self.step // self.world_size
-        yield from indices
+        yield from iter(indices)
         self.step = 0

     def __len__(self) -> int:
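
Note that "yield from indices" already calls iter(indices) under the hood, so wrapping the operand in iter() is behaviorally equivalent and the change is a matter of explicitness. The substantive pattern in both hunks is the self.step = 0 reset after a full pass: a resumed sampler finishes the current epoch from its offset, then starts subsequent epochs from the beginning. A toy sketch of that resume-then-reset behavior (illustrative class, not the xtuner sampler):

    from typing import Iterator

    class ToySampler:
        def __init__(self, total_size: int, rank: int, world_size: int, step: int = 0):
            self.total_size = total_size
            self.rank = rank
            self.world_size = world_size
            self.step = step  # samples already consumed across all ranks

        def __iter__(self) -> Iterator[int]:
            indices = list(range(self.total_size))
            indices = indices[self.step + self.rank : self.total_size : self.world_size]
            yield from iter(indices)
            self.step = 0  # later epochs start from scratch

    sampler = ToySampler(total_size=8, rank=0, world_size=2, step=4)
    assert list(sampler) == [4, 6]        # resumed mid-epoch
    assert list(sampler) == [0, 2, 4, 6]  # next epoch: the full rank-0 shard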
