Skip to content

Commit 19df587

Browse files
committed
[Opt] reduce comm usage for ep train and ep rollout in RL scenario
1 parent 714483a commit 19df587

2 files changed

Lines changed: 104 additions & 10 deletions

File tree

tests/ray/test_update_weight.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,15 @@ def test_lmdeploy_update_weight_and_generate(self):
132132
ray.get(train_controller.onload.remote(target="all"))
133133
ray.get(train_controller.offload.remote(["optimizer"]))
134134
ray.get(rollout_controller.onload_weights.remote())
135+
import time
136+
start_time = time.time()
135137
ray.get(train_controller.update_weights.remote())
138+
end_time = time.time()
139+
print(f"Update weights time 1: {end_time - start_time} seconds")
140+
start_time = time.time()
141+
ray.get(train_controller.update_weights.remote())
142+
end_time = time.time()
143+
print(f"Update weights time 2: {end_time - start_time} seconds")
136144
ray.get(train_controller.offload.remote(["model"]))
137145
ray.get(rollout_controller.onload_kvcache.remote())
138146

xtuner/v1/rl/base/worker.py

Lines changed: 96 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from xtuner.v1.model.base import ModelItem, TransformerConfig
3333
from xtuner.v1.model.compose.base import BaseComposeConfig, BaseComposeModel
3434
from xtuner.v1.model.compose.qwen3_vl import Qwen3VLForConditionalGeneration
35+
from xtuner.v1.model.moe.moe import MoE
3536
from xtuner.v1.model.utils.misc import ModelForwardExtraLogInfo
3637
from xtuner.v1.ray.base import SingleAcceleratorWorker
3738
from xtuner.v1.ray.config import RolloutConfig
@@ -46,7 +47,7 @@
4647
monkey_unpatch_torch_reductions,
4748
ray_method,
4849
)
49-
from xtuner.v1.utils.load_spec import LoadEnum
50+
from xtuner.v1.utils.load_spec import LoadEnum, LoadSpec
5051

