
Commit 9bbcf4a

HuiyingLi authored and claude committed
cp: 1813 fix: FSDP2 meta-device crash for Qwen3.5 GatedDeltaNet fp32 params (#1869)
* fix: FSDP2 meta-device crash for Qwen3.5 GatedDeltaNet fp32 params (#1813)

  * fix: FSDP2 meta-device crash for Qwen3.5 GatedDeltaNet fp32 params

    PR #1711 changed _should_load_before_shard to return False for multi-GPU DP, so models stay on meta device through FSDP wrapping. This broke the __dict__ trick in PR #1710's patch_hf_model. Move the gate computation into _Fp32ParamHolder.forward() so FSDP's unshard/reshard lifecycle fires naturally. Override CPAwareGatedDeltaNet.forward for both the CP and non-CP paths to route through the holder.

  * chore: remove test yaml not intended for this PR

  * fix: add sentinel to prevent __getattr__ re-wrapping

    Address Claude review: guard against re-wrapping __getattr__ on repeated patch_hf_model calls by checking a class-level sentinel attribute.

  * fix: add upstream version comment to _forward_no_cp

    Address Claude review: note the transformers version the forward was copied from to ease future upstream diffing.

  * fix: update MoE test expectations for the _forward_no_cp path

    The TestForwardFastPath tests expected super().forward() to be called, but the non-CP path now uses _forward_no_cp(). Update the mocks to match.

  * test: add coverage for _Fp32ParamHolder, _compute_gate, and the sentinel guard

    Add unit tests for:
    - _Fp32ParamHolder.forward gate computation and dtype preservation
    - _compute_gate routing through the holder vs. the inline fallback
    - the patch_hf_model sentinel preventing __getattr__ re-wrapping

  * test: add coverage for _forward_no_cp and forward() dispatch paths

    Add 14 new tests covering the critical _forward_no_cp method (lines 91-193) and the forward() dispatch logic (lines 207-213) to satisfy codecov/patch requirements for PR #1813:
    - _forward_no_cp basic forward, cache_params=None, causal_conv1d_fn fallback, causal_conv1d_fn set, attention_mask, GQA repeat-interleave, _compute_gate delegation, and output dtype
    - forward() dispatch when _cp_mesh is None or its size <= 1, parameter pass-through, and extra CP kwargs
    - _make_fp32_getattr fallback to AttributeError and real attribute resolution

* fix: update MoE test_no_cp_does_not_forward_cache_position to use _forward_no_cp

  The fast path in CPAwareGatedDeltaNet.forward was refactored to call self._forward_no_cp() instead of super().forward(), but this test still mocked the base-class forward, which is now called 0 times. Update the mock target to match the new dispatch, and apply ruff format to the two test files.

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
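The mechanics the holder buys are easy to check in isolation. Below is a rough, self-contained sketch of the property the new _Fp32ParamHolder unit tests assert, namely that the holder's forward reproduces the inline fp32 gate formula g = -exp(A_log) * softplus(a + dt_bias). It is not the repository's test code, and the tensor shapes are illustrative rather than taken from the model config.

import torch
import torch.nn as nn
import torch.nn.functional as F

from nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn import _Fp32ParamHolder

num_v_heads = 4
holder = _Fp32ParamHolder()
# patch_hf_model moves the bare fp32 A_log param onto the holder via setattr.
holder.A_log = nn.Parameter(torch.randn(num_v_heads, dtype=torch.float32))

a = torch.randn(2, 3, num_v_heads, dtype=torch.bfloat16)  # stand-in for the in_proj_a output
dt_bias = torch.randn(num_v_heads, dtype=torch.float32)

expected = -holder.A_log.float().exp() * F.softplus(a.float() + dt_bias)
torch.testing.assert_close(holder(a, dt_bias), expected)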
1 parent 09b91f0 commit 9bbcf4a

3 files changed

Lines changed: 679 additions & 69 deletions

File tree

nemo_automodel/components/models/qwen3_5_moe/cp_linear_attn.py

Lines changed: 171 additions & 13 deletions
@@ -77,6 +77,121 @@ def __init__(self, config, layer_idx: int):
         super().__init__(config, layer_idx)
         self._cp_mesh = None
 
+    def _compute_gate(self, a: torch.Tensor) -> torch.Tensor:
+        """Compute the gating value ``g`` using fp32 params.
+
+        When ``_fp32_params`` exists (FSDP mixed-dtype), delegates to
+        the holder's forward so FSDP unshard/reshard lifecycle is natural.
+        Otherwise falls back to the inline computation.
+        """
+        if hasattr(self, "_fp32_params"):
+            return self._fp32_params(a, self.dt_bias)
+        return -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
+
+    def _forward_no_cp(
+        self,
+        hidden_states: torch.Tensor,
+        cache_params=None,
+        cache_position=None,
+        attention_mask: torch.Tensor | None = None,
+    ):
+        """HF GatedDeltaNet forward with FSDP-safe fp32 gate computation.
+
+        Copied from transformers==5.3.0 Qwen3_5GatedDeltaNet.forward
+        with gate computation replaced by self._compute_gate(a).
+        """
+        from transformers.models.qwen3_5.modeling_qwen3_5 import apply_mask_to_padding_states
+
+        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
+        batch_size, seq_len, _ = hidden_states.shape
+
+        use_precomputed_states = (
+            cache_params is not None and cache_params.has_previous_state and seq_len == 1 and cache_position is not None
+        )
+
+        if cache_params is not None:
+            conv_state = cache_params.conv_states[self.layer_idx]
+            recurrent_state = cache_params.recurrent_states[self.layer_idx]
+
+        mixed_qkv = self.in_proj_qkv(hidden_states)
+        mixed_qkv = mixed_qkv.transpose(1, 2)
+
+        z = self.in_proj_z(hidden_states)
+        z = z.reshape(batch_size, seq_len, -1, self.head_v_dim)
+
+        b = self.in_proj_b(hidden_states)
+        a = self.in_proj_a(hidden_states)
+
+        if use_precomputed_states:
+            mixed_qkv = self.causal_conv1d_update(
+                mixed_qkv,
+                conv_state,
+                self.conv1d.weight.squeeze(1),
+                self.conv1d.bias,
+                self.activation,
+            )
+        else:
+            if cache_params is not None:
+                conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
+                cache_params.conv_states[self.layer_idx] = conv_state
+            if self.causal_conv1d_fn is not None:
+                mixed_qkv = self.causal_conv1d_fn(
+                    x=mixed_qkv,
+                    weight=self.conv1d.weight.squeeze(1),
+                    bias=self.conv1d.bias,
+                    activation=self.activation,
+                    seq_idx=None,
+                )
+            else:
+                mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
+
+        mixed_qkv = mixed_qkv.transpose(1, 2)
+        query, key, value = torch.split(mixed_qkv, [self.key_dim, self.key_dim, self.value_dim], dim=-1)
+
+        query = query.reshape(batch_size, seq_len, -1, self.head_k_dim)
+        key = key.reshape(batch_size, seq_len, -1, self.head_k_dim)
+        value = value.reshape(batch_size, seq_len, -1, self.head_v_dim)
+
+        beta = b.sigmoid()
+        g = self._compute_gate(a)
+
+        if self.num_v_heads // self.num_k_heads > 1:
+            query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
+            key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
+
+        if not use_precomputed_states:
+            core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
+                query,
+                key,
+                value,
+                g=g,
+                beta=beta,
+                initial_state=None,
+                output_final_state=cache_params is not None,
+                use_qk_l2norm_in_kernel=True,
+            )
+        else:
+            core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(
+                query,
+                key,
+                value,
+                g=g,
+                beta=beta,
+                initial_state=recurrent_state,
+                output_final_state=cache_params is not None,
+                use_qk_l2norm_in_kernel=True,
+            )
+
+        if cache_params is not None:
+            cache_params.recurrent_states[self.layer_idx] = last_recurrent_state
+
+        core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)
+        z = z.reshape(-1, self.head_v_dim)
+        core_attn_out = self.norm(core_attn_out, z)
+        core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1)
+
+        return self.out_proj(core_attn_out)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -88,9 +203,9 @@ def forward(
         cu_seqlens: torch.Tensor | None = None,
         seq_index: torch.Tensor | None = None,
     ):
-        # Fast path: no CP → original HF forward
+        # Fast path: no CP → run HF forward with fp32-safe gate computation
         if self._cp_mesh is None or self._cp_mesh.size() <= 1:
-            return super().forward(
+            return self._forward_no_cp(
                 hidden_states,
                 cache_params=cache_params,
                 attention_mask=attention_mask,
@@ -299,7 +414,7 @@ def _forward_with_cp(
 
         # ---- Gate & beta ----
         beta = b.sigmoid()
-        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
+        g = self._compute_gate(a)
 
         # GVA: repeat q/k heads to match v heads
         if self.num_v_heads // self.num_k_heads > 1:
@@ -340,14 +455,49 @@ def _forward_with_cp(
         return output
 
 
+class _Fp32ParamHolder(torch.nn.Module):
+    """Holder for float32 params (A_log) that need a separate FSDP group.
+
+    The ``forward`` computes the gating value ``g`` that HF's
+    ``Qwen3_5GatedDeltaNet.forward`` would normally compute inline.
+    By doing the computation *inside* this module's forward, FSDP's
+    unshard/reshard lifecycle works naturally — the params are
+    unsharded during the computation and resharded after.
+    """
+
+    def forward(self, a: torch.Tensor, dt_bias: torch.Tensor) -> torch.Tensor:
+        return -self.A_log.float().exp() * F.softplus(a.float() + dt_bias)
+
+
+def _make_fp32_getattr(orig_getattr):
+    """Create a ``__getattr__`` that resolves fp32 params from ``_fp32_params``.
+
+    Allows ``self.A_log`` to resolve from the holder submodule so that
+    code outside forward (e.g. state_dict, checkpointing) can still
+    access the parameter by name.
+    """
+
+    def _getattr_with_fp32(self, name):
+        modules = self.__dict__.get("_modules", {})
+        fp32_holder = modules.get("_fp32_params")
+        if fp32_holder is not None and name in fp32_holder._parameters:
+            return fp32_holder._parameters[name]
+        return orig_getattr(self, name)
+
+    return _getattr_with_fp32
+
+
 def patch_hf_model(model, cp_enabled=False):
     """Patch HF Qwen3.5 GatedDeltaNet modules for FSDP and optional CP support.
 
     For FSDP compatibility, move float32 bare params (A_log) into a
-    _fp32_params submodule so fully_shard_by_dtype can wrap them separately.
+    ``_fp32_params`` submodule so ``fully_shard_by_dtype`` can wrap them
+    in a separate FSDP group.
 
-    When ``cp_enabled=True``, also swap each module's __class__ to
-    CPAwareGatedDeltaNet for context parallelism support.
+    Every module's ``__class__`` is swapped to ``CPAwareGatedDeltaNet``
+    whose ``forward()`` calls ``self._fp32_params()`` to trigger FSDP
+    unshard before accessing the fp32 params. When ``cp_enabled=True``,
+    the CP mesh is also configured.
     """
     import logging
 
@@ -357,29 +507,37 @@ def patch_hf_model(model, cp_enabled=False):
         return
 
     _logger = logging.getLogger(__name__)
+    _PATCHED_ATTR = "_fp32_getattr_patched"
     patched = 0
+    patched_classes = set()
    for name, mod in model.named_modules():
         if not isinstance(mod, Qwen3_5GatedDeltaNet):
             continue
 
-        if cp_enabled:
-            mod.__class__ = CPAwareGatedDeltaNet
-            mod._cp_mesh = None
+        mod.__class__ = CPAwareGatedDeltaNet
+        mod._cp_mesh = None
 
         # Move float32 bare params into a holder submodule for FSDP.
-        # The __dict__ reference lets HF forward access self.A_log directly,
-        # while FSDP manages the param via the _fp32_params submodule.
+        # The CPAwareGatedDeltaNet forward calls self._fp32_params()
+        # to trigger FSDP unshard; __getattr__ redirects self.A_log
+        # to the holder so it returns the unsharded plain tensor.
         holder = None
         for pname in list(mod._parameters.keys()):
             param = mod._parameters[pname]
             if param is not None and param.dtype == torch.float32:
                 if holder is None:
-                    holder = torch.nn.Module()
+                    holder = _Fp32ParamHolder()
                 setattr(holder, pname, param)
                 del mod._parameters[pname]
-                mod.__dict__[pname] = param
        if holder is not None:
             mod.add_module("_fp32_params", holder)
+
+        # Guard against re-wrapping __getattr__ on repeated calls.
+        cls = type(mod)
+        if cls not in patched_classes and not getattr(cls, _PATCHED_ATTR, False):
+            cls.__getattr__ = _make_fp32_getattr(cls.__getattr__)
+            setattr(cls, _PATCHED_ATTR, True)
+            patched_classes.add(cls)
         patched += 1
 
     if patched > 0:

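For orientation, here is a minimal wiring sketch of how the patch is meant to be used: patch first, shard second, so the fp32 holders already exist when FSDP groups are formed. This is not code from the repository: torch's FSDP2 fully_shard and MixedPrecisionPolicy stand in for nemo_automodel's fully_shard_by_dtype (whose signature the diff does not show), and a default process group is assumed to be initialized already.

import torch
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

from nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn import (
    _Fp32ParamHolder,
    patch_hf_model,
)


def shard_patched_model(model: torch.nn.Module) -> torch.nn.Module:
    """Hypothetical helper: patch the HF modules, then shard them with FSDP2."""
    # 1) Move bare fp32 params (A_log) into _fp32_params holders and swap each
    #    GatedDeltaNet's class to CPAwareGatedDeltaNet. This is safe on a
    #    meta-device model, which is exactly the case this commit fixes.
    patch_hf_model(model, cp_enabled=False)

    # 2) Give each fp32 holder its own FSDP group so A_log stays float32, then
    #    wrap the rest of the model with a bf16 mixed-precision policy. This
    #    mimics what fully_shard_by_dtype is described as doing in the docstring.
    fp32_policy = MixedPrecisionPolicy(param_dtype=torch.float32)
    bf16_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
    for mod in model.modules():
        if isinstance(mod, _Fp32ParamHolder):
            fully_shard(mod, mp_policy=fp32_policy)
    fully_shard(model, mp_policy=bf16_policy)
    return model

With this ordering, the holder's forward runs inside an FSDP-managed module, so the A_log shard is gathered just before the gate is computed and released right after, which is what lets the model stay on the meta device through wrapping.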