Skip to content

Commit 897a38b

Browse files
committed
update tests
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
1 parent 030e92c commit 897a38b

3 files changed

Lines changed: 153 additions & 16 deletions

File tree

tests/unit_tests/models/afmoe/test_afmoe_layers.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,21 +80,58 @@ def test_has_qk_norm(self, tiny_config, backend_config):
8080
assert hasattr(attn, "k_norm")
8181

8282
def test_forward_shape(self, tiny_config, backend_config, device):
    """Attention at layer_idx=0 must return a tensor shaped (batch, seq_len, hidden_size)."""
    n_batch, n_tokens = 2, 8
    expected_shape = (n_batch, n_tokens, tiny_config.hidden_size)

    # Build the module first so random draws happen in the same order as before.
    module = AfmoeAttention(tiny_config, layer_idx=0, backend=backend_config)
    module = module.to(device)

    hidden = torch.randn(*expected_shape, device=device, dtype=torch.bfloat16)
    rope = torch.randn(n_batch, n_tokens, tiny_config.head_dim, device=device)

    result = module(hidden, freqs_cis=rope)
    assert result.shape == expected_shape
9191

9292
def test_global_attention_forward_shape(self, tiny_config, backend_config, device):
    """Attention at layer_idx=1 must return a tensor shaped (batch, seq_len, hidden_size)."""
    n_batch, n_tokens = 2, 8
    expected_shape = (n_batch, n_tokens, tiny_config.hidden_size)

    # Build the module first so random draws happen in the same order as before.
    module = AfmoeAttention(tiny_config, layer_idx=1, backend=backend_config)
    module = module.to(device)

    hidden = torch.randn(*expected_shape, device=device, dtype=torch.bfloat16)
    rope = torch.randn(n_batch, n_tokens, tiny_config.head_dim, device=device)

    result = module(hidden, freqs_cis=rope)
    assert result.shape == expected_shape
101+
102+
103+
class TestAfmoeAttentionParity:
    """Behavioural checks for AfmoeAttention that go beyond output shapes."""

    def test_rope_conditional_local_vs_global(self, tiny_config, backend_config, device):
        """Local attention (with RoPE) and global attention (without) must diverge given shared weights."""
        torch.manual_seed(42)
        local_attn = AfmoeAttention(tiny_config, layer_idx=0, backend=backend_config).to(device)
        global_attn = AfmoeAttention(tiny_config, layer_idx=1, backend=backend_config).to(device)
        # Copy the weights across so any output difference comes from the
        # per-layer code path rather than from different initialisations.
        global_attn.load_state_dict(local_attn.state_dict())

        batch, seq_len = 2, 8
        x = torch.randn(batch, seq_len, tiny_config.hidden_size, device=device, dtype=torch.bfloat16)
        freqs_cis = torch.randn(batch, seq_len, tiny_config.head_dim, device=device)

        with torch.no_grad():
            local_out = local_attn(x, freqs_cis=freqs_cis)
            global_out = global_attn(x, freqs_cis=freqs_cis)

        max_diff = (local_out - global_out).abs().max().item()
        assert max_diff > 0.01, f"RoPE should cause divergence, but max_diff={max_diff}"

    def test_qk_norm_reduces_head_variance(self, tiny_config, backend_config, device):
        """Per-head QK RMSNorm should equalize magnitudes across heads."""
        # Seed so the random query tensor (and therefore the measured
        # variances) is reproducible from run to run.
        torch.manual_seed(42)
        attn = AfmoeAttention(tiny_config, layer_idx=0, backend=backend_config).to(device)

        num_heads = tiny_config.num_attention_heads
        # Variance across a single head is undefined (NaN); fail with a clear
        # message instead of a misleading "should reduce variance" failure.
        assert num_heads > 1, "Need at least two heads to measure cross-head variance"

        batch, seq_len = 1, 4
        q = torch.randn(batch, seq_len, num_heads, tiny_config.head_dim, device=device, dtype=torch.bfloat16)
        q[:, :, 0, :] *= 10.0  # Make first head 10x larger

        with torch.no_grad():
            q_normed = attn.q_norm(q)

        # Compute the statistics in float32: bfloat16 keeps only ~8 bits of
        # mantissa, which is too coarse for a norm -> variance -> mean chain.
        pre_var = q.float().norm(dim=-1).var(dim=-1).mean().item()
        post_var = q_normed.float().norm(dim=-1).var(dim=-1).mean().item()
        assert post_var < pre_var, "QK norm should reduce variance across heads"

