From ec10d9841ad53795cafb58c465cc4986a62fbace Mon Sep 17 00:00:00 2001 From: zhaopenghao Date: Thu, 2 Apr 2026 13:46:36 +0000 Subject: [PATCH 01/14] skip reduce total tokens in every step --- ci/scripts/test_vlm_sft_trainer.py | 2 +- xtuner/v1/train/trainer.py | 71 ++++++++++++++---------------- 2 files changed, 34 insertions(+), 39 deletions(-) diff --git a/ci/scripts/test_vlm_sft_trainer.py b/ci/scripts/test_vlm_sft_trainer.py index 204eb9f5d..e36edc130 100644 --- a/ci/scripts/test_vlm_sft_trainer.py +++ b/ci/scripts/test_vlm_sft_trainer.py @@ -204,7 +204,7 @@ def parse_args(): def extract_data_from_log(logfile: Path): - pattern_str = r"\[XTuner\].*Step.*lr:\s(\d+.\d*)\s.*text_tokens:\s(\d+.\d*)\s.*reduced_llm_loss:\s(\d+.\d*)\s.*max_memory:\s(\d+.\d*)\s*GB\s.*grad_norm:\s(\d+.\d*)\s.*e2e_tgs:\s(\d+.\d*)" + pattern_str = r"\[XTuner\].*Step.*lr:\s(\d+.\d*)\s.*text_tokens:\s(\d+.\d*)\s.*reduced_llm_loss:\s(\d+.\d*)\s.*max_memory:\s(\d+.\d*)\s*GB\s.*grad_norm:\s(\d+.\d*)\s.*exp_tgs:\s(\d+.\d*)" compiled_pattern = re.compile(pattern_str) cur_lr = [] diff --git a/xtuner/v1/train/trainer.py b/xtuner/v1/train/trainer.py index d09b5f652..feb471050 100644 --- a/xtuner/v1/train/trainer.py +++ b/xtuner/v1/train/trainer.py @@ -86,11 +86,9 @@ class ExpHistory(TypedDict): class PerformanceStatistics(TypedDict): local_step_consumed_tokens: int local_step_consumed_img_tokens: int | None - step_consumed_tokens: int + local_total_consumed_tokens: int total_consumed_tokens: int - total_consumed_tokens_per_rank: float tgs: float - e2e_tgs: float exp_tgs: float eta_seconds: float eta_hms: str @@ -537,9 +535,11 @@ def __init__( self._debug = debug self._seed = seed - self._total_consumed_tokens = 0 - self._exp_consumed_tokens = 0 - self._total_consumed_samples = 0 + self._local_total_consumed_tokens = 0 + self._local_exp_consumed_tokens = 0 + self._local_consumed_samples = 0 + self._init_total_tokens = 0 + self._init_total_samples = 0 self._train_time = 0 self._train_time_offset = 0 @@ -759,17 +759,16 @@ def fit(self): internal_metrics = self._maybe_pop_model_internal_metrics(engine_input) self._cur_step += 1 - reduced_step_consumed_tokens = self._reduce_number_across_rank(train_step_info["step_consumed_tokens"]) - self._total_consumed_tokens += reduced_step_consumed_tokens - self._exp_consumed_tokens += reduced_step_consumed_tokens - self._total_consumed_samples += self._reduce_number_across_rank(consumed_samples) + step_tokens = train_step_info["step_consumed_tokens"] + self._local_total_consumed_tokens += step_tokens + self._local_exp_consumed_tokens += step_tokens + self._local_consumed_samples += consumed_samples self._train_time = time_after_train_step - train_begin # Compute training metrics training_metrics = self._compute_performance_metrics( - local_step_consumed_tokens=train_step_info["step_consumed_tokens"], + local_step_consumed_tokens=step_tokens, local_step_consumed_img_tokens=train_step_info.get("step_consumed_img_tokens"), - step_consumed_tokens=reduced_step_consumed_tokens, step_time=step_time, ) @@ -1129,8 +1128,15 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool: optimizer_dir=optimizer_path, ) + total_consumed_tokens = ( + self._reduce_number_across_rank(self._local_total_consumed_tokens) + self._init_total_tokens + ) + total_consumed_samples = ( + self._reduce_number_across_rank(self._local_consumed_samples) + self._init_total_samples + ) + # Save dataloader - self._save_dataloader(dataloader_path) + self._save_dataloader(dataloader_path, total_consumed_samples=total_consumed_samples) DEVICE_MODULE.empty_cache() @@ -1160,8 +1166,8 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool: { "cur_step": self.cur_step, "cur_epoch": self._cur_epoch, - "total_consumed_samples": self._total_consumed_samples, - "total_consumed_tokens": self._total_consumed_tokens, + "total_consumed_samples": total_consumed_samples, + "total_consumed_tokens": total_consumed_tokens, "train_time_offset": self._train_time + self._train_time_offset, } ) @@ -1173,8 +1179,8 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool: ckp_list.append(str(checkpoint_path)) current_exp.cur_step = self.cur_step current_exp.cur_epoch = self._cur_epoch - current_exp.consumed_samples = int(self._total_consumed_samples) - current_exp.consumed_tokens = int(self._total_consumed_tokens) + current_exp.consumed_samples = int(total_consumed_samples) + current_exp.consumed_tokens = int(total_consumed_tokens) current_exp.history[-1]["end"] = self.cur_step # Delete checkpoints and update meta's checkpoint_list @@ -1209,9 +1215,9 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool: return True - def _save_dataloader(self, dataloader_path: Path | str): + def _save_dataloader(self, dataloader_path: Path | str, total_consumed_samples: int): if self.rank == 0: - dataloader_state = self._dataloader.get_state_dict(self._total_consumed_samples) + dataloader_state = self._dataloader.get_state_dict(total_consumed_samples) torch.save(dataloader_state, dataloader_path) @property @@ -1444,7 +1450,6 @@ def _compute_performance_metrics( self, local_step_consumed_tokens: int, local_step_consumed_img_tokens: int | None, - step_consumed_tokens: int, step_time: float, ) -> PerformanceStatistics: """Compute training metrics including tokens and throughput statistics. @@ -1452,18 +1457,17 @@ def _compute_performance_metrics( Args: local_step_consumed_tokens (int): Tokens consumed in current step on current rank. local_step_consumed_img_tokens (int | None): Image tokens consumed in current step on current rank. - step_consumed_tokens (int): Total tokens consumed in current step across all ranks. step_time (float): Time spent on current training step in seconds. Returns: TrainingMetrics: Dictionary containing computed training metrics. """ e2e_train_time = self._train_time + self._train_time_offset - total_consumed_tokens_per_rank = self._total_consumed_tokens / self.world_size tgs = local_step_consumed_tokens / step_time - e2e_tgs = total_consumed_tokens_per_rank / e2e_train_time - exp_tgs = self._exp_consumed_tokens / self.world_size / self._train_time + total_consumed_tokens = self._init_total_tokens + self._local_total_consumed_tokens * self.world_size + total_consumed_tokens_per_rank = total_consumed_tokens / self.world_size + exp_tgs = self._local_exp_consumed_tokens / self._train_time if self._train_time > 0 else 0.0 remaining_steps = self.total_step - self.cur_step avg_tokens_per_step = total_consumed_tokens_per_rank / self.cur_step @@ -1474,11 +1478,9 @@ def _compute_performance_metrics( return PerformanceStatistics( local_step_consumed_tokens=local_step_consumed_tokens, local_step_consumed_img_tokens=local_step_consumed_img_tokens, - step_consumed_tokens=step_consumed_tokens, - total_consumed_tokens=self._total_consumed_tokens, - total_consumed_tokens_per_rank=total_consumed_tokens_per_rank, + local_total_consumed_tokens=self._local_total_consumed_tokens, + total_consumed_tokens=total_consumed_tokens, tgs=tgs, - e2e_tgs=e2e_tgs, exp_tgs=exp_tgs, eta_seconds=eta_seconds, eta_hms=eta_hms, @@ -1533,7 +1535,6 @@ def _log_step( f"Epoch {self._cur_epoch} Step {self.cur_step}/{self.total_step} " f"data_time: {data_time:.4f} lr: {lr:.6e} time: {step_time:.4f} " f"text_tokens: {training_metrics['local_step_consumed_tokens']} {img_tokens_str}" - f"step_consumed_tokens: {training_metrics['step_consumed_tokens']} " f"total_consumed_tokens: {training_metrics['total_consumed_tokens']} " f"{loss_log_str} " f"{data_info_str} " @@ -1543,7 +1544,6 @@ def _log_step( f"reserved_memory: {reserved_memory / (1024**3):.2f} GB " f"tgs: {training_metrics['tgs']:.1f} " f"exp_tgs: {training_metrics['exp_tgs']:.1f} " - f"e2e_tgs: {training_metrics['e2e_tgs']:.1f} " f"eta: {training_metrics['eta_hms']} " ) @@ -1554,11 +1554,9 @@ def _log_step( "time/train_time": round(self._train_time, 4), "time/eta_seconds": round(training_metrics["eta_seconds"], 1), "runtime_info/text_tokens": training_metrics["local_step_consumed_tokens"], - "runtime_info/step_consumed_tokens": training_metrics["step_consumed_tokens"], "runtime_info/total_consumed_tokens": training_metrics["total_consumed_tokens"], "runtime_info/tgs": training_metrics["tgs"], "runtime_info/exp_tgs": training_metrics["exp_tgs"], - "runtime_info/e2e_tgs": training_metrics["e2e_tgs"], "memory/max_memory_GB": round(max_memory / (1024**3), 3), "memory/reserved_memory_GB": round(reserved_memory / (1024**3), 3), "grad_norm": grad_norm, @@ -1810,14 +1808,11 @@ def _load_checkpoint(self): self._cur_step = train_state["cur_step"] self._cur_epoch = train_state["cur_epoch"] + self._init_total_tokens = train_state.get("total_consumed_tokens", 0) # default 0 for BC + self._init_total_samples = train_state.get("total_consumed_samples", 0) # default 0 for BC + if load_checkpoint_cfg.load_dataset: - self._total_consumed_tokens = train_state.get("total_consumed_tokens", 0) # default 0 for BC self._train_time_offset = train_state["train_time_offset"] - # _total_consumed_samples 会影响 save dcp时 dataloader.get_state_dict的状态。 - # 1) 如果加载 dataset,应该恢复_total_consumed_samples为checkpoint中的值。 - # 2) 如果不加载 dataset,应该保持_total_consumed_samples为初始值0,否则如果加载上旧dataloader的total_consumed_samples - # 会导致存储新dataloader时 total_consumed_samples 是不正确的值。 - self._total_consumed_samples = train_state.get("total_consumed_samples", 0) # default 0 for BC dataloader_path = resume_from / self._SAVE_DATALOADER_DIR self._resume_dataloader(dataloader_path) From 91fd38a39cef1c6de2f4acf13a1e496d8469f781 Mon Sep 17 00:00:00 2001 From: zhaopenghao Date: Thu, 2 Apr 2026 14:06:49 +0000 Subject: [PATCH 02/14] refine code --- xtuner/v1/train/trainer.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/xtuner/v1/train/trainer.py b/xtuner/v1/train/trainer.py index feb471050..19d4beeff 100644 --- a/xtuner/v1/train/trainer.py +++ b/xtuner/v1/train/trainer.py @@ -87,7 +87,7 @@ class PerformanceStatistics(TypedDict): local_step_consumed_tokens: int local_step_consumed_img_tokens: int | None local_total_consumed_tokens: int - total_consumed_tokens: int + approximate_total_consumed_tokens: int tgs: float exp_tgs: float eta_seconds: float @@ -535,9 +535,12 @@ def __init__( self._debug = debug self._seed = seed + # 日志变量前缀规则: + # 空间上,当前rank的用 local_,默认 reduced 无前缀 + # 时间上,当前步用 step_, 累积用 total_ + # self._local_total_consumed_tokens 表示时间上累积到现在的当前rank的和,resume则只考虑resume步数到现在 self._local_total_consumed_tokens = 0 - self._local_exp_consumed_tokens = 0 - self._local_consumed_samples = 0 + self._local_total_samples = 0 self._init_total_tokens = 0 self._init_total_samples = 0 @@ -761,8 +764,7 @@ def fit(self): self._cur_step += 1 step_tokens = train_step_info["step_consumed_tokens"] self._local_total_consumed_tokens += step_tokens - self._local_exp_consumed_tokens += step_tokens - self._local_consumed_samples += consumed_samples + self._local_total_samples += consumed_samples self._train_time = time_after_train_step - train_begin # Compute training metrics @@ -1131,9 +1133,7 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool: total_consumed_tokens = ( self._reduce_number_across_rank(self._local_total_consumed_tokens) + self._init_total_tokens ) - total_consumed_samples = ( - self._reduce_number_across_rank(self._local_consumed_samples) + self._init_total_samples - ) + total_consumed_samples = self._reduce_number_across_rank(self._local_total_samples) + self._init_total_samples # Save dataloader self._save_dataloader(dataloader_path, total_consumed_samples=total_consumed_samples) @@ -1465,12 +1465,14 @@ def _compute_performance_metrics( e2e_train_time = self._train_time + self._train_time_offset tgs = local_step_consumed_tokens / step_time - total_consumed_tokens = self._init_total_tokens + self._local_total_consumed_tokens * self.world_size - total_consumed_tokens_per_rank = total_consumed_tokens / self.world_size - exp_tgs = self._local_exp_consumed_tokens / self._train_time if self._train_time > 0 else 0.0 + approximate_total_consumed_tokens = ( + self._init_total_tokens + self._local_total_consumed_tokens * self.world_size + ) + approximate_total_consumed_tokens_per_rank = approximate_total_consumed_tokens / self.world_size + exp_tgs = self._local_total_consumed_tokens / self._train_time if self._train_time > 0 else 0.0 remaining_steps = self.total_step - self.cur_step - avg_tokens_per_step = total_consumed_tokens_per_rank / self.cur_step + avg_tokens_per_step = approximate_total_consumed_tokens_per_rank / self.cur_step remaining_tokens = remaining_steps * avg_tokens_per_step eta_seconds = remaining_tokens / max(tgs, 1) eta_hms = str(timedelta(seconds=int(eta_seconds))) @@ -1479,7 +1481,7 @@ def _compute_performance_metrics( local_step_consumed_tokens=local_step_consumed_tokens, local_step_consumed_img_tokens=local_step_consumed_img_tokens, local_total_consumed_tokens=self._local_total_consumed_tokens, - total_consumed_tokens=total_consumed_tokens, + approximate_total_consumed_tokens=approximate_total_consumed_tokens, tgs=tgs, exp_tgs=exp_tgs, eta_seconds=eta_seconds, @@ -1535,7 +1537,7 @@ def _log_step( f"Epoch {self._cur_epoch} Step {self.cur_step}/{self.total_step} " f"data_time: {data_time:.4f} lr: {lr:.6e} time: {step_time:.4f} " f"text_tokens: {training_metrics['local_step_consumed_tokens']} {img_tokens_str}" - f"total_consumed_tokens: {training_metrics['total_consumed_tokens']} " + f"approximate_total_consumed_tokens: {training_metrics['approximate_total_consumed_tokens']} " f"{loss_log_str} " f"{data_info_str} " f"{extra_info_str} " @@ -1554,7 +1556,7 @@ def _log_step( "time/train_time": round(self._train_time, 4), "time/eta_seconds": round(training_metrics["eta_seconds"], 1), "runtime_info/text_tokens": training_metrics["local_step_consumed_tokens"], - "runtime_info/total_consumed_tokens": training_metrics["total_consumed_tokens"], + "runtime_info/approximate_total_consumed_tokens": training_metrics["approximate_total_consumed_tokens"], "runtime_info/tgs": training_metrics["tgs"], "runtime_info/exp_tgs": training_metrics["exp_tgs"], "memory/max_memory_GB": round(max_memory / (1024**3), 3), From e0460afbdba45970a2019c640b6d709ec9e4d774 Mon Sep 17 00:00:00 2001 From: zhaopenghao Date: Thu, 2 Apr 2026 14:13:46 +0000 Subject: [PATCH 03/14] fix resuming init total tokens and samples --- xtuner/v1/train/trainer.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/xtuner/v1/train/trainer.py b/xtuner/v1/train/trainer.py index 19d4beeff..3e2619c38 100644 --- a/xtuner/v1/train/trainer.py +++ b/xtuner/v1/train/trainer.py @@ -1810,11 +1810,15 @@ def _load_checkpoint(self): self._cur_step = train_state["cur_step"] self._cur_epoch = train_state["cur_epoch"] - self._init_total_tokens = train_state.get("total_consumed_tokens", 0) # default 0 for BC - self._init_total_samples = train_state.get("total_consumed_samples", 0) # default 0 for BC - if load_checkpoint_cfg.load_dataset: self._train_time_offset = train_state["train_time_offset"] + self._init_total_tokens = train_state.get("total_consumed_tokens", 0) # default 0 for BC + # TODO: total_samples 由 Dataloader 维护, 包括 save/resume + # self._init_total_samples 会影响 save dcp时 dataloader.get_state_dict的状态。 + # 1) 如果加载 dataset,应该恢复_total_consumed_samples为checkpoint中的值。 + # 2) 如果不加载 dataset,应该保持 self._init_total_samples为初始值0,否则如果加载上旧dataloader的total_consumed_samples + # 会导致存储新dataloader时 total_consumed_samples 是不正确的值。 + self._init_total_samples = train_state.get("total_consumed_samples", 0) # default 0 for BC dataloader_path = resume_from / self._SAVE_DATALOADER_DIR self._resume_dataloader(dataloader_path) From cad18c50832b0aa650996828e184e5cec634f07f Mon Sep 17 00:00:00 2001 From: zhaopenghao Date: Fri, 3 Apr 2026 10:50:50 +0000 Subject: [PATCH 04/14] Sampler add ConsumedStepsTracker for tracking consumed samples across data-parallel groups --- xtuner/v1/datasets/consumed_steps.py | 65 ++++++++++++++++++++++++++++ xtuner/v1/datasets/dataloader.py | 37 +++++++++++++--- xtuner/v1/datasets/preset_sampler.py | 25 +++++++++-- xtuner/v1/datasets/resume.py | 18 ++++++-- xtuner/v1/datasets/sampler.py | 47 +++++++++++++++++--- xtuner/v1/train/trainer.py | 30 ++++++------- 6 files changed, 188 insertions(+), 34 deletions(-) create mode 100644 xtuner/v1/datasets/consumed_steps.py diff --git a/xtuner/v1/datasets/consumed_steps.py b/xtuner/v1/datasets/consumed_steps.py new file mode 100644 index 000000000..b2dc9833a --- /dev/null +++ b/xtuner/v1/datasets/consumed_steps.py @@ -0,0 +1,65 @@ +"""Track consumed samples for checkpointing; aggregate across DP only (not +SP/TP).""" + +from __future__ import annotations + +import torch +import torch.distributed as dist +from torch.distributed.device_mesh import DeviceMesh + + +def reduce_sum_across_dp_group(dp_mesh: DeviceMesh | None, local_value: int) -> int: + """Sum ``local_value`` over the DP process group (one contribution per + data-parallel replica). + + Ranks that only differ in SP/TP see identical data batches and must not be summed with the global world group; see + Training notes for SP+DP. + """ + if dp_mesh is None or dp_mesh.size() <= 1: + return int(local_value) + if not dist.is_available() or not dist.is_initialized(): + return int(local_value) + if torch.cuda.is_available(): + device = torch.device(f"cuda:{torch.cuda.current_device()}") + else: + device = torch.device("cpu") + tensor = torch.tensor([local_value], dtype=torch.int64, device=device) + dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=dp_mesh.get_group()) + return int(tensor.item()) + + +class ConsumedStepsTracker: + """Holds per-resume totals and per-rank local accumulation; checkpoint + total uses DP-only reduction.""" + + __slots__ = ("_dp_mesh", "_init_steps", "_local_steps") + + def __init__(self, dp_mesh: DeviceMesh | None) -> None: + self._dp_mesh = dp_mesh + self._init_steps = 0 + self._local_steps = 0 + + def record(self, n: int) -> None: + self._local_steps += int(n) + + def set_init_from_checkpoint(self, total: int) -> None: + """After loading a checkpoint: global total consumed so far; reset session-local accumulation.""" + self._init_steps = int(total) + self._local_steps = 0 + + def total_for_checkpoint(self) -> int: + """Global consumed sample count including this session (collective over + DP group).""" + return self._init_steps + reduce_sum_across_dp_group(self._dp_mesh, self._local_steps) + + +def apply_old_ckpt_init_steps(sampler: object, sampler_state: dict, train_state_total: int | None) -> None: + """If the sampler checkpoint predates ``total_consumed_steps``, copy the + total from ``train_state``.""" + if train_state_total is None: + return + if sampler_state.get("total_consumed_steps") is not None: + return + consumed: ConsumedStepsTracker | None = getattr(sampler, "_consumed", None) + if consumed is not None: + consumed.set_init_from_checkpoint(train_state_total) diff --git a/xtuner/v1/datasets/dataloader.py b/xtuner/v1/datasets/dataloader.py index bdd508d5b..e5e12acc8 100644 --- a/xtuner/v1/datasets/dataloader.py +++ b/xtuner/v1/datasets/dataloader.py @@ -5,6 +5,10 @@ from xtuner.v1.datasets.collator import ColateItem from xtuner.v1.datasets.resume import get_dataloader_state, load_dataloader_state +from xtuner.v1.utils import get_logger + + +logger = get_logger() class BaseDataloader(ABC): @@ -16,10 +20,10 @@ class BaseDataloader(ABC): """ @abstractmethod - def load_state_dict(self, state_dict: dict) -> None: ... + def load_state_dict(self, state_dict: dict, train_state_total_consumed_samples: int | None = None) -> None: ... @abstractmethod - def get_state_dict(self, consumed_samples: int) -> dict: ... + def get_state_dict(self, consumed_samples: int = -1) -> dict: ... @abstractmethod def __iter__(self) -> Iterator[list[ColateItem]]: ... @@ -33,13 +37,36 @@ class Dataloader(torch.utils.data.DataLoader, BaseDataloader): implement. """ - def load_state_dict(self, state_dict: dict) -> None: - load_dataloader_state(self, state_dict) + def load_state_dict( + self, + state_dict: dict, + train_state_total_consumed_samples: int | None = None, + ) -> None: + load_dataloader_state( + self, + state_dict, + train_state_total_consumed_samples=train_state_total_consumed_samples, + ) - def get_state_dict(self, consumed_samples: int) -> dict: + def get_state_dict(self, consumed_samples: int = -1) -> dict: + if consumed_samples != -1: + logger.warning( + "Dataloader.get_state_dict(consumed_samples=...) is deprecated; use the default (-1). " + "Consumed samples are tracked on the sampler." + ) dataloader_state = get_dataloader_state(self, consumed_samples) return cast(dict, dataloader_state) + def record_consumed_samples(self, n: int) -> None: + if hasattr(self.sampler, "record_consumed_samples"): + self.sampler.record_consumed_samples(n) + + def get_total_consumed_samples(self) -> int: + sampler = self.sampler + if hasattr(sampler, "get_total_consumed_steps"): + return int(sampler.get_total_consumed_steps()) + return 0 + # __iter__ is inherited from torch.utils.data.DataLoader # Streaming dataloader may not have `set_epoch` and `__len__` method, so we add here. diff --git a/xtuner/v1/datasets/preset_sampler.py b/xtuner/v1/datasets/preset_sampler.py index 636c9343e..5d2b49416 100644 --- a/xtuner/v1/datasets/preset_sampler.py +++ b/xtuner/v1/datasets/preset_sampler.py @@ -22,6 +22,7 @@ from xtuner.v1.utils import get_logger +from .consumed_steps import ConsumedStepsTracker from .preset_pack import PresetPackDataset @@ -116,6 +117,7 @@ def __init__( else: self.rank = 0 self.world_size = 1 + self._consumed = ConsumedStepsTracker(dp_mesh) self.dataset = dataset self.global_batch_size = global_batch_size @@ -170,19 +172,35 @@ def __len__(self) -> int: def set_epoch(self, epoch: int) -> None: self.epoch = epoch - def get_state_dict(self, step: int) -> dict: + def record_consumed_samples(self, n: int) -> None: + self._consumed.record(n) + + def get_total_consumed_steps(self) -> int: + return self._consumed.total_for_checkpoint() + + def get_state_dict(self, step: int | None = None) -> dict: # Same convention as :class:`LengthGroupedSampler`: ``step`` is the global pack offset # (modulo ``total_size``) into ``global_order``, shared across all ranks in the checkpoint. - global_step = step % self.total_size + if step is None: + total_consumed = self._consumed.total_for_checkpoint() + else: + total_consumed = int(step) + global_step = total_consumed % self.total_size return { "epoch": self.epoch, "step": global_step, + "total_consumed_steps": total_consumed, "world_size": self.world_size, "num_samples": self.num_samples, "total_size": self.total_size, } def load_state_dict(self, state_dict: dict) -> None: + tc = state_dict.get("total_consumed_steps") + if tc is not None: + self._consumed.set_init_from_checkpoint(int(tc)) + else: + self._consumed.set_init_from_checkpoint(0) if self.world_size != state_dict.get("world_size"): logger.warning( f"PresetSampler: world_size mismatch: checkpoint has " @@ -191,5 +209,4 @@ def load_state_dict(self, state_dict: dict) -> None: ) self.epoch = state_dict["epoch"] - global_step = int(state_dict["step"]) - self.step = global_step + self.step = int(state_dict["step"]) diff --git a/xtuner/v1/datasets/resume.py b/xtuner/v1/datasets/resume.py index 65ab62f3a..4c394a78c 100644 --- a/xtuner/v1/datasets/resume.py +++ b/xtuner/v1/datasets/resume.py @@ -3,6 +3,7 @@ from xtuner.v1.utils import get_logger +from .consumed_steps import apply_old_ckpt_init_steps from .packing import ExpandSoftPackDataset, _LegacySoftPackDataset from .sampler import LengthGroupedSampler, ParallelSampler @@ -15,15 +16,21 @@ class DataloaderState(TypedDict): dataset: dict -def get_dataloader_state(dataloader: DataLoader, consumed_samples: int) -> DataloaderState: +def get_dataloader_state(dataloader: DataLoader, consumed_samples: int = -1) -> DataloaderState: sampler: ParallelSampler | LengthGroupedSampler = dataloader.sampler # type: ignore[assignment] dataset: ExpandSoftPackDataset | _LegacySoftPackDataset = dataloader.dataset # type: ignore[assignment] dataloader_state = DataloaderState(sampler={}, dataset={}) if not hasattr(sampler, "load_state_dict") or not hasattr(sampler, "get_state_dict"): logger.warning(f"Resuming from {type(sampler)} is risky.") - else: + elif consumed_samples != -1: + logger.warning( + "Passing consumed_samples to get_dataloader_state is deprecated; " + "consumed sample totals are tracked on the sampler. Use the default consumed_samples=-1." + ) dataloader_state["sampler"].update(sampler.get_state_dict(step=consumed_samples)) + else: + dataloader_state["sampler"].update(sampler.get_state_dict()) if not hasattr(dataset, "load_state_dict") or not hasattr(dataset, "get_state_dict"): logger.warning(f"Resuming from {type(dataset)} is risky.") @@ -33,7 +40,11 @@ def get_dataloader_state(dataloader: DataLoader, consumed_samples: int) -> Datal return dataloader_state -def load_dataloader_state(dataloader: DataLoader, state: dict): +def load_dataloader_state( + dataloader: DataLoader, + state: dict, + train_state_total_consumed_samples: int | None = None, +): sampler = dataloader.sampler dataset = dataloader.dataset @@ -44,6 +55,7 @@ def load_dataloader_state(dataloader: DataLoader, state: dict): if hasattr(sampler, "load_state_dict"): sampler.load_state_dict(state["sampler"]) + apply_old_ckpt_init_steps(sampler, state["sampler"], train_state_total_consumed_samples) # If the dataset records the training progress, we also restore it. if hasattr(dataset, "load_state_dict"): diff --git a/xtuner/v1/datasets/sampler.py b/xtuner/v1/datasets/sampler.py index d4b591d6f..e1e52fc59 100644 --- a/xtuner/v1/datasets/sampler.py +++ b/xtuner/v1/datasets/sampler.py @@ -12,6 +12,7 @@ from xtuner.v1.utils import get_logger +from .consumed_steps import ConsumedStepsTracker from .jsonl import JsonlDataset from .packing import MLLMPretrainHybridPackDataset, _LegacySoftPackDataset from .preset_pack import PresetPackDataset @@ -84,6 +85,7 @@ def __init__( self.epoch = 0 self.step = 0 self.round_up = round_up + self._consumed = ConsumedStepsTracker(dp_mesh) if self.round_up: self.num_samples = math.ceil(len(self.dataset) / global_batch_size) * global_batch_size // world_size @@ -131,12 +133,23 @@ def set_epoch(self, epoch: int) -> None: """ self.epoch = epoch + def record_consumed_samples(self, n: int) -> None: + self._consumed.record(n) + + def get_total_consumed_steps(self) -> int: + return self._consumed.total_for_checkpoint() + def load_state_dict(self, state_dict) -> None: """Load the sampler state. Args: state_dict (dict): The state of the sampler. """ + tc = state_dict.get("total_consumed_steps") + if tc is not None: + self._consumed.set_init_from_checkpoint(int(tc)) + else: + self._consumed.set_init_from_checkpoint(0) self.epoch = state_dict["epoch"] self.step = state_dict["step"] @@ -146,12 +159,17 @@ def load_state_dict(self, state_dict) -> None: f"is different from the current shuffle ({self.shuffle})." ) - def get_state_dict(self, step: int): + def get_state_dict(self, step: int | None = None): # Attention! Do not set self.step here, or it will cause the next __iter__ to get less samples. - step = step % self.total_size + if step is None: + total_consumed = self._consumed.total_for_checkpoint() + else: + total_consumed = int(step) + step_mod = total_consumed % self.total_size return { "epoch": self.epoch, - "step": step, + "step": step_mod, + "total_consumed_steps": total_consumed, "world_size": self.world_size, "shuffle": self.shuffle, "round_up": self.round_up, @@ -233,6 +251,7 @@ def __init__( assert isinstance(self.max_lengths, (list, tuple, Column, np.ndarray)) self.global_batch_size = global_batch_size + self._consumed = ConsumedStepsTracker(dp_mesh) def __iter__(self) -> Iterator[int]: """Iterate the indices.""" @@ -275,12 +294,23 @@ def set_epoch(self, epoch: int) -> None: """ self.epoch = epoch + def record_consumed_samples(self, n: int) -> None: + self._consumed.record(n) + + def get_total_consumed_steps(self) -> int: + return self._consumed.total_for_checkpoint() + def load_state_dict(self, state_dict: dict) -> None: """Load the sampler state. Args: state_dict (dict): The state of the sampler. """ + tc = state_dict.get("total_consumed_steps") + if tc is not None: + self._consumed.set_init_from_checkpoint(int(tc)) + else: + self._consumed.set_init_from_checkpoint(0) self.epoch = state_dict["epoch"] self.step = state_dict["step"] @@ -298,17 +328,22 @@ def load_state_dict(self, state_dict: dict) -> None: ) self.group_size = origin_group_size - def get_state_dict(self, step: int): + def get_state_dict(self, step: int | None = None): """Get the sampler state dict. Returns: dict: The state of the sampler. """ # Attention! Do not set self.step here, or it will cause the next __iter__ to get less samples. - step = step % self.total_size + if step is None: + total_consumed = self._consumed.total_for_checkpoint() + else: + total_consumed = int(step) + step_mod = total_consumed % self.total_size return { "epoch": self.epoch, - "step": step, + "step": step_mod, + "total_consumed_steps": total_consumed, "world_size": self.world_size, "round_up": self.round_up, "num_samples": self.num_samples, diff --git a/xtuner/v1/train/trainer.py b/xtuner/v1/train/trainer.py index 3e2619c38..942359d91 100644 --- a/xtuner/v1/train/trainer.py +++ b/xtuner/v1/train/trainer.py @@ -540,9 +540,7 @@ def __init__( # 时间上,当前步用 step_, 累积用 total_ # self._local_total_consumed_tokens 表示时间上累积到现在的当前rank的和,resume则只考虑resume步数到现在 self._local_total_consumed_tokens = 0 - self._local_total_samples = 0 self._init_total_tokens = 0 - self._init_total_samples = 0 self._train_time = 0 self._train_time_offset = 0 @@ -764,7 +762,7 @@ def fit(self): self._cur_step += 1 step_tokens = train_step_info["step_consumed_tokens"] self._local_total_consumed_tokens += step_tokens - self._local_total_samples += consumed_samples + self._dataloader.record_consumed_samples(consumed_samples) self._train_time = time_after_train_step - train_begin # Compute training metrics @@ -1133,10 +1131,10 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool: total_consumed_tokens = ( self._reduce_number_across_rank(self._local_total_consumed_tokens) + self._init_total_tokens ) - total_consumed_samples = self._reduce_number_across_rank(self._local_total_samples) + self._init_total_samples + total_consumed_samples = self._dataloader.get_total_consumed_samples() # Save dataloader - self._save_dataloader(dataloader_path, total_consumed_samples=total_consumed_samples) + self._save_dataloader(dataloader_path) DEVICE_MODULE.empty_cache() @@ -1215,9 +1213,9 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool: return True - def _save_dataloader(self, dataloader_path: Path | str, total_consumed_samples: int): + def _save_dataloader(self, dataloader_path: Path | str): + dataloader_state = self._dataloader.get_state_dict() if self.rank == 0: - dataloader_state = self._dataloader.get_state_dict(total_consumed_samples) torch.save(dataloader_state, dataloader_path) @property @@ -1813,15 +1811,12 @@ def _load_checkpoint(self): if load_checkpoint_cfg.load_dataset: self._train_time_offset = train_state["train_time_offset"] self._init_total_tokens = train_state.get("total_consumed_tokens", 0) # default 0 for BC - # TODO: total_samples 由 Dataloader 维护, 包括 save/resume - # self._init_total_samples 会影响 save dcp时 dataloader.get_state_dict的状态。 - # 1) 如果加载 dataset,应该恢复_total_consumed_samples为checkpoint中的值。 - # 2) 如果不加载 dataset,应该保持 self._init_total_samples为初始值0,否则如果加载上旧dataloader的total_consumed_samples - # 会导致存储新dataloader时 total_consumed_samples 是不正确的值。 - self._init_total_samples = train_state.get("total_consumed_samples", 0) # default 0 for BC dataloader_path = resume_from / self._SAVE_DATALOADER_DIR - self._resume_dataloader(dataloader_path) + self._resume_dataloader( + dataloader_path, + train_state_total_consumed_samples=train_state.get("total_consumed_samples"), + ) if load_checkpoint_cfg.load_scheduler: scheduler_path = resume_from / self._SAVE_SCHEDULER_DIR @@ -1834,11 +1829,14 @@ def _load_checkpoint(self): scheduler_step = self.total_step - self._cur_step self._lr_scheduler = self.build_lr_scheduler(self._lr_cfg, scheduler_step) - def _resume_dataloader(self, dataloader_path: Path): + def _resume_dataloader(self, dataloader_path: Path, train_state_total_consumed_samples: int | None = None): if not dataloader_path.exists(): raise FileNotFoundError(f"Dataloader path {dataloader_path} does not exist.") dataloader_state = torch.load(dataloader_path, map_location=DEVICE) - self._dataloader.load_state_dict(dataloader_state) + self._dataloader.load_state_dict( + dataloader_state, + train_state_total_consumed_samples=train_state_total_consumed_samples, + ) def _setup_hooks(self, hooks_config: HooksConfig) -> HooksConfig: for stage in HookStage: From 418bb4ac01fbb521354e06ceb6738f92d26ddb5c Mon Sep 17 00:00:00 2001 From: zhaopenghao Date: Fri, 3 Apr 2026 11:08:34 +0000 Subject: [PATCH 05/14] refine local_steps update logic --- xtuner/v1/datasets/dataloader.py | 4 ---- xtuner/v1/datasets/preset_sampler.py | 7 +++---- xtuner/v1/datasets/sampler.py | 14 ++++++-------- xtuner/v1/train/trainer.py | 2 -- 4 files changed, 9 insertions(+), 18 deletions(-) diff --git a/xtuner/v1/datasets/dataloader.py b/xtuner/v1/datasets/dataloader.py index e5e12acc8..b5cea5a6d 100644 --- a/xtuner/v1/datasets/dataloader.py +++ b/xtuner/v1/datasets/dataloader.py @@ -57,10 +57,6 @@ def get_state_dict(self, consumed_samples: int = -1) -> dict: dataloader_state = get_dataloader_state(self, consumed_samples) return cast(dict, dataloader_state) - def record_consumed_samples(self, n: int) -> None: - if hasattr(self.sampler, "record_consumed_samples"): - self.sampler.record_consumed_samples(n) - def get_total_consumed_samples(self) -> int: sampler = self.sampler if hasattr(sampler, "get_total_consumed_steps"): diff --git a/xtuner/v1/datasets/preset_sampler.py b/xtuner/v1/datasets/preset_sampler.py index 5d2b49416..aae511de1 100644 --- a/xtuner/v1/datasets/preset_sampler.py +++ b/xtuner/v1/datasets/preset_sampler.py @@ -163,7 +163,9 @@ def __init__( def __iter__(self) -> Iterator[int]: # load order from npy → global_order → rank_view 类型均为 memmap, 子视图 的路径仍然保持 # memmap 语义(视图、按需分页、文件后端);单机多进程可共享同一份文件页缓存 - yield from self.global_order[self.step + self.rank : self.total_size : self.world_size] + for idx in self.global_order[self.step + self.rank : self.total_size : self.world_size]: + self._consumed.record(1) + yield int(idx) self.step = 0 def __len__(self) -> int: @@ -172,9 +174,6 @@ def __len__(self) -> int: def set_epoch(self, epoch: int) -> None: self.epoch = epoch - def record_consumed_samples(self, n: int) -> None: - self._consumed.record(n) - def get_total_consumed_steps(self) -> int: return self._consumed.total_for_checkpoint() diff --git a/xtuner/v1/datasets/sampler.py b/xtuner/v1/datasets/sampler.py index e1e52fc59..7aa58bb4c 100644 --- a/xtuner/v1/datasets/sampler.py +++ b/xtuner/v1/datasets/sampler.py @@ -113,7 +113,9 @@ def __iter__(self) -> Iterator[int]: # subsample indices = indices[self.step + self.rank : self.total_size : self.world_size] - yield from iter(indices) + for idx in indices: + self._consumed.record(1) + yield idx self.step = 0 def __len__(self) -> int: @@ -133,9 +135,6 @@ def set_epoch(self, epoch: int) -> None: """ self.epoch = epoch - def record_consumed_samples(self, n: int) -> None: - self._consumed.record(n) - def get_total_consumed_steps(self) -> int: return self._consumed.total_for_checkpoint() @@ -275,7 +274,9 @@ def __iter__(self) -> Iterator[int]: assert len(indices) == self.total_size indices = indices[self.step + self.rank : self.total_size : self.world_size] assert len(indices) == self.num_samples - self.step // self.world_size - yield from iter(indices) + for idx in indices: + self._consumed.record(1) + yield idx self.step = 0 def __len__(self) -> int: @@ -294,9 +295,6 @@ def set_epoch(self, epoch: int) -> None: """ self.epoch = epoch - def record_consumed_samples(self, n: int) -> None: - self._consumed.record(n) - def get_total_consumed_steps(self) -> int: return self._consumed.total_for_checkpoint() diff --git a/xtuner/v1/train/trainer.py b/xtuner/v1/train/trainer.py index 942359d91..8c0dd64ea 100644 --- a/xtuner/v1/train/trainer.py +++ b/xtuner/v1/train/trainer.py @@ -727,7 +727,6 @@ def fit(self): train_begin = time.time() time_before_get_data = time.time() for data_batch in self._data_iter(): - consumed_samples = len(data_batch) time_before_train_step = time.time() ProberList.set_step(self._cur_step + 1) @@ -762,7 +761,6 @@ def fit(self): self._cur_step += 1 step_tokens = train_step_info["step_consumed_tokens"] self._local_total_consumed_tokens += step_tokens - self._dataloader.record_consumed_samples(consumed_samples) self._train_time = time_after_train_step - train_begin # Compute training metrics From 920c24f2f323013b82c71d965c92238146a3fdea Mon Sep 17 00:00:00 2001 From: zhaopenghao Date: Fri, 3 Apr 2026 11:18:40 +0000 Subject: [PATCH 06/14] _save_dataloader return total_consumed_samples --- xtuner/v1/datasets/dataloader.py | 6 ------ xtuner/v1/datasets/preset_sampler.py | 3 --- xtuner/v1/datasets/sampler.py | 6 ------ xtuner/v1/train/trainer.py | 6 +++--- 4 files changed, 3 insertions(+), 18 deletions(-) diff --git a/xtuner/v1/datasets/dataloader.py b/xtuner/v1/datasets/dataloader.py index b5cea5a6d..566acf7b0 100644 --- a/xtuner/v1/datasets/dataloader.py +++ b/xtuner/v1/datasets/dataloader.py @@ -57,12 +57,6 @@ def get_state_dict(self, consumed_samples: int = -1) -> dict: dataloader_state = get_dataloader_state(self, consumed_samples) return cast(dict, dataloader_state) - def get_total_consumed_samples(self) -> int: - sampler = self.sampler - if hasattr(sampler, "get_total_consumed_steps"): - return int(sampler.get_total_consumed_steps()) - return 0 - # __iter__ is inherited from torch.utils.data.DataLoader # Streaming dataloader may not have `set_epoch` and `__len__` method, so we add here. diff --git a/xtuner/v1/datasets/preset_sampler.py b/xtuner/v1/datasets/preset_sampler.py index aae511de1..358c928be 100644 --- a/xtuner/v1/datasets/preset_sampler.py +++ b/xtuner/v1/datasets/preset_sampler.py @@ -174,9 +174,6 @@ def __len__(self) -> int: def set_epoch(self, epoch: int) -> None: self.epoch = epoch - def get_total_consumed_steps(self) -> int: - return self._consumed.total_for_checkpoint() - def get_state_dict(self, step: int | None = None) -> dict: # Same convention as :class:`LengthGroupedSampler`: ``step`` is the global pack offset # (modulo ``total_size``) into ``global_order``, shared across all ranks in the checkpoint. diff --git a/xtuner/v1/datasets/sampler.py b/xtuner/v1/datasets/sampler.py index 7aa58bb4c..e82eb7b53 100644 --- a/xtuner/v1/datasets/sampler.py +++ b/xtuner/v1/datasets/sampler.py @@ -135,9 +135,6 @@ def set_epoch(self, epoch: int) -> None: """ self.epoch = epoch - def get_total_consumed_steps(self) -> int: - return self._consumed.total_for_checkpoint() - def load_state_dict(self, state_dict) -> None: """Load the sampler state. @@ -295,9 +292,6 @@ def set_epoch(self, epoch: int) -> None: """ self.epoch = epoch - def get_total_consumed_steps(self) -> int: - return self._consumed.total_for_checkpoint() - def load_state_dict(self, state_dict: dict) -> None: """Load the sampler state. diff --git a/xtuner/v1/train/trainer.py b/xtuner/v1/train/trainer.py index 8c0dd64ea..638f7c596 100644 --- a/xtuner/v1/train/trainer.py +++ b/xtuner/v1/train/trainer.py @@ -1129,10 +1129,9 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool: total_consumed_tokens = ( self._reduce_number_across_rank(self._local_total_consumed_tokens) + self._init_total_tokens ) - total_consumed_samples = self._dataloader.get_total_consumed_samples() # Save dataloader - self._save_dataloader(dataloader_path) + total_consumed_samples = self._save_dataloader(dataloader_path) DEVICE_MODULE.empty_cache() @@ -1211,10 +1210,11 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool: return True - def _save_dataloader(self, dataloader_path: Path | str): + def _save_dataloader(self, dataloader_path: Path | str) -> int: dataloader_state = self._dataloader.get_state_dict() if self.rank == 0: torch.save(dataloader_state, dataloader_path) + return int(dataloader_state.get("sampler", {}).get("total_consumed_steps", 0)) @property def work_dir(self) -> Path: From ed91fb1dcdbbc978b491e2b64a3238862b442e75 Mon Sep 17 00:00:00 2001 From: zhaopenghao Date: Fri, 3 Apr 2026 11:56:06 +0000 Subject: [PATCH 07/14] refine code --- xtuner/v1/datasets/dataloader.py | 1 + xtuner/v1/datasets/preset_sampler.py | 7 ++----- xtuner/v1/datasets/resume.py | 3 ++- xtuner/v1/datasets/sampler.py | 14 ++++---------- 4 files changed, 9 insertions(+), 16 deletions(-) diff --git a/xtuner/v1/datasets/dataloader.py b/xtuner/v1/datasets/dataloader.py index 566acf7b0..8665a8b03 100644 --- a/xtuner/v1/datasets/dataloader.py +++ b/xtuner/v1/datasets/dataloader.py @@ -54,6 +54,7 @@ def get_state_dict(self, consumed_samples: int = -1) -> dict: "Dataloader.get_state_dict(consumed_samples=...) is deprecated; use the default (-1). " "Consumed samples are tracked on the sampler." ) + # TODO: remove consumed_samples parameter in get_dataloader_state in next major release dataloader_state = get_dataloader_state(self, consumed_samples) return cast(dict, dataloader_state) diff --git a/xtuner/v1/datasets/preset_sampler.py b/xtuner/v1/datasets/preset_sampler.py index 358c928be..234dfac6a 100644 --- a/xtuner/v1/datasets/preset_sampler.py +++ b/xtuner/v1/datasets/preset_sampler.py @@ -192,11 +192,8 @@ def get_state_dict(self, step: int | None = None) -> dict: } def load_state_dict(self, state_dict: dict) -> None: - tc = state_dict.get("total_consumed_steps") - if tc is not None: - self._consumed.set_init_from_checkpoint(int(tc)) - else: - self._consumed.set_init_from_checkpoint(0) + tc = int(state_dict.get("total_consumed_steps", 0)) + self._consumed.set_init_from_checkpoint(tc) if self.world_size != state_dict.get("world_size"): logger.warning( f"PresetSampler: world_size mismatch: checkpoint has " diff --git a/xtuner/v1/datasets/resume.py b/xtuner/v1/datasets/resume.py index 4c394a78c..d5bf0c57c 100644 --- a/xtuner/v1/datasets/resume.py +++ b/xtuner/v1/datasets/resume.py @@ -5,6 +5,7 @@ from .consumed_steps import apply_old_ckpt_init_steps from .packing import ExpandSoftPackDataset, _LegacySoftPackDataset +from .preset_sampler import PresetSampler from .sampler import LengthGroupedSampler, ParallelSampler @@ -45,7 +46,7 @@ def load_dataloader_state( state: dict, train_state_total_consumed_samples: int | None = None, ): - sampler = dataloader.sampler + sampler: ParallelSampler | LengthGroupedSampler | PresetSampler = dataloader.sampler # type: ignore[assignment] dataset = dataloader.dataset # Sampler require `load_state_dict` to restore the training progress since the sampler state will diff --git a/xtuner/v1/datasets/sampler.py b/xtuner/v1/datasets/sampler.py index e82eb7b53..ef75f32b8 100644 --- a/xtuner/v1/datasets/sampler.py +++ b/xtuner/v1/datasets/sampler.py @@ -141,11 +141,8 @@ def load_state_dict(self, state_dict) -> None: Args: state_dict (dict): The state of the sampler. """ - tc = state_dict.get("total_consumed_steps") - if tc is not None: - self._consumed.set_init_from_checkpoint(int(tc)) - else: - self._consumed.set_init_from_checkpoint(0) + tc = int(state_dict.get("total_consumed_steps", 0)) + self._consumed.set_init_from_checkpoint(tc) self.epoch = state_dict["epoch"] self.step = state_dict["step"] @@ -298,11 +295,8 @@ def load_state_dict(self, state_dict: dict) -> None: Args: state_dict (dict): The state of the sampler. """ - tc = state_dict.get("total_consumed_steps") - if tc is not None: - self._consumed.set_init_from_checkpoint(int(tc)) - else: - self._consumed.set_init_from_checkpoint(0) + tc = int(state_dict.get("total_consumed_steps", 0)) + self._consumed.set_init_from_checkpoint(tc) self.epoch = state_dict["epoch"] self.step = state_dict["step"] From fa2b8fc65b7a04690696da73d2ce7e67c443a5cc Mon Sep 17 00:00:00 2001 From: zhaopenghao Date: Fri, 3 Apr 2026 12:45:28 +0000 Subject: [PATCH 08/14] fix RL worker's sft_dataloader save --- xtuner/v1/rl/base/worker.py | 39 ++++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py index 855bc589a..bb48b6d87 100644 --- a/xtuner/v1/rl/base/worker.py +++ b/xtuner/v1/rl/base/worker.py @@ -263,7 +263,6 @@ def _init_sft(self, worker_cfg: WorkerConfig): self._rollout_step = 0 self._sft_cur_epoch = 0 - self._sft_total_consumed_samples = 0 self._sft_total_consumed_tokens = 0 if self._sft_dataloader_config is not None: @@ -672,7 +671,6 @@ def _fit_sft(self): time_before_train_step = time.time() data_time = time_before_train_step - time_before_get_data DEVICE_MODULE.reset_peak_memory_stats() - cur_sample_num = len(data_batch) train_step_info, grad_norm = self._train_one_step_sft(data_batch) @@ -680,7 +678,6 @@ def _fit_sft(self): step_time = time_after_train_step - time_before_train_step step_consumed_tokens = train_step_info["step_consumed_tokens"] - self._sft_total_consumed_samples += self._reduce_number_across_rank(cur_sample_num) reduced_step_consumed_tokens = self._reduce_number_across_rank(step_consumed_tokens) self._sft_total_consumed_tokens += reduced_step_consumed_tokens @@ -1391,9 +1388,13 @@ def save(self, checkpoint_path: Path | str, no_save_optimizer: bool = False): ) # Save sft dataloader - if self.rank == 0 and self._sft_dataloader is not None: + if self._sft_dataloader is not None: sft_dataloader_path = checkpoint_path / self._SAVE_SFT_DATALOADER_DIR - dataloader_state = self._sft_dataloader.get_state_dict(self._sft_total_consumed_samples) + dataloader_state = self._sft_dataloader.get_state_dict() + total_consumed_samples = int(dataloader_state.get("sampler", {}).get("total_consumed_steps", 0)) + if self.rank != 0: + return + torch.save(dataloader_state, sft_dataloader_path) train_state_path = checkpoint_path / self._SAVE_SFT_TRAIN_STATE_PATH @@ -1403,7 +1404,7 @@ def save(self, checkpoint_path: Path | str, no_save_optimizer: bool = False): { "cur_step": self._rollout_step, "cur_epoch": self._sft_cur_epoch, - "total_consumed_samples": self._sft_total_consumed_samples, + "total_consumed_samples": total_consumed_samples, "total_consumed_tokens": self._sft_total_consumed_tokens, } ) @@ -1437,24 +1438,26 @@ def resume(self, load_checkpoint_cfg: LoadCheckpointConfig): ) # Resume sft dataloader - sft_dataloader_path = resume_from / self._SAVE_SFT_DATALOADER_DIR if self._sft_dataloader is not None: - if not sft_dataloader_path.exists(): - raise FileNotFoundError(f"Dataloader path {sft_dataloader_path} does not exist.") - dataloader_state = torch.load(sft_dataloader_path, map_location=DEVICE) - self._sft_dataloader.load_state_dict(dataloader_state) - self.logger.info(f"Resume sft dataloader from {sft_dataloader_path}") - train_state_path = resume_from / self._SAVE_SFT_TRAIN_STATE_PATH if not train_state_path.exists(): raise FileNotFoundError(f"Train state path {train_state_path} does not exist.") with train_state_path.open("r") as f: train_state = json.loads(f.read()) - self._rollout_step = train_state["cur_step"] - self._sft_cur_epoch = train_state["cur_epoch"] - self._sft_total_consumed_samples = train_state["total_consumed_samples"] - self._sft_total_consumed_tokens = train_state["total_consumed_tokens"] - self.logger.info(f"Resume sft train state from {train_state_path}") + self._rollout_step = train_state["cur_step"] + self._sft_cur_epoch = train_state["cur_epoch"] + self._sft_total_consumed_tokens = train_state["total_consumed_tokens"] + self.logger.info(f"Resume sft train state from {train_state_path}") + + sft_dataloader_path = resume_from / self._SAVE_SFT_DATALOADER_DIR + if not sft_dataloader_path.exists(): + raise FileNotFoundError(f"Dataloader path {sft_dataloader_path} does not exist.") + dataloader_state = torch.load(sft_dataloader_path, map_location=DEVICE) + self._sft_dataloader.load_state_dict( + dataloader_state, + train_state_total_consumed_samples=train_state.get("total_consumed_samples", 0), + ) + self.logger.info(f"Resume sft dataloader from {sft_dataloader_path}") @ray_method def ready(self) -> bool: From 767d9ccd9597e97b86b8070e026bcce873a992eb Mon Sep 17 00:00:00 2001 From: zhaopenghao Date: Fri, 3 Apr 2026 13:38:35 +0000 Subject: [PATCH 09/14] fix sampler save when dataloader num_workers > 0 --- tests/datasets/test_dataloader.py | 79 ++++++++++++++-------------- xtuner/v1/datasets/dataloader.py | 18 ++++++- xtuner/v1/datasets/preset_sampler.py | 4 +- xtuner/v1/datasets/sampler.py | 8 +-- 4 files changed, 59 insertions(+), 50 deletions(-) diff --git a/tests/datasets/test_dataloader.py b/tests/datasets/test_dataloader.py index fa8ee4bec..5a20d8ba2 100644 --- a/tests/datasets/test_dataloader.py +++ b/tests/datasets/test_dataloader.py @@ -1,12 +1,21 @@ from pathlib import Path import os import pickle +import socket import torch -from xtuner.v1.datasets import build_dataloader, build_datasets, get_dataloader_state, load_dataloader_state, FTDPTokenizeFnConfig, DatasetConfig, DataloaderConfig +from xtuner.v1.datasets import ( + DataloaderConfig, + DatasetConfig, + FTDPTokenizeFnConfig, + build_dataloader, + build_datasets, + get_dataloader_state, + load_dataloader_state, +) from xtuner.v1.train.toy_tokenizer import UTF8ByteTokenizer -from torch.multiprocessing import spawn, get_context +from torch.multiprocessing import spawn from torch.distributed.device_mesh import init_device_mesh import pytest @@ -15,6 +24,12 @@ from itertools import repeat, chain +def _alloc_master_port() -> None: + """Bind an ephemeral TCP port so concurrent test runs avoid EADDRINUSE on a fixed port.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + os.environ["MASTER_PORT"] = str(s.getsockname()[1]) + class RandomDataset: @@ -282,65 +297,53 @@ def _test_resume_spmd( rank: int, world_size: int, dataloader_config: DataloaderConfig, - dataset_configs: list[dict], global_batch_size: int, micro_batch_size: int, - step:int, + step: int, seed: int, save_path: Path, dataloader_state: dict | None = None, - consumed_samples: int = 0, ): os.environ["RANK"] = str(rank) os.environ["LOCAL_RANK"] = str(rank) os.environ["WORLD_SIZE"] = str(world_size) - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29505" - + os.environ.setdefault("MASTER_ADDR", "localhost") + if "MASTER_PORT" not in os.environ: + raise RuntimeError("tests must call _alloc_master_port() before torch.multiprocessing.spawn") torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size) torch.cuda.set_device(rank) data_mesh = init_device_mesh( device_type="cuda", - mesh_shape=(world_size,) + mesh_shape=(world_size,), ) tokenizer = UTF8ByteTokenizer() - datasets = build_datasets( - dataset_config=dataset_configs, + dataloader = dataloader_config.build( tokenizer=tokenizer, - ) - dataloader = build_dataloader( - dataloader_config=dataloader_config, - datasets=datasets, + dp_mesh=data_mesh, global_batch_size=global_batch_size, micro_batch_size=micro_batch_size, seed=seed, - dp_mesh=data_mesh, ) if dataloader_state is not None: - load_dataloader_state(dataloader, dataloader_state) + dataloader.load_state_dict(dataloader_state) data_iter = iter(dataloader) data_list = [] for _ in range(step): batch = next(data_iter) data_list.append(batch) - consumed_samples += len(batch) - consumed_samples_list = [None for _ in range(world_size)] - torch.distributed.all_gather_object(consumed_samples_list, consumed_samples) - global_consumed_samples = sum(consumed_samples_list) + # Snapshot after the first `step` batches so total_consumed_steps matches resume intent. + dataloader_state = dataloader.get_state_dict() expected_data = [] - for _ in range(step): batch = next(data_iter) expected_data.append(batch) - dataloader_state = get_dataloader_state(dataloader, global_consumed_samples) - all_data_list = [None for _ in range(world_size)] torch.distributed.all_gather_object(all_data_list, list(chain(*data_list))) @@ -372,7 +375,6 @@ def _test_resume_spmd( "dataloader_state": dataloader_state, "data_list": all_data_list, "expected_data": all_expected_data, - "consumed_samples": consumed_samples } ) ) @@ -389,7 +391,6 @@ def _test_resume_spmd( ("none", 0, False), ("soft", 0, True), ("soft", 4, True), - ("soft", 4, True), ] ) def test_dataloader_resume_multi_process(tmp_path, pack_level, num_workers, group_by_length): @@ -402,36 +403,36 @@ def test_dataloader_resume_multi_process(tmp_path, pack_level, num_workers, grou _create_fake_dataset(data_dir1 / f"depth3", dataset_num=3, max_depth=3, dup_times=9) # 1. Test resuming with the same world size + dataset_configs = [ + { + "dataset": DatasetConfig(anno_path=str(data_dir1)), + "tokenize_fn": FTDPTokenizeFnConfig(max_length=1024), + }, + ] + dataloader_config = DataloaderConfig( + dataset_config_list=dataset_configs, pack_max_length=1024, pack_level=pack_level, num_workers=num_workers, group_by_length=group_by_length, - collator="fake_collator" + collator="fake_collator", ) - dataset_configs = [ - { - "dataset": DatasetConfig(anno_path=str(data_dir1)), - "tokenize_fn": FTDPTokenizeFnConfig(max_length=1024) - }, - ] - ctx = get_context("spawn") world_size = 2 save_path1 = tmp_path / "dataloader_state.pkl" + _alloc_master_port() spawn( _test_resume_spmd, args=( world_size, dataloader_config, - dataset_configs, 16, BATCH_SIZE, TOTAL_STEP, 10, save_path1, None, - 0, ), nprocs=2, join=True, @@ -443,19 +444,18 @@ def test_dataloader_resume_multi_process(tmp_path, pack_level, num_workers, grou # 2. tet Rsume with same world size save_path2 = tmp_path / "dataloader_state2.pkl" + _alloc_master_port() spawn( _test_resume_spmd, args=( world_size, dataloader_config, - dataset_configs, 16, BATCH_SIZE, TOTAL_STEP, 10, save_path2, result1["dataloader_state"], - result1["consumed_samples"], ), nprocs=world_size, join=True, @@ -470,19 +470,18 @@ def test_dataloader_resume_multi_process(tmp_path, pack_level, num_workers, grou world_size = 4 save_path3 = tmp_path / "dataloader_state3.pkl" + _alloc_master_port() spawn( _test_resume_spmd, args=( world_size, dataloader_config, - dataset_configs, 16, BATCH_SIZE, TOTAL_STEP, 10, save_path3, result1["dataloader_state"], - result1["consumed_samples"], ), nprocs=world_size, join=True, diff --git a/xtuner/v1/datasets/dataloader.py b/xtuner/v1/datasets/dataloader.py index 8665a8b03..d2ca6f0e9 100644 --- a/xtuner/v1/datasets/dataloader.py +++ b/xtuner/v1/datasets/dataloader.py @@ -4,6 +4,7 @@ import torch from xtuner.v1.datasets.collator import ColateItem +from xtuner.v1.datasets.consumed_steps import ConsumedStepsTracker from xtuner.v1.datasets.resume import get_dataloader_state, load_dataloader_state from xtuner.v1.utils import get_logger @@ -42,6 +43,11 @@ def load_state_dict( state_dict: dict, train_state_total_consumed_samples: int | None = None, ) -> None: + if train_state_total_consumed_samples is not None: + logger.warning( + "Dataloader.load_state_dict(train_state_total_consumed_samples=...) is deprecated; " + "use the default (None). Consumed samples are tracked on the sampler." + ) load_dataloader_state( self, state_dict, @@ -58,7 +64,17 @@ def get_state_dict(self, consumed_samples: int = -1) -> dict: dataloader_state = get_dataloader_state(self, consumed_samples) return cast(dict, dataloader_state) - # __iter__ is inherited from torch.utils.data.DataLoader + def __iter__(self) -> Iterator[list[ColateItem]]: # type: ignore[override] + # Override to count delivered batches, not prefetched indices. + # With num_workers > 0 the sampler is iterated ahead by DataLoader's prefetch queue, + # so recording inside sampler.__iter__ would count too many samples. Instead we + # increment _consumed exactly once per batch that reaches the caller. + sampler = self.sampler + consumed: ConsumedStepsTracker | None = getattr(sampler, "_consumed", None) + for batch in super().__iter__(): + if consumed is not None: + consumed.record(len(batch)) + yield batch # Streaming dataloader may not have `set_epoch` and `__len__` method, so we add here. def set_epoch(self, epoch: int) -> None: diff --git a/xtuner/v1/datasets/preset_sampler.py b/xtuner/v1/datasets/preset_sampler.py index 234dfac6a..031f69ba5 100644 --- a/xtuner/v1/datasets/preset_sampler.py +++ b/xtuner/v1/datasets/preset_sampler.py @@ -163,9 +163,7 @@ def __init__( def __iter__(self) -> Iterator[int]: # load order from npy → global_order → rank_view 类型均为 memmap, 子视图 的路径仍然保持 # memmap 语义(视图、按需分页、文件后端);单机多进程可共享同一份文件页缓存 - for idx in self.global_order[self.step + self.rank : self.total_size : self.world_size]: - self._consumed.record(1) - yield int(idx) + yield from (int(idx) for idx in self.global_order[self.step + self.rank : self.total_size : self.world_size]) self.step = 0 def __len__(self) -> int: diff --git a/xtuner/v1/datasets/sampler.py b/xtuner/v1/datasets/sampler.py index ef75f32b8..9e7dbf368 100644 --- a/xtuner/v1/datasets/sampler.py +++ b/xtuner/v1/datasets/sampler.py @@ -113,9 +113,7 @@ def __iter__(self) -> Iterator[int]: # subsample indices = indices[self.step + self.rank : self.total_size : self.world_size] - for idx in indices: - self._consumed.record(1) - yield idx + yield from indices self.step = 0 def __len__(self) -> int: @@ -268,9 +266,7 @@ def __iter__(self) -> Iterator[int]: assert len(indices) == self.total_size indices = indices[self.step + self.rank : self.total_size : self.world_size] assert len(indices) == self.num_samples - self.step // self.world_size - for idx in indices: - self._consumed.record(1) - yield idx + yield from indices self.step = 0 def __len__(self) -> int: From eadc87aba8b6a161cc494fea62cb5841847471c8 Mon Sep 17 00:00:00 2001 From: zhaopenghao Date: Fri, 3 Apr 2026 14:21:02 +0000 Subject: [PATCH 10/14] refine test dataloader ut --- tests/datasets/test_dataloader.py | 51 +++++++++++----------------- xtuner/v1/datasets/preset_sampler.py | 2 +- xtuner/v1/datasets/sampler.py | 4 +-- 3 files changed, 22 insertions(+), 35 deletions(-) diff --git a/tests/datasets/test_dataloader.py b/tests/datasets/test_dataloader.py index 5a20d8ba2..e5b26e150 100644 --- a/tests/datasets/test_dataloader.py +++ b/tests/datasets/test_dataloader.py @@ -10,9 +10,6 @@ DatasetConfig, FTDPTokenizeFnConfig, build_dataloader, - build_datasets, - get_dataloader_state, - load_dataloader_state, ) from xtuner.v1.train.toy_tokenizer import UTF8ByteTokenizer from torch.multiprocessing import spawn @@ -197,11 +194,12 @@ def test_dataloader_resume_single_process(tmp_path, pack_level, num_workers, gro dataset_configs = [ { "dataset": DatasetConfig(anno_path=str(data_dir1)), - "tokenize_fn": FTDPTokenizeFnConfig(max_length=1024) + "tokenize_fn": FTDPTokenizeFnConfig(max_length=1024), }, ] dataloader_config = DataloaderConfig( + dataset_config_list=dataset_configs, pack_max_length=1024, pack_level=pack_level, num_workers=num_workers, @@ -209,13 +207,9 @@ def test_dataloader_resume_single_process(tmp_path, pack_level, num_workers, gro pack_workers=pack_workers, ) - datasets = build_datasets( - dataset_config=dataset_configs, + dataloader1 = dataloader_config.build( tokenizer=tokenizer, - ) - dataloader1 = build_dataloader( - dataloader_config=dataloader_config, - datasets=datasets, + dp_mesh=None, global_batch_size=GLOBAL_BATCH_SIZE, micro_batch_size=BATCH_SIZE, seed=10, @@ -225,26 +219,22 @@ def test_dataloader_resume_single_process(tmp_path, pack_level, num_workers, gro assert len(dataloader1) > 10 dataloader_iter = iter(dataloader1) - consumed_sample = 0 for _ in range(RESUME_ITER): - batch = next(dataloader_iter) - consumed_sample += len(batch) + next(dataloader_iter) - dataloader_state = get_dataloader_state(dataloader1, consumed_sample) + dataloader_state = dataloader1.get_state_dict() expected_data = [] for _ in range(AFTER_RESUME_ITER): - batch = next(dataloader_iter) - consumed_sample += len(batch) - expected_data.append(batch) + expected_data.append(next(dataloader_iter)) - new_dataloader1 = build_dataloader( - dataloader_config=dataloader_config, - datasets=datasets, + new_dataloader1 = dataloader_config.build( + tokenizer=tokenizer, + dp_mesh=None, global_batch_size=GLOBAL_BATCH_SIZE, micro_batch_size=BATCH_SIZE, seed=10, ) - load_dataloader_state(new_dataloader1, dataloader_state) + new_dataloader1.load_state_dict(dataloader_state) new_dataloader_iter = iter(new_dataloader1) resume_data = [] @@ -257,32 +247,29 @@ def test_dataloader_resume_single_process(tmp_path, pack_level, num_workers, gro # 2. Test resume after consuming multiple epochs while True: try: - batch = next(dataloader_iter) - consumed_sample += len(batch) + next(dataloader_iter) except StopIteration: break - dataloader_iter = iter(dataloader1) - for batch in range(RESUME_ITER): - batch = next(dataloader_iter) - consumed_sample += len(batch) + for _ in range(RESUME_ITER): + next(dataloader_iter) - dataloader_state = get_dataloader_state(dataloader1, consumed_sample) + dataloader_state = dataloader1.get_state_dict() expected_data = [] for _ in range(AFTER_RESUME_ITER): expected_data.append(next(dataloader_iter)) - new_dataloader2 = build_dataloader( - dataloader_config=dataloader_config, - datasets=datasets, + new_dataloader2 = dataloader_config.build( + tokenizer=tokenizer, + dp_mesh=None, global_batch_size=GLOBAL_BATCH_SIZE, micro_batch_size=BATCH_SIZE, seed=10, ) - load_dataloader_state(new_dataloader2, dataloader_state) + new_dataloader2.load_state_dict(dataloader_state) new_dataloader_iter2 = iter(new_dataloader2) resume_data = [] diff --git a/xtuner/v1/datasets/preset_sampler.py b/xtuner/v1/datasets/preset_sampler.py index 031f69ba5..42b6804ad 100644 --- a/xtuner/v1/datasets/preset_sampler.py +++ b/xtuner/v1/datasets/preset_sampler.py @@ -163,7 +163,7 @@ def __init__( def __iter__(self) -> Iterator[int]: # load order from npy → global_order → rank_view 类型均为 memmap, 子视图 的路径仍然保持 # memmap 语义(视图、按需分页、文件后端);单机多进程可共享同一份文件页缓存 - yield from (int(idx) for idx in self.global_order[self.step + self.rank : self.total_size : self.world_size]) + yield from self.global_order[self.step + self.rank : self.total_size : self.world_size] self.step = 0 def __len__(self) -> int: diff --git a/xtuner/v1/datasets/sampler.py b/xtuner/v1/datasets/sampler.py index 9e7dbf368..d6c616d62 100644 --- a/xtuner/v1/datasets/sampler.py +++ b/xtuner/v1/datasets/sampler.py @@ -113,7 +113,7 @@ def __iter__(self) -> Iterator[int]: # subsample indices = indices[self.step + self.rank : self.total_size : self.world_size] - yield from indices + yield from iter(indices) self.step = 0 def __len__(self) -> int: @@ -266,7 +266,7 @@ def __iter__(self) -> Iterator[int]: assert len(indices) == self.total_size indices = indices[self.step + self.rank : self.total_size : self.world_size] assert len(indices) == self.num_samples - self.step // self.world_size - yield from indices + yield from iter(indices) self.step = 0 def __len__(self) -> int: From d0aea4b781724e5591d593c321f5b38e449489fa Mon Sep 17 00:00:00 2001 From: zhaopenghao Date: Fri, 3 Apr 2026 14:50:41 +0000 Subject: [PATCH 11/14] fix master port in ut --- tests/datasets/test_dataloader.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/tests/datasets/test_dataloader.py b/tests/datasets/test_dataloader.py index e5b26e150..847d8b1ca 100644 --- a/tests/datasets/test_dataloader.py +++ b/tests/datasets/test_dataloader.py @@ -21,14 +21,6 @@ from itertools import repeat, chain -def _alloc_master_port() -> None: - """Bind an ephemeral TCP port so concurrent test runs avoid EADDRINUSE on a fixed port.""" - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("127.0.0.1", 0)) - os.environ["MASTER_PORT"] = str(s.getsockname()[1]) - - - class RandomDataset: def __init__(self, size: int, **kwargs): self.size = size @@ -294,9 +286,8 @@ def _test_resume_spmd( os.environ["RANK"] = str(rank) os.environ["LOCAL_RANK"] = str(rank) os.environ["WORLD_SIZE"] = str(world_size) - os.environ.setdefault("MASTER_ADDR", "localhost") - if "MASTER_PORT" not in os.environ: - raise RuntimeError("tests must call _alloc_master_port() before torch.multiprocessing.spawn") + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "29505" torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size) torch.cuda.set_device(rank) @@ -408,7 +399,6 @@ def test_dataloader_resume_multi_process(tmp_path, pack_level, num_workers, grou world_size = 2 save_path1 = tmp_path / "dataloader_state.pkl" - _alloc_master_port() spawn( _test_resume_spmd, args=( @@ -431,7 +421,6 @@ def test_dataloader_resume_multi_process(tmp_path, pack_level, num_workers, grou # 2. tet Rsume with same world size save_path2 = tmp_path / "dataloader_state2.pkl" - _alloc_master_port() spawn( _test_resume_spmd, args=( @@ -457,7 +446,6 @@ def test_dataloader_resume_multi_process(tmp_path, pack_level, num_workers, grou world_size = 4 save_path3 = tmp_path / "dataloader_state3.pkl" - _alloc_master_port() spawn( _test_resume_spmd, args=( From 96e3033a8010d97d267cba486675ed53479fbcb7 Mon Sep 17 00:00:00 2001 From: zhaopenghao Date: Tue, 7 Apr 2026 06:00:39 +0000 Subject: [PATCH 12/14] DataLoader handles total consumed samples --- tests/datasets/test_dataloader.py | 2 +- tests/datasets/test_preset_dataloader.py | 6 +- tests/datasets/test_preset_sampler.py | 4 +- xtuner/v1/datasets/__init__.py | 3 - xtuner/v1/datasets/config.py | 1 + xtuner/v1/datasets/consumed_steps.py | 65 ------------- xtuner/v1/datasets/dataloader.py | 119 ++++++++++++++++++----- xtuner/v1/datasets/preset_sampler.py | 13 +-- xtuner/v1/datasets/resume.py | 63 ------------ xtuner/v1/datasets/sampler.py | 25 +---- xtuner/v1/rl/base/worker.py | 2 +- xtuner/v1/train/trainer.py | 2 +- 12 files changed, 107 insertions(+), 198 deletions(-) delete mode 100644 xtuner/v1/datasets/consumed_steps.py delete mode 100644 xtuner/v1/datasets/resume.py diff --git a/tests/datasets/test_dataloader.py b/tests/datasets/test_dataloader.py index 847d8b1ca..9aa82e695 100644 --- a/tests/datasets/test_dataloader.py +++ b/tests/datasets/test_dataloader.py @@ -314,7 +314,7 @@ def _test_resume_spmd( batch = next(data_iter) data_list.append(batch) - # Snapshot after the first `step` batches so total_consumed_steps matches resume intent. + # Snapshot after the first `step` batches so total_consumed_samples matches resume intent. dataloader_state = dataloader.get_state_dict() expected_data = [] diff --git a/tests/datasets/test_preset_dataloader.py b/tests/datasets/test_preset_dataloader.py index 9bf8f5f7c..7030ad8b7 100644 --- a/tests/datasets/test_preset_dataloader.py +++ b/tests/datasets/test_preset_dataloader.py @@ -20,7 +20,7 @@ from itertools import chain -from xtuner.v1.datasets import PretrainTokenizeFunctionConfig, get_dataloader_state, load_dataloader_state +from xtuner.v1.datasets import PretrainTokenizeFunctionConfig from xtuner.v1.datasets.config import DatasetConfig, DataloaderConfig from xtuner.v1.datasets.packing import get_pack_infos_by_hard_split from xtuner.v1.datasets.preset_pack import PresetPackDataset @@ -700,7 +700,7 @@ def _build(): global_consumed_samples = sum(int(x) for x in consumed_samples_list if x is not None) # 3. Get ckpt state - # dataloader_state = get_dataloader_state(dl, global_consumed_samples) + # dataloader_state = dl.get_state_dict(global_consumed_samples) dataloader_state = dl.get_state_dict(global_consumed_samples) # 4. Continue to consume data at [half_step, 2*half_step) @@ -738,7 +738,7 @@ def _build(): dl2 = _build() with ckpt_path.open("rb") as f: ckpt = pickle.load(f) - # load_dataloader_state(dl2, ckpt["dataloader_state"]) + # dl2.load_state_dict(ckpt["dataloader_state"]) dl2.load_state_dict(ckpt["dataloader_state"]) resume_iter = iter(dl2) diff --git a/tests/datasets/test_preset_sampler.py b/tests/datasets/test_preset_sampler.py index e2aab4e84..6ba87f1df 100644 --- a/tests/datasets/test_preset_sampler.py +++ b/tests/datasets/test_preset_sampler.py @@ -87,7 +87,7 @@ def test_state_dict_resume(tmp_path): sampler = PresetSampler(dataset, sampler_config_path=path, global_batch_size=1) - state = sampler.get_state_dict(step=3) + state = sampler.get_state_dict(3) assert state["step"] == 3 sampler2 = PresetSampler(dataset, sampler_config_path=path, global_batch_size=1) @@ -102,7 +102,7 @@ def test_state_dict_world_size_mismatch(tmp_path): path = _write_order_npy(tmp_path, "order.npy", _i64(0, 1, 2, 3)) sampler = PresetSampler(dataset, sampler_config_path=path, global_batch_size=1) - state = sampler.get_state_dict(step=0) + state = sampler.get_state_dict(0) state["world_size"] = 99 sampler.load_state_dict(state) diff --git a/xtuner/v1/datasets/__init__.py b/xtuner/v1/datasets/__init__.py index 30d91ee81..94a8f52a7 100644 --- a/xtuner/v1/datasets/__init__.py +++ b/xtuner/v1/datasets/__init__.py @@ -25,7 +25,6 @@ PretrainTokenizeFunction, PretrainTokenizeFunctionConfig, ) -from .resume import get_dataloader_state, load_dataloader_state from .rl_tokenize_fn import RLTokenizeFnConfig from .sampler import LengthGroupedSampler, ParallelSampler from .sft_tokenize_fn import OpenaiTokenizeFunction, OpenaiTokenizeFunctionConfig @@ -68,8 +67,6 @@ "InternS1VLTokenizeFnConfig", "fake_collator", "RLTokenizeFnConfig", - "get_dataloader_state", - "load_dataloader_state", "DatasetConfigList", "DataloaderConfig", "BaseTokenizeFnConfig", diff --git a/xtuner/v1/datasets/config.py b/xtuner/v1/datasets/config.py index 97e6e02ca..f9783369f 100644 --- a/xtuner/v1/datasets/config.py +++ b/xtuner/v1/datasets/config.py @@ -544,5 +544,6 @@ def build( collate_fn=collator, multiprocessing_context=ctx if self.num_workers > 0 else None, persistent_workers=self.num_workers > 0, + dp_mesh=dp_mesh, ) return dataloader diff --git a/xtuner/v1/datasets/consumed_steps.py b/xtuner/v1/datasets/consumed_steps.py deleted file mode 100644 index b2dc9833a..000000000 --- a/xtuner/v1/datasets/consumed_steps.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Track consumed samples for checkpointing; aggregate across DP only (not -SP/TP).""" - -from __future__ import annotations - -import torch -import torch.distributed as dist -from torch.distributed.device_mesh import DeviceMesh - - -def reduce_sum_across_dp_group(dp_mesh: DeviceMesh | None, local_value: int) -> int: - """Sum ``local_value`` over the DP process group (one contribution per - data-parallel replica). - - Ranks that only differ in SP/TP see identical data batches and must not be summed with the global world group; see - Training notes for SP+DP. - """ - if dp_mesh is None or dp_mesh.size() <= 1: - return int(local_value) - if not dist.is_available() or not dist.is_initialized(): - return int(local_value) - if torch.cuda.is_available(): - device = torch.device(f"cuda:{torch.cuda.current_device()}") - else: - device = torch.device("cpu") - tensor = torch.tensor([local_value], dtype=torch.int64, device=device) - dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=dp_mesh.get_group()) - return int(tensor.item()) - - -class ConsumedStepsTracker: - """Holds per-resume totals and per-rank local accumulation; checkpoint - total uses DP-only reduction.""" - - __slots__ = ("_dp_mesh", "_init_steps", "_local_steps") - - def __init__(self, dp_mesh: DeviceMesh | None) -> None: - self._dp_mesh = dp_mesh - self._init_steps = 0 - self._local_steps = 0 - - def record(self, n: int) -> None: - self._local_steps += int(n) - - def set_init_from_checkpoint(self, total: int) -> None: - """After loading a checkpoint: global total consumed so far; reset session-local accumulation.""" - self._init_steps = int(total) - self._local_steps = 0 - - def total_for_checkpoint(self) -> int: - """Global consumed sample count including this session (collective over - DP group).""" - return self._init_steps + reduce_sum_across_dp_group(self._dp_mesh, self._local_steps) - - -def apply_old_ckpt_init_steps(sampler: object, sampler_state: dict, train_state_total: int | None) -> None: - """If the sampler checkpoint predates ``total_consumed_steps``, copy the - total from ``train_state``.""" - if train_state_total is None: - return - if sampler_state.get("total_consumed_steps") is not None: - return - consumed: ConsumedStepsTracker | None = getattr(sampler, "_consumed", None) - if consumed is not None: - consumed.set_init_from_checkpoint(train_state_total) diff --git a/xtuner/v1/datasets/dataloader.py b/xtuner/v1/datasets/dataloader.py index d2ca6f0e9..7b0713444 100644 --- a/xtuner/v1/datasets/dataloader.py +++ b/xtuner/v1/datasets/dataloader.py @@ -1,17 +1,42 @@ +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import Iterator, cast +from typing import Iterator import torch +import torch.distributed as dist +from torch.distributed.device_mesh import DeviceMesh from xtuner.v1.datasets.collator import ColateItem -from xtuner.v1.datasets.consumed_steps import ConsumedStepsTracker -from xtuner.v1.datasets.resume import get_dataloader_state, load_dataloader_state +from xtuner.v1.datasets.packing import ExpandSoftPackDataset, _LegacySoftPackDataset +from xtuner.v1.datasets.preset_sampler import PresetSampler +from xtuner.v1.datasets.sampler import LengthGroupedSampler, ParallelSampler from xtuner.v1.utils import get_logger logger = get_logger() +def reduce_sum_across_dp_group(dp_mesh: DeviceMesh | None, local_value: int) -> int: + """Sum ``local_value`` over the DP process group (one contribution per + data-parallel replica). + + Ranks that only differ in SP/TP see identical data batches and must not be summed with the global world group; see + Training notes for SP+DP. + """ + if dp_mesh is None or dp_mesh.size() <= 1: + return int(local_value) + if not dist.is_available() or not dist.is_initialized(): + return int(local_value) + if torch.cuda.is_available(): + device = torch.device(f"cuda:{torch.cuda.current_device()}") + else: + device = torch.device("cpu") + tensor = torch.tensor([local_value], dtype=torch.int64, device=device) + dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=dp_mesh.get_group()) + return int(tensor.item()) + + class BaseDataloader(ABC): """BaseDataloader represents the whole data module to interact with the training process. @@ -24,7 +49,7 @@ class BaseDataloader(ABC): def load_state_dict(self, state_dict: dict, train_state_total_consumed_samples: int | None = None) -> None: ... @abstractmethod - def get_state_dict(self, consumed_samples: int = -1) -> dict: ... + def get_state_dict(self, total_consumed_steps_override: int | None = None) -> dict: ... @abstractmethod def __iter__(self) -> Iterator[list[ColateItem]]: ... @@ -38,6 +63,23 @@ class Dataloader(torch.utils.data.DataLoader, BaseDataloader): implement. """ + def __init__(self, *args, **kwargs) -> None: + dp_mesh: DeviceMesh | None = kwargs.pop("dp_mesh", None) + super().__init__(*args, **kwargs) + self._dp_mesh = dp_mesh + self._init_total_samples = 0 + self._local_samples = 0 + + @staticmethod + def _apply_old_ckpt_total_consumed_samples(state: dict, train_state_total: int | None) -> None: + """If the checkpoint has no ``total_consumed_samples`` (and no legacy + sampler field), copy from ``train_state``.""" + if train_state_total is None: + return + if state.get("total_consumed_samples") is not None: + return + state["total_consumed_samples"] = int(train_state_total) + def load_state_dict( self, state_dict: dict, @@ -45,35 +87,58 @@ def load_state_dict( ) -> None: if train_state_total_consumed_samples is not None: logger.warning( - "Dataloader.load_state_dict(train_state_total_consumed_samples=...) is deprecated; " - "use the default (None). Consumed samples are tracked on the sampler." - ) - load_dataloader_state( - self, - state_dict, - train_state_total_consumed_samples=train_state_total_consumed_samples, - ) - - def get_state_dict(self, consumed_samples: int = -1) -> dict: - if consumed_samples != -1: - logger.warning( - "Dataloader.get_state_dict(consumed_samples=...) is deprecated; use the default (-1). " - "Consumed samples are tracked on the sampler." + "Dataloader.load_state_dict(train_state_total_consumed_samples=...) is deprecated except for " + "very old checkpoints missing total_consumed_samples." ) - # TODO: remove consumed_samples parameter in get_dataloader_state in next major release - dataloader_state = get_dataloader_state(self, consumed_samples) - return cast(dict, dataloader_state) + self._apply_old_ckpt_total_consumed_samples(state_dict, train_state_total_consumed_samples) + + sampler: ParallelSampler | LengthGroupedSampler | PresetSampler = self.sampler # type: ignore[assignment] + dataset = self.dataset + sampler_state = state_dict["sampler"] + + if not hasattr(sampler, "load_state_dict"): + logger.warning(f"Resuming from {type(sampler)} is risky.") + else: + sampler.load_state_dict(sampler_state) + + self._init_total_samples = state_dict["total_consumed_samples"] + self._local_samples = 0 + + if hasattr(dataset, "load_state_dict"): + dataset.load_state_dict(state_dict["dataset"]) + + def get_state_dict(self, total_consumed_steps_override: int | None = None) -> dict: + if total_consumed_steps_override is not None: + total_steps = int(total_consumed_steps_override) + else: + total_steps = self._init_total_samples + reduce_sum_across_dp_group(self._dp_mesh, self._local_samples) + sampler: ParallelSampler | LengthGroupedSampler | PresetSampler = self.sampler # type: ignore[assignment] + dataset: ExpandSoftPackDataset | _LegacySoftPackDataset = self.dataset # type: ignore[assignment] + dataloader_state: dict = { + "sampler": {}, + "dataset": {}, + "total_consumed_samples": total_steps, + } + + if not hasattr(sampler, "load_state_dict") or not hasattr(sampler, "get_state_dict"): + logger.warning(f"Resuming from {type(sampler)} is risky.") + else: + dataloader_state["sampler"].update(sampler.get_state_dict(total_steps)) + + if not hasattr(dataset, "load_state_dict") or not hasattr(dataset, "get_state_dict"): + logger.warning(f"Resuming from {type(dataset)} is risky.") + else: + dataloader_state["dataset"].update(dataset.get_state_dict()) + + return dataloader_state def __iter__(self) -> Iterator[list[ColateItem]]: # type: ignore[override] # Override to count delivered batches, not prefetched indices. # With num_workers > 0 the sampler is iterated ahead by DataLoader's prefetch queue, - # so recording inside sampler.__iter__ would count too many samples. Instead we - # increment _consumed exactly once per batch that reaches the caller. - sampler = self.sampler - consumed: ConsumedStepsTracker | None = getattr(sampler, "_consumed", None) + # so recording inside sampler.__iter__ would count too many samples. Instead we + # increment local consumed exactly once per batch that reaches the caller. for batch in super().__iter__(): - if consumed is not None: - consumed.record(len(batch)) + self._local_samples += len(batch) yield batch # Streaming dataloader may not have `set_epoch` and `__len__` method, so we add here. diff --git a/xtuner/v1/datasets/preset_sampler.py b/xtuner/v1/datasets/preset_sampler.py index 42b6804ad..8a730033d 100644 --- a/xtuner/v1/datasets/preset_sampler.py +++ b/xtuner/v1/datasets/preset_sampler.py @@ -22,7 +22,6 @@ from xtuner.v1.utils import get_logger -from .consumed_steps import ConsumedStepsTracker from .preset_pack import PresetPackDataset @@ -117,7 +116,6 @@ def __init__( else: self.rank = 0 self.world_size = 1 - self._consumed = ConsumedStepsTracker(dp_mesh) self.dataset = dataset self.global_batch_size = global_batch_size @@ -172,26 +170,19 @@ def __len__(self) -> int: def set_epoch(self, epoch: int) -> None: self.epoch = epoch - def get_state_dict(self, step: int | None = None) -> dict: + def get_state_dict(self, total_consumed_steps: int) -> dict: # Same convention as :class:`LengthGroupedSampler`: ``step`` is the global pack offset # (modulo ``total_size``) into ``global_order``, shared across all ranks in the checkpoint. - if step is None: - total_consumed = self._consumed.total_for_checkpoint() - else: - total_consumed = int(step) - global_step = total_consumed % self.total_size + global_step = total_consumed_steps % self.total_size return { "epoch": self.epoch, "step": global_step, - "total_consumed_steps": total_consumed, "world_size": self.world_size, "num_samples": self.num_samples, "total_size": self.total_size, } def load_state_dict(self, state_dict: dict) -> None: - tc = int(state_dict.get("total_consumed_steps", 0)) - self._consumed.set_init_from_checkpoint(tc) if self.world_size != state_dict.get("world_size"): logger.warning( f"PresetSampler: world_size mismatch: checkpoint has " diff --git a/xtuner/v1/datasets/resume.py b/xtuner/v1/datasets/resume.py deleted file mode 100644 index d5bf0c57c..000000000 --- a/xtuner/v1/datasets/resume.py +++ /dev/null @@ -1,63 +0,0 @@ -from torch.utils.data import DataLoader -from typing_extensions import TypedDict - -from xtuner.v1.utils import get_logger - -from .consumed_steps import apply_old_ckpt_init_steps -from .packing import ExpandSoftPackDataset, _LegacySoftPackDataset -from .preset_sampler import PresetSampler -from .sampler import LengthGroupedSampler, ParallelSampler - - -logger = get_logger() - - -class DataloaderState(TypedDict): - sampler: dict - dataset: dict - - -def get_dataloader_state(dataloader: DataLoader, consumed_samples: int = -1) -> DataloaderState: - sampler: ParallelSampler | LengthGroupedSampler = dataloader.sampler # type: ignore[assignment] - dataset: ExpandSoftPackDataset | _LegacySoftPackDataset = dataloader.dataset # type: ignore[assignment] - dataloader_state = DataloaderState(sampler={}, dataset={}) - - if not hasattr(sampler, "load_state_dict") or not hasattr(sampler, "get_state_dict"): - logger.warning(f"Resuming from {type(sampler)} is risky.") - elif consumed_samples != -1: - logger.warning( - "Passing consumed_samples to get_dataloader_state is deprecated; " - "consumed sample totals are tracked on the sampler. Use the default consumed_samples=-1." - ) - dataloader_state["sampler"].update(sampler.get_state_dict(step=consumed_samples)) - else: - dataloader_state["sampler"].update(sampler.get_state_dict()) - - if not hasattr(dataset, "load_state_dict") or not hasattr(dataset, "get_state_dict"): - logger.warning(f"Resuming from {type(dataset)} is risky.") - else: - dataloader_state["dataset"].update(dataset.get_state_dict()) - - return dataloader_state - - -def load_dataloader_state( - dataloader: DataLoader, - state: dict, - train_state_total_consumed_samples: int | None = None, -): - sampler: ParallelSampler | LengthGroupedSampler | PresetSampler = dataloader.sampler # type: ignore[assignment] - dataset = dataloader.dataset - - # Sampler require `load_state_dict` to restore the training progress since the sampler state will - # record the consumed samples. - if not hasattr(sampler, "load_state_dict"): - logger.warning(f"Resuming from {type(sampler)} is risky.") - - if hasattr(sampler, "load_state_dict"): - sampler.load_state_dict(state["sampler"]) - apply_old_ckpt_init_steps(sampler, state["sampler"], train_state_total_consumed_samples) - - # If the dataset records the training progress, we also restore it. - if hasattr(dataset, "load_state_dict"): - dataset.load_state_dict(state["dataset"]) diff --git a/xtuner/v1/datasets/sampler.py b/xtuner/v1/datasets/sampler.py index d6c616d62..ca1ebca30 100644 --- a/xtuner/v1/datasets/sampler.py +++ b/xtuner/v1/datasets/sampler.py @@ -12,7 +12,6 @@ from xtuner.v1.utils import get_logger -from .consumed_steps import ConsumedStepsTracker from .jsonl import JsonlDataset from .packing import MLLMPretrainHybridPackDataset, _LegacySoftPackDataset from .preset_pack import PresetPackDataset @@ -85,7 +84,6 @@ def __init__( self.epoch = 0 self.step = 0 self.round_up = round_up - self._consumed = ConsumedStepsTracker(dp_mesh) if self.round_up: self.num_samples = math.ceil(len(self.dataset) / global_batch_size) * global_batch_size // world_size @@ -139,8 +137,6 @@ def load_state_dict(self, state_dict) -> None: Args: state_dict (dict): The state of the sampler. """ - tc = int(state_dict.get("total_consumed_steps", 0)) - self._consumed.set_init_from_checkpoint(tc) self.epoch = state_dict["epoch"] self.step = state_dict["step"] @@ -150,17 +146,12 @@ def load_state_dict(self, state_dict) -> None: f"is different from the current shuffle ({self.shuffle})." ) - def get_state_dict(self, step: int | None = None): + def get_state_dict(self, total_consumed_steps: int): # Attention! Do not set self.step here, or it will cause the next __iter__ to get less samples. - if step is None: - total_consumed = self._consumed.total_for_checkpoint() - else: - total_consumed = int(step) - step_mod = total_consumed % self.total_size + step_mod = total_consumed_steps % self.total_size return { "epoch": self.epoch, "step": step_mod, - "total_consumed_steps": total_consumed, "world_size": self.world_size, "shuffle": self.shuffle, "round_up": self.round_up, @@ -242,7 +233,6 @@ def __init__( assert isinstance(self.max_lengths, (list, tuple, Column, np.ndarray)) self.global_batch_size = global_batch_size - self._consumed = ConsumedStepsTracker(dp_mesh) def __iter__(self) -> Iterator[int]: """Iterate the indices.""" @@ -291,8 +281,6 @@ def load_state_dict(self, state_dict: dict) -> None: Args: state_dict (dict): The state of the sampler. """ - tc = int(state_dict.get("total_consumed_steps", 0)) - self._consumed.set_init_from_checkpoint(tc) self.epoch = state_dict["epoch"] self.step = state_dict["step"] @@ -310,22 +298,17 @@ def load_state_dict(self, state_dict: dict) -> None: ) self.group_size = origin_group_size - def get_state_dict(self, step: int | None = None): + def get_state_dict(self, total_consumed_steps: int): """Get the sampler state dict. Returns: dict: The state of the sampler. """ # Attention! Do not set self.step here, or it will cause the next __iter__ to get less samples. - if step is None: - total_consumed = self._consumed.total_for_checkpoint() - else: - total_consumed = int(step) - step_mod = total_consumed % self.total_size + step_mod = total_consumed_steps % self.total_size return { "epoch": self.epoch, "step": step_mod, - "total_consumed_steps": total_consumed, "world_size": self.world_size, "round_up": self.round_up, "num_samples": self.num_samples, diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py index bb48b6d87..de8612d2c 100644 --- a/xtuner/v1/rl/base/worker.py +++ b/xtuner/v1/rl/base/worker.py @@ -1391,7 +1391,7 @@ def save(self, checkpoint_path: Path | str, no_save_optimizer: bool = False): if self._sft_dataloader is not None: sft_dataloader_path = checkpoint_path / self._SAVE_SFT_DATALOADER_DIR dataloader_state = self._sft_dataloader.get_state_dict() - total_consumed_samples = int(dataloader_state.get("sampler", {}).get("total_consumed_steps", 0)) + total_consumed_samples = dataloader_state["total_consumed_samples"] if self.rank != 0: return diff --git a/xtuner/v1/train/trainer.py b/xtuner/v1/train/trainer.py index 638f7c596..f7b8f094b 100644 --- a/xtuner/v1/train/trainer.py +++ b/xtuner/v1/train/trainer.py @@ -1214,7 +1214,7 @@ def _save_dataloader(self, dataloader_path: Path | str) -> int: dataloader_state = self._dataloader.get_state_dict() if self.rank == 0: torch.save(dataloader_state, dataloader_path) - return int(dataloader_state.get("sampler", {}).get("total_consumed_steps", 0)) + return dataloader_state["total_consumed_samples"] @property def work_dir(self) -> Path: From 9a81d9c56ee8429625df04bf212ebeb3b3c23bc3 Mon Sep 17 00:00:00 2001 From: zhaopenghao Date: Tue, 7 Apr 2026 06:12:49 +0000 Subject: [PATCH 13/14] remove old resume logic of total consumed samples --- tests/datasets/test_preset_dataloader.py | 4 +-- xtuner/v1/datasets/dataloader.py | 36 ++++-------------------- xtuner/v1/ray/dataflow/replay_buffer.py | 4 +-- xtuner/v1/rl/base/worker.py | 5 +--- xtuner/v1/train/trainer.py | 12 ++------ 5 files changed, 14 insertions(+), 47 deletions(-) diff --git a/tests/datasets/test_preset_dataloader.py b/tests/datasets/test_preset_dataloader.py index 7030ad8b7..ec40bac73 100644 --- a/tests/datasets/test_preset_dataloader.py +++ b/tests/datasets/test_preset_dataloader.py @@ -700,8 +700,8 @@ def _build(): global_consumed_samples = sum(int(x) for x in consumed_samples_list if x is not None) # 3. Get ckpt state - # dataloader_state = dl.get_state_dict(global_consumed_samples) - dataloader_state = dl.get_state_dict(global_consumed_samples) + dataloader_state = dl.get_state_dict() + assert dataloader_state["total_consumed_samples"] == global_consumed_samples # 4. Continue to consume data at [half_step, 2*half_step) expected_batches = [] diff --git a/xtuner/v1/datasets/dataloader.py b/xtuner/v1/datasets/dataloader.py index 7b0713444..fc3390271 100644 --- a/xtuner/v1/datasets/dataloader.py +++ b/xtuner/v1/datasets/dataloader.py @@ -46,10 +46,10 @@ class BaseDataloader(ABC): """ @abstractmethod - def load_state_dict(self, state_dict: dict, train_state_total_consumed_samples: int | None = None) -> None: ... + def load_state_dict(self, state_dict: dict) -> None: ... @abstractmethod - def get_state_dict(self, total_consumed_steps_override: int | None = None) -> dict: ... + def get_state_dict(self) -> dict: ... @abstractmethod def __iter__(self) -> Iterator[list[ColateItem]]: ... @@ -70,28 +70,7 @@ def __init__(self, *args, **kwargs) -> None: self._init_total_samples = 0 self._local_samples = 0 - @staticmethod - def _apply_old_ckpt_total_consumed_samples(state: dict, train_state_total: int | None) -> None: - """If the checkpoint has no ``total_consumed_samples`` (and no legacy - sampler field), copy from ``train_state``.""" - if train_state_total is None: - return - if state.get("total_consumed_samples") is not None: - return - state["total_consumed_samples"] = int(train_state_total) - - def load_state_dict( - self, - state_dict: dict, - train_state_total_consumed_samples: int | None = None, - ) -> None: - if train_state_total_consumed_samples is not None: - logger.warning( - "Dataloader.load_state_dict(train_state_total_consumed_samples=...) is deprecated except for " - "very old checkpoints missing total_consumed_samples." - ) - self._apply_old_ckpt_total_consumed_samples(state_dict, train_state_total_consumed_samples) - + def load_state_dict(self, state_dict: dict) -> None: sampler: ParallelSampler | LengthGroupedSampler | PresetSampler = self.sampler # type: ignore[assignment] dataset = self.dataset sampler_state = state_dict["sampler"] @@ -101,17 +80,14 @@ def load_state_dict( else: sampler.load_state_dict(sampler_state) - self._init_total_samples = state_dict["total_consumed_samples"] + self._init_total_samples = int(state_dict["total_consumed_samples"]) self._local_samples = 0 if hasattr(dataset, "load_state_dict"): dataset.load_state_dict(state_dict["dataset"]) - def get_state_dict(self, total_consumed_steps_override: int | None = None) -> dict: - if total_consumed_steps_override is not None: - total_steps = int(total_consumed_steps_override) - else: - total_steps = self._init_total_samples + reduce_sum_across_dp_group(self._dp_mesh, self._local_samples) + def get_state_dict(self) -> dict: + total_steps = self._init_total_samples + reduce_sum_across_dp_group(self._dp_mesh, self._local_samples) sampler: ParallelSampler | LengthGroupedSampler | PresetSampler = self.sampler # type: ignore[assignment] dataset: ExpandSoftPackDataset | _LegacySoftPackDataset = self.dataset # type: ignore[assignment] dataloader_state: dict = { diff --git a/xtuner/v1/ray/dataflow/replay_buffer.py b/xtuner/v1/ray/dataflow/replay_buffer.py index 7068406ef..cdf10fd68 100644 --- a/xtuner/v1/ray/dataflow/replay_buffer.py +++ b/xtuner/v1/ray/dataflow/replay_buffer.py @@ -296,7 +296,7 @@ def resume(self, dataloader_path): dataloader_state = torch.load(dataloader_path, map_location=DEVICE) self.dataloader.load_state_dict(dataloader_state) self.dataloader_iter = iter(self.dataloader) - self.reduced_consumed_samples = dataloader_state["sampler"]["step"] + self.reduced_consumed_samples = int(dataloader_state["total_consumed_samples"]) self.cur_epoch = dataloader_state["sampler"]["epoch"] @@ -923,7 +923,7 @@ def save(self, file_path: Path | str): # save dataloader dataloader_path = file_path / "dataloader" - dataloader_state = self.sampler.dataloader.get_state_dict(self.sampler.reduced_consumed_samples) + dataloader_state = self.sampler.dataloader.get_state_dict() torch.save(dataloader_state, dataloader_path) # save storage diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py index de8612d2c..22f1bf275 100644 --- a/xtuner/v1/rl/base/worker.py +++ b/xtuner/v1/rl/base/worker.py @@ -1453,10 +1453,7 @@ def resume(self, load_checkpoint_cfg: LoadCheckpointConfig): if not sft_dataloader_path.exists(): raise FileNotFoundError(f"Dataloader path {sft_dataloader_path} does not exist.") dataloader_state = torch.load(sft_dataloader_path, map_location=DEVICE) - self._sft_dataloader.load_state_dict( - dataloader_state, - train_state_total_consumed_samples=train_state.get("total_consumed_samples", 0), - ) + self._sft_dataloader.load_state_dict(dataloader_state) self.logger.info(f"Resume sft dataloader from {sft_dataloader_path}") @ray_method diff --git a/xtuner/v1/train/trainer.py b/xtuner/v1/train/trainer.py index f7b8f094b..4ee55ec29 100644 --- a/xtuner/v1/train/trainer.py +++ b/xtuner/v1/train/trainer.py @@ -1811,10 +1811,7 @@ def _load_checkpoint(self): self._init_total_tokens = train_state.get("total_consumed_tokens", 0) # default 0 for BC dataloader_path = resume_from / self._SAVE_DATALOADER_DIR - self._resume_dataloader( - dataloader_path, - train_state_total_consumed_samples=train_state.get("total_consumed_samples"), - ) + self._resume_dataloader(dataloader_path) if load_checkpoint_cfg.load_scheduler: scheduler_path = resume_from / self._SAVE_SCHEDULER_DIR @@ -1827,14 +1824,11 @@ def _load_checkpoint(self): scheduler_step = self.total_step - self._cur_step self._lr_scheduler = self.build_lr_scheduler(self._lr_cfg, scheduler_step) - def _resume_dataloader(self, dataloader_path: Path, train_state_total_consumed_samples: int | None = None): + def _resume_dataloader(self, dataloader_path: Path): if not dataloader_path.exists(): raise FileNotFoundError(f"Dataloader path {dataloader_path} does not exist.") dataloader_state = torch.load(dataloader_path, map_location=DEVICE) - self._dataloader.load_state_dict( - dataloader_state, - train_state_total_consumed_samples=train_state_total_consumed_samples, - ) + self._dataloader.load_state_dict(dataloader_state) def _setup_hooks(self, hooks_config: HooksConfig) -> HooksConfig: for stage in HookStage: From 0cd59d1d98d0a9019b5b0dfd0b43580927677395 Mon Sep 17 00:00:00 2001 From: zhaopenghao Date: Tue, 7 Apr 2026 08:41:21 +0000 Subject: [PATCH 14/14] Refactor Trainer._save_dataloader to remove return value of consumed samples --- xtuner/v1/train/trainer.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/xtuner/v1/train/trainer.py b/xtuner/v1/train/trainer.py index 4ee55ec29..5d8e60556 100644 --- a/xtuner/v1/train/trainer.py +++ b/xtuner/v1/train/trainer.py @@ -1131,7 +1131,7 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool: ) # Save dataloader - total_consumed_samples = self._save_dataloader(dataloader_path) + self._save_dataloader(dataloader_path) DEVICE_MODULE.empty_cache() @@ -1161,7 +1161,6 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool: { "cur_step": self.cur_step, "cur_epoch": self._cur_epoch, - "total_consumed_samples": total_consumed_samples, "total_consumed_tokens": total_consumed_tokens, "train_time_offset": self._train_time + self._train_time_offset, } @@ -1174,7 +1173,6 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool: ckp_list.append(str(checkpoint_path)) current_exp.cur_step = self.cur_step current_exp.cur_epoch = self._cur_epoch - current_exp.consumed_samples = int(total_consumed_samples) current_exp.consumed_tokens = int(total_consumed_tokens) current_exp.history[-1]["end"] = self.cur_step @@ -1210,11 +1208,10 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool: return True - def _save_dataloader(self, dataloader_path: Path | str) -> int: + def _save_dataloader(self, dataloader_path: Path | str): dataloader_state = self._dataloader.get_state_dict() if self.rank == 0: torch.save(dataloader_state, dataloader_path) - return dataloader_state["total_consumed_samples"] @property def work_dir(self) -> Path: @@ -1464,6 +1461,8 @@ def _compute_performance_metrics( approximate_total_consumed_tokens = ( self._init_total_tokens + self._local_total_consumed_tokens * self.world_size ) + # TODO: approximate_total_consumed_tokens_per_rank could be incorrect if world_size changed. + # So calculate `eta_seconds = step_time * remaining_steps` instead? approximate_total_consumed_tokens_per_rank = approximate_total_consumed_tokens / self.world_size exp_tgs = self._local_total_consumed_tokens / self._train_time if self._train_time > 0 else 0.0