Skip to content

Commit e1fc09e

Browse files
authored
fix(deepseek-v4): transpose hyperconnection comb (#2159)
Signed-off-by: larkzhang-nv <larkz@nvidia.com>
1 parent 1581fd0 commit e1fc09e

2 files changed

Lines changed: 63 additions & 3 deletions

File tree

nemo_automodel/components/models/deepseek_v4/model.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -153,8 +153,8 @@ def forward(
153153
rotary_compress=rotary_compress,
154154
)
155155
dtype = x.dtype
156-
# Expand: new_stream[h] = post[h] * attn_out + Σ_k comb[h,k] * x[k]
157-
x = post.to(dtype).unsqueeze(-1) * attn_out.unsqueeze(-2) + torch.matmul(comb.to(dtype), x)
156+
# Expand: native DSV4 uses comb[j, h] * residual[j], i.e. comb.T @ residual.
157+
x = post.to(dtype).unsqueeze(-1) * attn_out.unsqueeze(-2) + torch.matmul(comb.transpose(-1, -2).to(dtype), x)
158158

159159
# --- MLP site: same pattern ---
160160
pre, post, comb = self.ffn_hc.compute_weights(x)
@@ -165,7 +165,7 @@ def forward(
165165
self.mlp.gate.set_input_ids(input_ids)
166166
mlp_out = self.mlp(self.post_attention_layernorm(collapsed), padding_mask)
167167
dtype = x.dtype
168-
return post.to(dtype).unsqueeze(-1) * mlp_out.unsqueeze(-2) + torch.matmul(comb.to(dtype), x)
168+
return post.to(dtype).unsqueeze(-1) * mlp_out.unsqueeze(-2) + torch.matmul(comb.transpose(-1, -2).to(dtype), x)
169169

170170
def init_weights(self, buffer_device: torch.device) -> None:
171171
self.input_layernorm.reset_parameters()

tests/unit_tests/models/deepseek_v4/test_dsv4_model_smoke.py

Lines changed: 60 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -114,6 +114,66 @@ def _make_model(config: DeepseekV4Config) -> DeepseekV4ForCausalLM:
114114

115115

116116
class TestDeepseekV4ModelSmoke:
117+
def test_hc_comb_transpose_used_at_attn_and_mlp_sites(self):
118+
"""Both HC expand sites mix residual streams as ``comb.T @ x``."""
119+
120+
class _FixedHC(torch.nn.Module):
121+
def __init__(self, comb):
122+
super().__init__()
123+
self.register_buffer("comb", comb)
124+
125+
def compute_weights(self, hidden_streams):
126+
bsz, seq, hc_mult = hidden_streams.shape[:3]
127+
pre = torch.zeros(bsz, seq, hc_mult, dtype=torch.float32, device=hidden_streams.device)
128+
post = torch.zeros_like(pre)
129+
return pre, post, self.comb.expand(bsz, seq, -1, -1)
130+
131+
class _ZeroAttention(torch.nn.Module):
132+
def forward(self, hidden_states, **kwargs):
133+
return torch.zeros_like(hidden_states), None
134+
135+
class _ZeroMLP(torch.nn.Module):
136+
def forward(self, hidden_states, padding_mask=None):
137+
return torch.zeros_like(hidden_states)
138+
139+
cfg = _tiny_config(num_hidden_layers=1, num_hash_layers=0, compress_ratios=[0])
140+
model = _make_model(cfg)
141+
block = model.model.layers["0"]
142+
block.attn_hc = _FixedHC(
143+
torch.tensor(
144+
[
145+
[1.0, 2.0, 0.0, 1.0],
146+
[0.0, 1.0, 3.0, 0.0],
147+
[4.0, 0.0, 1.0, 2.0],
148+
[0.0, 5.0, 0.0, 1.0],
149+
]
150+
)
151+
)
152+
block.ffn_hc = _FixedHC(
153+
torch.tensor(
154+
[
155+
[2.0, 0.0, 1.0, 0.0],
156+
[1.0, 3.0, 0.0, 2.0],
157+
[0.0, 1.0, 4.0, 0.0],
158+
[5.0, 0.0, 1.0, 1.0],
159+
]
160+
)
161+
)
162+
block.self_attn = _ZeroAttention()
163+
block.mlp = _ZeroMLP()
164+
block.input_layernorm = torch.nn.Identity()
165+
block.post_attention_layernorm = torch.nn.Identity()
166+
167+
x = torch.arange(cfg.hc_mult * cfg.hidden_size, dtype=torch.float32).view(1, 1, cfg.hc_mult, cfg.hidden_size)
168+
expected = torch.matmul(block.attn_hc.comb.transpose(-1, -2), x)
169+
expected = torch.matmul(block.ffn_hc.comb.transpose(-1, -2), expected)
170+
wrong_orientation = torch.matmul(block.ffn_hc.comb, torch.matmul(block.attn_hc.comb, x))
171+
172+
actual = block(x, position_embeddings=(torch.empty(0), torch.empty(0)))
173+
174+
torch.testing.assert_close(actual, expected)
175+
assert not torch.allclose(expected, wrong_orientation)
176+
117177
@_REQUIRES_CUDA
118178
def test_forward_shape(self):
119179
"""Forward pass produces logits of the right shape."""

0 commit comments

Comments
 (0)