tests/unit_tests/models/afmoe/test_afmoe_model.py

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
from nemo_automodel.components.models.afmoe.config import AfmoeConfig
2121
from nemo_automodel.components.models.afmoe.model import AfmoeForCausalLM, AfmoeModel, Block, _build_moe_config
2222
from nemo_automodel.components.models.common import BackendConfig
23-
from nemo_automodel.components.moe.layers import MLP, MoE
23+
from nemo_automodel.components.moe.config import MoEConfig
24+
from nemo_automodel.components.moe.layers import MLP, Gate, MoE
2425

2526
pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
2627

@@ -195,3 +196,85 @@ def test_fields_mapped_correctly(self, tiny_config):
195196
assert moe_cfg.route_scale == tiny_config.route_scale
196197
assert moe_cfg.norm_topk_prob is True
197198
assert moe_cfg.force_e_score_correction_bias is True
199+
200+
201+
class TestDualNormParity:
    """Compares Block.forward() with a hand-written trace through its four norm layers."""

    def test_manual_trace_matches_forward(self, tiny_config, backend_config, device):
        """Manual 4-norm residual trace must be bit-identical to Block.forward()."""
        torch.manual_seed(42)
        layer = Block(
            layer_idx=0,
            config=tiny_config,
            moe_config=_build_moe_config(tiny_config),
            backend=backend_config,
        ).to(device)
        layer.eval()

        n_batch, n_tokens = 1, 4
        hidden_in = torch.randn(n_batch, n_tokens, tiny_config.hidden_size, device=device, dtype=torch.bfloat16)
        rope = torch.randn(n_batch, n_tokens, tiny_config.head_dim, device=device)

        with torch.no_grad():
            # Attention sublayer: input norm -> attention -> post norm, then add the skip path.
            attn_branch = layer.post_attention_layernorm(
                layer.self_attn(layer.input_layernorm(hidden_in), freqs_cis=rope)
            )
            hidden_mid = hidden_in + attn_branch

            # MLP sublayer: pre norm -> MLP -> post norm, then add the skip path.
            mlp_branch = layer.post_mlp_layernorm(
                layer._mlp(layer.pre_mlp_layernorm(hidden_mid), padding_mask=None)
            )
            expected = hidden_mid + mlp_branch

            # Reference: the block's own forward pass on the same inputs.
            actual = layer(hidden_in, freqs_cis=rope)

        # Zero tolerances: the two paths must agree exactly, not just approximately.
        torch.testing.assert_close(actual, expected, rtol=0, atol=0)
232+
233+
234+
class TestMoeRoutingParity:
    """Compares Gate.forward() with a hand-written routing reference."""

    def test_sigmoid_norm_scale(self, device):
        """Manual sigmoid -> topk -> normalize -> scale must match Gate.forward()."""
        # Sizes shared by the config and by the manual reference below.
        hidden_dim, num_experts, top_k = 64, 4, 2
        num_tokens, scale = 8, 2.0

        torch.manual_seed(42)

        moe_config = MoEConfig(
            dim=hidden_dim,
            inter_dim=128,
            moe_inter_dim=32,
            n_routed_experts=num_experts,
            n_shared_experts=1,
            n_activated_experts=top_k,
            n_expert_groups=1,
            n_limited_groups=1,
            train_gate=False,
            gate_bias_update_factor=0.0,
            score_func="sigmoid",
            route_scale=scale,
            aux_loss_coeff=0.0,
            norm_topk_prob=True,
            force_e_score_correction_bias=True,
            dtype=torch.bfloat16,
        )

        gate = Gate(moe_config).to(device)
        torch.manual_seed(123)
        # Overwrite the router weight with known random values.
        gate.weight.data = torch.randn(num_experts, hidden_dim, device=device, dtype=torch.bfloat16)

        tokens = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16)
        all_valid = torch.ones(num_tokens, dtype=torch.bool, device=device)

        with torch.no_grad():
            weights, indices, aux_loss = gate(tokens, all_valid, cp_mesh=None)

            # Manual reference: sigmoid -> bias -> topk -> gather original -> normalize -> scale
            raw_scores = torch.sigmoid(tokens @ gate.weight.data.T)  # [num_tokens, num_experts]
            # The bias only influences which experts are selected; the test's
            # author notes it is all zeros here, i.e. a no-op.
            selection_scores = raw_scores + gate.e_score_correction_bias
            ref_idx = selection_scores.topk(top_k, dim=-1).indices
            # Weights are gathered from the unbiased scores.
            ref_w = raw_scores.gather(1, ref_idx)
            ref_w = ref_w / (ref_w.sum(dim=-1, keepdim=True) + 1e-20)
            ref_w = ref_w * scale

        assert torch.equal(indices, ref_idx), "Expert indices mismatch"
        torch.testing.assert_close(weights, ref_w, rtol=1e-3, atol=1e-3)
        assert aux_loss is None

