
fix: FSDP2 meta-device crash for Qwen3.5 GatedDeltaNet fp32 params#1813

Merged
HuiyingLi merged 7 commits into r0.4.0 from huiyingl/fix-qwen35-fsdp2-meta-device on Apr 15, 2026
Conversation

@HuiyingLi
Contributor

HuiyingLi commented Apr 13, 2026

Summary

  • PR #1711 (feat: FSDP2 weight prefetching and async TP optimization) changed _should_load_before_shard to return False for multi-GPU DP, so models stay on the meta device through FSDP wrapping. This broke the __dict__ trick in patch_hf_model from PR #1710 (fix: Qwen3.5 dense CP support and FSDP mixed-dtype fix).
  • Move the gate computation (g = -A_log.exp() * softplus(a + dt_bias)) into _Fp32ParamHolder.forward() so FSDP's unshard/reshard lifecycle fires naturally around the fp32 params.
  • Override CPAwareGatedDeltaNet forward for both CP and non-CP paths to route through the holder. Class swap is now unconditional (needed for the non-CP forward override).
  • Add __getattr__ on the class for checkpoint/state_dict access to moved params.
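As a quick numeric illustration of the gate formula quoted above (the names A_log, a, and dt_bias come from the PR description; the scalar values and helper functions here are made up for illustration, not the model's actual tensor code):

```python
import math

def softplus(x: float) -> float:
    # Numerically stable softplus: log(1 + exp(x))
    return math.log1p(math.exp(-abs(x))) + max(x, 0.0)

def gate(A_log: float, a: float, dt_bias: float) -> float:
    # g = -A_log.exp() * softplus(a + dt_bias), per the PR summary,
    # written out for scalars.
    return -math.exp(A_log) * softplus(a + dt_bias)

g = gate(A_log=0.0, a=0.0, dt_bias=0.0)
# softplus(0) = ln 2, so g = -ln 2 ≈ -0.6931
```

Since A_log, a, and dt_bias are the fp32 parameters the PR moves into _Fp32ParamHolder, computing g inside the holder's forward() is what lets FSDP's unshard/reshard hooks fire around exactly these parameters.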

Test plan

🤖 Generated with Claude Code

PR #1711 changed _should_load_before_shard to return False for multi-GPU
DP, so models stay on meta device through FSDP wrapping. This broke the
__dict__ trick in PR #1710's patch_hf_model.

Move the gate computation into _Fp32ParamHolder.forward() so FSDP's
unshard/reshard lifecycle fires naturally. Override CPAwareGatedDeltaNet
forward for both CP and non-CP paths to route through the holder.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@copy-pr-bot

copy-pr-bot Bot commented Apr 13, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@HuiyingLi
Contributor Author

/ok to test a08e559

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi
Contributor Author

/claude review

@HuiyingLi
Contributor Author

/ok to test b9a7513

@HuiyingLi
Contributor Author

/claude review


_logger = logging.getLogger(__name__)
patched = 0
patched_classes = set()
Contributor


Bug: patched_classes is local to each patch_hf_model call, but CPAwareGatedDeltaNet inherits from Qwen3_5GatedDeltaNet, so a second invocation would pass the isinstance check on line 515 and re-wrap cls.__getattr__, creating a growing chain of wrappers (each calling the next, all checking _fp32_params).

This may not happen today since parallelizer.py only calls it once, but it's a latent bug. A simple guard would be a sentinel on the class:

Suggested change (keep the existing line and add a module-level sentinel name below it):

patched_classes = set()
_PATCHED_ATTR = "_fp32_getattr_patched"

Then at line 537:

if cls not in patched_classes and not getattr(cls, _PATCHED_ATTR, False):
    cls.__getattr__ = _make_fp32_getattr(cls.__getattr__)
    setattr(cls, _PATCHED_ATTR, True)
    patched_classes.add(cls)
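The wrapper-chain failure mode and the sentinel guard can be demonstrated in isolation. Everything below (make_getattr, Dummy, patch) is an illustrative stand-in for the PR's _make_fp32_getattr and patch_hf_model, not the actual code:

```python
CALLS = []

def make_getattr(prev):
    # Mimics _make_fp32_getattr: wraps the class's previous __getattr__.
    def wrapped(self, name):
        CALLS.append("wrapper")
        return prev(self, name)
    return wrapped

class Dummy:
    def __getattr__(self, name):
        raise AttributeError(name)

_PATCHED_ATTR = "_fp32_getattr_patched"

def patch(cls):
    # Sentinel guard: wrap __getattr__ at most once per class,
    # even across repeated patch() invocations.
    if not getattr(cls, _PATCHED_ATTR, False):
        cls.__getattr__ = make_getattr(cls.__getattr__)
        setattr(cls, _PATCHED_ATTR, True)

patch(Dummy)
patch(Dummy)  # second call is a no-op thanks to the sentinel

try:
    Dummy().missing
except AttributeError:
    pass
# Without the guard, CALLS would be ["wrapper", "wrapper"]: a chain.
assert CALLS == ["wrapper"]
```

Without the sentinel, the second patch() would wrap the already-wrapped __getattr__, and every attribute miss would walk the whole chain, exactly the latent bug described above.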

…guard

Add unit tests for:
- _Fp32ParamHolder.forward gate computation and dtype preservation
- _compute_gate routing through holder vs inline fallback
- patch_hf_model sentinel preventing __getattr__ re-wrapping

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi
Contributor Author

/ok to test cea2a0c

Add 14 new tests covering the critical _forward_no_cp method (lines
91-193) and forward() dispatch logic (lines 207-213) to satisfy
codecov/patch requirements for PR #1813:

- _forward_no_cp basic forward, cache_params=None, causal_conv1d_fn
  fallback, causal_conv1d_fn set, attention_mask, GQA repeat-interleave,
  _compute_gate delegation, and output dtype
- forward() dispatch when _cp_mesh is None or size <= 1, parameter
  pass-through, and extra CP kwargs
- _make_fp32_getattr fallback to AttributeError and real attr resolution
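The dispatch condition tested above (route to the non-CP path when _cp_mesh is None or its size is <= 1) can be sketched with stubs. DummyMesh and GatedDeltaNetStub are hypothetical stand-ins modeling only the dispatch logic, not CPAwareGatedDeltaNet's real forward:

```python
class DummyMesh:
    # Minimal stand-in for a device mesh: only size() is needed here.
    def __init__(self, n: int):
        self._n = n

    def size(self) -> int:
        return self._n

class GatedDeltaNetStub:
    def __init__(self, cp_mesh=None):
        self._cp_mesh = cp_mesh

    def forward(self, x):
        # Non-CP path when there is no CP mesh, or the mesh has one rank.
        if self._cp_mesh is None or self._cp_mesh.size() <= 1:
            return self._forward_no_cp(x)
        return self._forward_cp(x)

    def _forward_no_cp(self, x):
        return ("no_cp", x)

    def _forward_cp(self, x):
        return ("cp", x)

assert GatedDeltaNetStub().forward(1) == ("no_cp", 1)
assert GatedDeltaNetStub(DummyMesh(1)).forward(2) == ("no_cp", 2)
assert GatedDeltaNetStub(DummyMesh(4)).forward(3) == ("cp", 3)
```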

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@HuiyingLi
Contributor Author

/claude review

@HuiyingLi
Contributor Author

/ok to test a17d84a


claude Bot left a comment


LGTM


Labels

r0.4.0 Auto-cherrypick to release branch. Apply before merge; cherrypick happens after merge.
