
Commit d3f185d

[perf] Accelerate HF model saving and avoid OOM on npu (#1631)
1 parent 8b33a95 commit d3f185d

2 files changed

Lines changed: 19 additions & 18 deletions


xtuner/v1/ops/comm/foreach_allgather.py

Lines changed: 14 additions & 18 deletions
@@ -1,5 +1,3 @@
-from typing import cast
-
 import torch
 import torch.distributed as dist
 
@@ -10,6 +8,10 @@ def foreach_all_gather(
     params: list[torch.Tensor],
     group: dist.ProcessGroup | None,
 ) -> list[list[torch.Tensor]]:
+    """Perform a fused all-gather on a list of tensors.
+
+    All ranks must contribute tensors with identical numels and shapes.
+    """
     if group is None:
         group = dist.group.WORLD
 
@@ -18,29 +20,23 @@ def foreach_all_gather(
 
     input_tensor_numels = [param.numel() for param in params]
     input_tensor_shapes = [param.shape for param in params]
+    world_size = dist.get_world_size(group)
+    local_tensor_size = sum(input_tensor_numels)
+    global_tensor_size = local_tensor_size * world_size
 
-    flatten_copyin_tensor = torch.empty((sum(input_tensor_numels),), dtype=param0.dtype, device=param0.device)
+    # prepare flatten tensor
+    flatten_copyin_tensor = torch.empty((local_tensor_size,), dtype=param0.dtype, device=param0.device)
     splits_copyin_tensor = torch.split(flatten_copyin_tensor, input_tensor_numels)
     torch._foreach_copy_(splits_copyin_tensor, [p.flatten() for p in params])
+    flatten_copyout_tensor = torch.empty((global_tensor_size,), dtype=param0.dtype, device=param0.device)
 
-    input_tensor_numels_tensor = torch.tensor(input_tensor_numels, dtype=torch.int64, device=param0.device)
-    global_input_tensor_numels = [
-        torch.zeros_like(input_tensor_numels_tensor) for _ in range(dist.get_world_size(group))
-    ]
-
-    dist.all_gather(global_input_tensor_numels, input_tensor_numels_tensor, group=group)
-    copyout_size = int(sum(sum(i) for i in global_input_tensor_numels))
-    flatten_copyout_tensor = torch.empty((copyout_size,), dtype=param0.dtype, device=param0.device)
-
+    # allgather global flatten tensor
    dist.all_gather_into_tensor(flatten_copyout_tensor, flatten_copyin_tensor, group=group)
-    copyout_split_size: list[int] = sum([i.tolist() for i in global_input_tensor_numels], [])
+    copyout_split_size: list[int] = input_tensor_numels * world_size
     splits_copyout_tensor = torch.split(flatten_copyout_tensor, copyout_split_size)
+    global_input_tensor_shapes = input_tensor_shapes * world_size
 
-    _global_input_tensor_shapes: list[None] | list[list[tuple]] = [None for _ in range(dist.get_world_size(group))]
-    dist.all_gather_object(_global_input_tensor_shapes, input_tensor_shapes, group=group)
-    _global_input_tensor_shapes = cast(list[list[tuple]], _global_input_tensor_shapes)
-    global_input_tensor_shapes: list[tuple] = sum(_global_input_tensor_shapes, [])
-
+    # gathered_params: [[params1/p, params1/p,...], [params2/p, params2/p,...], ...]
     gathered_params: list[list[torch.Tensor]] = []
     for i in range(len(params)):
         single_gathered_params: list[torch.Tensor] = []
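
The rewritten foreach_all_gather removes two metadata collectives from the hot path: the old code all-gathered per-rank numels with dist.all_gather and per-rank shapes with dist.all_gather_object (which pickles objects on the host) before it could size and split the output buffer. By requiring every rank to contribute tensors with identical numels and shapes, as the new docstring states, the output size is simply local_tensor_size * world_size and the split sizes are the local lists repeated, so a single all_gather_into_tensor suffices. Below is a minimal standalone sketch of the same pattern; the name fused_all_gather is illustrative rather than xtuner's API, and it uses torch.cat for brevity where the commit preallocates a buffer and fills it with torch._foreach_copy_.

import os

import torch
import torch.distributed as dist


def fused_all_gather(
    params: list[torch.Tensor], group: dist.ProcessGroup | None = None
) -> list[list[torch.Tensor]]:
    """Gather every rank's copy of each tensor with one fused collective.

    Assumes all ranks pass tensors of identical shapes, so no metadata
    exchange is needed to size the output buffer.
    """
    world_size = dist.get_world_size(group)
    numels = [p.numel() for p in params]
    # Pack all local tensors into one flat send buffer.
    flat_in = torch.cat([p.flatten() for p in params])
    flat_out = torch.empty(
        flat_in.numel() * world_size, dtype=flat_in.dtype, device=flat_in.device
    )
    dist.all_gather_into_tensor(flat_out, flat_in, group=group)
    # Identical shapes on every rank => split sizes are the local sizes repeated.
    chunks = torch.split(flat_out, numels * world_size)
    shapes = [p.shape for p in params] * world_size
    gathered: list[list[torch.Tensor]] = [[] for _ in params]
    for i, (chunk, shape) in enumerate(zip(chunks, shapes)):
        gathered[i % len(params)].append(chunk.view(shape))
    return gathered


if __name__ == "__main__":
    # Demo: torchrun --nproc-per-node=2 fused_all_gather_sketch.py
    # all_gather_into_tensor needs a backend that supports it (e.g. NCCL).
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    device = torch.device("cuda")
    out = fused_all_gather(
        [torch.ones(2, 3, device=device), torch.arange(4.0, device=device)]
    )
    if dist.get_rank() == 0:
        print([t.shape for per_rank in out for t in per_rank])
    dist.destroy_process_group()

Besides issuing one collective instead of three, this avoids all_gather_object entirely, whose host-side pickling and synchronization are particularly costly when every checkpoint save triggers many such calls.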

xtuner/v1/train/trainer.py

Lines changed: 5 additions & 0 deletions
@@ -1113,6 +1113,9 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool:
         scheduler_path = checkpoint_path / self._SAVE_SCHEDULER_DIR
         train_state_path = checkpoint_path / self._SAVE_TRAIN_STATE_PATH
 
+        if self.cur_step % ckp_interval == 0:
+            DEVICE_MODULE.empty_cache()
+
         # Save model and optimizer
         self._engine.save_dcp(
             model_dir=model_path,
@@ -1122,6 +1125,8 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool:
         # Save dataloader
         self._save_dataloader(dataloader_path)
 
+        DEVICE_MODULE.empty_cache()
+
         # Save scheduler
         if self.rank == 0:
             lr_scheduler_state = self._lr_scheduler.state_dict()
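
The trainer change targets the npu OOM from the commit title: before a full checkpoint (when cur_step % ckp_interval == 0) and again after the dataloader is saved, DEVICE_MODULE.empty_cache() releases cached allocator blocks, so the large transient buffers materialized while saving weights do not stack on top of stale cached memory. A hedged sketch of the device-agnostic pattern follows; this diff does not show how xtuner resolves DEVICE_MODULE, so the helper below is an assumption, not the project's actual code.

import torch

# Hypothetical resolution of a device-agnostic handle; xtuner's real
# DEVICE_MODULE may be selected differently.
try:
    import torch_npu  # noqa: F401  (registers torch.npu on Ascend machines)
except ImportError:
    torch_npu = None


def get_device_module():
    """Pick the active accelerator namespace (torch.npu, torch.cuda, ...)."""
    if torch_npu is not None and torch.npu.is_available():
        return torch.npu
    if torch.cuda.is_available():
        return torch.cuda
    return None


DEVICE_MODULE = get_device_module()

if DEVICE_MODULE is not None:
    # Both torch.cuda and torch.npu expose empty_cache(); it returns cached,
    # currently-unused blocks to the device allocator, lowering the
    # reserved-memory footprint before checkpointing allocates its buffers.
    DEVICE_MODULE.empty_cache()

Note that empty_cache() frees only cached, unused blocks and cannot shrink live tensors; the benefit here is lowering the reserved-memory high-water mark right before the save path allocates its gather buffers.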