tests/unit_tests/models/afmoe/test_afmoe_state_dict_adapter.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -82,30 +82,30 @@ def adapter(config, moe_config, backend):
8282
return AfmoeStateDictAdapter(config, moe_config, backend, dtype=torch.bfloat16)
8383

8484

85-
def _make_hf_expert_state_dict(
    n_layers=2, n_experts=4, hidden=64, moe_inter=32, num_dense=1, dtype=torch.bfloat16, dense_inter=128
):
    """Create a minimal HF-format state dict with router, experts, and expert_bias.

    Layers with index below ``num_dense`` get a dense MLP; the remaining layers
    get a router gate, an expert bias, per-expert projections and a shared expert.

    Args:
        n_layers: Number of layers to generate keys for.
        n_experts: Number of routed experts per MoE layer.
        hidden: Hidden (model) dimension.
        moe_inter: Intermediate width of each routed expert and of the shared expert.
        num_dense: Number of leading layers that use a dense MLP instead of MoE.
        dtype: dtype of every randomly initialised weight tensor.
        dense_inter: Intermediate width of the dense MLP layers.

    Returns:
        dict mapping HF parameter names to tensors.
    """
    sd = {}

    def _add_projections(mlp_prefix, inter):
        # Draw order (gate, up, down) is fixed so seeded runs stay reproducible.
        sd[f"{mlp_prefix}.gate_proj.weight"] = torch.randn(inter, hidden, dtype=dtype)
        sd[f"{mlp_prefix}.up_proj.weight"] = torch.randn(inter, hidden, dtype=dtype)
        sd[f"{mlp_prefix}.down_proj.weight"] = torch.randn(hidden, inter, dtype=dtype)

    for layer_idx in range(n_layers):
        prefix = f"model.layers.{layer_idx}"
        if layer_idx < num_dense:
            # Dense MLP
            _add_projections(f"{prefix}.mlp", dense_inter)
            continue

        # Router gate
        sd[f"{prefix}.mlp.router.gate.weight"] = torch.randn(n_experts, hidden, dtype=dtype)
        # Expert bias: created without ``dtype`` so it stays in torch's default dtype.
        sd[f"{prefix}.mlp.expert_bias"] = torch.zeros(n_experts)
        # Per-expert weights
        for e in range(n_experts):
            _add_projections(f"{prefix}.mlp.experts.{e}", moe_inter)
        # Shared expert
        _add_projections(f"{prefix}.mlp.shared_experts", moe_inter)
    return sd
110110

111111

@@ -187,3 +187,20 @@ def test_to_hf_splits_experts(self, adapter):
187187
assert f"model.layers.1.mlp.experts.{e}.gate_proj.weight" in hf_sd
188188
assert f"model.layers.1.mlp.experts.{e}.up_proj.weight" in hf_sd
189189
assert f"model.layers.1.mlp.experts.{e}.down_proj.weight" in hf_sd
190+
191+
def test_roundtrip_preserves_all_values(self, adapter):
    """HF -> NeMo -> HF round-trip must preserve exact tensor values."""
    torch.manual_seed(42)
    hf_sd = _make_hf_expert_state_dict()
    # Snapshot the inputs before conversion so the comparison does not depend
    # on whether the adapter leaves its argument untouched.
    originals = {name: tensor.clone() for name, tensor in hf_sd.items()}

    roundtrip_sd = adapter.to_hf(adapter.from_hf(hf_sd))

    expected_keys = set(originals.keys())
    actual_keys = set(roundtrip_sd.keys())
    assert actual_keys == expected_keys, (
        f"Missing: {expected_keys - actual_keys}, "
        f"Extra: {actual_keys - expected_keys}"
    )
    for key, reference in originals.items():
        # Compare in float32 so tensors of different dtypes can be subtracted.
        max_diff = (reference.float() - roundtrip_sd[key].float()).abs().max().item()
        assert max_diff == 0.0, f"Round-trip mismatch for {key}: max_diff={max_diff}"

0 commit comments

Comments
 (0)