diff --git a/ci/scripts/test_vlm_sft_trainer.py b/ci/scripts/test_vlm_sft_trainer.py index 204eb9f5d..e36edc130 100644 --- a/ci/scripts/test_vlm_sft_trainer.py +++ b/ci/scripts/test_vlm_sft_trainer.py @@ -204,7 +204,7 @@ def parse_args(): def extract_data_from_log(logfile: Path): - pattern_str = r"\[XTuner\].*Step.*lr:\s(\d+.\d*)\s.*text_tokens:\s(\d+.\d*)\s.*reduced_llm_loss:\s(\d+.\d*)\s.*max_memory:\s(\d+.\d*)\s*GB\s.*grad_norm:\s(\d+.\d*)\s.*e2e_tgs:\s(\d+.\d*)" + pattern_str = r"\[XTuner\].*Step.*lr:\s(\d+.\d*)\s.*text_tokens:\s(\d+.\d*)\s.*reduced_llm_loss:\s(\d+.\d*)\s.*max_memory:\s(\d+.\d*)\s*GB\s.*grad_norm:\s(\d+.\d*)\s.*exp_tgs:\s(\d+.\d*)" compiled_pattern = re.compile(pattern_str) cur_lr = [] diff --git a/tests/datasets/test_dataloader.py b/tests/datasets/test_dataloader.py index fa8ee4bec..9aa82e695 100644 --- a/tests/datasets/test_dataloader.py +++ b/tests/datasets/test_dataloader.py @@ -1,12 +1,18 @@ from pathlib import Path import os import pickle +import socket import torch -from xtuner.v1.datasets import build_dataloader, build_datasets, get_dataloader_state, load_dataloader_state, FTDPTokenizeFnConfig, DatasetConfig, DataloaderConfig +from xtuner.v1.datasets import ( + DataloaderConfig, + DatasetConfig, + FTDPTokenizeFnConfig, + build_dataloader, +) from xtuner.v1.train.toy_tokenizer import UTF8ByteTokenizer -from torch.multiprocessing import spawn, get_context +from torch.multiprocessing import spawn from torch.distributed.device_mesh import init_device_mesh import pytest @@ -15,8 +21,6 @@ from itertools import repeat, chain - - class RandomDataset: def __init__(self, size: int, **kwargs): self.size = size @@ -182,11 +186,12 @@ def test_dataloader_resume_single_process(tmp_path, pack_level, num_workers, gro dataset_configs = [ { "dataset": DatasetConfig(anno_path=str(data_dir1)), - "tokenize_fn": FTDPTokenizeFnConfig(max_length=1024) + "tokenize_fn": FTDPTokenizeFnConfig(max_length=1024), }, ] dataloader_config = DataloaderConfig( + dataset_config_list=dataset_configs, pack_max_length=1024, pack_level=pack_level, num_workers=num_workers, @@ -194,13 +199,9 @@ def test_dataloader_resume_single_process(tmp_path, pack_level, num_workers, gro pack_workers=pack_workers, ) - datasets = build_datasets( - dataset_config=dataset_configs, + dataloader1 = dataloader_config.build( tokenizer=tokenizer, - ) - dataloader1 = build_dataloader( - dataloader_config=dataloader_config, - datasets=datasets, + dp_mesh=None, global_batch_size=GLOBAL_BATCH_SIZE, micro_batch_size=BATCH_SIZE, seed=10, @@ -210,26 +211,22 @@ def test_dataloader_resume_single_process(tmp_path, pack_level, num_workers, gro assert len(dataloader1) > 10 dataloader_iter = iter(dataloader1) - consumed_sample = 0 for _ in range(RESUME_ITER): - batch = next(dataloader_iter) - consumed_sample += len(batch) + next(dataloader_iter) - dataloader_state = get_dataloader_state(dataloader1, consumed_sample) + dataloader_state = dataloader1.get_state_dict() expected_data = [] for _ in range(AFTER_RESUME_ITER): - batch = next(dataloader_iter) - consumed_sample += len(batch) - expected_data.append(batch) + expected_data.append(next(dataloader_iter)) - new_dataloader1 = build_dataloader( - dataloader_config=dataloader_config, - datasets=datasets, + new_dataloader1 = dataloader_config.build( + tokenizer=tokenizer, + dp_mesh=None, global_batch_size=GLOBAL_BATCH_SIZE, micro_batch_size=BATCH_SIZE, seed=10, ) - load_dataloader_state(new_dataloader1, dataloader_state) + new_dataloader1.load_state_dict(dataloader_state) new_dataloader_iter = iter(new_dataloader1) resume_data = [] @@ -242,32 +239,29 @@ def test_dataloader_resume_single_process(tmp_path, pack_level, num_workers, gro # 2. Test resume after consuming multiple epochs while True: try: - batch = next(dataloader_iter) - consumed_sample += len(batch) + next(dataloader_iter) except StopIteration: break - dataloader_iter = iter(dataloader1) - for batch in range(RESUME_ITER): - batch = next(dataloader_iter) - consumed_sample += len(batch) + for _ in range(RESUME_ITER): + next(dataloader_iter) - dataloader_state = get_dataloader_state(dataloader1, consumed_sample) + dataloader_state = dataloader1.get_state_dict() expected_data = [] for _ in range(AFTER_RESUME_ITER): expected_data.append(next(dataloader_iter)) - new_dataloader2 = build_dataloader( - dataloader_config=dataloader_config, - datasets=datasets, + new_dataloader2 = dataloader_config.build( + tokenizer=tokenizer, + dp_mesh=None, global_batch_size=GLOBAL_BATCH_SIZE, micro_batch_size=BATCH_SIZE, seed=10, ) - load_dataloader_state(new_dataloader2, dataloader_state) + new_dataloader2.load_state_dict(dataloader_state) new_dataloader_iter2 = iter(new_dataloader2) resume_data = [] @@ -282,14 +276,12 @@ def _test_resume_spmd( rank: int, world_size: int, dataloader_config: DataloaderConfig, - dataset_configs: list[dict], global_batch_size: int, micro_batch_size: int, - step:int, + step: int, seed: int, save_path: Path, dataloader_state: dict | None = None, - consumed_samples: int = 0, ): os.environ["RANK"] = str(rank) os.environ["LOCAL_RANK"] = str(rank) @@ -297,50 +289,39 @@ def _test_resume_spmd( os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "29505" - torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size) torch.cuda.set_device(rank) data_mesh = init_device_mesh( device_type="cuda", - mesh_shape=(world_size,) + mesh_shape=(world_size,), ) tokenizer = UTF8ByteTokenizer() - datasets = build_datasets( - dataset_config=dataset_configs, + dataloader = dataloader_config.build( tokenizer=tokenizer, - ) - dataloader = build_dataloader( - dataloader_config=dataloader_config, - datasets=datasets, + dp_mesh=data_mesh, global_batch_size=global_batch_size, micro_batch_size=micro_batch_size, seed=seed, - dp_mesh=data_mesh, ) if dataloader_state is not None: - load_dataloader_state(dataloader, dataloader_state) + dataloader.load_state_dict(dataloader_state) data_iter = iter(dataloader) data_list = [] for _ in range(step): batch = next(data_iter) data_list.append(batch) - consumed_samples += len(batch) - consumed_samples_list = [None for _ in range(world_size)] - torch.distributed.all_gather_object(consumed_samples_list, consumed_samples) - global_consumed_samples = sum(consumed_samples_list) + # Snapshot after the first `step` batches so total_consumed_samples matches resume intent. + dataloader_state = dataloader.get_state_dict() expected_data = [] - for _ in range(step): batch = next(data_iter) expected_data.append(batch) - dataloader_state = get_dataloader_state(dataloader, global_consumed_samples) - all_data_list = [None for _ in range(world_size)] torch.distributed.all_gather_object(all_data_list, list(chain(*data_list))) @@ -372,7 +353,6 @@ def _test_resume_spmd( "dataloader_state": dataloader_state, "data_list": all_data_list, "expected_data": all_expected_data, - "consumed_samples": consumed_samples } ) ) @@ -389,7 +369,6 @@ def _test_resume_spmd( ("none", 0, False), ("soft", 0, True), ("soft", 4, True), - ("soft", 4, True), ] ) def test_dataloader_resume_multi_process(tmp_path, pack_level, num_workers, group_by_length): @@ -402,21 +381,22 @@ def test_dataloader_resume_multi_process(tmp_path, pack_level, num_workers, grou _create_fake_dataset(data_dir1 / f"depth3", dataset_num=3, max_depth=3, dup_times=9) # 1. Test resuming with the same world size + dataset_configs = [ + { + "dataset": DatasetConfig(anno_path=str(data_dir1)), + "tokenize_fn": FTDPTokenizeFnConfig(max_length=1024), + }, + ] + dataloader_config = DataloaderConfig( + dataset_config_list=dataset_configs, pack_max_length=1024, pack_level=pack_level, num_workers=num_workers, group_by_length=group_by_length, - collator="fake_collator" + collator="fake_collator", ) - dataset_configs = [ - { - "dataset": DatasetConfig(anno_path=str(data_dir1)), - "tokenize_fn": FTDPTokenizeFnConfig(max_length=1024) - }, - ] - ctx = get_context("spawn") world_size = 2 save_path1 = tmp_path / "dataloader_state.pkl" spawn( @@ -424,14 +404,12 @@ def test_dataloader_resume_multi_process(tmp_path, pack_level, num_workers, grou args=( world_size, dataloader_config, - dataset_configs, 16, BATCH_SIZE, TOTAL_STEP, 10, save_path1, None, - 0, ), nprocs=2, join=True, @@ -448,14 +426,12 @@ def test_dataloader_resume_multi_process(tmp_path, pack_level, num_workers, grou args=( world_size, dataloader_config, - dataset_configs, 16, BATCH_SIZE, TOTAL_STEP, 10, save_path2, result1["dataloader_state"], - result1["consumed_samples"], ), nprocs=world_size, join=True, @@ -475,14 +451,12 @@ def test_dataloader_resume_multi_process(tmp_path, pack_level, num_workers, grou args=( world_size, dataloader_config, - dataset_configs, 16, BATCH_SIZE, TOTAL_STEP, 10, save_path3, result1["dataloader_state"], - result1["consumed_samples"], ), nprocs=world_size, join=True, diff --git a/tests/datasets/test_preset_dataloader.py b/tests/datasets/test_preset_dataloader.py index 9bf8f5f7c..ec40bac73 100644 --- a/tests/datasets/test_preset_dataloader.py +++ b/tests/datasets/test_preset_dataloader.py @@ -20,7 +20,7 @@ from itertools import chain -from xtuner.v1.datasets import PretrainTokenizeFunctionConfig, get_dataloader_state, load_dataloader_state +from xtuner.v1.datasets import PretrainTokenizeFunctionConfig from xtuner.v1.datasets.config import DatasetConfig, DataloaderConfig from xtuner.v1.datasets.packing import get_pack_infos_by_hard_split from xtuner.v1.datasets.preset_pack import PresetPackDataset @@ -700,8 +700,8 @@ def _build(): global_consumed_samples = sum(int(x) for x in consumed_samples_list if x is not None) # 3. Get ckpt state - # dataloader_state = get_dataloader_state(dl, global_consumed_samples) - dataloader_state = dl.get_state_dict(global_consumed_samples) + dataloader_state = dl.get_state_dict() + assert dataloader_state["total_consumed_samples"] == global_consumed_samples # 4. Continue to consume data at [half_step, 2*half_step) expected_batches = [] @@ -738,7 +738,7 @@ def _build(): dl2 = _build() with ckpt_path.open("rb") as f: ckpt = pickle.load(f) - # load_dataloader_state(dl2, ckpt["dataloader_state"]) + # dl2.load_state_dict(ckpt["dataloader_state"]) dl2.load_state_dict(ckpt["dataloader_state"]) resume_iter = iter(dl2) diff --git a/tests/datasets/test_preset_sampler.py b/tests/datasets/test_preset_sampler.py index e2aab4e84..6ba87f1df 100644 --- a/tests/datasets/test_preset_sampler.py +++ b/tests/datasets/test_preset_sampler.py @@ -87,7 +87,7 @@ def test_state_dict_resume(tmp_path): sampler = PresetSampler(dataset, sampler_config_path=path, global_batch_size=1) - state = sampler.get_state_dict(step=3) + state = sampler.get_state_dict(3) assert state["step"] == 3 sampler2 = PresetSampler(dataset, sampler_config_path=path, global_batch_size=1) @@ -102,7 +102,7 @@ def test_state_dict_world_size_mismatch(tmp_path): path = _write_order_npy(tmp_path, "order.npy", _i64(0, 1, 2, 3)) sampler = PresetSampler(dataset, sampler_config_path=path, global_batch_size=1) - state = sampler.get_state_dict(step=0) + state = sampler.get_state_dict(0) state["world_size"] = 99 sampler.load_state_dict(state) diff --git a/xtuner/v1/datasets/__init__.py b/xtuner/v1/datasets/__init__.py index 30d91ee81..94a8f52a7 100644 --- a/xtuner/v1/datasets/__init__.py +++ b/xtuner/v1/datasets/__init__.py @@ -25,7 +25,6 @@ PretrainTokenizeFunction, PretrainTokenizeFunctionConfig, ) -from .resume import get_dataloader_state, load_dataloader_state from .rl_tokenize_fn import RLTokenizeFnConfig from .sampler import LengthGroupedSampler, ParallelSampler from .sft_tokenize_fn import OpenaiTokenizeFunction, OpenaiTokenizeFunctionConfig @@ -68,8 +67,6 @@ "InternS1VLTokenizeFnConfig", "fake_collator", "RLTokenizeFnConfig", - "get_dataloader_state", - "load_dataloader_state", "DatasetConfigList", "DataloaderConfig", "BaseTokenizeFnConfig", diff --git a/xtuner/v1/datasets/config.py b/xtuner/v1/datasets/config.py index 97e6e02ca..f9783369f 100644 --- a/xtuner/v1/datasets/config.py +++ b/xtuner/v1/datasets/config.py @@ -544,5 +544,6 @@ def build( collate_fn=collator, multiprocessing_context=ctx if self.num_workers > 0 else None, persistent_workers=self.num_workers > 0, + dp_mesh=dp_mesh, ) return dataloader diff --git a/xtuner/v1/datasets/dataloader.py b/xtuner/v1/datasets/dataloader.py index bdd508d5b..fc3390271 100644 --- a/xtuner/v1/datasets/dataloader.py +++ b/xtuner/v1/datasets/dataloader.py @@ -1,10 +1,40 @@ +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import Iterator, cast +from typing import Iterator import torch +import torch.distributed as dist +from torch.distributed.device_mesh import DeviceMesh from xtuner.v1.datasets.collator import ColateItem -from xtuner.v1.datasets.resume import get_dataloader_state, load_dataloader_state +from xtuner.v1.datasets.packing import ExpandSoftPackDataset, _LegacySoftPackDataset +from xtuner.v1.datasets.preset_sampler import PresetSampler +from xtuner.v1.datasets.sampler import LengthGroupedSampler, ParallelSampler +from xtuner.v1.utils import get_logger + + +logger = get_logger() + + +def reduce_sum_across_dp_group(dp_mesh: DeviceMesh | None, local_value: int) -> int: + """Sum ``local_value`` over the DP process group (one contribution per + data-parallel replica). + + Ranks that only differ in SP/TP see identical data batches and must not be summed with the global world group; see + Training notes for SP+DP. + """ + if dp_mesh is None or dp_mesh.size() <= 1: + return int(local_value) + if not dist.is_available() or not dist.is_initialized(): + return int(local_value) + if torch.cuda.is_available(): + device = torch.device(f"cuda:{torch.cuda.current_device()}") + else: + device = torch.device("cpu") + tensor = torch.tensor([local_value], dtype=torch.int64, device=device) + dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=dp_mesh.get_group()) + return int(tensor.item()) class BaseDataloader(ABC): @@ -19,7 +49,7 @@ class BaseDataloader(ABC): def load_state_dict(self, state_dict: dict) -> None: ... @abstractmethod - def get_state_dict(self, consumed_samples: int) -> dict: ... + def get_state_dict(self) -> dict: ... @abstractmethod def __iter__(self) -> Iterator[list[ColateItem]]: ... @@ -33,14 +63,59 @@ class Dataloader(torch.utils.data.DataLoader, BaseDataloader): implement. """ + def __init__(self, *args, **kwargs) -> None: + dp_mesh: DeviceMesh | None = kwargs.pop("dp_mesh", None) + super().__init__(*args, **kwargs) + self._dp_mesh = dp_mesh + self._init_total_samples = 0 + self._local_samples = 0 + def load_state_dict(self, state_dict: dict) -> None: - load_dataloader_state(self, state_dict) + sampler: ParallelSampler | LengthGroupedSampler | PresetSampler = self.sampler # type: ignore[assignment] + dataset = self.dataset + sampler_state = state_dict["sampler"] + + if not hasattr(sampler, "load_state_dict"): + logger.warning(f"Resuming from {type(sampler)} is risky.") + else: + sampler.load_state_dict(sampler_state) + + self._init_total_samples = int(state_dict["total_consumed_samples"]) + self._local_samples = 0 + + if hasattr(dataset, "load_state_dict"): + dataset.load_state_dict(state_dict["dataset"]) + + def get_state_dict(self) -> dict: + total_steps = self._init_total_samples + reduce_sum_across_dp_group(self._dp_mesh, self._local_samples) + sampler: ParallelSampler | LengthGroupedSampler | PresetSampler = self.sampler # type: ignore[assignment] + dataset: ExpandSoftPackDataset | _LegacySoftPackDataset = self.dataset # type: ignore[assignment] + dataloader_state: dict = { + "sampler": {}, + "dataset": {}, + "total_consumed_samples": total_steps, + } + + if not hasattr(sampler, "load_state_dict") or not hasattr(sampler, "get_state_dict"): + logger.warning(f"Resuming from {type(sampler)} is risky.") + else: + dataloader_state["sampler"].update(sampler.get_state_dict(total_steps)) + + if not hasattr(dataset, "load_state_dict") or not hasattr(dataset, "get_state_dict"): + logger.warning(f"Resuming from {type(dataset)} is risky.") + else: + dataloader_state["dataset"].update(dataset.get_state_dict()) - def get_state_dict(self, consumed_samples: int) -> dict: - dataloader_state = get_dataloader_state(self, consumed_samples) - return cast(dict, dataloader_state) + return dataloader_state - # __iter__ is inherited from torch.utils.data.DataLoader + def __iter__(self) -> Iterator[list[ColateItem]]: # type: ignore[override] + # Override to count delivered batches, not prefetched indices. + # With num_workers > 0 the sampler is iterated ahead by DataLoader's prefetch queue, + # so recording inside sampler.__iter__ would count too many samples. Instead we + # increment local consumed exactly once per batch that reaches the caller. + for batch in super().__iter__(): + self._local_samples += len(batch) + yield batch # Streaming dataloader may not have `set_epoch` and `__len__` method, so we add here. def set_epoch(self, epoch: int) -> None: diff --git a/xtuner/v1/datasets/preset_sampler.py b/xtuner/v1/datasets/preset_sampler.py index 636c9343e..8a730033d 100644 --- a/xtuner/v1/datasets/preset_sampler.py +++ b/xtuner/v1/datasets/preset_sampler.py @@ -170,10 +170,10 @@ def __len__(self) -> int: def set_epoch(self, epoch: int) -> None: self.epoch = epoch - def get_state_dict(self, step: int) -> dict: + def get_state_dict(self, total_consumed_steps: int) -> dict: # Same convention as :class:`LengthGroupedSampler`: ``step`` is the global pack offset # (modulo ``total_size``) into ``global_order``, shared across all ranks in the checkpoint. - global_step = step % self.total_size + global_step = total_consumed_steps % self.total_size return { "epoch": self.epoch, "step": global_step, @@ -191,5 +191,4 @@ def load_state_dict(self, state_dict: dict) -> None: ) self.epoch = state_dict["epoch"] - global_step = int(state_dict["step"]) - self.step = global_step + self.step = int(state_dict["step"]) diff --git a/xtuner/v1/datasets/resume.py b/xtuner/v1/datasets/resume.py deleted file mode 100644 index 65ab62f3a..000000000 --- a/xtuner/v1/datasets/resume.py +++ /dev/null @@ -1,50 +0,0 @@ -from torch.utils.data import DataLoader -from typing_extensions import TypedDict - -from xtuner.v1.utils import get_logger - -from .packing import ExpandSoftPackDataset, _LegacySoftPackDataset -from .sampler import LengthGroupedSampler, ParallelSampler - - -logger = get_logger() - - -class DataloaderState(TypedDict): - sampler: dict - dataset: dict - - -def get_dataloader_state(dataloader: DataLoader, consumed_samples: int) -> DataloaderState: - sampler: ParallelSampler | LengthGroupedSampler = dataloader.sampler # type: ignore[assignment] - dataset: ExpandSoftPackDataset | _LegacySoftPackDataset = dataloader.dataset # type: ignore[assignment] - dataloader_state = DataloaderState(sampler={}, dataset={}) - - if not hasattr(sampler, "load_state_dict") or not hasattr(sampler, "get_state_dict"): - logger.warning(f"Resuming from {type(sampler)} is risky.") - else: - dataloader_state["sampler"].update(sampler.get_state_dict(step=consumed_samples)) - - if not hasattr(dataset, "load_state_dict") or not hasattr(dataset, "get_state_dict"): - logger.warning(f"Resuming from {type(dataset)} is risky.") - else: - dataloader_state["dataset"].update(dataset.get_state_dict()) - - return dataloader_state - - -def load_dataloader_state(dataloader: DataLoader, state: dict): - sampler = dataloader.sampler - dataset = dataloader.dataset - - # Sampler require `load_state_dict` to restore the training progress since the sampler state will - # record the consumed samples. - if not hasattr(sampler, "load_state_dict"): - logger.warning(f"Resuming from {type(sampler)} is risky.") - - if hasattr(sampler, "load_state_dict"): - sampler.load_state_dict(state["sampler"]) - - # If the dataset records the training progress, we also restore it. - if hasattr(dataset, "load_state_dict"): - dataset.load_state_dict(state["dataset"]) diff --git a/xtuner/v1/datasets/sampler.py b/xtuner/v1/datasets/sampler.py index d4b591d6f..ca1ebca30 100644 --- a/xtuner/v1/datasets/sampler.py +++ b/xtuner/v1/datasets/sampler.py @@ -146,12 +146,12 @@ def load_state_dict(self, state_dict) -> None: f"is different from the current shuffle ({self.shuffle})." ) - def get_state_dict(self, step: int): + def get_state_dict(self, total_consumed_steps: int): # Attention! Do not set self.step here, or it will cause the next __iter__ to get less samples. - step = step % self.total_size + step_mod = total_consumed_steps % self.total_size return { "epoch": self.epoch, - "step": step, + "step": step_mod, "world_size": self.world_size, "shuffle": self.shuffle, "round_up": self.round_up, @@ -298,17 +298,17 @@ def load_state_dict(self, state_dict: dict) -> None: ) self.group_size = origin_group_size - def get_state_dict(self, step: int): + def get_state_dict(self, total_consumed_steps: int): """Get the sampler state dict. Returns: dict: The state of the sampler. """ # Attention! Do not set self.step here, or it will cause the next __iter__ to get less samples. - step = step % self.total_size + step_mod = total_consumed_steps % self.total_size return { "epoch": self.epoch, - "step": step, + "step": step_mod, "world_size": self.world_size, "round_up": self.round_up, "num_samples": self.num_samples, diff --git a/xtuner/v1/ray/dataflow/replay_buffer.py b/xtuner/v1/ray/dataflow/replay_buffer.py index 7068406ef..cdf10fd68 100644 --- a/xtuner/v1/ray/dataflow/replay_buffer.py +++ b/xtuner/v1/ray/dataflow/replay_buffer.py @@ -296,7 +296,7 @@ def resume(self, dataloader_path): dataloader_state = torch.load(dataloader_path, map_location=DEVICE) self.dataloader.load_state_dict(dataloader_state) self.dataloader_iter = iter(self.dataloader) - self.reduced_consumed_samples = dataloader_state["sampler"]["step"] + self.reduced_consumed_samples = int(dataloader_state["total_consumed_samples"]) self.cur_epoch = dataloader_state["sampler"]["epoch"] @@ -923,7 +923,7 @@ def save(self, file_path: Path | str): # save dataloader dataloader_path = file_path / "dataloader" - dataloader_state = self.sampler.dataloader.get_state_dict(self.sampler.reduced_consumed_samples) + dataloader_state = self.sampler.dataloader.get_state_dict() torch.save(dataloader_state, dataloader_path) # save storage diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py index 855bc589a..22f1bf275 100644 --- a/xtuner/v1/rl/base/worker.py +++ b/xtuner/v1/rl/base/worker.py @@ -263,7 +263,6 @@ def _init_sft(self, worker_cfg: WorkerConfig): self._rollout_step = 0 self._sft_cur_epoch = 0 - self._sft_total_consumed_samples = 0 self._sft_total_consumed_tokens = 0 if self._sft_dataloader_config is not None: @@ -672,7 +671,6 @@ def _fit_sft(self): time_before_train_step = time.time() data_time = time_before_train_step - time_before_get_data DEVICE_MODULE.reset_peak_memory_stats() - cur_sample_num = len(data_batch) train_step_info, grad_norm = self._train_one_step_sft(data_batch) @@ -680,7 +678,6 @@ def _fit_sft(self): step_time = time_after_train_step - time_before_train_step step_consumed_tokens = train_step_info["step_consumed_tokens"] - self._sft_total_consumed_samples += self._reduce_number_across_rank(cur_sample_num) reduced_step_consumed_tokens = self._reduce_number_across_rank(step_consumed_tokens) self._sft_total_consumed_tokens += reduced_step_consumed_tokens @@ -1391,9 +1388,13 @@ def save(self, checkpoint_path: Path | str, no_save_optimizer: bool = False): ) # Save sft dataloader - if self.rank == 0 and self._sft_dataloader is not None: + if self._sft_dataloader is not None: sft_dataloader_path = checkpoint_path / self._SAVE_SFT_DATALOADER_DIR - dataloader_state = self._sft_dataloader.get_state_dict(self._sft_total_consumed_samples) + dataloader_state = self._sft_dataloader.get_state_dict() + total_consumed_samples = dataloader_state["total_consumed_samples"] + if self.rank != 0: + return + torch.save(dataloader_state, sft_dataloader_path) train_state_path = checkpoint_path / self._SAVE_SFT_TRAIN_STATE_PATH @@ -1403,7 +1404,7 @@ def save(self, checkpoint_path: Path | str, no_save_optimizer: bool = False): { "cur_step": self._rollout_step, "cur_epoch": self._sft_cur_epoch, - "total_consumed_samples": self._sft_total_consumed_samples, + "total_consumed_samples": total_consumed_samples, "total_consumed_tokens": self._sft_total_consumed_tokens, } ) @@ -1437,24 +1438,23 @@ def resume(self, load_checkpoint_cfg: LoadCheckpointConfig): ) # Resume sft dataloader - sft_dataloader_path = resume_from / self._SAVE_SFT_DATALOADER_DIR if self._sft_dataloader is not None: - if not sft_dataloader_path.exists(): - raise FileNotFoundError(f"Dataloader path {sft_dataloader_path} does not exist.") - dataloader_state = torch.load(sft_dataloader_path, map_location=DEVICE) - self._sft_dataloader.load_state_dict(dataloader_state) - self.logger.info(f"Resume sft dataloader from {sft_dataloader_path}") - train_state_path = resume_from / self._SAVE_SFT_TRAIN_STATE_PATH if not train_state_path.exists(): raise FileNotFoundError(f"Train state path {train_state_path} does not exist.") with train_state_path.open("r") as f: train_state = json.loads(f.read()) - self._rollout_step = train_state["cur_step"] - self._sft_cur_epoch = train_state["cur_epoch"] - self._sft_total_consumed_samples = train_state["total_consumed_samples"] - self._sft_total_consumed_tokens = train_state["total_consumed_tokens"] - self.logger.info(f"Resume sft train state from {train_state_path}") + self._rollout_step = train_state["cur_step"] + self._sft_cur_epoch = train_state["cur_epoch"] + self._sft_total_consumed_tokens = train_state["total_consumed_tokens"] + self.logger.info(f"Resume sft train state from {train_state_path}") + + sft_dataloader_path = resume_from / self._SAVE_SFT_DATALOADER_DIR + if not sft_dataloader_path.exists(): + raise FileNotFoundError(f"Dataloader path {sft_dataloader_path} does not exist.") + dataloader_state = torch.load(sft_dataloader_path, map_location=DEVICE) + self._sft_dataloader.load_state_dict(dataloader_state) + self.logger.info(f"Resume sft dataloader from {sft_dataloader_path}") @ray_method def ready(self) -> bool: diff --git a/xtuner/v1/train/trainer.py b/xtuner/v1/train/trainer.py index d09b5f652..5d8e60556 100644 --- a/xtuner/v1/train/trainer.py +++ b/xtuner/v1/train/trainer.py @@ -86,11 +86,9 @@ class ExpHistory(TypedDict): class PerformanceStatistics(TypedDict): local_step_consumed_tokens: int local_step_consumed_img_tokens: int | None - step_consumed_tokens: int - total_consumed_tokens: int - total_consumed_tokens_per_rank: float + local_total_consumed_tokens: int + approximate_total_consumed_tokens: int tgs: float - e2e_tgs: float exp_tgs: float eta_seconds: float eta_hms: str @@ -537,9 +535,12 @@ def __init__( self._debug = debug self._seed = seed - self._total_consumed_tokens = 0 - self._exp_consumed_tokens = 0 - self._total_consumed_samples = 0 + # 日志变量前缀规则: + # 空间上,当前rank的用 local_,默认 reduced 无前缀 + # 时间上,当前步用 step_, 累积用 total_ + # self._local_total_consumed_tokens 表示时间上累积到现在的当前rank的和,resume则只考虑resume步数到现在 + self._local_total_consumed_tokens = 0 + self._init_total_tokens = 0 self._train_time = 0 self._train_time_offset = 0 @@ -726,7 +727,6 @@ def fit(self): train_begin = time.time() time_before_get_data = time.time() for data_batch in self._data_iter(): - consumed_samples = len(data_batch) time_before_train_step = time.time() ProberList.set_step(self._cur_step + 1) @@ -759,17 +759,14 @@ def fit(self): internal_metrics = self._maybe_pop_model_internal_metrics(engine_input) self._cur_step += 1 - reduced_step_consumed_tokens = self._reduce_number_across_rank(train_step_info["step_consumed_tokens"]) - self._total_consumed_tokens += reduced_step_consumed_tokens - self._exp_consumed_tokens += reduced_step_consumed_tokens - self._total_consumed_samples += self._reduce_number_across_rank(consumed_samples) + step_tokens = train_step_info["step_consumed_tokens"] + self._local_total_consumed_tokens += step_tokens self._train_time = time_after_train_step - train_begin # Compute training metrics training_metrics = self._compute_performance_metrics( - local_step_consumed_tokens=train_step_info["step_consumed_tokens"], + local_step_consumed_tokens=step_tokens, local_step_consumed_img_tokens=train_step_info.get("step_consumed_img_tokens"), - step_consumed_tokens=reduced_step_consumed_tokens, step_time=step_time, ) @@ -1129,6 +1126,10 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool: optimizer_dir=optimizer_path, ) + total_consumed_tokens = ( + self._reduce_number_across_rank(self._local_total_consumed_tokens) + self._init_total_tokens + ) + # Save dataloader self._save_dataloader(dataloader_path) @@ -1160,8 +1161,7 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool: { "cur_step": self.cur_step, "cur_epoch": self._cur_epoch, - "total_consumed_samples": self._total_consumed_samples, - "total_consumed_tokens": self._total_consumed_tokens, + "total_consumed_tokens": total_consumed_tokens, "train_time_offset": self._train_time + self._train_time_offset, } ) @@ -1173,8 +1173,7 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool: ckp_list.append(str(checkpoint_path)) current_exp.cur_step = self.cur_step current_exp.cur_epoch = self._cur_epoch - current_exp.consumed_samples = int(self._total_consumed_samples) - current_exp.consumed_tokens = int(self._total_consumed_tokens) + current_exp.consumed_tokens = int(total_consumed_tokens) current_exp.history[-1]["end"] = self.cur_step # Delete checkpoints and update meta's checkpoint_list @@ -1210,8 +1209,8 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool: return True def _save_dataloader(self, dataloader_path: Path | str): + dataloader_state = self._dataloader.get_state_dict() if self.rank == 0: - dataloader_state = self._dataloader.get_state_dict(self._total_consumed_samples) torch.save(dataloader_state, dataloader_path) @property @@ -1444,7 +1443,6 @@ def _compute_performance_metrics( self, local_step_consumed_tokens: int, local_step_consumed_img_tokens: int | None, - step_consumed_tokens: int, step_time: float, ) -> PerformanceStatistics: """Compute training metrics including tokens and throughput statistics. @@ -1452,21 +1450,24 @@ def _compute_performance_metrics( Args: local_step_consumed_tokens (int): Tokens consumed in current step on current rank. local_step_consumed_img_tokens (int | None): Image tokens consumed in current step on current rank. - step_consumed_tokens (int): Total tokens consumed in current step across all ranks. step_time (float): Time spent on current training step in seconds. Returns: TrainingMetrics: Dictionary containing computed training metrics. """ e2e_train_time = self._train_time + self._train_time_offset - total_consumed_tokens_per_rank = self._total_consumed_tokens / self.world_size tgs = local_step_consumed_tokens / step_time - e2e_tgs = total_consumed_tokens_per_rank / e2e_train_time - exp_tgs = self._exp_consumed_tokens / self.world_size / self._train_time + approximate_total_consumed_tokens = ( + self._init_total_tokens + self._local_total_consumed_tokens * self.world_size + ) + # TODO: approximate_total_consumed_tokens_per_rank could be incorrect if world_size changed. + # So calculate `eta_seconds = step_time * remaining_steps` instead? + approximate_total_consumed_tokens_per_rank = approximate_total_consumed_tokens / self.world_size + exp_tgs = self._local_total_consumed_tokens / self._train_time if self._train_time > 0 else 0.0 remaining_steps = self.total_step - self.cur_step - avg_tokens_per_step = total_consumed_tokens_per_rank / self.cur_step + avg_tokens_per_step = approximate_total_consumed_tokens_per_rank / self.cur_step remaining_tokens = remaining_steps * avg_tokens_per_step eta_seconds = remaining_tokens / max(tgs, 1) eta_hms = str(timedelta(seconds=int(eta_seconds))) @@ -1474,11 +1475,9 @@ def _compute_performance_metrics( return PerformanceStatistics( local_step_consumed_tokens=local_step_consumed_tokens, local_step_consumed_img_tokens=local_step_consumed_img_tokens, - step_consumed_tokens=step_consumed_tokens, - total_consumed_tokens=self._total_consumed_tokens, - total_consumed_tokens_per_rank=total_consumed_tokens_per_rank, + local_total_consumed_tokens=self._local_total_consumed_tokens, + approximate_total_consumed_tokens=approximate_total_consumed_tokens, tgs=tgs, - e2e_tgs=e2e_tgs, exp_tgs=exp_tgs, eta_seconds=eta_seconds, eta_hms=eta_hms, @@ -1533,8 +1532,7 @@ def _log_step( f"Epoch {self._cur_epoch} Step {self.cur_step}/{self.total_step} " f"data_time: {data_time:.4f} lr: {lr:.6e} time: {step_time:.4f} " f"text_tokens: {training_metrics['local_step_consumed_tokens']} {img_tokens_str}" - f"step_consumed_tokens: {training_metrics['step_consumed_tokens']} " - f"total_consumed_tokens: {training_metrics['total_consumed_tokens']} " + f"approximate_total_consumed_tokens: {training_metrics['approximate_total_consumed_tokens']} " f"{loss_log_str} " f"{data_info_str} " f"{extra_info_str} " @@ -1543,7 +1541,6 @@ def _log_step( f"reserved_memory: {reserved_memory / (1024**3):.2f} GB " f"tgs: {training_metrics['tgs']:.1f} " f"exp_tgs: {training_metrics['exp_tgs']:.1f} " - f"e2e_tgs: {training_metrics['e2e_tgs']:.1f} " f"eta: {training_metrics['eta_hms']} " ) @@ -1554,11 +1551,9 @@ def _log_step( "time/train_time": round(self._train_time, 4), "time/eta_seconds": round(training_metrics["eta_seconds"], 1), "runtime_info/text_tokens": training_metrics["local_step_consumed_tokens"], - "runtime_info/step_consumed_tokens": training_metrics["step_consumed_tokens"], - "runtime_info/total_consumed_tokens": training_metrics["total_consumed_tokens"], + "runtime_info/approximate_total_consumed_tokens": training_metrics["approximate_total_consumed_tokens"], "runtime_info/tgs": training_metrics["tgs"], "runtime_info/exp_tgs": training_metrics["exp_tgs"], - "runtime_info/e2e_tgs": training_metrics["e2e_tgs"], "memory/max_memory_GB": round(max_memory / (1024**3), 3), "memory/reserved_memory_GB": round(reserved_memory / (1024**3), 3), "grad_norm": grad_norm, @@ -1811,13 +1806,8 @@ def _load_checkpoint(self): self._cur_epoch = train_state["cur_epoch"] if load_checkpoint_cfg.load_dataset: - self._total_consumed_tokens = train_state.get("total_consumed_tokens", 0) # default 0 for BC self._train_time_offset = train_state["train_time_offset"] - # _total_consumed_samples 会影响 save dcp时 dataloader.get_state_dict的状态。 - # 1) 如果加载 dataset,应该恢复_total_consumed_samples为checkpoint中的值。 - # 2) 如果不加载 dataset,应该保持_total_consumed_samples为初始值0,否则如果加载上旧dataloader的total_consumed_samples - # 会导致存储新dataloader时 total_consumed_samples 是不正确的值。 - self._total_consumed_samples = train_state.get("total_consumed_samples", 0) # default 0 for BC + self._init_total_tokens = train_state.get("total_consumed_tokens", 0) # default 0 for BC dataloader_path = resume_from / self._SAVE_DATALOADER_DIR self._resume_dataloader(dataloader_path)