cp: 1813 fix: FSDP2 meta-device crash for Qwen3.5 GatedDeltaNet fp32 params #1869

Merged
akoumpa merged 2 commits into main from cherry-pick-1813-main
Apr 16, 2026

@HuiyingLi
Contributor

Summary

Cherry-pick of #1813 from r0.4.0 to main.

  • PR #1711 (feat: FSDP2 w weight prefetching and async TP optimization) changed _should_load_before_shard to return False for multi-GPU DP, so models stay on the meta device through FSDP wrapping. This broke the __dict__ trick used by patch_hf_model, which was introduced in PR #1710 (fix: Qwen3.5 dense CP support and FSDP mixed-dtype fix).
  • Move the gate computation (g = -A_log.exp() * softplus(a + dt_bias)) into _Fp32ParamHolder.forward() so FSDP's unshard/reshard lifecycle fires naturally around the fp32 params.
  • Override CPAwareGatedDeltaNet forward for both CP and non-CP paths to route through the holder. Class swap is now unconditional (needed for the non-CP forward override).
  • Add __getattr__ on the class for checkpoint/state_dict access to moved params.
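
To make the second bullet concrete, here is a minimal sketch of a holder module whose forward() evaluates the gate formula quoted above. It is not the code from this PR: the class name `Fp32ParamHolder`, the constructor argument `num_heads`, and the initial parameter values are assumptions chosen for illustration. Only the formula g = -A_log.exp() * softplus(a + dt_bias) comes from the description.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Fp32ParamHolder(nn.Module):
    """Illustrative stand-in for _Fp32ParamHolder (hypothetical layout)."""

    def __init__(self, num_heads: int):
        super().__init__()
        # Both parameters are kept in fp32, independent of the compute dtype.
        self.A_log = nn.Parameter(torch.zeros(num_heads, dtype=torch.float32))
        self.dt_bias = nn.Parameter(torch.ones(num_heads, dtype=torch.float32))

    def forward(self, a: torch.Tensor) -> torch.Tensor:
        # g = -A_log.exp() * softplus(a + dt_bias), evaluated in fp32.
        # The parameters are only touched inside forward(), so any wrapper
        # that hooks this module's forward call runs around the access.
        return -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)


holder = Fp32ParamHolder(num_heads=4)
a = torch.randn(2, 3, 4, dtype=torch.bfloat16)  # (batch, seq, heads)
g = holder(a)  # fp32 result, same shape as `a`
```

Keeping the parameter reads inside the holder's forward(), rather than reading them from the parent layer, is what lets forward hooks registered on the holder wrap the computation. The sketch does not exercise FSDP itself.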

Test plan

🤖 Generated with Claude Code

…1813)

* fix: FSDP2 meta-device crash for Qwen3.5 GatedDeltaNet fp32 params

PR #1711 changed _should_load_before_shard to return False for multi-GPU
DP, so models stay on meta device through FSDP wrapping. This broke the
__dict__ trick in PR #1710's patch_hf_model.

Move the gate computation into _Fp32ParamHolder.forward() so FSDP's
unshard/reshard lifecycle fires naturally. Override CPAwareGatedDeltaNet
forward for both CP and non-CP paths to route through the holder.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>

* chore: remove test yaml not intended for PR

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>

* fix: add sentinel to prevent __getattr__ re-wrapping

Address Claude review: guard against re-wrapping __getattr__ on
repeated patch_hf_model calls by checking a class-level sentinel
attribute.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
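
The sentinel guard described in this commit can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the PR's implementation: the names `Holder`, `Layer`, `patch_class`, `_fp32_params`, and `_fp32_getattr_patched` are hypothetical, and plain classes stand in for the real modules.

```python
class Holder:
    """Hypothetical object that owns attributes moved off the parent."""

    def __init__(self):
        self.A_log = 0.5


class Layer:
    """Hypothetical parent whose moved attributes live on a holder."""

    def __init__(self):
        self._fp32_params = Holder()


_SENTINEL = "_fp32_getattr_patched"  # hypothetical sentinel attribute name


def patch_class(cls) -> bool:
    """Install a fallback __getattr__ once; return False if already patched."""
    # Check the class's own __dict__ so a parent's sentinel is not inherited.
    if cls.__dict__.get(_SENTINEL, False):
        return False
    original = getattr(cls, "__getattr__", None)

    def __getattr__(self, name):
        # Python calls this only after normal attribute lookup has failed.
        holder = self.__dict__.get("_fp32_params")
        if holder is not None and name in vars(holder):
            return getattr(holder, name)
        if original is not None:
            return original(self, name)
        raise AttributeError(name)

    cls.__getattr__ = __getattr__
    setattr(cls, _SENTINEL, True)
    return True


first = patch_class(Layer)
patched_fn = Layer.__getattr__
second = patch_class(Layer)  # no-op because the sentinel is already set
layer = Layer()
value = layer.A_log  # resolved through the holder
```

Without the guard, each repeated call would capture the previously installed wrapper as `original` and stack another layer of fallback on top of it.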

* fix: add upstream version comment to _forward_no_cp

Address Claude review: note the transformers version the forward was
copied from to ease future upstream diffing.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>

* fix: update MoE test expectations for _forward_no_cp path

TestForwardFastPath tests expected super().forward() to be called,
but the non-CP path now uses _forward_no_cp(). Update mocks to match.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
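
The mock change described in this commit can be illustrated with a small, self-contained sketch. The classes below are hypothetical stand-ins rather than the real CPAwareGatedDeltaNet or its tests. They only reproduce the dispatch shape named in the commit messages: the non-CP path calls _forward_no_cp() instead of super().forward().

```python
from unittest import mock


class Base:
    """Hypothetical stand-in for the upstream layer."""

    def forward(self, x):
        return ("base", x)


class CPAware(Base):
    """Hypothetical stand-in for the CP-aware subclass."""

    _cp_mesh = None  # no context-parallel mesh configured

    def _forward_no_cp(self, x):
        return ("no_cp", x)

    def forward(self, x):
        # Non-CP path: route through _forward_no_cp, not super().forward().
        if self._cp_mesh is None:
            return self._forward_no_cp(x)
        return ("cp", x)


layer = CPAware()
with mock.patch.object(Base, "forward") as base_fwd:
    with mock.patch.object(CPAware, "_forward_no_cp", return_value="ok") as no_cp:
        out = layer.forward(1)

# A test that still mocks Base.forward sees zero calls on this path.
base_fwd.assert_not_called()
no_cp.assert_called_once_with(1)
```

Patching the method that the dispatch actually reaches is what makes the call-count assertion meaningful. A mock on the base class forward records nothing once the fast path no longer goes through it.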

* test: add coverage for _Fp32ParamHolder, _compute_gate, and sentinel guard

Add unit tests for:
- _Fp32ParamHolder.forward gate computation and dtype preservation
- _compute_gate routing through holder vs inline fallback
- patch_hf_model sentinel preventing __getattr__ re-wrapping

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>

* test: add coverage for _forward_no_cp and forward() dispatch paths

Add 14 new tests covering the critical _forward_no_cp method (lines
91-193) and forward() dispatch logic (lines 207-213) to satisfy
codecov/patch requirements for PR #1813:

- _forward_no_cp basic forward, cache_params=None, causal_conv1d_fn
  fallback, causal_conv1d_fn set, attention_mask, GQA repeat-interleave,
  _compute_gate delegation, and output dtype
- forward() dispatch when _cp_mesh is None or size <= 1, parameter
  pass-through, and extra CP kwargs
- _make_fp32_getattr fallback to AttributeError and real attr resolution

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@copy-pr-bot

copy-pr-bot Bot commented Apr 16, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@HuiyingLi
Contributor Author

/ok to test b8e9bc7

HuiyingLi changed the title from "fix: FSDP2 meta-device crash for Qwen3.5 GatedDeltaNet fp32 params" to "cp: 1813 fix: FSDP2 meta-device crash for Qwen3.5 GatedDeltaNet fp32 params" on Apr 16, 2026
…rward_no_cp

The fast-path in CPAwareGatedDeltaNet.forward was refactored to call
self._forward_no_cp() instead of super().forward(), but this test still
mocked the base class forward and thus got called 0 times. Update the
mock target to match the new dispatch, and apply ruff format to the
two test files.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi
Contributor Author

/ok to test de746bf

@HuiyingLi
Contributor Author

/claude review

Contributor

@claude claude Bot left a comment

LGTM

akoumpa merged commit db82563 into main on Apr 16, 2026
70 of 72 checks passed
akoumpa deleted the cherry-pick-1813-main branch on April 16, 2026 21:43
linnanwang pushed a commit that referenced this pull request Apr 24, 2026
…params (#1869)

* fix: FSDP2 meta-device crash for Qwen3.5 GatedDeltaNet fp32 params (#1813)

* fix: update MoE test_no_cp_does_not_forward_cache_position to use _forward_no_cp
