
Commit 147cb2e

Track consumed samples in DataLoader; skip per-step total-token reduction (#1652)
This PR refactors how total consumed samples are tracked and resumed:

- Move ownership of total consumed samples to the DataLoader (with updated save/restore paths); remove the older resume helpers from resume.py.
- Fix token accounting: total consumed tokens previously used a global all-reduce, which over-counted tokens when sequence parallel (SP) is enabled. Totals are now reduced on the dp_mesh only, matching DP-local semantics.
- Stop reducing total tokens on every training step, to speed up e2e TGS.
1 parent 27080d3 commit 147cb2e
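
A minimal sketch of the first point, assuming only the API surface visible in the diffs below (get_state_dict(), load_state_dict(), and a total_consumed_samples key in the state dict); the class body here is illustrative, not XTuner's actual DataLoader:

# Illustrative only: the dataloader counts the samples it yields and carries
# that count in its own state dict, so callers no longer thread a separate
# `consumed_samples` value through standalone get/load helpers.
from typing import Any, Iterator


class SketchDataloader:
    def __init__(self, batches: list[list[int]]) -> None:
        self._batches = batches
        self.total_consumed_samples = 0

    def __iter__(self) -> Iterator[list[int]]:
        for batch in self._batches:
            self.total_consumed_samples += len(batch)
            yield batch

    def get_state_dict(self) -> dict[str, Any]:
        # Captured at snapshot time; no external counter argument needed.
        return {"total_consumed_samples": self.total_consumed_samples}

    def load_state_dict(self, state: dict[str, Any]) -> None:
        self.total_consumed_samples = state["total_consumed_samples"]


loader = SketchDataloader([[1, 2], [3, 4], [5, 6]])
it = iter(loader)
next(it)
assert loader.get_state_dict()["total_consumed_samples"] == 2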

13 files changed

Lines changed: 194 additions & 208 deletions


ci/scripts/test_vlm_sft_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -204,7 +204,7 @@ def parse_args():
 
 
 def extract_data_from_log(logfile: Path):
-    pattern_str = r"\[XTuner\].*Step.*lr:\s(\d+.\d*)\s.*text_tokens:\s(\d+.\d*)\s.*reduced_llm_loss:\s(\d+.\d*)\s.*max_memory:\s(\d+.\d*)\s*GB\s.*grad_norm:\s(\d+.\d*)\s.*e2e_tgs:\s(\d+.\d*)"
+    pattern_str = r"\[XTuner\].*Step.*lr:\s(\d+.\d*)\s.*text_tokens:\s(\d+.\d*)\s.*reduced_llm_loss:\s(\d+.\d*)\s.*max_memory:\s(\d+.\d*)\s*GB\s.*grad_norm:\s(\d+.\d*)\s.*exp_tgs:\s(\d+.\d*)"
     compiled_pattern = re.compile(pattern_str)
 
     cur_lr = []

tests/datasets/test_dataloader.py

Lines changed: 43 additions & 69 deletions
@@ -1,12 +1,18 @@
 from pathlib import Path
 import os
 import pickle
+import socket
 
 import torch
 
-from xtuner.v1.datasets import build_dataloader, build_datasets, get_dataloader_state, load_dataloader_state, FTDPTokenizeFnConfig, DatasetConfig, DataloaderConfig
+from xtuner.v1.datasets import (
+    DataloaderConfig,
+    DatasetConfig,
+    FTDPTokenizeFnConfig,
+    build_dataloader,
+)
 from xtuner.v1.train.toy_tokenizer import UTF8ByteTokenizer
-from torch.multiprocessing import spawn, get_context
+from torch.multiprocessing import spawn
 from torch.distributed.device_mesh import init_device_mesh
 import pytest
 
@@ -15,8 +21,6 @@
 from itertools import repeat, chain
 
 
-
-
 class RandomDataset:
     def __init__(self, size: int, **kwargs):
         self.size = size
@@ -182,25 +186,22 @@ def test_dataloader_resume_single_process(tmp_path, pack_level, num_workers, group_by_length):
     dataset_configs = [
         {
             "dataset": DatasetConfig(anno_path=str(data_dir1)),
-            "tokenize_fn": FTDPTokenizeFnConfig(max_length=1024)
+            "tokenize_fn": FTDPTokenizeFnConfig(max_length=1024),
         },
     ]
 
     dataloader_config = DataloaderConfig(
+        dataset_config_list=dataset_configs,
         pack_max_length=1024,
         pack_level=pack_level,
         num_workers=num_workers,
         group_by_length=group_by_length,
         pack_workers=pack_workers,
     )
 
-    datasets = build_datasets(
-        dataset_config=dataset_configs,
+    dataloader1 = dataloader_config.build(
         tokenizer=tokenizer,
-    )
-    dataloader1 = build_dataloader(
-        dataloader_config=dataloader_config,
-        datasets=datasets,
+        dp_mesh=None,
         global_batch_size=GLOBAL_BATCH_SIZE,
         micro_batch_size=BATCH_SIZE,
         seed=10,
@@ -210,26 +211,22 @@ def test_dataloader_resume_single_process(tmp_path, pack_level, num_workers, group_by_length):
     assert len(dataloader1) > 10
 
     dataloader_iter = iter(dataloader1)
-    consumed_sample = 0
     for _ in range(RESUME_ITER):
-        batch = next(dataloader_iter)
-        consumed_sample += len(batch)
+        next(dataloader_iter)
 
-    dataloader_state = get_dataloader_state(dataloader1, consumed_sample)
+    dataloader_state = dataloader1.get_state_dict()
     expected_data = []
     for _ in range(AFTER_RESUME_ITER):
-        batch = next(dataloader_iter)
-        consumed_sample += len(batch)
-        expected_data.append(batch)
+        expected_data.append(next(dataloader_iter))
 
-    new_dataloader1 = build_dataloader(
-        dataloader_config=dataloader_config,
-        datasets=datasets,
+    new_dataloader1 = dataloader_config.build(
+        tokenizer=tokenizer,
+        dp_mesh=None,
         global_batch_size=GLOBAL_BATCH_SIZE,
         micro_batch_size=BATCH_SIZE,
         seed=10,
     )
-    load_dataloader_state(new_dataloader1, dataloader_state)
+    new_dataloader1.load_state_dict(dataloader_state)
     new_dataloader_iter = iter(new_dataloader1)
 
     resume_data = []
@@ -242,32 +239,29 @@ def test_dataloader_resume_single_process(tmp_path, pack_level, num_workers, group_by_length):
     # 2. Test resume after consuming multiple epochs
     while True:
         try:
-            batch = next(dataloader_iter)
-            consumed_sample += len(batch)
+            next(dataloader_iter)
        except StopIteration:
             break
 
-
     dataloader_iter = iter(dataloader1)
 
-    for batch in range(RESUME_ITER):
-        batch = next(dataloader_iter)
-        consumed_sample += len(batch)
+    for _ in range(RESUME_ITER):
+        next(dataloader_iter)
 
-    dataloader_state = get_dataloader_state(dataloader1, consumed_sample)
+    dataloader_state = dataloader1.get_state_dict()
 
     expected_data = []
     for _ in range(AFTER_RESUME_ITER):
         expected_data.append(next(dataloader_iter))
 
-    new_dataloader2 = build_dataloader(
-        dataloader_config=dataloader_config,
-        datasets=datasets,
+    new_dataloader2 = dataloader_config.build(
+        tokenizer=tokenizer,
+        dp_mesh=None,
         global_batch_size=GLOBAL_BATCH_SIZE,
         micro_batch_size=BATCH_SIZE,
         seed=10,
     )
-    load_dataloader_state(new_dataloader2, dataloader_state)
+    new_dataloader2.load_state_dict(dataloader_state)
     new_dataloader_iter2 = iter(new_dataloader2)
 
     resume_data = []
@@ -282,65 +276,52 @@ def _test_resume_spmd(
     rank: int,
     world_size: int,
     dataloader_config: DataloaderConfig,
-    dataset_configs: list[dict],
     global_batch_size: int,
     micro_batch_size: int,
-    step:int,
+    step: int,
     seed: int,
     save_path: Path,
     dataloader_state: dict | None = None,
-    consumed_samples: int = 0,
 ):
     os.environ["RANK"] = str(rank)
     os.environ["LOCAL_RANK"] = str(rank)
     os.environ["WORLD_SIZE"] = str(world_size)
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "29505"
 
-
     torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)
     torch.cuda.set_device(rank)
     data_mesh = init_device_mesh(
         device_type="cuda",
-        mesh_shape=(world_size,)
+        mesh_shape=(world_size,),
     )
     tokenizer = UTF8ByteTokenizer()
 
-    datasets = build_datasets(
-        dataset_config=dataset_configs,
+    dataloader = dataloader_config.build(
         tokenizer=tokenizer,
-    )
-    dataloader = build_dataloader(
-        dataloader_config=dataloader_config,
-        datasets=datasets,
+        dp_mesh=data_mesh,
         global_batch_size=global_batch_size,
         micro_batch_size=micro_batch_size,
         seed=seed,
-        dp_mesh=data_mesh,
     )
 
     if dataloader_state is not None:
-        load_dataloader_state(dataloader, dataloader_state)
+        dataloader.load_state_dict(dataloader_state)
 
     data_iter = iter(dataloader)
     data_list = []
     for _ in range(step):
         batch = next(data_iter)
         data_list.append(batch)
-        consumed_samples += len(batch)
 
-    consumed_samples_list = [None for _ in range(world_size)]
-    torch.distributed.all_gather_object(consumed_samples_list, consumed_samples)
-    global_consumed_samples = sum(consumed_samples_list)
+    # Snapshot after the first `step` batches so total_consumed_samples matches resume intent.
+    dataloader_state = dataloader.get_state_dict()
 
     expected_data = []
-
     for _ in range(step):
         batch = next(data_iter)
         expected_data.append(batch)
 
-    dataloader_state = get_dataloader_state(dataloader, global_consumed_samples)
-
     all_data_list = [None for _ in range(world_size)]
     torch.distributed.all_gather_object(all_data_list, list(chain(*data_list)))
 
@@ -372,7 +353,6 @@ def _test_resume_spmd(
                 "dataloader_state": dataloader_state,
                 "data_list": all_data_list,
                 "expected_data": all_expected_data,
-                "consumed_samples": consumed_samples
             }
         )
     )
@@ -389,7 +369,6 @@ def _test_resume_spmd(
         ("none", 0, False),
         ("soft", 0, True),
         ("soft", 4, True),
-        ("soft", 4, True),
     ]
 )
 def test_dataloader_resume_multi_process(tmp_path, pack_level, num_workers, group_by_length):
@@ -402,36 +381,35 @@ def test_dataloader_resume_multi_process(tmp_path, pack_level, num_workers, group_by_length):
     _create_fake_dataset(data_dir1 / f"depth3", dataset_num=3, max_depth=3, dup_times=9)
 
     # 1. Test resuming with the same world size
+    dataset_configs = [
+        {
+            "dataset": DatasetConfig(anno_path=str(data_dir1)),
+            "tokenize_fn": FTDPTokenizeFnConfig(max_length=1024),
+        },
+    ]
+
     dataloader_config = DataloaderConfig(
+        dataset_config_list=dataset_configs,
         pack_max_length=1024,
         pack_level=pack_level,
         num_workers=num_workers,
         group_by_length=group_by_length,
-        collator="fake_collator"
+        collator="fake_collator",
     )
-    dataset_configs = [
-        {
-            "dataset": DatasetConfig(anno_path=str(data_dir1)),
-            "tokenize_fn": FTDPTokenizeFnConfig(max_length=1024)
-        },
-    ]
 
-    ctx = get_context("spawn")
     world_size = 2
     save_path1 = tmp_path / "dataloader_state.pkl"
     spawn(
         _test_resume_spmd,
         args=(
             world_size,
             dataloader_config,
-            dataset_configs,
             16,
             BATCH_SIZE,
             TOTAL_STEP,
             10,
             save_path1,
             None,
-            0,
         ),
         nprocs=2,
         join=True,
@@ -448,14 +426,12 @@ def test_dataloader_resume_multi_process(tmp_path, pack_level, num_workers, group_by_length):
         args=(
             world_size,
             dataloader_config,
-            dataset_configs,
             16,
             BATCH_SIZE,
             TOTAL_STEP,
             10,
             save_path2,
             result1["dataloader_state"],
-            result1["consumed_samples"],
         ),
         nprocs=world_size,
         join=True,
@@ -475,14 +451,12 @@ def test_dataloader_resume_multi_process(tmp_path, pack_level, num_workers, group_by_length):
         args=(
             world_size,
             dataloader_config,
-            dataset_configs,
             16,
             BATCH_SIZE,
             TOTAL_STEP,
             10,
             save_path3,
             result1["dataloader_state"],
-            result1["consumed_samples"],
        ),
         nprocs=world_size,
         join=True,

tests/datasets/test_preset_dataloader.py

Lines changed: 4 additions & 4 deletions
@@ -20,7 +20,7 @@
 
 from itertools import chain
 
-from xtuner.v1.datasets import PretrainTokenizeFunctionConfig, get_dataloader_state, load_dataloader_state
+from xtuner.v1.datasets import PretrainTokenizeFunctionConfig
 from xtuner.v1.datasets.config import DatasetConfig, DataloaderConfig
 from xtuner.v1.datasets.packing import get_pack_infos_by_hard_split
 from xtuner.v1.datasets.preset_pack import PresetPackDataset
@@ -700,8 +700,8 @@ def _build():
     global_consumed_samples = sum(int(x) for x in consumed_samples_list if x is not None)
 
     # 3. Get ckpt state
-    # dataloader_state = get_dataloader_state(dl, global_consumed_samples)
-    dataloader_state = dl.get_state_dict(global_consumed_samples)
+    dataloader_state = dl.get_state_dict()
+    assert dataloader_state["total_consumed_samples"] == global_consumed_samples
 
     # 4. Continue to consume data at [half_step, 2*half_step)
     expected_batches = []
@@ -738,7 +738,7 @@ def _build():
     dl2 = _build()
     with ckpt_path.open("rb") as f:
         ckpt = pickle.load(f)
-    # load_dataloader_state(dl2, ckpt["dataloader_state"])
+    # dl2.load_state_dict(ckpt["dataloader_state"])
     dl2.load_state_dict(ckpt["dataloader_state"])
 
     resume_iter = iter(dl2)

tests/datasets/test_preset_sampler.py

Lines changed: 2 additions & 2 deletions
@@ -87,7 +87,7 @@ def test_state_dict_resume(tmp_path):
 
     sampler = PresetSampler(dataset, sampler_config_path=path, global_batch_size=1)
 
-    state = sampler.get_state_dict(step=3)
+    state = sampler.get_state_dict(3)
     assert state["step"] == 3
 
     sampler2 = PresetSampler(dataset, sampler_config_path=path, global_batch_size=1)
@@ -102,7 +102,7 @@ def test_state_dict_world_size_mismatch(tmp_path):
     path = _write_order_npy(tmp_path, "order.npy", _i64(0, 1, 2, 3))
 
     sampler = PresetSampler(dataset, sampler_config_path=path, global_batch_size=1)
-    state = sampler.get_state_dict(step=0)
+    state = sampler.get_state_dict(0)
     state["world_size"] = 99
 
     sampler.load_state_dict(state)

xtuner/v1/datasets/__init__.py

Lines changed: 0 additions & 3 deletions
@@ -25,7 +25,6 @@
     PretrainTokenizeFunction,
     PretrainTokenizeFunctionConfig,
 )
-from .resume import get_dataloader_state, load_dataloader_state
 from .rl_tokenize_fn import RLTokenizeFnConfig
 from .sampler import LengthGroupedSampler, ParallelSampler
 from .sft_tokenize_fn import OpenaiTokenizeFunction, OpenaiTokenizeFunctionConfig
@@ -68,8 +67,6 @@
     "InternS1VLTokenizeFnConfig",
     "fake_collator",
     "RLTokenizeFnConfig",
-    "get_dataloader_state",
-    "load_dataloader_state",
     "DatasetConfigList",
     "DataloaderConfig",
     "BaseTokenizeFnConfig",

xtuner/v1/datasets/config.py

Lines changed: 1 addition & 0 deletions
@@ -544,5 +544,6 @@ def build(
            collate_fn=collator,
            multiprocessing_context=ctx if self.num_workers > 0 else None,
            persistent_workers=self.num_workers > 0,
+           dp_mesh=dp_mesh,
        )
        return dataloader
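
The dp_mesh threaded into the dataloader above is what enables the DP-local accounting. A hedged sketch of the reduction pattern the commit message describes: sum token counts over the dp_mesh process group instead of the default world group, and only when a synced total is needed rather than on every step. The function name and the commented loop are illustrative, not the trainer's actual code:

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh


def reduce_consumed_tokens(local_tokens: int, dp_mesh: DeviceMesh) -> int:
    """Sum token counts across DP ranks only.

    Reducing over the dp_mesh group keeps sequence-parallel peers, which see
    the same samples, out of the sum, so totals are not over-counted under SP.
    """
    t = torch.tensor(local_tokens, dtype=torch.int64, device="cuda")
    dist.all_reduce(t, group=dp_mesh.get_group())  # default op is SUM
    return int(t.item())


# Deferring the collective: accumulate locally each step and reduce only when
# a synced total is actually needed (e.g. at log time), instead of paying one
# all-reduce per training step.
#
# local_tokens = 0
# for step, batch in enumerate(loader, start=1):
#     local_tokens += batch["num_tokens"]
#     if step % log_interval == 0:
#         total = reduce_consumed_tokens(local_tokens, dp_mesh)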
