2 changes: 1 addition & 1 deletion ci/scripts/test_vlm_sft_trainer.py
@@ -204,7 +204,7 @@ def parse_args():


def extract_data_from_log(logfile: Path):
pattern_str = r"\[XTuner\].*Step.*lr:\s(\d+.\d*)\s.*text_tokens:\s(\d+.\d*)\s.*reduced_llm_loss:\s(\d+.\d*)\s.*max_memory:\s(\d+.\d*)\s*GB\s.*grad_norm:\s(\d+.\d*)\s.*e2e_tgs:\s(\d+.\d*)"
pattern_str = r"\[XTuner\].*Step.*lr:\s(\d+.\d*)\s.*text_tokens:\s(\d+.\d*)\s.*reduced_llm_loss:\s(\d+.\d*)\s.*max_memory:\s(\d+.\d*)\s*GB\s.*grad_norm:\s(\d+.\d*)\s.*exp_tgs:\s(\d+.\d*)"
compiled_pattern = re.compile(pattern_str)

cur_lr = []
112 changes: 43 additions & 69 deletions tests/datasets/test_dataloader.py
@@ -1,12 +1,18 @@
from pathlib import Path
import os
import pickle
import socket

import torch

from xtuner.v1.datasets import build_dataloader, build_datasets, get_dataloader_state, load_dataloader_state, FTDPTokenizeFnConfig, DatasetConfig, DataloaderConfig
from xtuner.v1.datasets import (
DataloaderConfig,
DatasetConfig,
FTDPTokenizeFnConfig,
build_dataloader,
)
from xtuner.v1.train.toy_tokenizer import UTF8ByteTokenizer
from torch.multiprocessing import spawn, get_context
from torch.multiprocessing import spawn
from torch.distributed.device_mesh import init_device_mesh
import pytest

@@ -15,8 +21,6 @@
from itertools import repeat, chain




class RandomDataset:
def __init__(self, size: int, **kwargs):
self.size = size
@@ -182,25 +186,22 @@ def test_dataloader_resume_single_process(tmp_path, pack_level, num_workers, gro
dataset_configs = [
{
"dataset": DatasetConfig(anno_path=str(data_dir1)),
"tokenize_fn": FTDPTokenizeFnConfig(max_length=1024)
"tokenize_fn": FTDPTokenizeFnConfig(max_length=1024),
},
]

dataloader_config = DataloaderConfig(
dataset_config_list=dataset_configs,
pack_max_length=1024,
pack_level=pack_level,
num_workers=num_workers,
group_by_length=group_by_length,
pack_workers=pack_workers,
)

datasets = build_datasets(
dataset_config=dataset_configs,
dataloader1 = dataloader_config.build(
tokenizer=tokenizer,
)
dataloader1 = build_dataloader(
dataloader_config=dataloader_config,
datasets=datasets,
dp_mesh=None,
global_batch_size=GLOBAL_BATCH_SIZE,
micro_batch_size=BATCH_SIZE,
seed=10,
@@ -210,26 +211,22 @@ def test_dataloader_resume_single_process(tmp_path, pack_level, num_workers, gro
assert len(dataloader1) > 10

dataloader_iter = iter(dataloader1)
consumed_sample = 0
for _ in range(RESUME_ITER):
batch = next(dataloader_iter)
consumed_sample += len(batch)
next(dataloader_iter)

dataloader_state = get_dataloader_state(dataloader1, consumed_sample)
dataloader_state = dataloader1.get_state_dict()
expected_data = []
for _ in range(AFTER_RESUME_ITER):
batch = next(dataloader_iter)
consumed_sample += len(batch)
expected_data.append(batch)
expected_data.append(next(dataloader_iter))

new_dataloader1 = build_dataloader(
dataloader_config=dataloader_config,
datasets=datasets,
new_dataloader1 = dataloader_config.build(
tokenizer=tokenizer,
dp_mesh=None,
global_batch_size=GLOBAL_BATCH_SIZE,
micro_batch_size=BATCH_SIZE,
seed=10,
)
load_dataloader_state(new_dataloader1, dataloader_state)
new_dataloader1.load_state_dict(dataloader_state)
new_dataloader_iter = iter(new_dataloader1)

resume_data = []
@@ -242,32 +239,29 @@ def test_dataloader_resume_single_process(tmp_path, pack_level, num_workers, gro
# 2. Test resume after consuming multiple epochs
while True:
try:
batch = next(dataloader_iter)
consumed_sample += len(batch)
next(dataloader_iter)
except StopIteration:
break


dataloader_iter = iter(dataloader1)

for batch in range(RESUME_ITER):
batch = next(dataloader_iter)
consumed_sample += len(batch)
for _ in range(RESUME_ITER):
next(dataloader_iter)

dataloader_state = get_dataloader_state(dataloader1, consumed_sample)
dataloader_state = dataloader1.get_state_dict()

expected_data = []
for _ in range(AFTER_RESUME_ITER):
expected_data.append(next(dataloader_iter))

new_dataloader2 = build_dataloader(
dataloader_config=dataloader_config,
datasets=datasets,
new_dataloader2 = dataloader_config.build(
tokenizer=tokenizer,
dp_mesh=None,
global_batch_size=GLOBAL_BATCH_SIZE,
micro_batch_size=BATCH_SIZE,
seed=10,
)
load_dataloader_state(new_dataloader2, dataloader_state)
new_dataloader2.load_state_dict(dataloader_state)
new_dataloader_iter2 = iter(new_dataloader2)

resume_data = []
@@ -282,65 +276,52 @@ def _test_resume_spmd(
rank: int,
world_size: int,
dataloader_config: DataloaderConfig,
dataset_configs: list[dict],
global_batch_size: int,
micro_batch_size: int,
step:int,
step: int,
seed: int,
save_path: Path,
dataloader_state: dict | None = None,
consumed_samples: int = 0,
):
os.environ["RANK"] = str(rank)
os.environ["LOCAL_RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29505"


torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
data_mesh = init_device_mesh(
device_type="cuda",
mesh_shape=(world_size,)
mesh_shape=(world_size,),
)
tokenizer = UTF8ByteTokenizer()

datasets = build_datasets(
dataset_config=dataset_configs,
dataloader = dataloader_config.build(
tokenizer=tokenizer,
)
dataloader = build_dataloader(
dataloader_config=dataloader_config,
datasets=datasets,
dp_mesh=data_mesh,
global_batch_size=global_batch_size,
micro_batch_size=micro_batch_size,
seed=seed,
dp_mesh=data_mesh,
)

if dataloader_state is not None:
load_dataloader_state(dataloader, dataloader_state)
dataloader.load_state_dict(dataloader_state)

data_iter = iter(dataloader)
data_list = []
for _ in range(step):
batch = next(data_iter)
data_list.append(batch)
consumed_samples += len(batch)

consumed_samples_list = [None for _ in range(world_size)]
torch.distributed.all_gather_object(consumed_samples_list, consumed_samples)
global_consumed_samples = sum(consumed_samples_list)
# Snapshot after the first `step` batches so total_consumed_samples matches resume intent.
dataloader_state = dataloader.get_state_dict()

expected_data = []

for _ in range(step):
batch = next(data_iter)
expected_data.append(batch)

dataloader_state = get_dataloader_state(dataloader, global_consumed_samples)

all_data_list = [None for _ in range(world_size)]
torch.distributed.all_gather_object(all_data_list, list(chain(*data_list)))

@@ -372,7 +353,6 @@ def _test_resume_spmd(
"dataloader_state": dataloader_state,
"data_list": all_data_list,
"expected_data": all_expected_data,
"consumed_samples": consumed_samples
}
)
)
@@ -389,7 +369,6 @@ def _test_resume_spmd(
("none", 0, False),
("soft", 0, True),
("soft", 4, True),
("soft", 4, True),
]
)
def test_dataloader_resume_multi_process(tmp_path, pack_level, num_workers, group_by_length):
@@ -402,36 +381,35 @@ def test_dataloader_resume_multi_process(tmp_path, pack_level, num_workers, grou
_create_fake_dataset(data_dir1 / f"depth3", dataset_num=3, max_depth=3, dup_times=9)

# 1. Test resuming with the same world size
dataset_configs = [
{
"dataset": DatasetConfig(anno_path=str(data_dir1)),
"tokenize_fn": FTDPTokenizeFnConfig(max_length=1024),
},
]

dataloader_config = DataloaderConfig(
dataset_config_list=dataset_configs,
pack_max_length=1024,
pack_level=pack_level,
num_workers=num_workers,
group_by_length=group_by_length,
collator="fake_collator"
collator="fake_collator",
)
dataset_configs = [
{
"dataset": DatasetConfig(anno_path=str(data_dir1)),
"tokenize_fn": FTDPTokenizeFnConfig(max_length=1024)
},
]

ctx = get_context("spawn")
world_size = 2
save_path1 = tmp_path / "dataloader_state.pkl"
spawn(
_test_resume_spmd,
args=(
world_size,
dataloader_config,
dataset_configs,
16,
BATCH_SIZE,
TOTAL_STEP,
10,
save_path1,
None,
0,
),
nprocs=2,
join=True,
@@ -448,14 +426,12 @@ def test_dataloader_resume_multi_process(tmp_path, pack_level, num_workers, grou
args=(
world_size,
dataloader_config,
dataset_configs,
16,
BATCH_SIZE,
TOTAL_STEP,
10,
save_path2,
result1["dataloader_state"],
result1["consumed_samples"],
),
nprocs=world_size,
join=True,
@@ -475,14 +451,12 @@ def test_dataloader_resume_multi_process(tmp_path, pack_level, num_workers, grou
args=(
world_size,
dataloader_config,
dataset_configs,
16,
BATCH_SIZE,
TOTAL_STEP,
10,
save_path3,
result1["dataloader_state"],
result1["consumed_samples"],
),
nprocs=world_size,
join=True,
8 changes: 4 additions & 4 deletions tests/datasets/test_preset_dataloader.py
@@ -20,7 +20,7 @@

from itertools import chain

from xtuner.v1.datasets import PretrainTokenizeFunctionConfig, get_dataloader_state, load_dataloader_state
from xtuner.v1.datasets import PretrainTokenizeFunctionConfig
from xtuner.v1.datasets.config import DatasetConfig, DataloaderConfig
from xtuner.v1.datasets.packing import get_pack_infos_by_hard_split
from xtuner.v1.datasets.preset_pack import PresetPackDataset
@@ -700,8 +700,8 @@ def _build():
global_consumed_samples = sum(int(x) for x in consumed_samples_list if x is not None)

# 3. Get ckpt state
# dataloader_state = get_dataloader_state(dl, global_consumed_samples)
dataloader_state = dl.get_state_dict(global_consumed_samples)
dataloader_state = dl.get_state_dict()
assert dataloader_state["total_consumed_samples"] == global_consumed_samples

# 4. Continue to consume data at [half_step, 2*half_step)
expected_batches = []
@@ -738,7 +738,7 @@ def _build():
dl2 = _build()
with ckpt_path.open("rb") as f:
ckpt = pickle.load(f)
# load_dataloader_state(dl2, ckpt["dataloader_state"])
# dl2.load_state_dict(ckpt["dataloader_state"])
dl2.load_state_dict(ckpt["dataloader_state"])

resume_iter = iter(dl2)
4 changes: 2 additions & 2 deletions tests/datasets/test_preset_sampler.py
@@ -87,7 +87,7 @@ def test_state_dict_resume(tmp_path):

sampler = PresetSampler(dataset, sampler_config_path=path, global_batch_size=1)

state = sampler.get_state_dict(step=3)
state = sampler.get_state_dict(3)
assert state["step"] == 3

sampler2 = PresetSampler(dataset, sampler_config_path=path, global_batch_size=1)
@@ -102,7 +102,7 @@ def test_state_dict_world_size_mismatch(tmp_path):
path = _write_order_npy(tmp_path, "order.npy", _i64(0, 1, 2, 3))

sampler = PresetSampler(dataset, sampler_config_path=path, global_batch_size=1)
state = sampler.get_state_dict(step=0)
state = sampler.get_state_dict(0)
state["world_size"] = 99

sampler.load_state_dict(state)
3 changes: 0 additions & 3 deletions xtuner/v1/datasets/__init__.py
@@ -25,7 +25,6 @@
PretrainTokenizeFunction,
PretrainTokenizeFunctionConfig,
)
from .resume import get_dataloader_state, load_dataloader_state
from .rl_tokenize_fn import RLTokenizeFnConfig
from .sampler import LengthGroupedSampler, ParallelSampler
from .sft_tokenize_fn import OpenaiTokenizeFunction, OpenaiTokenizeFunctionConfig
@@ -68,8 +67,6 @@
"InternS1VLTokenizeFnConfig",
"fake_collator",
"RLTokenizeFnConfig",
"get_dataloader_state",
"load_dataloader_state",
"DatasetConfigList",
"DataloaderConfig",
"BaseTokenizeFnConfig",
1 change: 1 addition & 0 deletions xtuner/v1/datasets/config.py
@@ -544,5 +544,6 @@ def build(
collate_fn=collator,
multiprocessing_context=ctx if self.num_workers > 0 else None,
persistent_workers=self.num_workers > 0,
dp_mesh=dp_mesh,
)
return dataloader
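Taken together, these diffs move the tests off the module-level `build_dataloader` / `get_dataloader_state` / `load_dataloader_state` helpers and onto methods of the config and dataloader objects. Below is a minimal sketch of the resume flow the updated tests exercise; the annotation path and batch-size values are placeholders, and the call signatures simply mirror the test code above rather than documenting the library API.

```python
from xtuner.v1.datasets import DataloaderConfig, DatasetConfig, FTDPTokenizeFnConfig
from xtuner.v1.train.toy_tokenizer import UTF8ByteTokenizer

# Placeholder annotation directory; substitute a real dataset path.
dataset_configs = [
    {
        "dataset": DatasetConfig(anno_path="path/to/annotations"),
        "tokenize_fn": FTDPTokenizeFnConfig(max_length=1024),
    },
]

config = DataloaderConfig(
    dataset_config_list=dataset_configs,
    pack_max_length=1024,
)

tokenizer = UTF8ByteTokenizer()
dataloader = config.build(
    tokenizer=tokenizer,
    dp_mesh=None,
    global_batch_size=16,  # placeholder values
    micro_batch_size=1,
    seed=10,
)

# Consume a few batches, then snapshot the dataloader state.
it = iter(dataloader)
for _ in range(4):
    next(it)
state = dataloader.get_state_dict()

# A freshly built dataloader restored from the snapshot should resume
# from the same position in the sampling order.
resumed = config.build(
    tokenizer=tokenizer,
    dp_mesh=None,
    global_batch_size=16,
    micro_batch_size=1,
    seed=10,
)
resumed.load_state_dict(state)
```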