     MixedPrecisionPolicy,
     fully_shard,
 )
-from torch.distributed.tensor import DTensor, Placement, Shard
+from torch.distributed.tensor import DTensor, Placement, Replicate, Shard, distribute_tensor
 from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
 from typing_extensions import NotRequired, Self, TypedDict, overload
 
@@ -82,6 +82,12 @@ class HFSaveCfg(PydanticBaseModel):
     worker_per_rank: Annotated[int, Parameter(group="model")] = 16
     max_save_rank: Annotated[int, Parameter(group="model")] = 16
     bucket_size: Annotated[int, Parameter(group="model")] = 1024**3 * 4
+    # TODO: `XTunerBaseModel` should also be able to specify which parameters are trained in fp32;
+    # currently this can only be configured in `HFSaveCfg`.
+    # Each entry is a **regex** pattern (passed to `re.search`) matched against the HF parameter name.
+    # Remember to escape literal dots, e.g. use r"model\.layers\.\d+\.weight" instead of
+    # r"model.layers.\d+.weight" to avoid unintended wildcard matches.
+    fp32_keys_pattern: Annotated[list[str] | None, Parameter(group="model")] = None
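+    # Illustrative (hypothetical) config, not part of any real model: keep the LM head and router
+    # gates in full precision:
+    #   fp32_keys_pattern=[r"lm_head\.weight", r"model\.layers\.\d+\.mlp\.gate\.weight"]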
 
 
 class XTunerBaseModelConfig(PydanticBaseModel):
@@ -313,6 +319,7 @@ def fully_shard(
         """Fully shard the model parameters."""
         self.fsdp_config = fsdp_config
         self.fsdp_mesh = self._init_world_mesh()
+        self._world_mesh = self.fsdp_mesh
 
         if self.fsdp_config.requires_grad:
             for name, module in self.named_modules():
@@ -337,15 +344,79 @@ def fully_shard(
         mp_policy = MixedPrecisionPolicy(
             param_dtype=self.fsdp_config.param_dtype, reduce_dtype=fsdp_config.reduce_dtype
         )
-        fully_shard(
-            self,
+        self._fully_shard(
             mesh=self.fsdp_mesh,
             mp_policy=mp_policy,
             reshard_after_forward=fsdp_config.reshard_after_forward,
             offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
         )
         return self
 
+    def _fully_shard(
+        self,
+        mesh: DeviceMesh,
+        mp_policy: MixedPrecisionPolicy,
+        reshard_after_forward: bool,
+        offload_policy: CPUOffloadPolicy | None,
+        module: nn.Module | None = None,
+    ) -> None:
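+        # Wraps `fully_shard`: parameters whose HF name matches `fp32_keys_pattern` are first
+        # replicated across the world mesh and passed as `ignored_params`, so FSDP neither shards
+        # them nor applies the mixed-precision cast to them.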
+        def traverse(module):
+            for name, param in module.named_parameters(recurse=False):
+                full_name = full_param_name_mapping[id(param)]
+                full_name = self._clean_param_name(full_name)
+                hf_name_list = self.to_hf_key_list(full_name)
+
+                for hf_name in hf_name_list:
+                    if any(re.search(p, hf_name) for p in patterns):  # type: ignore
+                        if not isinstance(param, DTensor):
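+                            # Replicate on every dim of the world mesh: each rank keeps a full copy
+                            # of the parameter, and FSDP skips it because it ends up in `ignored_params`.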
+                            dist_param = nn.Parameter(
+                                distribute_tensor(
+                                    param, self.world_mesh, [Replicate() for _ in range(self.world_mesh.ndim)]
+                                ),
+                                requires_grad=param.requires_grad,
+                            )
+                            module.register_parameter(name, dist_param)
+                            ignored_params.add(dist_param)
+                        else:
+                            # param is already a DTensor (e.g. distributed by
+                            # MoE._replicate_other_params on ep_mesh before _fully_shard
+                            # is called). We skip re-distributing on world_mesh and just
+                            # add it to ignored_params so FSDP leaves it alone.
+                            # ASSUMPTION: fp32 distribution always happens AFTER any
+                            # prior EP distribution, so the existing placement is correct.
+                            ignored_params.add(param)
+                        break
+
+            for child in module.children():
+                traverse(child)
+
+        # Collect the parameters of `target` that match any fp32 pattern so they can be
+        # excluded from FSDP sharding (passed as `ignored_params`).
+        #
+        # We intentionally iterate over `self.named_parameters()` rather than
+        # `target.named_parameters()` so that `name` is always relative to the root model
+        # (`self`). This matters when `target` is a sub-module (e.g. `self.embed_tokens`):
+        # `target.named_parameters()` would yield bare names like `"weight"`, which
+        # `to_hf_key_list` cannot resolve correctly. By iterating from `self` we get the
+        # full path (e.g. `"embed_tokens.weight"`) and filter to `target`'s parameters
+        # using identity comparison.
+        full_param_name_mapping = {id(param): name for name, param in self.named_parameters()}
+        ignored_params: set[nn.Parameter] = set()
+        patterns = self.config.hf_save_cfg.fp32_keys_pattern
+
+        target = module or self
+        if patterns:
+            traverse(target)
+
+        fully_shard(
+            target,
+            mesh=mesh,
+            mp_policy=mp_policy,
+            reshard_after_forward=reshard_after_forward,
+            offload_policy=offload_policy,
+            ignored_params=ignored_params if ignored_params else None,
+        )
+
     def save_hf(self, hf_dir: Path | str, save_dtype: torch.dtype = torch.bfloat16, safetensors_prefix: str = "model"):
         with profile_time_and_memory(f"[Saving HF to [{safetensors_prefix}]{hf_dir} cost]"):
             self._save_hf(hf_dir=hf_dir, save_dtype=save_dtype, safetensors_prefix=safetensors_prefix)
@@ -396,6 +467,12 @@ def device(self) -> torch.device:
             return torch.device("cpu")
         return torch.device(DEVICE)
 
+    @property
+    def world_mesh(self) -> DeviceMesh | None:
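+        # Reuse the mesh set by `fully_shard()` when present; otherwise initialize one lazily on
+        # first access, so `_fully_shard` can replicate parameters even before FSDP setup.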
+        if not hasattr(self, "_world_mesh"):
+            self._world_mesh = self._init_world_mesh()
+        return self._world_mesh
+
     @property
     def default_compile_cfg(self) -> dict[str, TorchCompileOption]:
         return {}
@@ -670,6 +747,12 @@ def post_micro_batch_forward(self, batch_outputs: Sequence[ModelOutputs]) -> Bat
         )
         return ret
 
+    def _get_save_dtype(self, name: str, dtype: torch.dtype) -> torch.dtype:
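+        # e.g. with fp32_keys_pattern=[r"lm_head\.weight"] (hypothetical), "lm_head.weight" is saved
+        # as torch.float32, while any non-matching key falls back to the requested `dtype`.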
+        patterns = self.config.hf_save_cfg.fp32_keys_pattern
+        if patterns and any(re.search(p, name) for p in patterns):
+            return torch.float32
+        return dtype
+
     def _get_shard_hf_param(
         self,
         params: list[tuple[torch.Tensor, LoadSpec]],
@@ -679,6 +762,16 @@ def _get_shard_hf_param(
     ) -> Generator[tuple[list[str], list[torch.Tensor]], None, None]:
         if not params:
             return
+
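+        # fp32 (pattern-matched) params were excluded from FSDP sharding and replicated instead,
+        # so their local tensor is already the full parameter: yield them directly and skip the
+        # all-gather path below.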
+        ignored_params, params = self._split_ignored_params(params)
+        if ignored_params:
+            name_list: list[str] = [load_spec.hf_keys[0] for _, load_spec in ignored_params]
+            hf_params = [param._local_tensor if isinstance(param, DTensor) else param for param, _ in ignored_params]
+            yield name_list, hf_params
+
+        if not params:
+            return
+
         if dtype != torch.bfloat16:
             raise NotImplementedError
 
@@ -696,7 +789,7 @@ def _get_hf_params(fsdp_tensor_list: list[tuple[torch.Tensor, LoadSpec]]) -> lis
             # Get unsharded params
             _unsharded_tensor_list = foreach_all_gather(fsdp_unsharded_tensor_list, load_spec0.group)
             unsharded_tensor_list = [
-                torch.cat([i.to(dtype) for i in tensors], dim=load_spec0.dim) for tensors in _unsharded_tensor_list
+                torch.cat(list(tensors), dim=load_spec0.dim) for tensors in _unsharded_tensor_list
             ]
             name_list = [spec.hf_keys[0] for _, spec in fsdp_tensor_list]
             unsharded_tensor_list = [
@@ -711,11 +804,11 @@ def _get_hf_params(fsdp_tensor_list: list[tuple[torch.Tensor, LoadSpec]]) -> lis
 
         safetensor_size = 0
         tensor_list: list[tuple[torch.Tensor, LoadSpec]] = []
-        name_list: list[str] = []
+        name_list = []
 
         for param, load_spec in params:
             local_tensor = param._local_tensor if isinstance(param, DTensor) else param
-            local_tensor = local_tensor.to(dtype=dtype)
+            local_tensor = local_tensor.to(dtype=self._get_save_dtype(load_spec.hf_keys[0], torch.bfloat16))
             tensor_size = self._get_tensor_size(param, dtype)
             if safetensor_size + tensor_size > bucket_size and tensor_list:
                 hf_params = _get_hf_params(tensor_list)
@@ -744,6 +837,12 @@ def _get_fused_hf_param(
         if not params:
             return
 
+        ignored_params, params = self._split_ignored_params(params)
+        if ignored_params:
+            fp32_name_list: list[str] = [load_spec.hf_keys[0] for _, load_spec in ignored_params]
+            fp32_params = [param._local_tensor if isinstance(param, DTensor) else param for param, _ in ignored_params]
+            yield fp32_name_list, fp32_params
+
         def _get_hf_params(
             fsdp_tensor_list: list[tuple[torch.Tensor, LoadSpec]],
             name_list: list[str],
@@ -867,7 +966,7 @@ def _get_hf_params(
 
         for param, load_spec in params:
             local_tensor = param._local_tensor if isinstance(param, DTensor) else param
-            local_tensor = local_tensor.bfloat16()
+            local_tensor = local_tensor.to(dtype=self._get_save_dtype(load_spec.hf_keys[0], torch.bfloat16))
             tensor_size = self._get_tensor_size(param, dtype)
             if safetensor_size + tensor_size > bucket_size and tensor_list:
                 hf_params, name_list = _get_hf_params(tensor_list, name_list)
@@ -893,6 +992,15 @@ def _get_same_hf_param(
     ) -> Generator[tuple[list[str], list[torch.Tensor]], None, None]:
         if not params:
             return
+
+        ignored_params, params = self._split_ignored_params(params)
+        if ignored_params:
+            fp32_name_list: list[str] = [load_spec.hf_keys[0] for _, load_spec in ignored_params]
+            fp32_tensor_list: list[torch.Tensor] = [
+                param._local_tensor if isinstance(param, DTensor) else param for param, _ in ignored_params
+            ]
+            yield fp32_name_list, fp32_tensor_list
+
         if bucket_size is None:
             bucket_size = self.config.hf_save_cfg.bucket_size
         safetensor_size = 0
@@ -909,7 +1017,7 @@ def _get_same_hf_param(
                 buffer_name_list.append(load_spec.hf_keys[0])
                 continue
             local_tensor = param._local_tensor if isinstance(param, DTensor) else param
-            local_tensor = local_tensor.bfloat16()
+            local_tensor = local_tensor.to(dtype=self._get_save_dtype(load_spec.hf_keys[0], torch.bfloat16))
             tensor_size = self._get_tensor_size(param, dtype)
             if safetensor_size + tensor_size > bucket_size and tensor_list:
                 if self.fsdp_mesh is not None:
@@ -953,6 +1061,21 @@ def _get_same_hf_param(
         if buffer_tensor_list:
             yield buffer_name_list, buffer_tensor_list
 
+    def _is_ignored_params(self, key: str) -> bool:
+        patterns = self.config.hf_save_cfg.fp32_keys_pattern
+        if patterns is None:
+            return False
+        return any(re.search(p, key) for p in patterns)
+
+    def _split_ignored_params(
+        self, params: list[tuple[torch.Tensor, LoadSpec]]
+    ) -> tuple[list[tuple[torch.Tensor, LoadSpec]], list[tuple[torch.Tensor, LoadSpec]]]:
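+        # Split `params` by matching each entry's first HF key against `fp32_keys_pattern`:
+        # returns (fp32 / FSDP-ignored entries, remaining entries). E.g. a hypothetical
+        # "lm_head.weight" entry lands in the first list when r"lm_head\.weight" is configured.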
+        if not self.config.hf_save_cfg.fp32_keys_pattern:
+            return [], params
+        ignored_params = [(p, l) for p, l in params if self._is_ignored_params(l.hf_keys[0])]
+        remaining = [(p, l) for p, l in params if not self._is_ignored_params(l.hf_keys[0])]
+        return ignored_params, remaining
+
     # TODO: Use `xtuner.v1.utils.misc.clean_param_name`
     def _clean_param_name(self, name: str) -> str:
         if "_checkpoint_wrapped_module." in name:
@@ -1230,7 +1353,12 @@ def _load_same_hf_param(
 
             loaded_tensor = loaded_tensor.to(local_tensor.device)
 
-            if self.fsdp_mesh is not None and isinstance(param, nn.Parameter):
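+            # Only apply the FSDP re-shard offset when the parameter is actually sharded; replicated
+            # DTensors (e.g. the fp32 params FSDP ignores) load the full tensor directly.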
+            if (
+                self.fsdp_mesh is not None
+                and isinstance(param, nn.Parameter)
+                and isinstance(param, DTensor)
+                and any(isinstance(p, Shard) for p in param.placements)
+            ):
                 shape_before_fsdp = load_spec.shape
                 _, _offset = compute_local_shape_and_global_offset(
                     shape_before_fsdp, self.fsdp_mesh, [Shard(self.FSDP_SHARD_DIM)]