Skip to content

Commit ad5f962

Browse files
committed
fix sampler save when dataloader num_workers > 0
1 parent 5c783a7 commit ad5f962

4 files changed

Lines changed: 59 additions & 50 deletions

File tree

tests/datasets/test_dataloader.py

Lines changed: 39 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,21 @@
11
from pathlib import Path
22
import os
33
import pickle
4+
import socket
45

56
import torch
67

7-
from xtuner.v1.datasets import build_dataloader, build_datasets, get_dataloader_state, load_dataloader_state, FTDPTokenizeFnConfig, DatasetConfig, DataloaderConfig
8+
from xtuner.v1.datasets import (
9+
DataloaderConfig,
10+
DatasetConfig,
11+
FTDPTokenizeFnConfig,
12+
build_dataloader,
13+
build_datasets,
14+
get_dataloader_state,
15+
load_dataloader_state,
16+
)
817
from xtuner.v1.train.toy_tokenizer import UTF8ByteTokenizer
9-
from torch.multiprocessing import spawn, get_context
18+
from torch.multiprocessing import spawn
1019
from torch.distributed.device_mesh import init_device_mesh
1120
import pytest
1221

@@ -15,6 +24,12 @@
1524
from itertools import repeat, chain
1625

1726

27+
def _alloc_master_port() -> None:
28+
"""Bind an ephemeral TCP port so concurrent test runs avoid EADDRINUSE on a fixed port."""
29+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
30+
s.bind(("127.0.0.1", 0))
31+
os.environ["MASTER_PORT"] = str(s.getsockname()[1])
32+
1833

1934

2035
class RandomDataset:
@@ -282,65 +297,53 @@ def _test_resume_spmd(
282297
rank: int,
283298
world_size: int,
284299
dataloader_config: DataloaderConfig,
285-
dataset_configs: list[dict],
286300
global_batch_size: int,
287301
micro_batch_size: int,
288-
step:int,
302+
step: int,
289303
seed: int,
290304
save_path: Path,
291305
dataloader_state: dict | None = None,
292-
consumed_samples: int = 0,
293306
):
294307
os.environ["RANK"] = str(rank)
295308
os.environ["LOCAL_RANK"] = str(rank)
296309
os.environ["WORLD_SIZE"] = str(world_size)
297-
os.environ["MASTER_ADDR"] = "localhost"
298-
os.environ["MASTER_PORT"] = "29505"
299-
310+
os.environ.setdefault("MASTER_ADDR", "localhost")
311+
if "MASTER_PORT" not in os.environ:
312+
raise RuntimeError("tests must call _alloc_master_port() before torch.multiprocessing.spawn")
300313

301314
torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)
302315
torch.cuda.set_device(rank)
303316
data_mesh = init_device_mesh(
304317
device_type="cuda",
305-
mesh_shape=(world_size,)
318+
mesh_shape=(world_size,),
306319
)
307320
tokenizer = UTF8ByteTokenizer()
308321

