Skip to content

Commit adb12e8

Browse files
committed
Sampler: add ConsumedStepsTracker to track consumed samples across data-parallel groups
1 parent 3931e83 commit adb12e8

6 files changed

Lines changed: 188 additions & 34 deletions

File tree

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
"""Track consumed samples for checkpointing; aggregate across DP only (not
2+
SP/TP)."""
3+
4+
from __future__ import annotations
5+
6+
import torch
7+
import torch.distributed as dist
8+
from torch.distributed.device_mesh import DeviceMesh
9+
10+
11+
def reduce_sum_across_dp_group(dp_mesh: DeviceMesh | None, local_value: int) -> int:
12+
"""Sum ``local_value`` over the DP process group (one contribution per
13+
data-parallel replica).
14+
15+
Ranks that only differ in SP/TP see identical data batches and must not be summed with the global world group; see
16+
Training notes for SP+DP.
17+
"""
18+
if dp_mesh is None or dp_mesh.size() <= 1:
19+
return int(local_value)
20+
if not dist.is_available() or not dist.is_initialized():
21+
return int(local_value)
22+
if torch.cuda.is_available():
23+
device = torch.device(f"cuda:{torch.cuda.current_device()}")
24+
else:
25+
device = torch.device("cpu")
26+
tensor = torch.tensor([local_value], dtype=torch.int64, device=device)
27+
dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=dp_mesh.get_group())
28+
return int(tensor.item())
29+
30+
31+
class ConsumedStepsTracker:
32+
"""Holds per-resume totals and per-rank local accumulation; checkpoint
33+
total uses DP-only reduction."""
34+
35+
__slots__ = ("_dp_mesh", "_init_steps", "_local_steps")
36+
37+
def __init__(self, dp_mesh: DeviceMesh | None) -> None:
38+
self._dp_mesh = dp_mesh
39+
self._init_steps = 0
40+
self._local_steps = 0
41+
42+
def record(self, n: int) -> None:
43+
self._local_steps += int(n)
44+
45+
def set_init_from_checkpoint(self, total: int) -> None:
46+
"""After loading a checkpoint: global total consumed so far; reset session-local accumulation."""
47+
self._init_steps = int(total)
48+
self._local_steps = 0
49+
50+
def total_for_checkpoint(self) -> int:
51+
"""Global consumed sample count including this session (collective over
52+
DP group)."""
53+
return self._init_steps + reduce_sum_across_dp_group(self._dp_mesh, self._local_steps)
54+
55+
56+
def apply_old_ckpt_init_steps(sampler: object, sampler_state: dict, train_state_total: int | None) -> None:
    """Backfill the sampler's consumed total from ``train_state`` for old checkpoints.

    Sampler checkpoints written before ``total_consumed_steps`` existed lack
    that key; in that case the total recorded in the train state is used.

    Args:
        sampler: Object that may expose a ``_consumed`` tracker; samplers
            without one are left untouched.
        sampler_state: The sampler's checkpointed state dict.
        train_state_total: Total taken from the train state, or ``None`` when
            unavailable (nothing is done).
    """
    # Nothing to backfill from, or the sampler checkpoint already has its own total.
    if train_state_total is None or sampler_state.get("total_consumed_steps") is not None:
        return

    tracker = getattr(sampler, "_consumed", None)
    if tracker is None:
        return
    tracker.set_init_from_checkpoint(train_state_total)

xtuner/v1/datasets/dataloader.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55

66
from xtuner.v1.datasets.collator import ColateItem
77
from xtuner.v1.datasets.resume import get_dataloader_state, load_dataloader_state
8+
from xtuner.v1.utils import get_logger
9+
10+
11+
logger = get_logger()
812

913

1014
class BaseDataloader(ABC):
@@ -16,10 +20,10 @@ class BaseDataloader(ABC):
1620
"""
1721

1822
@abstractmethod
19-
def load_state_dict(self, state_dict: dict) -> None: ...
23+
def load_state_dict(self, state_dict: dict, train_state_total_consumed_samples: int | None = None) -> None: ...
2024

2125
@abstractmethod
22-
def get_state_dict(self, consumed_samples: int) -> dict: ...
26+
def get_state_dict(self, consumed_samples: int = -1) -> dict: ...
2327

2428
@abstractmethod
2529
def __iter__(self) -> Iterator[list[ColateItem]]: ...
@@ -33,13 +37,36 @@ class Dataloader(torch.utils.data.DataLoader, BaseDataloader):
3337
implement.
3438
"""
3539

