|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +""" |
| 8 | +Correctness test: fully-fused Triton GatedDeltaNet decode kernel vs PyTorch reference. |
| 9 | +
|
| 10 | +Verifies that torch.ops.triton.fused_deltanet_decode produces the same output |
| 11 | +and state as the original GatedDeltaNet T=1 recurrence with manual Q/K/V split, |
| 12 | +L2 norm, head repeat, and gating. |
| 13 | +""" |
| 14 | + |
| 15 | +import unittest |
| 16 | + |
| 17 | +import torch |
| 18 | +import torch.nn.functional as F |
| 19 | + |
| 20 | + |
| 21 | +def _skip_if_no_cuda(): |
| 22 | + if not torch.cuda.is_available(): |
| 23 | + raise unittest.SkipTest("CUDA not available") |
| 24 | + if not torch.cuda.is_bf16_supported(): |
| 25 | + raise unittest.SkipTest("BF16 not supported on this GPU") |
| 26 | + |
| 27 | + |
| 28 | +def _import_fused_deltanet_decode(): |
| 29 | + from executorch.backends.cuda.triton.kernels.fused_deltanet_decode import ( |
| 30 | + fused_deltanet_decode, # noqa: F401 — registers torch.ops.triton.* |
| 31 | + ) |
| 32 | + |
| 33 | + return fused_deltanet_decode |
| 34 | + |
| 35 | + |
| 36 | +def _max_abs_error(a, b): |
| 37 | + return (a.float() - b.float()).abs().max().item() |
| 38 | + |
| 39 | + |
| 40 | +# bf16 kernel vs fp32 reference tolerance. |
| 41 | +MAX_ABS_TOL = 0.05 |
| 42 | +MULTISTEP_TOL = 0.1 |
| 43 | + |
| 44 | + |
| 45 | +def _reference_deltanet_decode( |
| 46 | + qkv_conv, |
| 47 | + alpha, |
| 48 | + beta_raw, |
| 49 | + neg_A_exp, |
| 50 | + dt_bias, |
| 51 | + state, |
| 52 | + num_k_heads, |
| 53 | + num_v_heads, |
| 54 | + head_k_dim, |
| 55 | + head_v_dim, |
| 56 | +): |
| 57 | + """Reference PyTorch implementation matching model.py's original T=1 path. |
| 58 | +
|
| 59 | + Does Q/K/V split, L2 norm, head repeat, gating, then recurrent update. |
| 60 | + """ |
| 61 | + B = qkv_conv.shape[0] |
| 62 | + key_dim = num_k_heads * head_k_dim |
| 63 | + |
| 64 | + q = qkv_conv[:, :key_dim].reshape(B, num_k_heads, head_k_dim) |
| 65 | + k = qkv_conv[:, key_dim : 2 * key_dim].reshape(B, num_k_heads, head_k_dim) |
| 66 | + v = qkv_conv[:, 2 * key_dim :].reshape(B, num_v_heads, head_v_dim) |
| 67 | + |
| 68 | + q = F.normalize(q.float(), p=2, dim=-1) |
| 69 | + k = F.normalize(k.float(), p=2, dim=-1) |
| 70 | + v = v.float() |
| 71 | + |
| 72 | + head_repeat = num_v_heads // num_k_heads |
| 73 | + if head_repeat > 1: |
| 74 | + q = q.repeat_interleave(head_repeat, dim=1) |
| 75 | + k = k.repeat_interleave(head_repeat, dim=1) |
| 76 | + |
| 77 | + beta = torch.sigmoid(beta_raw.float()) |
| 78 | + g = neg_A_exp.float() * F.softplus(alpha.float() + dt_bias.float()) |
| 79 | + |
| 80 | + scale = head_k_dim**-0.5 |
| 81 | + state_f32 = state.float() |
| 82 | + |
| 83 | + decay = torch.exp(g).unsqueeze(-1).unsqueeze(-1) |
| 84 | + state_f32 = state_f32 * decay |
| 85 | + |
| 86 | + Sk = torch.einsum("bhkv,bhk->bhv", state_f32, k) |
| 87 | + delta = beta.unsqueeze(-1) * (v - Sk) |
| 88 | + state_f32 = state_f32 + torch.einsum("bhk,bhv->bhkv", k, delta) |
| 89 | + |
| 90 | + output = torch.einsum("bhkv,bhk->bhv", state_f32, q) * scale |
| 91 | + |
| 92 | + new_state = state_f32.to(state.dtype) |
| 93 | + return output, new_state |
| 94 | + |
| 95 | + |
| 96 | +# Qwen3.5 MoE dimensions (used across tests) |
| 97 | +NUM_K_HEADS = 16 |
| 98 | +NUM_V_HEADS = 32 |
| 99 | +HEAD_K_DIM = 128 |
| 100 | +HEAD_V_DIM = 128 |
| 101 | +KEY_DIM = NUM_K_HEADS * HEAD_K_DIM # 2048 |
| 102 | +VALUE_DIM = NUM_V_HEADS * HEAD_V_DIM # 4096 |
| 103 | +CONV_DIM = 2 * KEY_DIM + VALUE_DIM # 8192 |
| 104 | + |
| 105 | + |
| 106 | +class TestFusedDeltanetDecode(unittest.TestCase): |
| 107 | + """Test fused GatedDeltaNet decode kernel correctness against PyTorch reference.""" |
| 108 | + |
| 109 | + @classmethod |
| 110 | + def setUpClass(cls): |
| 111 | + _skip_if_no_cuda() |
| 112 | + cls.fused_fn = _import_fused_deltanet_decode() |
| 113 | + torch.manual_seed(42) |
| 114 | + |
| 115 | + cls.A_log = torch.log(torch.empty(NUM_V_HEADS, device="cuda").uniform_(0.5, 8)) |
| 116 | + cls.neg_A_exp = -torch.exp(cls.A_log).float() |
| 117 | + cls.dt_bias = torch.ones(NUM_V_HEADS, device="cuda", dtype=torch.float32) |
| 118 | + |
| 119 | + def _run_fused(self, qkv, alpha, beta_raw, state): |
| 120 | + """Run fused kernel and return (output, new_state).""" |
| 121 | + output, new_state = torch.ops.triton.fused_deltanet_decode( |
| 122 | + qkv, |
| 123 | + alpha, |
| 124 | + beta_raw, |
| 125 | + self.A_log, |
| 126 | + self.dt_bias, |
| 127 | + state, |
| 128 | + ) |
| 129 | + return output, new_state |
| 130 | + |
| 131 | + def _run_reference(self, qkv, alpha, beta_raw, state): |
| 132 | + """Run reference and return (output, new_state).""" |
| 133 | + return _reference_deltanet_decode( |
| 134 | + qkv, |
| 135 | + alpha, |
| 136 | + beta_raw, |
| 137 | + self.neg_A_exp, |
| 138 | + self.dt_bias, |
| 139 | + state, |
| 140 | + NUM_K_HEADS, |
| 141 | + NUM_V_HEADS, |
| 142 | + HEAD_K_DIM, |
| 143 | + HEAD_V_DIM, |
| 144 | + ) |
| 145 | + |
| 146 | + # ------------------------------------------------------------------ |
| 147 | + # Correctness |
| 148 | + # ------------------------------------------------------------------ |
| 149 | + |
| 150 | + def test_basic(self): |
| 151 | + """Single batch, Qwen3.5 MoE dimensions.""" |
| 152 | + B = 1 |
| 153 | + torch.manual_seed(42) |
| 154 | + qkv = torch.randn(B, CONV_DIM, device="cuda", dtype=torch.bfloat16) * 0.1 |
| 155 | + alpha = torch.randn(B, NUM_V_HEADS, device="cuda", dtype=torch.float32) |
| 156 | + beta_raw = torch.randn(B, NUM_V_HEADS, device="cuda", dtype=torch.float32) |
| 157 | + state = ( |
| 158 | + torch.randn( |
| 159 | + B, |
| 160 | + NUM_V_HEADS, |
| 161 | + HEAD_K_DIM, |
| 162 | + HEAD_V_DIM, |
| 163 | + device="cuda", |
| 164 | + dtype=torch.bfloat16, |
| 165 | + ) |
| 166 | + * 0.1 |
| 167 | + ) |
| 168 | + |
| 169 | + ref_out, ref_state = self._run_reference( |
| 170 | + qkv.clone(), |
| 171 | + alpha.clone(), |
| 172 | + beta_raw.clone(), |
| 173 | + state.clone(), |
| 174 | + ) |
| 175 | + fused_out, fused_state = self._run_fused( |
| 176 | + qkv.clone(), |
| 177 | + alpha.clone(), |
| 178 | + beta_raw.clone(), |
| 179 | + state.clone(), |
| 180 | + ) |
| 181 | + |
| 182 | + self.assertLess( |
| 183 | + _max_abs_error(fused_out, ref_out), MAX_ABS_TOL, "output mismatch" |
| 184 | + ) |
| 185 | + self.assertLess( |
| 186 | + _max_abs_error(fused_state, ref_state), MAX_ABS_TOL, "state mismatch" |
| 187 | + ) |
| 188 | + |
| 189 | + def test_batch(self): |
| 190 | + """Batch size > 1.""" |
| 191 | + for B in [2, 4]: |
| 192 | + with self.subTest(B=B): |
| 193 | + torch.manual_seed(42) |
| 194 | + qkv = ( |
| 195 | + torch.randn(B, CONV_DIM, device="cuda", dtype=torch.bfloat16) * 0.1 |
| 196 | + ) |
| 197 | + alpha = torch.randn(B, NUM_V_HEADS, device="cuda", dtype=torch.float32) |
| 198 | + beta_raw = torch.randn( |
| 199 | + B, NUM_V_HEADS, device="cuda", dtype=torch.float32 |
| 200 | + ) |
| 201 | + state = ( |
| 202 | + torch.randn( |
| 203 | + B, |
| 204 | + NUM_V_HEADS, |
| 205 | + HEAD_K_DIM, |
| 206 | + HEAD_V_DIM, |
| 207 | + device="cuda", |
| 208 | + dtype=torch.bfloat16, |
| 209 | + ) |
| 210 | + * 0.1 |
| 211 | + ) |
| 212 | + |
| 213 | + ref_out, ref_state = self._run_reference( |
| 214 | + qkv.clone(), |
| 215 | + alpha.clone(), |
| 216 | + beta_raw.clone(), |
| 217 | + state.clone(), |
| 218 | + ) |
| 219 | + fused_out, fused_state = self._run_fused( |
| 220 | + qkv.clone(), |
| 221 | + alpha.clone(), |
| 222 | + beta_raw.clone(), |
| 223 | + state.clone(), |
| 224 | + ) |
| 225 | + |
| 226 | + self.assertLess( |
| 227 | + _max_abs_error(fused_out, ref_out), |
| 228 | + MAX_ABS_TOL, |
| 229 | + f"B={B} output mismatch", |
| 230 | + ) |
| 231 | + self.assertLess( |
| 232 | + _max_abs_error(fused_state, ref_state), |
| 233 | + MAX_ABS_TOL, |
| 234 | + f"B={B} state mismatch", |
| 235 | + ) |
| 236 | + |
| 237 | + def test_multistep(self): |
| 238 | + """10-step sequential decode checks accumulation drift.""" |
| 239 | + torch.manual_seed(42) |
| 240 | + state_ref = ( |
| 241 | + torch.randn( |
| 242 | + 1, |
| 243 | + NUM_V_HEADS, |
| 244 | + HEAD_K_DIM, |
| 245 | + HEAD_V_DIM, |
| 246 | + device="cuda", |
| 247 | + dtype=torch.bfloat16, |
| 248 | + ) |
| 249 | + * 0.01 |
| 250 | + ) |
| 251 | + state_fused = state_ref.clone() |
| 252 | + |
| 253 | + for _ in range(10): |
| 254 | + qkv = torch.randn(1, CONV_DIM, device="cuda", dtype=torch.bfloat16) * 0.1 |
| 255 | + alpha = torch.randn(1, NUM_V_HEADS, device="cuda", dtype=torch.float32) |
| 256 | + beta_raw = torch.randn(1, NUM_V_HEADS, device="cuda", dtype=torch.float32) |
| 257 | + |
| 258 | + ref_out, state_ref = self._run_reference( |
| 259 | + qkv.clone(), |
| 260 | + alpha.clone(), |
| 261 | + beta_raw.clone(), |
| 262 | + state_ref, |
| 263 | + ) |
| 264 | + fused_out, state_fused = self._run_fused( |
| 265 | + qkv.clone(), |
| 266 | + alpha.clone(), |
| 267 | + beta_raw.clone(), |
| 268 | + state_fused, |
| 269 | + ) |
| 270 | + |
| 271 | + self.assertLess( |
| 272 | + _max_abs_error(fused_out, ref_out), |
| 273 | + MULTISTEP_TOL, |
| 274 | + "multi-step output drift", |
| 275 | + ) |
| 276 | + self.assertLess( |
| 277 | + _max_abs_error(state_fused, state_ref), |
| 278 | + MULTISTEP_TOL, |
| 279 | + "multi-step state drift", |
| 280 | + ) |
| 281 | + |
| 282 | + def test_state_not_mutated(self): |
| 283 | + """Kernel must not mutate the input state tensor.""" |
| 284 | + B = 1 |
| 285 | + torch.manual_seed(42) |
| 286 | + qkv = torch.randn(B, CONV_DIM, device="cuda", dtype=torch.bfloat16) * 0.1 |
| 287 | + alpha = torch.randn(B, NUM_V_HEADS, device="cuda", dtype=torch.float32) |
| 288 | + beta_raw = torch.randn(B, NUM_V_HEADS, device="cuda", dtype=torch.float32) |
| 289 | + state = ( |
| 290 | + torch.randn( |
| 291 | + B, |
| 292 | + NUM_V_HEADS, |
| 293 | + HEAD_K_DIM, |
| 294 | + HEAD_V_DIM, |
| 295 | + device="cuda", |
| 296 | + dtype=torch.bfloat16, |
| 297 | + ) |
| 298 | + * 0.1 |
| 299 | + ) |
| 300 | + state_copy = state.clone() |
| 301 | + |
| 302 | + _, _ = self._run_fused(qkv, alpha, beta_raw, state) |
| 303 | + |
| 304 | + self.assertTrue(torch.equal(state, state_copy), "input state was mutated") |
| 305 | + |
| 306 | + # ------------------------------------------------------------------ |
| 307 | + # CUDA Graph compatibility |
| 308 | + # ------------------------------------------------------------------ |
| 309 | + |
| 310 | + def test_cuda_graph(self): |
| 311 | + """Kernel must be capturable in a CUDA graph.""" |
| 312 | + B = 1 |
| 313 | + torch.manual_seed(42) |
| 314 | + qkv = torch.randn(B, CONV_DIM, device="cuda", dtype=torch.bfloat16) * 0.1 |
| 315 | + alpha = torch.randn(B, NUM_V_HEADS, device="cuda", dtype=torch.float32) |
| 316 | + beta_raw = torch.randn(B, NUM_V_HEADS, device="cuda", dtype=torch.float32) |
| 317 | + state = ( |
| 318 | + torch.randn( |
| 319 | + B, |
| 320 | + NUM_V_HEADS, |
| 321 | + HEAD_K_DIM, |
| 322 | + HEAD_V_DIM, |
| 323 | + device="cuda", |
| 324 | + dtype=torch.bfloat16, |
| 325 | + ) |
| 326 | + * 0.1 |
| 327 | + ) |
| 328 | + |
| 329 | + # Warmup |
| 330 | + for _ in range(3): |
| 331 | + _ = self._run_fused(qkv, alpha, beta_raw, state) |
| 332 | + |
| 333 | + # Capture |
| 334 | + graph = torch.cuda.CUDAGraph() |
| 335 | + with torch.cuda.graph(graph): |
| 336 | + out_cg, state_cg = self._run_fused(qkv, alpha, beta_raw, state) |
| 337 | + |
| 338 | + # Replay |
| 339 | + graph.replay() |
| 340 | + |
| 341 | + # Compare with reference |
| 342 | + ref_out, _ = self._run_reference( |
| 343 | + qkv.clone(), |
| 344 | + alpha.clone(), |
| 345 | + beta_raw.clone(), |
| 346 | + state.clone(), |
| 347 | + ) |
| 348 | + self.assertFalse(torch.isnan(out_cg).any(), "NaN in CUDA graph output") |
| 349 | + self.assertLess( |
| 350 | + _max_abs_error(out_cg, ref_out), |
| 351 | + MAX_ABS_TOL, |
| 352 | + "CUDA graph output mismatch", |
| 353 | + ) |
| 354 | + |
| 355 | + |
| 356 | +if __name__ == "__main__": |
| 357 | + unittest.main() |
0 commit comments