11from pathlib import Path
22import os
33import pickle
4+ import socket
45
56import torch
67
7- from xtuner .v1 .datasets import build_dataloader , build_datasets , get_dataloader_state , load_dataloader_state , FTDPTokenizeFnConfig , DatasetConfig , DataloaderConfig
8+ from xtuner .v1 .datasets import (
9+ DataloaderConfig ,
10+ DatasetConfig ,
11+ FTDPTokenizeFnConfig ,
12+ build_dataloader ,
13+ build_datasets ,
14+ get_dataloader_state ,
15+ load_dataloader_state ,
16+ )
817from xtuner .v1 .train .toy_tokenizer import UTF8ByteTokenizer
9- from torch .multiprocessing import spawn , get_context
18+ from torch .multiprocessing import spawn
1019from torch .distributed .device_mesh import init_device_mesh
1120import pytest
1221
1524from itertools import repeat , chain
1625
1726
27+ def _alloc_master_port () -> None :
28+ """Bind an ephemeral TCP port so concurrent test runs avoid EADDRINUSE on a fixed port."""
29+ with socket .socket (socket .AF_INET , socket .SOCK_STREAM ) as s :
30+ s .bind (("127.0.0.1" , 0 ))
31+ os .environ ["MASTER_PORT" ] = str (s .getsockname ()[1 ])
32+
1833
1934
2035class RandomDataset :
@@ -282,65 +297,53 @@ def _test_resume_spmd(
282297 rank : int ,
283298 world_size : int ,
284299 dataloader_config : DataloaderConfig ,
285- dataset_configs : list [dict ],
286300 global_batch_size : int ,
287301 micro_batch_size : int ,
288- step :int ,
302+ step : int ,
289303 seed : int ,
290304 save_path : Path ,
291305 dataloader_state : dict | None = None ,
292- consumed_samples : int = 0 ,
293306):
294307 os .environ ["RANK" ] = str (rank )
295308 os .environ ["LOCAL_RANK" ] = str (rank )
296309 os .environ ["WORLD_SIZE" ] = str (world_size )
297- os .environ [ "MASTER_ADDR" ] = "localhost"
298- os . environ [ "MASTER_PORT" ] = "29505"
299-
310+ os .environ . setdefault ( "MASTER_ADDR" , "localhost" )
311+ if "MASTER_PORT" not in os . environ :
312+ raise RuntimeError ( "tests must call _alloc_master_port() before torch.multiprocessing.spawn" )
300313
301314 torch .distributed .init_process_group (backend = "nccl" , rank = rank , world_size = world_size )
302315 torch .cuda .set_device (rank )
303316 data_mesh = init_device_mesh (
304317 device_type = "cuda" ,
305- mesh_shape = (world_size ,)
318+ mesh_shape = (world_size ,),
306319 )
307320 tokenizer = UTF8ByteTokenizer ()
308321
309- datasets = build_datasets (
310- dataset_config = dataset_configs ,
322+ dataloader = dataloader_config .build (
311323 tokenizer = tokenizer ,
312- )
313- dataloader = build_dataloader (
314- dataloader_config = dataloader_config ,
315- datasets = datasets ,
324+ dp_mesh = data_mesh ,
316325 global_batch_size = global_batch_size ,
317326 micro_batch_size = micro_batch_size ,
318327 seed = seed ,
319- dp_mesh = data_mesh ,
320328 )
321329
322330 if dataloader_state is not None :
323- load_dataloader_state ( dataloader , dataloader_state )
331+ dataloader . load_state_dict ( dataloader_state )
324332
325333 data_iter = iter (dataloader )
326334 data_list = []
327335 for _ in range (step ):
328336 batch = next (data_iter )
329337 data_list .append (batch )
330- consumed_samples += len (batch )
331338
332- consumed_samples_list = [None for _ in range (world_size )]
333- torch .distributed .all_gather_object (consumed_samples_list , consumed_samples )
334- global_consumed_samples = sum (consumed_samples_list )
339+ # Snapshot after the first `step` batches so total_consumed_steps matches resume intent.
340+ dataloader_state = dataloader .get_state_dict ()
335341
336342 expected_data = []
337-
338343 for _ in range (step ):
339344 batch = next (data_iter )
340345 expected_data .append (batch )
341346
342- dataloader_state = get_dataloader_state (dataloader , global_consumed_samples )
343-
344347 all_data_list = [None for _ in range (world_size )]
345348 torch .distributed .all_gather_object (all_data_list , list (chain (* data_list )))
346349
@@ -372,7 +375,6 @@ def _test_resume_spmd(
372375 "dataloader_state" : dataloader_state ,
373376 "data_list" : all_data_list ,
374377 "expected_data" : all_expected_data ,
375- "consumed_samples" : consumed_samples
376378 }
377379 )
378380 )
@@ -389,7 +391,6 @@ def _test_resume_spmd(
389391 ("none" , 0 , False ),
390392 ("soft" , 0 , True ),
391393 ("soft" , 4 , True ),
392- ("soft" , 4 , True ),
393394 ]
394395)
395396def test_dataloader_resume_multi_process (tmp_path , pack_level , num_workers , group_by_length ):
@@ -402,36 +403,36 @@ def test_dataloader_resume_multi_process(tmp_path, pack_level, num_workers, grou
402403 _create_fake_dataset (data_dir1 / f"depth3" , dataset_num = 3 , max_depth = 3 , dup_times = 9 )
403404
404405 # 1. Test resuming with the same world size
406+ dataset_configs = [
407+ {
408+ "dataset" : DatasetConfig (anno_path = str (data_dir1 )),
409+ "tokenize_fn" : FTDPTokenizeFnConfig (max_length = 1024 ),
410+ },
411+ ]
412+
405413 dataloader_config = DataloaderConfig (
414+ dataset_config_list = dataset_configs ,
406415 pack_max_length = 1024 ,
407416 pack_level = pack_level ,
408417 num_workers = num_workers ,
409418 group_by_length = group_by_length ,
410- collator = "fake_collator"
419+ collator = "fake_collator" ,
411420 )
412- dataset_configs = [
413- {
414- "dataset" : DatasetConfig (anno_path = str (data_dir1 )),
415- "tokenize_fn" : FTDPTokenizeFnConfig (max_length = 1024 )
416- },
417- ]
418421
419- ctx = get_context ("spawn" )
420422 world_size = 2
421423 save_path1 = tmp_path / "dataloader_state.pkl"
424+ _alloc_master_port ()
422425 spawn (
423426 _test_resume_spmd ,
424427 args = (
425428 world_size ,
426429 dataloader_config ,
427- dataset_configs ,
428430 16 ,
429431 BATCH_SIZE ,
430432 TOTAL_STEP ,
431433 10 ,
432434 save_path1 ,
433435 None ,
434- 0 ,
435436 ),
436437 nprocs = 2 ,
437438 join = True ,
@@ -443,19 +444,18 @@ def test_dataloader_resume_multi_process(tmp_path, pack_level, num_workers, grou
443444
444445 # 2. tet Rsume with same world size
445446 save_path2 = tmp_path / "dataloader_state2.pkl"
447+ _alloc_master_port ()
446448 spawn (
447449 _test_resume_spmd ,
448450 args = (
449451 world_size ,
450452 dataloader_config ,
451- dataset_configs ,
452453 16 ,
453454 BATCH_SIZE ,
454455 TOTAL_STEP ,
455456 10 ,
456457 save_path2 ,
457458 result1 ["dataloader_state" ],
458- result1 ["consumed_samples" ],
459459 ),
460460 nprocs = world_size ,
461461 join = True ,
@@ -470,19 +470,18 @@ def test_dataloader_resume_multi_process(tmp_path, pack_level, num_workers, grou
470470
471471 world_size = 4
472472 save_path3 = tmp_path / "dataloader_state3.pkl"
473+ _alloc_master_port ()
473474 spawn (
474475 _test_resume_spmd ,
475476 args = (
476477 world_size ,
477478 dataloader_config ,
478- dataset_configs ,
479479 16 ,
480480 BATCH_SIZE ,
481481 TOTAL_STEP ,
482482 10 ,
483483 save_path3 ,
484484 result1 ["dataloader_state" ],
485- result1 ["consumed_samples" ],
486485 ),
487486 nprocs = world_size ,
488487 join = True ,
0 commit comments