36-
def load_state_dict(self, state_dict: dict) -> None:
37-
load_dataloader_state(self, state_dict)
40+
def load_state_dict(
    self,
    state_dict: dict,
    train_state_total_consumed_samples: int | None = None,
) -> None:
    """Restore this dataloader from ``state_dict``.

    Delegates to ``load_dataloader_state``, forwarding
    ``train_state_total_consumed_samples`` by keyword.
    """
    fallback_total = train_state_total_consumed_samples
    load_dataloader_state(self, state_dict, train_state_total_consumed_samples=fallback_total)
3850

39-
def get_state_dict(self, consumed_samples: int) -> dict:
51+
def get_state_dict(self, consumed_samples: int = -1) -> dict:
    """Return this dataloader's checkpoint state.

    Args:
        consumed_samples: Deprecated. Any value other than the default ``-1``
            logs a deprecation warning; the value is still forwarded to
            ``get_dataloader_state``.
    """
    caller_supplied_count = consumed_samples != -1
    if caller_supplied_count:
        logger.warning(
            "Dataloader.get_state_dict(consumed_samples=...) is deprecated; use the default (-1). "
            "Consumed samples are tracked on the sampler."
        )
    return cast(dict, get_dataloader_state(self, consumed_samples))
4259

60+
def record_consumed_samples(self, n: int) -> None:
    """Forward ``n`` to the sampler when it supports ``record_consumed_samples``.

    Samplers without that method are silently ignored.
    """
    if not hasattr(self.sampler, "record_consumed_samples"):
        return
    self.sampler.record_consumed_samples(n)
63+
64+
def get_total_consumed_samples(self) -> int:
    """Return the sampler's total consumed count, or ``0`` if it tracks none.

    The value comes from the sampler's ``get_total_consumed_steps`` and is
    coerced to ``int``.
    """
    sampler = self.sampler
    if not hasattr(sampler, "get_total_consumed_steps"):
        return 0
    total = sampler.get_total_consumed_steps()
    return int(total)
69+
4370
# __iter__ is inherited from torch.utils.data.DataLoader
4471

4572
# Streaming dataloader may not have `set_epoch` and `__len__` method, so we add here.

xtuner/v1/datasets/preset_sampler.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from xtuner.v1.utils import get_logger
2424

25+
from .consumed_steps import ConsumedStepsTracker
2526
from .preset_pack import PresetPackDataset
2627

2728

