 from pathlib import Path
 import os
 import pickle
+import socket
 
 import torch
 
-from xtuner.v1.datasets import build_dataloader, build_datasets, get_dataloader_state, load_dataloader_state, FTDPTokenizeFnConfig, DatasetConfig, DataloaderConfig
+from xtuner.v1.datasets import (
+    DataloaderConfig,
+    DatasetConfig,
+    FTDPTokenizeFnConfig,
+    build_dataloader,
+)
 from xtuner.v1.train.toy_tokenizer import UTF8ByteTokenizer
-from torch.multiprocessing import spawn, get_context
+from torch.multiprocessing import spawn
 from torch.distributed.device_mesh import init_device_mesh
 import pytest
 
 from itertools import repeat, chain
 
 
-
-
 class RandomDataset:
     def __init__(self, size: int, **kwargs):
         self.size = size
@@ -182,25 +186,22 @@ def test_dataloader_resume_single_process(tmp_path, pack_level, num_workers, group_by_length, pack_workers):
     dataset_configs = [
         {
             "dataset": DatasetConfig(anno_path=str(data_dir1)),
-            "tokenize_fn": FTDPTokenizeFnConfig(max_length=1024)
+            "tokenize_fn": FTDPTokenizeFnConfig(max_length=1024),
         },
     ]
 
     dataloader_config = DataloaderConfig(
+        dataset_config_list=dataset_configs,
         pack_max_length=1024,
         pack_level=pack_level,
         num_workers=num_workers,
         group_by_length=group_by_length,
         pack_workers=pack_workers,
     )
 
-    datasets = build_datasets(
-        dataset_config=dataset_configs,
+    dataloader1 = dataloader_config.build(
         tokenizer=tokenizer,
-    )
-    dataloader1 = build_dataloader(
-        dataloader_config=dataloader_config,
-        datasets=datasets,
+        dp_mesh=None,
         global_batch_size=GLOBAL_BATCH_SIZE,
         micro_batch_size=BATCH_SIZE,
         seed=10,
@@ -210,26 +211,22 @@ def test_dataloader_resume_single_process(tmp_path, pack_level, num_workers, group_by_length, pack_workers):
     assert len(dataloader1) > 10
 
     dataloader_iter = iter(dataloader1)
-    consumed_sample = 0
     for _ in range(RESUME_ITER):
-        batch = next(dataloader_iter)
-        consumed_sample += len(batch)
+        next(dataloader_iter)
 
-    dataloader_state = get_dataloader_state(dataloader1, consumed_sample)
+    dataloader_state = dataloader1.get_state_dict()
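+    # The state dict records how many samples were consumed; no manual counting needed.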
     expected_data = []
     for _ in range(AFTER_RESUME_ITER):
-        batch = next(dataloader_iter)
-        consumed_sample += len(batch)
-        expected_data.append(batch)
+        expected_data.append(next(dataloader_iter))
 
-    new_dataloader1 = build_dataloader(
-        dataloader_config=dataloader_config,
-        datasets=datasets,
+    new_dataloader1 = dataloader_config.build(
+        tokenizer=tokenizer,
+        dp_mesh=None,
         global_batch_size=GLOBAL_BATCH_SIZE,
         micro_batch_size=BATCH_SIZE,
         seed=10,
     )
-    load_dataloader_state(new_dataloader1, dataloader_state)
+    new_dataloader1.load_state_dict(dataloader_state)
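+    # Resuming should replay exactly the AFTER_RESUME_ITER batches recorded in expected_data.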
     new_dataloader_iter = iter(new_dataloader1)
 
     resume_data = []
@@ -242,32 +239,29 @@ def test_dataloader_resume_single_process(tmp_path, pack_level, num_workers, group_by_length, pack_workers):
     # 2. Test resume after consuming multiple epochs
     while True:
         try:
-            batch = next(dataloader_iter)
-            consumed_sample += len(batch)
+            next(dataloader_iter)
         except StopIteration:
             break
 
-
     dataloader_iter = iter(dataloader1)
 
-    for batch in range(RESUME_ITER):
-        batch = next(dataloader_iter)
-        consumed_sample += len(batch)
+    for _ in range(RESUME_ITER):
+        next(dataloader_iter)
 
-    dataloader_state = get_dataloader_state(dataloader1, consumed_sample)
+    dataloader_state = dataloader1.get_state_dict()
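+    # This snapshot is taken mid-second-epoch: the loader was exhausted once, then advanced again.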
 
     expected_data = []
     for _ in range(AFTER_RESUME_ITER):
         expected_data.append(next(dataloader_iter))
 
-    new_dataloader2 = build_dataloader(
-        dataloader_config=dataloader_config,
-        datasets=datasets,
+    new_dataloader2 = dataloader_config.build(
+        tokenizer=tokenizer,
+        dp_mesh=None,
         global_batch_size=GLOBAL_BATCH_SIZE,
         micro_batch_size=BATCH_SIZE,
         seed=10,
     )
-    load_dataloader_state(new_dataloader2, dataloader_state)
+    new_dataloader2.load_state_dict(dataloader_state)
     new_dataloader_iter2 = iter(new_dataloader2)
 
     resume_data = []
@@ -282,65 +276,52 @@ def _test_resume_spmd(
     rank: int,
     world_size: int,
     dataloader_config: DataloaderConfig,
-    dataset_configs: list[dict],
     global_batch_size: int,
     micro_batch_size: int,
-    step:int,
+    step: int,
     seed: int,
     save_path: Path,
     dataloader_state: dict | None = None,
-    consumed_samples: int = 0,
 ):
     os.environ["RANK"] = str(rank)
     os.environ["LOCAL_RANK"] = str(rank)
     os.environ["WORLD_SIZE"] = str(world_size)
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "29505"
 
-
     torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)
     torch.cuda.set_device(rank)
     data_mesh = init_device_mesh(
         device_type="cuda",
-        mesh_shape=(world_size,)
+        mesh_shape=(world_size,),
     )
     tokenizer = UTF8ByteTokenizer()
 
-    datasets = build_datasets(
-        dataset_config=dataset_configs,
+    dataloader = dataloader_config.build(
         tokenizer=tokenizer,
-    )
-    dataloader = build_dataloader(
-        dataloader_config=dataloader_config,
-        datasets=datasets,
+        dp_mesh=data_mesh,
         global_batch_size=global_batch_size,
         micro_batch_size=micro_batch_size,
         seed=seed,
-        dp_mesh=data_mesh,
     )
 
     if dataloader_state is not None:
-        load_dataloader_state(dataloader, dataloader_state)
+        dataloader.load_state_dict(dataloader_state)
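+        # Every rank restores the same snapshot before iterating.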
 
     data_iter = iter(dataloader)
     data_list = []
     for _ in range(step):
         batch = next(data_iter)
         data_list.append(batch)
-        consumed_samples += len(batch)
 
-    consumed_samples_list = [None for _ in range(world_size)]
-    torch.distributed.all_gather_object(consumed_samples_list, consumed_samples)
-    global_consumed_samples = sum(consumed_samples_list)
+    # Snapshot after the first `step` batches so total_consumed_samples matches resume intent.
+    dataloader_state = dataloader.get_state_dict()
 
     expected_data = []
-
     for _ in range(step):
         batch = next(data_iter)
         expected_data.append(batch)
 
-    dataloader_state = get_dataloader_state(dataloader, global_consumed_samples)
-
     all_data_list = [None for _ in range(world_size)]
     torch.distributed.all_gather_object(all_data_list, list(chain(*data_list)))
 
@@ -372,7 +353,6 @@ def _test_resume_spmd(
372353 "dataloader_state" : dataloader_state ,
373354 "data_list" : all_data_list ,
374355 "expected_data" : all_expected_data ,
375- "consumed_samples" : consumed_samples
376356 }
377357 )
378358 )
@@ -389,7 +369,6 @@ def _test_resume_spmd(
389369 ("none" , 0 , False ),
390370 ("soft" , 0 , True ),
391371 ("soft" , 4 , True ),
392- ("soft" , 4 , True ),
393372 ]
394373)
395374def test_dataloader_resume_multi_process (tmp_path , pack_level , num_workers , group_by_length ):
@@ -402,36 +381,35 @@ def test_dataloader_resume_multi_process(tmp_path, pack_level, num_workers, group_by_length):
     _create_fake_dataset(data_dir1 / f"depth3", dataset_num=3, max_depth=3, dup_times=9)
 
     # 1. Test resuming with the same world size
+    dataset_configs = [
+        {
+            "dataset": DatasetConfig(anno_path=str(data_dir1)),
+            "tokenize_fn": FTDPTokenizeFnConfig(max_length=1024),
+        },
+    ]
+
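+    # The dataset configs are part of the DataloaderConfig, so build() needs only runtime args.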
     dataloader_config = DataloaderConfig(
+        dataset_config_list=dataset_configs,
         pack_max_length=1024,
         pack_level=pack_level,
         num_workers=num_workers,
         group_by_length=group_by_length,
-        collator="fake_collator"
+        collator="fake_collator",
     )
-    dataset_configs = [
-        {
-            "dataset": DatasetConfig(anno_path=str(data_dir1)),
-            "tokenize_fn": FTDPTokenizeFnConfig(max_length=1024)
-        },
-    ]
 
-    ctx = get_context("spawn")
     world_size = 2
     save_path1 = tmp_path / "dataloader_state.pkl"
     spawn(
         _test_resume_spmd,
         args=(
             world_size,
             dataloader_config,
-            dataset_configs,
             16,
             BATCH_SIZE,
             TOTAL_STEP,
             10,
             save_path1,
             None,
-            0,
         ),
         nprocs=2,
         join=True,
@@ -448,14 +426,12 @@ def test_dataloader_resume_multi_process(tmp_path, pack_level, num_workers, group_by_length):
         args=(
             world_size,
             dataloader_config,
-            dataset_configs,
             16,
             BATCH_SIZE,
             TOTAL_STEP,
             10,
             save_path2,
             result1["dataloader_state"],
-            result1["consumed_samples"],
         ),
         nprocs=world_size,
         join=True,
@@ -475,14 +451,12 @@ def test_dataloader_resume_multi_process(tmp_path, pack_level, num_workers, group_by_length):
         args=(
             world_size,
             dataloader_config,
-            dataset_configs,
             16,
             BATCH_SIZE,
             TOTAL_STEP,
             10,
             save_path3,
             result1["dataloader_state"],
-            result1["consumed_samples"],
         ),
         nprocs=world_size,
         join=True,