
Commit b0fdc8d

[Fix] Support fp32 param preservation during FSDP save and load
- Add `fp32_keys_pattern` to `HFSaveCfg` to specify params that should be saved in fp32 regardless of the global save dtype
- Add `_fully_shard()` to `BaseModel`, which wraps `fully_shard()` and distributes matched params as Replicate DTensors so FSDP ignores them
- Replace all direct `fully_shard()` call sites with `self._fully_shard()`
- Add `_get_save_dtype()` to select fp32 per-param at save time
- Fix load path in `_load_same_hf_param` to skip shard-offset logic for Replicate DTensors (apply it only for Shard-placed params)
- Fix `world_mesh` property setter bug: use `_world_mesh` directly instead of assigning through the read-only property
- Fix gradient reduce in `MoE.scale_and_reduce_grad` to correctly identify the Replicate mesh dimension via DTensor placements
- Preserve `requires_grad` when distributing params in the EP and fp32 paths
- Configure `Qwen3_5_VLTextMoEConfig` with fp32 patterns for `linear_attn.norm.weight` and `linear_attn.A_log`
- Add `test_save_hf_with_mtp` to verify round-trip weight preservation
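For illustration, a minimal usage sketch of the new field (the `HFSaveCfg` import path is an assumption based on this diff, and the two patterns mirror the `Qwen3_5_VLTextMoEConfig` defaults named above; this is not a verified public API):

# Hypothetical usage sketch; import path assumed from this commit's diff.
from xtuner.v1.model.base import HFSaveCfg

hf_save_cfg = HFSaveCfg(
    # Regex patterns, matched with `re.search` against HF parameter names.
    # Dots are escaped so the patterns match a literal "." rather than any character.
    fp32_keys_pattern=[
        r"linear_attn\.norm\.weight",
        r"linear_attn\.A_log",
    ],
)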
1 parent 3bbbb88 commit b0fdc8d

11 files changed

Lines changed: 278 additions & 62 deletions


tests/model/test_qwen3_5.py

Lines changed: 81 additions & 0 deletions
@@ -15,6 +15,11 @@
 from xtuner.v1.config import FSDPConfig
 from xtuner.v1.model.compose.qwen3_vl.modeling_vision import init_world_mesh
 
+import tempfile
+from pathlib import Path
+import json
+from safetensors import safe_open
+
 
 VIDEO_ROOT = os.environ["VIDEO_ROOT"]
 
@@ -216,6 +221,82 @@ def test_qwen3_5_vl_run(self, device, sp_size, tol):
         self.assertTrue(torch.allclose(loss_xtuner_image_fsdp, loss_xtuner_image, atol=tol, rtol=tol))
         self.assertTrue(torch.allclose(loss_xtuner_video_fsdp, loss_xtuner_video, atol=tol, rtol=tol))
 
+    @parametrize.parametrize(
+        "device,sp_size",
+        [
+            ("cuda", 1),
+        ],
+    )
+    def test_save_hf_with_mtp(self, device, sp_size):
+        self.create_pg(device)
+        QWEN3_VL_MOE_PATH = os.environ["QWEN3_5_MOE_PATH"]
+
+        with torch.device("meta"):
+            model_cfg = Qwen3_5_VLMoE35BA3Config(compile_cfg=False)
+            qwen3vl_model = model_cfg.build().to(torch.bfloat16)
+
+        fsdp_config = FSDPConfig(cpu_offload=False)
+        fsdp_mesh = init_world_mesh()
+        qwen3vl_model.vision_tower.fsdp_mesh = fsdp_mesh
+        qwen3vl_model.vision_tower.fsdp_config = fsdp_config
+        qwen3vl_model.fully_shard(fsdp_config=fsdp_config)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            syncdir = [tmpdir]
+            dist.broadcast_object_list(syncdir, src=0)
+            tmpdir = Path(syncdir[0])
+            qwen3vl_model.from_hf(QWEN3_VL_MOE_PATH)
+            qwen3vl_model.save_hf(tmpdir)
+
+            origin_hf_path = Path(QWEN3_VL_MOE_PATH)
+            origin_index_path = origin_hf_path / "model.safetensors.index.json"
+            saved_index_path = tmpdir / "model.safetensors.index.json"
+
+            if dist.get_rank() == 0:
+                with open(origin_index_path, "r") as f:
+                    origin_index = json.load(f)
+                with open(saved_index_path, "r") as f:
+                    saved_index = json.load(f)
+
+                cache_save_fh: dict = {}
+
+                # Verify all original HF weights are preserved correctly
+                for key in origin_index["weight_map"].keys():
+                    if "mtp" in key:
+                        continue  # TODO: remove this after MTP is implemented
+                    origin_safetensor_name = origin_index["weight_map"][key]
+                    saved_safetensor_name = saved_index["weight_map"][key]
+
+                    origin_sf_fh_name = str(origin_hf_path / origin_safetensor_name)
+                    saved_sf_fh_name = str(tmpdir / saved_safetensor_name)
+
+                    if origin_sf_fh_name not in cache_save_fh:
+                        cache_save_fh[origin_sf_fh_name] = safe_open(origin_sf_fh_name, framework="pt")
+                    if saved_sf_fh_name not in cache_save_fh:
+                        cache_save_fh[saved_sf_fh_name] = safe_open(saved_sf_fh_name, framework="pt")
+
+                    origin_tensor = cache_save_fh[origin_sf_fh_name].get_tensor(key)
+                    saved_tensor = cache_save_fh[saved_sf_fh_name].get_tensor(key)
+
+                    self.assertTrue(torch.equal(origin_tensor, saved_tensor), f"Tensor mismatch for key: {key}")
+
+                # Verify MTP weights are present in the saved output
+                mtp_keys = [key for key in saved_index["weight_map"].keys() if key.startswith("mtp.")]
+                # TODO: remove skip after MTP is implemented
+                _ = mtp_keys
+
+                # Verify the tensor count in safetensors matches the saved index
+                safetensor_keys: list[str] = []
+                for safetensor_path in tmpdir.glob("*.safetensors"):
+                    fh = safe_open(str(safetensor_path), framework="pt")
+                    safetensor_keys.extend(fh.keys())
+                safetensor_keys.sort()
+                model_index_keys = list(saved_index["weight_map"].keys())
+                model_index_keys.sort()
+                self.assertListEqual(safetensor_keys, model_index_keys)
+
+        dist.barrier()
+
     @property
     def world_size(self) -> int:
         return int(os.getenv("XTUNER_TEST_WORLD_SIZE", "4"))
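Beyond bitwise equality, a natural extra check (not part of the committed test) is that keys matching the fp32 patterns really land on disk as float32. The helper below is a hypothetical sketch using the same safetensors API the test already relies on:

# Hypothetical helper, not in the diff above: verify fp32 preservation on disk.
import re
from pathlib import Path

import torch
from safetensors import safe_open


def assert_fp32_preserved(saved_index: dict, saved_dir: Path, patterns: list[str]) -> None:
    for key, shard_name in saved_index["weight_map"].items():
        if not any(re.search(p, key) for p in patterns):
            continue
        with safe_open(str(saved_dir / shard_name), framework="pt") as fh:
            # get_tensor returns a torch.Tensor whose dtype mirrors what was saved.
            assert fh.get_tensor(key).dtype == torch.float32, key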

xtuner/v1/model/base.py

Lines changed: 137 additions & 9 deletions
@@ -26,7 +26,7 @@
     MixedPrecisionPolicy,
     fully_shard,
 )
-from torch.distributed.tensor import DTensor, Placement, Shard
+from torch.distributed.tensor import DTensor, Placement, Replicate, Shard, distribute_tensor
 from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
 from typing_extensions import NotRequired, Self, TypedDict, overload
 
@@ -82,6 +82,12 @@ class HFSaveCfg(PydanticBaseModel):
     worker_per_rank: Annotated[int, Parameter(group="model")] = 16
     max_save_rank: Annotated[int, Parameter(group="model")] = 16
     bucket_size: Annotated[int, Parameter(group="model")] = 1024**3 * 4
+    # TODO: `XTunerBaseModel` should also be able to specify which parameters to be trained in fp32,
+    # currently it could only be specified in HFSaveCfg
+    # Each entry is a **regex** pattern (passed to `re.search`) matched against the HF parameter name.
+    # Remember to escape literal dots, e.g. use r"model\.layers\.\d+\.weight" instead of
+    # r"model.layers.\d+.weight" to avoid unintended wildcard matches.
+    fp32_keys_pattern: Annotated[list[str] | None, Parameter(group="model")] = None
 
 
 class XTunerBaseModelConfig(PydanticBaseModel):
@@ -313,6 +319,7 @@ def fully_shard(
         """Fully shard the model parameters."""
         self.fsdp_config = fsdp_config
         self.fsdp_mesh = self._init_world_mesh()
+        self._world_mesh = self.fsdp_mesh
 
         if self.fsdp_config.requires_grad:
             for name, module in self.named_modules():
@@ -337,15 +344,79 @@ def fully_shard(
         mp_policy = MixedPrecisionPolicy(
             param_dtype=self.fsdp_config.param_dtype, reduce_dtype=fsdp_config.reduce_dtype
         )
-        fully_shard(
-            self,
+        self._fully_shard(
             mesh=self.fsdp_mesh,
             mp_policy=mp_policy,
             reshard_after_forward=fsdp_config.reshard_after_forward,
             offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
         )
         return self
 
+    def _fully_shard(
+        self,
+        mesh: DeviceMesh,
+        mp_policy: MixedPrecisionPolicy,
+        reshard_after_forward: bool,
+        offload_policy: CPUOffloadPolicy | None,
+        module: nn.Module | None = None,
+    ) -> None:
+        def traverse(module):
+            for name, param in module.named_parameters(recurse=False):
+                full_name = full_param_name_mapping[id(param)]
+                full_name = self._clean_param_name(full_name)
+                hf_name_list = self.to_hf_key_list(full_name)
+
+                for hf_name in hf_name_list:
+                    if any(re.search(p, hf_name) for p in patterns):  # type: ignore
+                        if not isinstance(param, DTensor):
+                            dist_param = nn.Parameter(
+                                distribute_tensor(
+                                    param, self.world_mesh, [Replicate() for _ in range(self.world_mesh.ndim)]
+                                ),
+                                requires_grad=param.requires_grad,
+                            )
+                            module.register_parameter(name, dist_param)
+                            ignored_params.add(dist_param)
+                        else:
+                            # param is already a DTensor (e.g. distributed by
+                            # MoE._replicate_other_params on ep_mesh before _fully_shard
+                            # is called). We skip re-distributing on world_mesh and just
+                            # add it to ignored_params so FSDP leaves it alone.
+                            # ASSUMPTION: fp32 distribution always happens AFTER any
+                            # prior EP distribution, so the existing placement is correct.
+                            ignored_params.add(param)
+                        break
+
+            for child in module.children():
+                traverse(child)
+
+        # Collect the parameters of `target` that match any fp32 pattern so they can be
+        # excluded from FSDP sharding (passed as `ignored_params`).
+        #
+        # We intentionally iterate over `self.named_parameters()` rather than
+        # `target.named_parameters()` so that `name` is always relative to the root model
+        # (`self`). This matters when `target` is a sub-module (e.g. `self.embed_tokens`):
+        # `target.named_parameters()` would yield bare names like `"weight"`, which
+        # `to_hf_key_list` cannot resolve correctly. By iterating from `self` we get the
+        # full path (e.g. `"embed_tokens.weight"`) and filter to `target`'s parameters
+        # using identity comparison.
+        full_param_name_mapping = {id(param): name for name, param in self.named_parameters()}
+        ignored_params: set[nn.Parameter] = set()
+        patterns = self.config.hf_save_cfg.fp32_keys_pattern
+
+        target = module or self
+        if patterns:
+            traverse(target)
+
+        fully_shard(
+            target,
+            mesh=mesh,
+            mp_policy=mp_policy,
+            reshard_after_forward=reshard_after_forward,
+            offload_policy=offload_policy,
+            ignored_params=ignored_params if ignored_params else None,
+        )
+
     def save_hf(self, hf_dir: Path | str, save_dtype: torch.dtype = torch.bfloat16, safetensors_prefix: str = "model"):
         with profile_time_and_memory(f"[Saving HF to [{safetensors_prefix}]{hf_dir} cost]"):
             self._save_hf(hf_dir=hf_dir, save_dtype=save_dtype, safetensors_prefix=safetensors_prefix)
@@ -396,6 +467,12 @@ def device(self) -> torch.device:
             return torch.device("cpu")
         return torch.device(DEVICE)
 
+    @property
+    def world_mesh(self) -> DeviceMesh | None:
+        if not hasattr(self, "_world_mesh"):
+            self._world_mesh = self._init_world_mesh()
+        return self._world_mesh
+
     @property
     def default_compile_cfg(self) -> dict[str, TorchCompileOption]:
         return {}
@@ -670,6 +747,12 @@ def post_micro_batch_forward(self, batch_outputs: Sequence[ModelOutputs]) -> Bat
         )
         return ret
 
+    def _get_save_dtype(self, name: str, dtype: torch.dtype) -> torch.dtype:
+        patterns = self.config.hf_save_cfg.fp32_keys_pattern
+        if patterns and any(re.search(p, name) for p in patterns):
+            return torch.float32
+        return dtype
+
     def _get_shard_hf_param(
         self,
         params: list[tuple[torch.Tensor, LoadSpec]],
@@ -679,6 +762,16 @@ def _get_shard_hf_param(
     ) -> Generator[tuple[list[str], list[torch.Tensor]], None, None]:
         if not params:
             return
+
+        ignored_params, params = self._split_ignored_params(params)
+        if ignored_params:
+            name_list: list[str] = [load_spec.hf_keys[0] for _, load_spec in ignored_params]
+            hf_params = [param._local_tensor if isinstance(param, DTensor) else param for param, _ in ignored_params]
+            yield name_list, hf_params
+
+        if not params:
+            return
+
         if dtype != torch.bfloat16:
             raise NotImplementedError
 
@@ -696,7 +789,7 @@ def _get_hf_params(fsdp_tensor_list: list[tuple[torch.Tensor, LoadSpec]]) -> lis
             # Get unsharded params
             _unsharded_tensor_list = foreach_all_gather(fsdp_unsharded_tensor_list, load_spec0.group)
             unsharded_tensor_list = [
-                torch.cat([i.to(dtype) for i in tensors], dim=load_spec0.dim) for tensors in _unsharded_tensor_list
+                torch.cat(list(tensors), dim=load_spec0.dim) for tensors in _unsharded_tensor_list
             ]
             name_list = [spec.hf_keys[0] for _, spec in fsdp_tensor_list]
             unsharded_tensor_list = [
@@ -711,11 +804,11 @@
 
         safetensor_size = 0
         tensor_list: list[tuple[torch.Tensor, LoadSpec]] = []
-        name_list: list[str] = []
+        name_list = []
 
         for param, load_spec in params:
             local_tensor = param._local_tensor if isinstance(param, DTensor) else param
-            local_tensor = local_tensor.to(dtype=dtype)
+            local_tensor = local_tensor.to(dtype=self._get_save_dtype(load_spec.hf_keys[0], torch.bfloat16))
             tensor_size = self._get_tensor_size(param, dtype)
             if safetensor_size + tensor_size > bucket_size and tensor_list:
                 hf_params = _get_hf_params(tensor_list)
@@ -744,6 +837,12 @@ def _get_fused_hf_param(
         if not params:
             return
 
+        ignored_params, params = self._split_ignored_params(params)
+        if ignored_params:
+            fp32_name_list: list[str] = [load_spec.hf_keys[0] for _, load_spec in ignored_params]
+            fp32_params = [param._local_tensor if isinstance(param, DTensor) else param for param, _ in ignored_params]
+            yield fp32_name_list, fp32_params
+
         def _get_hf_params(
             fsdp_tensor_list: list[tuple[torch.Tensor, LoadSpec]],
             name_list: list[str],
@@ -867,7 +966,7 @@ def _get_hf_params(
 
         for param, load_spec in params:
             local_tensor = param._local_tensor if isinstance(param, DTensor) else param
-            local_tensor = local_tensor.bfloat16()
+            local_tensor = local_tensor.to(dtype=self._get_save_dtype(load_spec.hf_keys[0], torch.bfloat16))
             tensor_size = self._get_tensor_size(param, dtype)
             if safetensor_size + tensor_size > bucket_size and tensor_list:
                 hf_params, name_list = _get_hf_params(tensor_list, name_list)
@@ -893,6 +992,15 @@ def _get_same_hf_param(
     ) -> Generator[tuple[list[str], list[torch.Tensor]], None, None]:
         if not params:
             return
+
+        ignored_params, params = self._split_ignored_params(params)
+        if ignored_params:
+            fp32_name_list: list[str] = [load_spec.hf_keys[0] for _, load_spec in ignored_params]
+            fp32_tensor_list: list[torch.Tensor] = [
+                param._local_tensor if isinstance(param, DTensor) else param for param, _ in ignored_params
+            ]
+            yield fp32_name_list, fp32_tensor_list
+
         if bucket_size is None:
             bucket_size = self.config.hf_save_cfg.bucket_size
         safetensor_size = 0
@@ -909,7 +1017,7 @@ def _get_same_hf_param(
                 buffer_name_list.append(load_spec.hf_keys[0])
                 continue
             local_tensor = param._local_tensor if isinstance(param, DTensor) else param
-            local_tensor = local_tensor.bfloat16()
+            local_tensor = local_tensor.to(dtype=self._get_save_dtype(load_spec.hf_keys[0], torch.bfloat16))
             tensor_size = self._get_tensor_size(param, dtype)
             if safetensor_size + tensor_size > bucket_size and tensor_list:
                 if self.fsdp_mesh is not None:
@@ -953,6 +1061,21 @@ def _get_same_hf_param(
         if buffer_tensor_list:
             yield buffer_name_list, buffer_tensor_list
 
+    def _is_ignored_params(self, key: str):
+        patterns = self.config.hf_save_cfg.fp32_keys_pattern
+        if patterns is None:
+            return False
+        return any(re.search(p, key) for p in patterns)
+
+    def _split_ignored_params(
+        self, params: list[tuple[torch.Tensor, LoadSpec]]
+    ) -> tuple[list[tuple[torch.Tensor, LoadSpec]], list[tuple[torch.Tensor, LoadSpec]]]:
+        if not self.config.hf_save_cfg.fp32_keys_pattern:
+            return [], params
+        ignored_params = [(p, l) for p, l in params if self._is_ignored_params(l.hf_keys[0])]
+        remaining = [(p, l) for p, l in params if not self._is_ignored_params(l.hf_keys[0])]
+        return ignored_params, remaining
+
     # TODO: Using `xtuenr.v1.utils.misc.clean_param_name`
     def _clean_param_name(self, name: str) -> str:
         if "_checkpoint_wrapped_module." in name:
@@ -1230,7 +1353,12 @@ def _load_same_hf_param(
 
             loaded_tensor = loaded_tensor.to(local_tensor.device)
 
-            if self.fsdp_mesh is not None and isinstance(param, nn.Parameter):
+            if (
+                self.fsdp_mesh is not None
+                and isinstance(param, nn.Parameter)
+                and isinstance(param, DTensor)
+                and any(isinstance(p, Shard) for p in param.placements)
+            ):
                 shape_before_fsdp = load_spec.shape
                 _, _offset = compute_local_shape_and_global_offset(
                     shape_before_fsdp, self.fsdp_mesh, [Shard(self.FSDP_SHARD_DIM)]
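For context, the load-path change above only applies shard-offset logic when a parameter carries a Shard placement. The standalone sketch below (not part of the commit; it assumes an already-initialized gloo process group and a tensor dim divisible by the world size) shows why Replicate-placed DTensors need no offset handling:

# Standalone illustration, assuming torch.distributed is already initialized (e.g. gloo).
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard, distribute_tensor


def inspect_placements(world_size: int) -> None:
    mesh = init_device_mesh("cpu", (world_size,))
    weight = torch.randn(8, 4)  # assume 8 is divisible by world_size

    replicated = distribute_tensor(weight, mesh, [Replicate()])
    sharded = distribute_tensor(weight, mesh, [Shard(0)])

    # Replicate: every rank holds the full tensor, so saving can dump the local
    # tensor directly and loading needs no shard-offset computation.
    assert replicated.to_local().shape == weight.shape
    # Shard(0): each rank holds only a slice, so loading must compute this rank's
    # global offset, which is the branch guarded by `isinstance(p, Shard)` above.
    assert sharded.to_local().shape[0] == weight.shape[0] // world_size
    assert any(isinstance(p, Shard) for p in sharded.placements)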

xtuner/v1/model/compose/base.py

Lines changed: 1 addition & 3 deletions
@@ -10,7 +10,6 @@
     CPUOffloadPolicy,
     FSDPModule,
     MixedPrecisionPolicy,
-    fully_shard,
 )
 from typing_extensions import override
 
@@ -108,8 +107,7 @@ def fully_shard(
         # Note: critical, do not delete this assert
         assert self.fsdp_mesh is not None
 
-        fully_shard(
-            self,
+        self._fully_shard(
             mesh=self.fsdp_mesh,
             mp_policy=mp_policy,
             reshard_after_forward=fsdp_config.reshard_after_forward,

xtuner/v1/model/compose/intern_s1/modeling_intern_s1.py

Lines changed: 1 addition & 2 deletions
@@ -83,8 +83,7 @@ def fully_shard(
         # Note: critical, do not delete this assert
         assert self.fsdp_mesh is not None
 
-        fully_shard(
-            self,
+        self._fully_shard(
             mesh=self.fsdp_mesh,
             mp_policy=mp_policy,
             reshard_after_forward=fsdp_config.reshard_after_forward,

xtuner/v1/model/compose/intern_s1/modeling_projector.py

Lines changed: 1 addition & 2 deletions
@@ -70,8 +70,7 @@ def fully_shard(
         for param in self.parameters():
            param.requires_grad = False
 
-        fully_shard(
-            self,
+        self._fully_shard(
             mesh=self.fsdp_mesh,
             mp_policy=mp_policy,
             reshard_after_forward=True,
