184 changes: 171 additions & 13 deletions nemo_automodel/components/models/qwen3_5_moe/cp_linear_attn.py
@@ -77,6 +77,121 @@ def __init__(self, config, layer_idx: int):
super().__init__(config, layer_idx)
self._cp_mesh = None

def _compute_gate(self, a: torch.Tensor) -> torch.Tensor:
"""Compute the gating value ``g`` using fp32 params.

When ``_fp32_params`` exists (FSDP mixed-dtype), delegates to
the holder's forward so the FSDP unshard/reshard lifecycle works
naturally. Otherwise falls back to the inline computation.
"""
if hasattr(self, "_fp32_params"):
return self._fp32_params(a, self.dt_bias)
return -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)

def _forward_no_cp(
self,
hidden_states: torch.Tensor,
cache_params=None,
cache_position=None,
attention_mask: torch.Tensor | None = None,
):
"""HF GatedDeltaNet forward with FSDP-safe fp32 gate computation.

Copied from transformers==5.3.0 Qwen3_5GatedDeltaNet.forward
with gate computation replaced by self._compute_gate(a).
"""
from transformers.models.qwen3_5.modeling_qwen3_5 import apply_mask_to_padding_states

hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
batch_size, seq_len, _ = hidden_states.shape

use_precomputed_states = (
cache_params is not None and cache_params.has_previous_state and seq_len == 1 and cache_position is not None
)

if cache_params is not None:
conv_state = cache_params.conv_states[self.layer_idx]
recurrent_state = cache_params.recurrent_states[self.layer_idx]

mixed_qkv = self.in_proj_qkv(hidden_states)
mixed_qkv = mixed_qkv.transpose(1, 2)

z = self.in_proj_z(hidden_states)
z = z.reshape(batch_size, seq_len, -1, self.head_v_dim)

b = self.in_proj_b(hidden_states)
a = self.in_proj_a(hidden_states)

if use_precomputed_states:
mixed_qkv = self.causal_conv1d_update(
mixed_qkv,
conv_state,
self.conv1d.weight.squeeze(1),
self.conv1d.bias,
self.activation,
)
else:
if cache_params is not None:
conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
cache_params.conv_states[self.layer_idx] = conv_state
if self.causal_conv1d_fn is not None:
mixed_qkv = self.causal_conv1d_fn(
x=mixed_qkv,
weight=self.conv1d.weight.squeeze(1),
bias=self.conv1d.bias,
activation=self.activation,
seq_idx=None,
)
else:
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])

mixed_qkv = mixed_qkv.transpose(1, 2)
query, key, value = torch.split(mixed_qkv, [self.key_dim, self.key_dim, self.value_dim], dim=-1)

query = query.reshape(batch_size, seq_len, -1, self.head_k_dim)
key = key.reshape(batch_size, seq_len, -1, self.head_k_dim)
value = value.reshape(batch_size, seq_len, -1, self.head_v_dim)

beta = b.sigmoid()
g = self._compute_gate(a)

if self.num_v_heads // self.num_k_heads > 1:
query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)

if not use_precomputed_states:
core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
query,
key,
value,
g=g,
beta=beta,
initial_state=None,
output_final_state=cache_params is not None,
use_qk_l2norm_in_kernel=True,
)
else:
core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(
query,
key,
value,
g=g,
beta=beta,
initial_state=recurrent_state,
output_final_state=cache_params is not None,
use_qk_l2norm_in_kernel=True,
)

if cache_params is not None:
cache_params.recurrent_states[self.layer_idx] = last_recurrent_state

core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)
z = z.reshape(-1, self.head_v_dim)
core_attn_out = self.norm(core_attn_out, z)
core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1)

return self.out_proj(core_attn_out)

def forward(
self,
hidden_states: torch.Tensor,
@@ -88,9 +203,9 @@ def forward(
cu_seqlens: torch.Tensor | None = None,
seq_index: torch.Tensor | None = None,
):
# Fast path: no CP → original HF forward
# Fast path: no CP → run HF forward with fp32-safe gate computation
if self._cp_mesh is None or self._cp_mesh.size() <= 1:
return super().forward(
return self._forward_no_cp(
hidden_states,
cache_params=cache_params,
attention_mask=attention_mask,
@@ -299,7 +414,7 @@ def _forward_with_cp(

# ---- Gate & beta ----
beta = b.sigmoid()
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
g = self._compute_gate(a)

# GVA: repeat q/k heads to match v heads
if self.num_v_heads // self.num_k_heads > 1:
@@ -340,14 +455,49 @@ def _forward_with_cp(
return output


class _Fp32ParamHolder(torch.nn.Module):
"""Holder for float32 params (A_log) that need a separate FSDP group.

The ``forward`` computes the gating value ``g`` that HF's
``Qwen3_5GatedDeltaNet.forward`` would normally compute inline.
By doing the computation *inside* this module's forward, FSDP's
unshard/reshard lifecycle works naturally — the params are
unsharded during the computation and resharded after.
"""

def forward(self, a: torch.Tensor, dt_bias: torch.Tensor) -> torch.Tensor:
return -self.A_log.float().exp() * F.softplus(a.float() + dt_bias)


def _make_fp32_getattr(orig_getattr):
"""Create a ``__getattr__`` that resolves fp32 params from ``_fp32_params``.

Allows ``self.A_log`` to resolve from the holder submodule so that
code outside forward (e.g. state_dict, checkpointing) can still
access the parameter by name.
"""

def _getattr_with_fp32(self, name):
modules = self.__dict__.get("_modules", {})
fp32_holder = modules.get("_fp32_params")
if fp32_holder is not None and name in fp32_holder._parameters:
return fp32_holder._parameters[name]
return orig_getattr(self, name)

return _getattr_with_fp32


def patch_hf_model(model, cp_enabled=False):
"""Patch HF Qwen3.5 GatedDeltaNet modules for FSDP and optional CP support.

For FSDP compatibility, move float32 bare params (A_log) into a
_fp32_params submodule so fully_shard_by_dtype can wrap them separately.
``_fp32_params`` submodule so ``fully_shard_by_dtype`` can wrap them
in a separate FSDP group.

When ``cp_enabled=True``, also swap each module's __class__ to
CPAwareGatedDeltaNet for context parallelism support.
Every module's ``__class__`` is swapped to ``CPAwareGatedDeltaNet``
whose ``forward()`` calls ``self._fp32_params()`` to trigger FSDP
unshard before accessing the fp32 params. When ``cp_enabled=True``,
the CP mesh is also configured.
"""
import logging

Expand All @@ -357,29 +507,37 @@ def patch_hf_model(model, cp_enabled=False):
return

_logger = logging.getLogger(__name__)
_PATCHED_ATTR = "_fp32_getattr_patched"
patched = 0
patched_classes = set()
for name, mod in model.named_modules():
if not isinstance(mod, Qwen3_5GatedDeltaNet):
continue

if cp_enabled:
mod.__class__ = CPAwareGatedDeltaNet
mod._cp_mesh = None
mod.__class__ = CPAwareGatedDeltaNet
mod._cp_mesh = None

# Move float32 bare params into a holder submodule for FSDP.
# The __dict__ reference lets HF forward access self.A_log directly,
# while FSDP manages the param via the _fp32_params submodule.
# The CPAwareGatedDeltaNet forward calls self._fp32_params()
# to trigger FSDP unshard; __getattr__ redirects self.A_log
# to the holder so it returns the unsharded plain tensor.
holder = None
for pname in list(mod._parameters.keys()):
param = mod._parameters[pname]
if param is not None and param.dtype == torch.float32:
if holder is None:
holder = torch.nn.Module()
holder = _Fp32ParamHolder()
setattr(holder, pname, param)
del mod._parameters[pname]
mod.__dict__[pname] = param
if holder is not None:
mod.add_module("_fp32_params", holder)

# Guard against re-wrapping __getattr__ on repeated calls.
cls = type(mod)
if cls not in patched_classes and not getattr(cls, _PATCHED_ATTR, False):
cls.__getattr__ = _make_fp32_getattr(cls.__getattr__)
setattr(cls, _PATCHED_ATTR, True)
patched_classes.add(cls)
patched += 1

if patched > 0:
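
A hypothetical usage sketch, not taken from this PR: the model id, the way patched modules are located, and the asserted invariants are illustrative assumptions. It shows the intended order of operations (patch first, then shard) and that `A_log` still resolves by name on the mixer module after it has been moved to the holder.

```python
# Illustrative only: "Qwen/Qwen3.5-MoE" is a placeholder model id.
from transformers import AutoModelForCausalLM

from nemo_automodel.components.models.qwen3_5_moe.cp_linear_attn import patch_hf_model

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-MoE")
patch_hf_model(model, cp_enabled=False)

# After patching, A_log is registered on the _fp32_params holder (the module
# FSDP will wrap separately) but still resolves by name on the mixer itself,
# via the __dict__ entry and the patched __getattr__.
for module in model.modules():
    if hasattr(module, "_fp32_params"):
        assert module.A_log is module._fp32_params.A_log
        break
```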
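
The holder pattern exists because FSDP2 builds one parameter group per `fully_shard` call, so a dedicated submodule lets the float32 params keep their own group and mixed-precision policy while the rest of the layer is sharded in bf16. Below is a minimal sketch of that grouping with PyTorch's `fully_shard` (FSDP2, torch>=2.6 import path); it is an assumption about what the repo's `fully_shard_by_dtype` helper does, since that helper's implementation is not part of this diff.

```python
# Sketch of per-dtype FSDP2 grouping, assuming a pre-built DeviceMesh.
import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard


def shard_linear_attn(layer: torch.nn.Module, mesh: DeviceMesh) -> None:
    # The fp32 holder gets its own FSDP group so A_log is never down-cast.
    if hasattr(layer, "_fp32_params"):
        fully_shard(
            layer._fp32_params,
            mesh=mesh,
            mp_policy=MixedPrecisionPolicy(param_dtype=torch.float32),
        )
    # The remaining params of the layer are sharded as a bf16 group; params
    # already claimed by the child group above are excluded automatically.
    fully_shard(
        layer,
        mesh=mesh,
        mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
    )
```

Because the holder's `forward` is what computes the gate, FSDP's pre-forward unshard and post-forward reshard fire exactly around the point where the fp32 params are read, which is the lifecycle the `_Fp32ParamHolder` docstring describes.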