5152
from ..loss_fn import kl_penalty
5253
from .loss import BaseRLLossConfig
@@ -252,6 +253,7 @@ def __init__(
252253
else:
253254
mode = "eager"
254255
self.logprob_cfg = LogProbConfig(chunk_size=worker_cfg.loss_cfg.chunk_size, mode=mode)
256+
self._global_hf_keys_mapping_cache: dict[str, list[str]] = dict()
255257

256258
def _init_sft(self, worker_cfg: WorkerConfig):
257259
self._sft_dataloader_config = worker_cfg.sft_dataloader_cfg
@@ -863,7 +865,79 @@ def update_weights(self):
863865
else:
864866
self._update_weights_hf_generator()
865867

866-
def _update_weights_hf_generator(self, submodule=None, final_update=True):
868+
def _rl_get_fused_ep_hf_param(self, model: MoE, target_ep_rank: int, target_ep_size: int, bucket_size: int):
869+
fused_param_groups: list[tuple[torch.Tensor, LoadSpec]] = model._group_param_by_load_spec(LoadEnum.FUSED)
870+
model_ep_size = 1 if model.fsdp_config is None else model.fsdp_config.ep_size
871+
if not fused_param_groups:
872+
return
873+
874+
def _get_hf_params(
875+
fsdp_tensor_list: list[tuple[torch.Tensor, LoadSpec]],
876+
) -> tuple[list[torch.Tensor], list[str]]:
877+
hf_keys_list: list[str] = []
878+
# Split the fused tensor into target hf tensors
879+
hf_tensor_list: list[torch.Tensor] = []
880+
881+
for fsdp_tensor, load_spec in fsdp_tensor_list:
882+
hf_keys = load_spec.hf_keys
883+
if model_ep_size > 1 and model.ep_mesh is not None:
884+
if load_spec.name not in self._global_hf_keys_mapping_cache:
885+
global_hf_keys: list[list[str] | None] = [None] * model_ep_size
886+
dist.all_gather_object(global_hf_keys, hf_keys, group=model.ep_mesh.get_group())
887+
global_hf_keys_gathered = cast(list[list[str]], global_hf_keys)
888+
self._global_hf_keys_mapping_cache[load_spec.name] = list(
889+
chain.from_iterable(global_hf_keys_gathered)
890+
)
891+
hf_keys = self._global_hf_keys_mapping_cache[load_spec.name]
892+
893+
fused_full_tensor = fsdp_tensor.bfloat16()
894+
if isinstance(fused_full_tensor, DTensor):
895+
fused_full_tensor = fused_full_tensor.full_tensor()
896+
dim = cast(int, load_spec.dim)
897+
num_split = len(hf_keys)
898+
hf_tensor_size = fused_full_tensor.shape[dim] / num_split
899+
assert hf_tensor_size.is_integer(), "Internal Error, hf_tensor_size is not integer"
900+
hf_tensor_size = int(hf_tensor_size)
901+
902+
hf_tensor = fused_full_tensor.split([hf_tensor_size] * num_split, dim=dim)
903+
# slice target ep rank
904+
assert num_split % target_ep_size == 0, (
905+
f"len(hf_keys) of '{hf_keys}' is {num_split}, it must be divisible by target_ep_size {target_ep_size}"
906+
)
907+
start_idx = (num_split // target_ep_size) * target_ep_rank
908+
end_idx = (num_split // target_ep_size) * (target_ep_rank + 1)
909+
910+
hf_keys_list.extend(hf_keys[start_idx:end_idx])
911+
hf_tensor_list.extend(hf_tensor[start_idx:end_idx])
912+
913+
hf_tensor_list = [
914+
model.param_to_safetensor(safetensor, name) for safetensor, name in zip(hf_tensor_list, hf_keys_list)
915+
]
916+
917+
return hf_tensor_list, hf_keys_list
918+
919+
safetensor_size = 0
920+
dtype = torch.bfloat16 # hardcode bfloat16 for now
921+
tensor_list: list[tuple[torch.Tensor, LoadSpec]] = []
922+
923+
for param, load_spec in fused_param_groups:
924+
tensor_size = dtype.itemsize * param.numel() // target_ep_size
925+
if safetensor_size + tensor_size > bucket_size and tensor_list:
926+
hf_params, name_list = _get_hf_params(tensor_list)
927+
yield name_list, hf_params
928+
safetensor_size = tensor_size
929+
name_list = load_spec.hf_keys.copy()
930+
tensor_list = [(param, load_spec)]
931+
continue
932+
safetensor_size += tensor_size
933+
tensor_list.append((param, load_spec))
934+
935+
if tensor_list:
936+
hf_params, name_list = _get_hf_params(tensor_list)
937+
yield name_list, hf_params
938+
939+
@torch.no_grad()
940+
def _update_weights_hf_generator(self, submodule=None, final_update=False):
867941
"""Update the model weights."""
868942
self.endpoints["update_weights"] = "update_weights"
869943
assert self.rollout_device_mesh is not None
@@ -881,20 +955,32 @@ def _update_weights_hf_generator(self, submodule=None, final_update=True):
881955
same_gen = model._get_same_hf_param(
882956
model._group_param_by_load_spec(LoadEnum.SAME), dtype=dtype, device=DEVICE, bucket_size=bucket_size
883957
)
884-
fused_gen = model._get_fused_hf_param(
885-
model._group_param_by_load_spec(LoadEnum.FUSED),
886-
dtype=dtype,
887-
device=DEVICE,
888-
bucket_size=bucket_size,
889-
update_weights_for_rl=True,
890-
)
958+
959+
train_enable_ep = model.fsdp_config is not None and model.fsdp_config.ep_size > 1
960+
if train_enable_ep and self.rollout_cfg_info["ep"] > 1:
961+
# rollout_device_mesh contains the coordinate info of rollout engine
962+
# whose coordinate is the same as the training engine rank
963+
fused_gen = self._rl_get_fused_ep_hf_param(
964+
model,
965+
target_ep_rank=self.rollout_device_mesh["engine_parallel"].get_coordinate()[0],
966+
target_ep_size=self.rollout_device_mesh["engine_parallel"].size(),
967+
bucket_size=bucket_size,
968+
)
969+
else:
970+
fused_gen = model._get_fused_hf_param(
971+
model._group_param_by_load_spec(LoadEnum.FUSED),
972+
dtype=dtype,
973+
device=DEVICE,
974+
bucket_size=bucket_size,
975+
update_weights_for_rl=True,
976+
)
891977
shard_gen = model._get_shard_hf_param(
892978
model._group_param_by_load_spec(LoadEnum.SHARD), dtype=dtype, device=DEVICE, bucket_size=bucket_size
893979
)
894980

895981
for name_list, fused_param_list in fused_gen:
896982
state_dict = {name: param.detach() for name, param in zip(name_list, fused_param_list)}
897-
if model.fsdp_config.ep_size > 1:
983+
if train_enable_ep and self.rollout_cfg_info["tp"] > 1:
898984
# When ep_size > 1, generator generates part of the fused param on each ep rank in one ep_group.
899985
# We can all-gather them to get the full fused param, but it would lead to higher memory usage.
900986
# So we broadcast the partial fused param from each ep rank in the ep_group sequentially,

0 commit comments

Comments (0)