@@ -116,6 +117,7 @@ def __init__(
116117
else:
117118
self.rank = 0
118119
self.world_size = 1
120+
self._consumed = ConsumedStepsTracker(dp_mesh)
119121

120122
self.dataset = dataset
121123
self.global_batch_size = global_batch_size
@@ -170,19 +172,35 @@ def __len__(self) -> int:
170172
def set_epoch(self, epoch: int) -> None:
171173
self.epoch = epoch
172174

173-
def get_state_dict(self, step: int) -> dict:
175+
def record_consumed_samples(self, n: int) -> None:
    """Add ``n`` to this sampler's consumed-steps tracker."""
    tracker = self._consumed
    tracker.record(n)
177+
178+
def get_total_consumed_steps(self) -> int:
    """Return the tracker's checkpoint total (a collective over the DP group)."""
    tracker = self._consumed
    return tracker.total_for_checkpoint()
180+
181+
def get_state_dict(self, step: int | None = None) -> dict:
174182
# Same convention as :class:`LengthGroupedSampler`: ``step`` is the global pack offset
175183
# (modulo ``total_size``) into ``global_order``, shared across all ranks in the checkpoint.
176-
global_step = step % self.total_size
184+
if step is None:
185+
total_consumed = self._consumed.total_for_checkpoint()
186+
else:
187+
total_consumed = int(step)
188+
global_step = total_consumed % self.total_size
177189
return {
178190
"epoch": self.epoch,
179191
"step": global_step,
192+
"total_consumed_steps": total_consumed,
180193
"world_size": self.world_size,
181194
"num_samples": self.num_samples,
182195
"total_size": self.total_size,
183196
}
184197

185198
def load_state_dict(self, state_dict: dict) -> None:
199+
tc = state_dict.get("total_consumed_steps")
200+
if tc is not None:
201+
self._consumed.set_init_from_checkpoint(int(tc))
202+
else:
203+
self._consumed.set_init_from_checkpoint(0)
186204
if self.world_size != state_dict.get("world_size"):
187205
logger.warning(
188206
f"PresetSampler: world_size mismatch: checkpoint has "
@@ -191,5 +209,4 @@ def load_state_dict(self, state_dict: dict) -> None:
191209
)
192210

193211
self.epoch = state_dict["epoch"]
194-
global_step = int(state_dict["step"])
195-
self.step = global_step
212+
self.step = int(state_dict["step"])

xtuner/v1/datasets/resume.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from xtuner.v1.utils import get_logger
55

6+
from .consumed_steps import apply_old_ckpt_init_steps
67
from .packing import ExpandSoftPackDataset, _LegacySoftPackDataset
78
from .sampler import LengthGroupedSampler, ParallelSampler
89

@@ -15,15 +16,21 @@ class DataloaderState(TypedDict):
1516
dataset: dict
1617

1718

18-
def get_dataloader_state(dataloader: DataLoader, consumed_samples: int) -> DataloaderState:
19+
def get_dataloader_state(dataloader: DataLoader, consumed_samples: int = -1) -> DataloaderState:
1920
sampler: ParallelSampler | LengthGroupedSampler = dataloader.sampler # type: ignore[assignment]
2021
dataset: ExpandSoftPackDataset | _LegacySoftPackDataset = dataloader.dataset # type: ignore[assignment]
2122
dataloader_state = DataloaderState(sampler={}, dataset={})
2223

2324
if not hasattr(sampler, "load_state_dict") or not hasattr(sampler, "get_state_dict"):
2425
logger.warning(f"Resuming from {type(sampler)} is risky.")
25-
else:
26+
elif consumed_samples != -1:
27+
logger.warning(
28+
"Passing consumed_samples to get_dataloader_state is deprecated; "
29+
"consumed sample totals are tracked on the sampler. Use the default consumed_samples=-1."
30+
)
2631
dataloader_state["sampler"].update(sampler.get_state_dict(step=consumed_samples))
32+
else:
33+
dataloader_state["sampler"].update(sampler.get_state_dict())
2734

2835
if not hasattr(dataset, "load_state_dict") or not hasattr(dataset, "get_state_dict"):
2936
logger.warning(f"Resuming from {type(dataset)} is risky.")
@@ -33,7 +40,11 @@ def get_dataloader_state(dataloader: DataLoader, consumed_samples: int) -> Datal
3340
return dataloader_state
3441

3542

36-
def load_dataloader_state(dataloader: DataLoader, state: dict):
43+
def load_dataloader_state(
44+
dataloader: DataLoader,
45+
state: dict,
46+
train_state_total_consumed_samples: int | None = None,
47+
):
3748
sampler = dataloader.sampler
3849
dataset = dataloader.dataset
3950

@@ -44,6 +55,7 @@ def load_dataloader_state(dataloader: DataLoader, state: dict):
4455

4556
if hasattr(sampler, "load_state_dict"):
4657
sampler.load_state_dict(state["sampler"])
58+
apply_old_ckpt_init_steps(sampler, state["sampler"], train_state_total_consumed_samples)
4759

4860
# If the dataset records the training progress, we also restore it.
4961
if hasattr(dataset, "load_state_dict"):

xtuner/v1/datasets/sampler.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from xtuner.v1.utils import get_logger
1414

15+
from .consumed_steps import ConsumedStepsTracker
1516
from .jsonl import JsonlDataset
1617
from .packing import MLLMPretrainHybridPackDataset, _LegacySoftPackDataset
1718
from .preset_pack import PresetPackDataset
@@ -84,6 +85,7 @@ def __init__(
8485
self.epoch = 0
8586
self.step = 0
8687
self.round_up = round_up
88+
self._consumed = ConsumedStepsTracker(dp_mesh)
8789

8890
if self.round_up:
8991
self.num_samples = math.ceil(len(self.dataset) / global_batch_size) * global_batch_size // world_size
@@ -131,12 +133,23 @@ def set_epoch(self, epoch: int) -> None:
131133
"""
132134
self.epoch = epoch
133135

136+
def record_consumed_samples(self, n: int) -> None:
    """Add ``n`` to this sampler's consumed-steps tracker."""
    tracker = self._consumed
    tracker.record(n)
138+
139+
def get_total_consumed_steps(self) -> int:
    """Return the tracker's checkpoint total (a collective over the DP group)."""
    tracker = self._consumed
    return tracker.total_for_checkpoint()
141+
134142
def load_state_dict(self, state_dict) -> None:
135143
"""Load the sampler state.
136144
137145
Args:
138146
state_dict (dict): The state of the sampler.
139147
"""
148+
tc = state_dict.get("total_consumed_steps")
149+
if tc is not None:
150+
self._consumed.set_init_from_checkpoint(int(tc))
151+
else:
152+
self._consumed.set_init_from_checkpoint(0)
140153
self.epoch = state_dict["epoch"]
141154
self.step = state_dict["step"]
142155

@@ -146,12 +159,17 @@ def load_state_dict(self, state_dict) -> None:
146159
f"is different from the current shuffle ({self.shuffle})."
147160
)
148161

149-
def get_state_dict(self, step: int):
162+
def get_state_dict(self, step: int | None = None):
150163
# Attention! Do not set self.step here, or it will cause the next __iter__ to get less samples.
151-
step = step % self.total_size
164+
if step is None:
165+
total_consumed = self._consumed.total_for_checkpoint()
166+
else:
167+
total_consumed = int(step)
168+
step_mod = total_consumed % self.total_size
152169
return {
153170
"epoch": self.epoch,
154-
"step": step,
171+
"step": step_mod,
172+
"total_consumed_steps": total_consumed,
155173
"world_size": self.world_size,
156174
"shuffle": self.shuffle,
157175
"round_up": self.round_up,
@@ -233,6 +251,7 @@ def __init__(
233251
assert isinstance(self.max_lengths, (list, tuple, Column, np.ndarray))
234252

235253
self.global_batch_size = global_batch_size
254+
self._consumed = ConsumedStepsTracker(dp_mesh)
236255

237256
def __iter__(self) -> Iterator[int]:
238257
"""Iterate the indices."""
@@ -275,12 +294,23 @@ def set_epoch(self, epoch: int) -> None:
275294
"""
276295
self.epoch = epoch
277296

297+
def record_consumed_samples(self, n: int) -> None:
    """Add ``n`` to this sampler's consumed-steps tracker."""
    tracker = self._consumed
    tracker.record(n)
299+
300+
def get_total_consumed_steps(self) -> int:
    """Return the tracker's checkpoint total (a collective over the DP group)."""
    tracker = self._consumed
    return tracker.total_for_checkpoint()
302+
278303
def load_state_dict(self, state_dict: dict) -> None:
279304
"""Load the sampler state.
280305
281306
Args:
282307
state_dict (dict): The state of the sampler.
283308
"""
309+
tc = state_dict.get("total_consumed_steps")
310+
if tc is not None:
311+
self._consumed.set_init_from_checkpoint(int(tc))
312+
else:
313+
self._consumed.set_init_from_checkpoint(0)
284314
self.epoch = state_dict["epoch"]
285315
self.step = state_dict["step"]
286316

@@ -298,17 +328,22 @@ def load_state_dict(self, state_dict: dict) -> None:
298328
)
299329
self.group_size = origin_group_size
300330

301-
def get_state_dict(self, step: int):
331+
def get_state_dict(self, step: int | None = None):
302332
"""Get the sampler state dict.
303333
304334
Returns:
305335
dict: The state of the sampler.
306336
"""
307337
# Attention! Do not set self.step here, or it will cause the next __iter__ to get less samples.
308-
step = step % self.total_size
338+
if step is None:
339+
total_consumed = self._consumed.total_for_checkpoint()
340+
else:
341+
total_consumed = int(step)
342+
step_mod = total_consumed % self.total_size
309343
return {
310344
"epoch": self.epoch,
311-
"step": step,
345+
"step": step_mod,
346+
"total_consumed_steps": total_consumed,
312347
"world_size": self.world_size,
313348
"round_up": self.round_up,
314349
"num_samples": self.num_samples,

0 commit comments

Comments
 (0)