 from xtuner.v1.model.base import ModelItem, TransformerConfig
 from xtuner.v1.model.compose.base import BaseComposeConfig, BaseComposeModel
 from xtuner.v1.model.compose.qwen3_vl import Qwen3VLForConditionalGeneration
+from xtuner.v1.model.moe.moe import MoE
 from xtuner.v1.model.utils.misc import ModelForwardExtraLogInfo
 from xtuner.v1.ray.base import SingleAcceleratorWorker
 from xtuner.v1.ray.config import RolloutConfig
@@ -46,7 +47,7 @@
     monkey_unpatch_torch_reductions,
     ray_method,
 )
-from xtuner.v1.utils.load_spec import LoadEnum
+from xtuner.v1.utils.load_spec import LoadEnum, LoadSpec
 
 from ..loss_fn import kl_penalty
 from .loss import BaseRLLossConfig
@@ -252,6 +253,7 @@ def __init__(
         else:
             mode = "eager"
         self.logprob_cfg = LogProbConfig(chunk_size=worker_cfg.loss_cfg.chunk_size, mode=mode)
+        self._global_hf_keys_mapping_cache: dict[str, list[str]] = dict()
 
     def _init_sft(self, worker_cfg: WorkerConfig):
         self._sft_dataloader_config = worker_cfg.sft_dataloader_cfg
@@ -863,7 +865,79 @@ def update_weights(self):
         else:
             self._update_weights_hf_generator()
 
-    def _update_weights_hf_generator(self, submodule=None, final_update=True):
+    def _rl_get_fused_ep_hf_param(self, model: MoE, target_ep_rank: int, target_ep_size: int, bucket_size: int):
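+        """Yield ``(hf_keys, tensors)`` buckets for fused MoE params, containing only
+        the expert slices owned by ``target_ep_rank`` in the rollout EP group."""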
+        fused_param_groups: list[tuple[torch.Tensor, LoadSpec]] = model._group_param_by_load_spec(LoadEnum.FUSED)
+        model_ep_size = 1 if model.fsdp_config is None else model.fsdp_config.ep_size
+        if not fused_param_groups:
+            return
+
+        def _get_hf_params(
+            fsdp_tensor_list: list[tuple[torch.Tensor, LoadSpec]],
+        ) -> tuple[list[torch.Tensor], list[str]]:
+            hf_keys_list: list[str] = []
+            # Split the fused tensor into target hf tensors
+            hf_tensor_list: list[torch.Tensor] = []
+
+            for fsdp_tensor, load_spec in fsdp_tensor_list:
+                hf_keys = load_spec.hf_keys
+                if model_ep_size > 1 and model.ep_mesh is not None:
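+                    # With training-side EP, each rank only holds the hf keys of its
+                    # local experts, so all-gather them once per param and cache the result.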
+                    if load_spec.name not in self._global_hf_keys_mapping_cache:
+                        global_hf_keys: list[list[str] | None] = [None] * model_ep_size
+                        dist.all_gather_object(global_hf_keys, hf_keys, group=model.ep_mesh.get_group())
+                        global_hf_keys_gathered = cast(list[list[str]], global_hf_keys)
+                        self._global_hf_keys_mapping_cache[load_spec.name] = list(
+                            chain.from_iterable(global_hf_keys_gathered)
+                        )
+                    hf_keys = self._global_hf_keys_mapping_cache[load_spec.name]
+
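+                # Materialize the full fused tensor (DTensor -> local full tensor)
+                # before splitting it along the expert dimension.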
+                fused_full_tensor = fsdp_tensor.bfloat16()
+                if isinstance(fused_full_tensor, DTensor):
+                    fused_full_tensor = fused_full_tensor.full_tensor()
+                dim = cast(int, load_spec.dim)
+                num_split = len(hf_keys)
+                hf_tensor_size = fused_full_tensor.shape[dim] / num_split
+                assert hf_tensor_size.is_integer(), "Internal Error, hf_tensor_size is not integer"
+                hf_tensor_size = int(hf_tensor_size)
+
+                hf_tensor = fused_full_tensor.split([hf_tensor_size] * num_split, dim=dim)
+                # slice target ep rank
+                assert num_split % target_ep_size == 0, (
+                    f"len(hf_keys) of '{hf_keys}' is {num_split}, it must be divisible by target_ep_size {target_ep_size}"
+                )
+                start_idx = (num_split // target_ep_size) * target_ep_rank
+                end_idx = (num_split // target_ep_size) * (target_ep_rank + 1)
+
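+                # Each target rank keeps a contiguous block of num_split // target_ep_size splits.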
+                hf_keys_list.extend(hf_keys[start_idx:end_idx])
+                hf_tensor_list.extend(hf_tensor[start_idx:end_idx])
+
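+            # Presumably converts each split from the training-side param layout to its
+            # HF safetensor layout, mirroring the non-EP fused path.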
+            hf_tensor_list = [
+                model.param_to_safetensor(safetensor, name) for safetensor, name in zip(hf_tensor_list, hf_keys_list)
+            ]
+
+            return hf_tensor_list, hf_keys_list
+
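+        # Greedily pack fused params into buckets of at most ``bucket_size`` bytes,
+        # sized by the slice each target EP rank actually receives.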
+        safetensor_size = 0
+        dtype = torch.bfloat16  # hardcode bfloat16 for now
+        tensor_list: list[tuple[torch.Tensor, LoadSpec]] = []
+
+        for param, load_spec in fused_param_groups:
+            tensor_size = dtype.itemsize * param.numel() // target_ep_size
+            if safetensor_size + tensor_size > bucket_size and tensor_list:
+                hf_params, name_list = _get_hf_params(tensor_list)
+                yield name_list, hf_params
+                safetensor_size = tensor_size
+                name_list = load_spec.hf_keys.copy()
+                tensor_list = [(param, load_spec)]
+                continue
+            safetensor_size += tensor_size
+            tensor_list.append((param, load_spec))
+
+        if tensor_list:
+            hf_params, name_list = _get_hf_params(tensor_list)
+            yield name_list, hf_params
+
+    @torch.no_grad()
+    def _update_weights_hf_generator(self, submodule=None, final_update=False):
         """Update the model weights."""
         self.endpoints["update_weights"] = "update_weights"
         assert self.rollout_device_mesh is not None
@@ -881,20 +955,32 @@ def _update_weights_hf_generator(self, submodule=None, final_update=True):
         same_gen = model._get_same_hf_param(
             model._group_param_by_load_spec(LoadEnum.SAME), dtype=dtype, device=DEVICE, bucket_size=bucket_size
         )
-        fused_gen = model._get_fused_hf_param(
-            model._group_param_by_load_spec(LoadEnum.FUSED),
-            dtype=dtype,
-            device=DEVICE,
-            bucket_size=bucket_size,
-            update_weights_for_rl=True,
-        )
+
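+        # Use the EP-aware generator only when both the training side and the rollout
+        # engine shard experts; otherwise fall back to the generic fused-param path.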
+        train_enable_ep = model.fsdp_config is not None and model.fsdp_config.ep_size > 1
+        if train_enable_ep and self.rollout_cfg_info["ep"] > 1:
+            # rollout_device_mesh holds the coordinate info of the rollout engine;
+            # its coordinate matches the training engine rank
+            fused_gen = self._rl_get_fused_ep_hf_param(
+                model,
+                target_ep_rank=self.rollout_device_mesh["engine_parallel"].get_coordinate()[0],
+                target_ep_size=self.rollout_device_mesh["engine_parallel"].size(),
+                bucket_size=bucket_size,
+            )
+        else:
+            fused_gen = model._get_fused_hf_param(
+                model._group_param_by_load_spec(LoadEnum.FUSED),
+                dtype=dtype,
+                device=DEVICE,
+                bucket_size=bucket_size,
+                update_weights_for_rl=True,
+            )
         shard_gen = model._get_shard_hf_param(
             model._group_param_by_load_spec(LoadEnum.SHARD), dtype=dtype, device=DEVICE, bucket_size=bucket_size
         )
 
         for name_list, fused_param_list in fused_gen:
             state_dict = {name: param.detach() for name, param in zip(name_list, fused_param_list)}
-            if model.fsdp_config.ep_size > 1:
+            if train_enable_ep and self.rollout_cfg_info["tp"] > 1:
                 # When ep_size > 1, the generator yields part of the fused param on each ep rank in one ep_group.
                 # We could all-gather them to get the full fused param, but that would lead to larger memory usage.
                 # So we broadcast the partial fused param from each ep rank in the ep_group sequentially,