309-
datasets = build_datasets(
310-
dataset_config=dataset_configs,
322+
dataloader = dataloader_config.build(
311323
tokenizer=tokenizer,
312-
)
313-
dataloader = build_dataloader(
314-
dataloader_config=dataloader_config,
315-
datasets=datasets,
324+
dp_mesh=data_mesh,
316325
global_batch_size=global_batch_size,
317326
micro_batch_size=micro_batch_size,
318327
seed=seed,
319-
dp_mesh=data_mesh,
320328
)
321329

322330
if dataloader_state is not None:
323-
load_dataloader_state(dataloader, dataloader_state)
331+
dataloader.load_state_dict(dataloader_state)
324332

325333
data_iter = iter(dataloader)
326334
data_list = []
327335
for _ in range(step):
328336
batch = next(data_iter)
329337
data_list.append(batch)
330-
consumed_samples += len(batch)
331338

332-
consumed_samples_list = [None for _ in range(world_size)]
333-
torch.distributed.all_gather_object(consumed_samples_list, consumed_samples)
334-
global_consumed_samples = sum(consumed_samples_list)
339+
# Snapshot after the first `step` batches so total_consumed_steps matches resume intent.
340+
dataloader_state = dataloader.get_state_dict()
335341

336342
expected_data = []
337-
338343
for _ in range(step):
339344
batch = next(data_iter)
340345
expected_data.append(batch)
341346

342-
dataloader_state = get_dataloader_state(dataloader, global_consumed_samples)
343-
344347
all_data_list = [None for _ in range(world_size)]
345348
torch.distributed.all_gather_object(all_data_list, list(chain(*data_list)))
346349

@@ -372,7 +375,6 @@ def _test_resume_spmd(
372375
"dataloader_state": dataloader_state,
373376
"data_list": all_data_list,
374377
"expected_data": all_expected_data,
375-
"consumed_samples": consumed_samples
376378
}
377379
)
378380
)
@@ -389,7 +391,6 @@ def _test_resume_spmd(
389391
("none", 0, False),
390392
("soft", 0, True),
391393
("soft", 4, True),
392-
("soft", 4, True),
393394
]
394395
)
395396
def test_dataloader_resume_multi_process(tmp_path, pack_level, num_workers, group_by_length):
@@ -402,36 +403,36 @@ def test_dataloader_resume_multi_process(tmp_path, pack_level, num_workers, grou
402403
_create_fake_dataset(data_dir1 / f"depth3", dataset_num=3, max_depth=3, dup_times=9)
403404

404405
# 1. Test resuming with the same world size
406+
dataset_configs = [
407+
{
408+
"dataset": DatasetConfig(anno_path=str(data_dir1)),
409+
"tokenize_fn": FTDPTokenizeFnConfig(max_length=1024),
410+
},
411+
]
412+
405413
dataloader_config = DataloaderConfig(
414+
dataset_config_list=dataset_configs,
406415
pack_max_length=1024,
407416
pack_level=pack_level,
408417
num_workers=num_workers,
409418
group_by_length=group_by_length,
410-
collator="fake_collator"
419+
collator="fake_collator",
411420
)
412-
dataset_configs = [
413-
{
414-
"dataset": DatasetConfig(anno_path=str(data_dir1)),
415-
"tokenize_fn": FTDPTokenizeFnConfig(max_length=1024)
416-
},
417-
]
418421

419-
ctx = get_context("spawn")
420422
world_size = 2
421423
save_path1 = tmp_path / "dataloader_state.pkl"
424+
_alloc_master_port()
422425
spawn(
423426
_test_resume_spmd,
424427
args=(
425428
world_size,
426429
dataloader_config,
427-
dataset_configs,
428430
16,
429431
BATCH_SIZE,
430432
TOTAL_STEP,
431433
10,
432434
save_path1,
433435
None,
434-
0,
435436
),
436437
nprocs=2,
437438
join=True,
@@ -443,19 +444,18 @@ def test_dataloader_resume_multi_process(tmp_path, pack_level, num_workers, grou
443444

444445
# 2. Test resuming with the same world size
445446
save_path2 = tmp_path / "dataloader_state2.pkl"
447+
_alloc_master_port()
446448
spawn(
447449
_test_resume_spmd,
448450
args=(
449451
world_size,
450452
dataloader_config,
451-
dataset_configs,
452453
16,
453454
BATCH_SIZE,
454455
TOTAL_STEP,
455456
10,
456457
save_path2,
457458
result1["dataloader_state"],
458-
result1["consumed_samples"],
459459
),
460460
nprocs=world_size,
461461
join=True,
@@ -470,19 +470,18 @@ def test_dataloader_resume_multi_process(tmp_path, pack_level, num_workers, grou
470470

471471
world_size = 4
472472
save_path3 = tmp_path / "dataloader_state3.pkl"
473+
_alloc_master_port()
473474
spawn(
474475
_test_resume_spmd,
475476
args=(
476477
world_size,
477478
dataloader_config,
478-
dataset_configs,
479479
16,
480480
BATCH_SIZE,
481481
TOTAL_STEP,
482482
10,
483483
save_path3,
484484
result1["dataloader_state"],
485-
result1["consumed_samples"],
486485
),
487486
nprocs=world_size,
488487
join=True,

xtuner/v1/datasets/dataloader.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch
55

66
from xtuner.v1.datasets.collator import ColateItem
7+
from xtuner.v1.datasets.consumed_steps import ConsumedStepsTracker
78
from xtuner.v1.datasets.resume import get_dataloader_state, load_dataloader_state
89
from xtuner.v1.utils import get_logger
910

@@ -42,6 +43,11 @@ def load_state_dict(
4243
state_dict: dict,
4344
train_state_total_consumed_samples: int | None = None,
4445
) -> None:
46+
if train_state_total_consumed_samples is not None:
47+
logger.warning(
48+
"Dataloader.load_state_dict(train_state_total_consumed_samples=...) is deprecated; "
49+
"use the default (None). Consumed samples are tracked on the sampler."
50+
)
4551
load_dataloader_state(
4652
self,
4753
state_dict,
@@ -58,7 +64,17 @@ def get_state_dict(self, consumed_samples: int = -1) -> dict:
5864
dataloader_state = get_dataloader_state(self, consumed_samples)
5965
return cast(dict, dataloader_state)
6066

61-
# __iter__ is inherited from torch.utils.data.DataLoader
67+
def __iter__(self) -> Iterator[list[ColateItem]]: # type: ignore[override]
68+
# Override to count delivered batches, not prefetched indices.
69+
# With num_workers > 0 the sampler is iterated ahead by DataLoader's prefetch queue,
70+
# so recording inside sampler.__iter__ would count too many samples. Instead we
71+
# increment _consumed exactly once per batch that reaches the caller.
72+
sampler = self.sampler
73+
consumed: ConsumedStepsTracker | None = getattr(sampler, "_consumed", None)
74+
for batch in super().__iter__():
75+
if consumed is not None:
76+
consumed.record(len(batch))
77+
yield batch
6278

6379
# Streaming dataloader may not have `set_epoch` and `__len__` method, so we add here.
6480
def set_epoch(self, epoch: int) -> None:

xtuner/v1/datasets/preset_sampler.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,7 @@ def __init__(
163163
def __iter__(self) -> Iterator[int]:
164164
# load order from npy → global_order → rank_view 类型均为 memmap, 子视图 的路径仍然保持
165165
# memmap 语义(视图、按需分页、文件后端);单机多进程可共享同一份文件页缓存
166-
for idx in self.global_order[self.step + self.rank : self.total_size : self.world_size]:
167-
self._consumed.record(1)
168-
yield int(idx)
166+
yield from (int(idx) for idx in self.global_order[self.step + self.rank : self.total_size : self.world_size])
169167
self.step = 0
170168

171169
def __len__(self) -> int:

xtuner/v1/datasets/sampler.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,7 @@ def __iter__(self) -> Iterator[int]:
113113
# subsample
114114
indices = indices[self.step + self.rank : self.total_size : self.world_size]
115115

116-
for idx in indices:
117-
self._consumed.record(1)
118-
yield idx
116+
yield from indices
119117
self.step = 0
120118

121119
def __len__(self) -> int:
@@ -268,9 +266,7 @@ def __iter__(self) -> Iterator[int]:
268266
assert len(indices) == self.total_size
269267
indices = indices[self.step + self.rank : self.total_size : self.world_size]
270268
assert len(indices) == self.num_samples - self.step // self.world_size
271-
for idx in indices:
272-
self._consumed.record(1)
273-
yield idx
269+
yield from indices
274270
self.step = 0
275271

276272
def __len__(self) -> int:

0 commit comments

Comments
 (0)