|
20 | 20 | from nemo_automodel.components.models.afmoe.config import AfmoeConfig |
21 | 21 | from nemo_automodel.components.models.afmoe.model import AfmoeForCausalLM, AfmoeModel, Block, _build_moe_config |
22 | 22 | from nemo_automodel.components.models.common import BackendConfig |
23 | | -from nemo_automodel.components.moe.layers import MLP, MoE |
| 23 | +from nemo_automodel.components.moe.config import MoEConfig |
| 24 | +from nemo_automodel.components.moe.layers import MLP, Gate, MoE |
24 | 25 |
|
25 | 26 | pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") |
26 | 27 |
|
@@ -195,3 +196,85 @@ def test_fields_mapped_correctly(self, tiny_config): |
195 | 196 | assert moe_cfg.route_scale == tiny_config.route_scale |
196 | 197 | assert moe_cfg.norm_topk_prob is True |
197 | 198 | assert moe_cfg.force_e_score_correction_bias is True |
| 199 | + |
| 200 | + |
| 201 | +class TestDualNormParity: |
| 202 | + def test_manual_trace_matches_forward(self, tiny_config, backend_config, device): |
| 203 | + """Manual 4-norm residual trace must be bit-identical to Block.forward().""" |
| 204 | + torch.manual_seed(42) |
| 205 | + moe_config = _build_moe_config(tiny_config) |
| 206 | + block = Block(layer_idx=0, config=tiny_config, moe_config=moe_config, backend=backend_config).to(device) |
| 207 | + block.eval() |
| 208 | + |
| 209 | + batch, seq_len = 1, 4 |
| 210 | + x = torch.randn(batch, seq_len, tiny_config.hidden_size, device=device, dtype=torch.bfloat16) |
| 211 | + freqs_cis = torch.randn(batch, seq_len, tiny_config.head_dim, device=device) |
| 212 | + |
| 213 | + with torch.no_grad(): |
| 214 | + # Manual trace: attention sublayer |
| 215 | + residual = x |
| 216 | + h = block.input_layernorm(x) |
| 217 | + h = block.self_attn(h, freqs_cis=freqs_cis) |
| 218 | + h = block.post_attention_layernorm(h) |
| 219 | + after_attn = residual + h |
| 220 | + |
| 221 | + # Manual trace: MLP sublayer |
| 222 | + residual = after_attn |
| 223 | + h = block.pre_mlp_layernorm(after_attn) |
| 224 | + h = block._mlp(h, padding_mask=None) |
| 225 | + h = block.post_mlp_layernorm(h) |
| 226 | + expected = residual + h |
| 227 | + |
| 228 | + # Block forward |
| 229 | + actual = block(x, freqs_cis=freqs_cis) |
| 230 | + |
| 231 | + torch.testing.assert_close(actual, expected, rtol=0, atol=0) |
| 232 | + |
| 233 | + |
| 234 | +class TestMoeRoutingParity: |
| 235 | + def test_sigmoid_norm_scale(self, device): |
| 236 | + """Manual sigmoid -> topk -> normalize -> scale must match Gate.forward().""" |
| 237 | + torch.manual_seed(42) |
| 238 | + |
| 239 | + moe_config = MoEConfig( |
| 240 | + dim=64, |
| 241 | + inter_dim=128, |
| 242 | + moe_inter_dim=32, |
| 243 | + n_routed_experts=4, |
| 244 | + n_shared_experts=1, |
| 245 | + n_activated_experts=2, |
| 246 | + n_expert_groups=1, |
| 247 | + n_limited_groups=1, |
| 248 | + train_gate=False, |
| 249 | + gate_bias_update_factor=0.0, |
| 250 | + score_func="sigmoid", |
| 251 | + route_scale=2.0, |
| 252 | + aux_loss_coeff=0.0, |
| 253 | + norm_topk_prob=True, |
| 254 | + force_e_score_correction_bias=True, |
| 255 | + dtype=torch.bfloat16, |
| 256 | + ) |
| 257 | + |
| 258 | + gate = Gate(moe_config).to(device) |
| 259 | + torch.manual_seed(123) |
| 260 | + gate.weight.data = torch.randn(4, 64, device=device, dtype=torch.bfloat16) |
| 261 | + |
| 262 | + x = torch.randn(8, 64, device=device, dtype=torch.bfloat16) # 8 tokens |
| 263 | + token_mask = torch.ones(8, dtype=torch.bool, device=device) |
| 264 | + |
| 265 | + with torch.no_grad(): |
| 266 | + weights, indices, aux_loss = gate(x, token_mask, cp_mesh=None) |
| 267 | + |
| 268 | + # Manual reference: sigmoid -> bias -> topk -> gather original -> normalize -> scale |
| 269 | + with torch.no_grad(): |
| 270 | + scores = torch.sigmoid(x @ gate.weight.data.T) # [8, 4] |
| 271 | + original_scores = scores.clone() |
| 272 | + biased = scores + gate.e_score_correction_bias # zeros, no-op |
| 273 | + manual_idx = torch.topk(biased, 2, dim=-1)[1] |
| 274 | + manual_w = original_scores.gather(1, manual_idx) |
| 275 | + manual_w = manual_w / (manual_w.sum(dim=-1, keepdim=True) + 1e-20) |
| 276 | + manual_w = manual_w * 2.0 |
| 277 | + |
| 278 | + assert torch.equal(indices, manual_idx), "Expert indices mismatch" |
| 279 | + torch.testing.assert_close(weights, manual_w, rtol=1e-3, atol=1e-3) |
| 280 | + assert aux_loss is None |
0 commit comments