Skip to content

Commit c90a8e8

Browse files
committed
Add tests for recurrent (T=1) and multi-T dispatch
- test_recurrent_t1: verify T=1 recurrent kernel against FLA naive reference across all FLA test configs - test_dispatch_multiple_seq_lengths: verify correctness for T in {1, 2, 32, 63, 64, 65, 128, 256}, covering both dispatch paths and chunk boundary edge cases
1 parent fc5018e commit c90a8e8

File tree

1 file changed

+80
-13
lines changed

1 file changed

+80
-13
lines changed

backends/cuda/tests/test_chunk_gated_delta_rule.py

Lines changed: 80 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@
4242
from torch.export import export
4343

4444

45-
# Default tensor dimensions shared by all tests:
# batch, heads, key head-dim, value head-dim.
B, H, K, V = 1, 4, 64, 64
# Default sequence length for chunked-kernel tests; per-test overrides use
# the seq_len keyword of the input builders.
T = 128  # default T for chunked tests
4647

4748
EXECUTORCH_ROOT = os.path.normpath(os.path.join(os.path.dirname(__file__), "../../.."))
4849
RUNNER_PATH = os.path.join(EXECUTORCH_ROOT, "cmake-out", "executor_runner")
def _make_inputs_from_fla(
    seed,
    gate_logit_normalizer,
    mask_p=0.0,
    nonzero_h0=False,
    seq_len=T,
    dtype=torch.bfloat16,
    device="cuda",
):
    """Generate inputs following FLA test_chunk() conventions."""
    # Seed first so every tensor below is reproducible for a given config.
    torch.manual_seed(seed)
    q = torch.rand(B, seq_len, H, K, dtype=dtype, device=device)
    k = torch.rand(B, seq_len, H, K, dtype=dtype, device=device)
    v = torch.rand(B, seq_len, H, V, dtype=dtype, device=device)
    # beta and g are sampled in fp32 (matching FLA) before casting down.
    beta = torch.rand(B, seq_len, H, dtype=torch.float32, device=device).sigmoid().to(dtype)
    g = F.logsigmoid(torch.rand(B, seq_len, H, dtype=torch.float32, device=device))
    g = (g / gate_logit_normalizer).to(dtype)
    if mask_p > 0:
        # Randomly zero out a fraction of the gates.
        g = g * (torch.rand(B, seq_len, H, dtype=dtype, device=device) > mask_p)
    # Initial state: random when requested, otherwise all zeros.
    h0_factory = torch.randn if nonzero_h0 else torch.zeros
    h0 = h0_factory(B, H, K, V, dtype=dtype, device=device)
    return q, k, v, g, beta, h0
109111

110112

111-
def _make_inputs(seq_len=T, dtype=torch.bfloat16, device="cuda"):
    """Unseeded random inputs (normal q/k/v, sigmoid beta) for smoke tests."""
    opts = {"dtype": dtype, "device": device}
    q = torch.randn(B, seq_len, H, K, **opts)
    k = torch.randn(B, seq_len, H, K, **opts)
    v = torch.randn(B, seq_len, H, V, **opts)
    g = F.logsigmoid(torch.randn(B, seq_len, H, **opts))
    beta = torch.rand(B, seq_len, H, **opts).sigmoid()
    initial_state = torch.randn(B, H, K, V, **opts)
    return q, k, v, g, beta, initial_state
119121

@@ -252,6 +254,71 @@ def test_eager_matches_fla(self):
252254

253255
self.assertLess((o_ours.float() - o_ref.float()).abs().max().item(), 0.01)
254256

257+
def test_recurrent_t1(self):
258+
"""T=1 (decode) uses recurrent kernel — verify vs naive reference."""
259+
from fla.ops.gated_delta_rule.naive import naive_recurrent_gated_delta_rule
260+
261+
model = ChunkGatedDeltaModel().eval()
262+
for seed, norm, mask_p, nonzero_h0, desc in FLA_TEST_CONFIGS:
263+
with self.subTest(desc=desc):
264+
inputs = _make_inputs_from_fla(seed, norm, mask_p, nonzero_h0, seq_len=1)
265+
q, k, v, g, beta, h0 = inputs
266+
267+
with torch.no_grad():
268+
o_ours, s_ours = model(q, k, v, g, beta, h0)
269+
270+
o_ref, s_ref = naive_recurrent_gated_delta_rule(
271+
q=F.normalize(q, p=2, dim=-1),
272+
k=F.normalize(k, p=2, dim=-1),
273+
v=v,
274+
beta=beta,
275+
g=g,
276+
initial_state=h0,
277+
output_final_state=True,
278+
)
279+
280+
self.assertEqual(o_ours.shape, torch.Size([B, 1, H, V]))
281+
self.assertEqual(s_ours.shape, torch.Size([B, H, K, V]))
282+
o_diff = (o_ours.float() - o_ref.float()).abs().max().item()
283+
s_diff = (s_ours.float() - s_ref.float()).abs().max().item()
284+
self.assertLess(o_diff, 0.01, f"{desc}: output diff {o_diff}")
285+
self.assertLess(s_diff, 0.01, f"{desc}: state diff {s_diff}")
286+
287+
def test_dispatch_multiple_seq_lengths(self):
288+
"""Verify correctness across T values hitting both dispatch paths."""
289+
from fla.ops.gated_delta_rule.naive import naive_recurrent_gated_delta_rule
290+
291+
model = ChunkGatedDeltaModel().eval()
292+
# T=1 → recurrent, T>1 → chunked; include boundary values
293+
for seq_len in [1, 2, 32, 63, 64, 65, 128, 256]:
294+
with self.subTest(T=seq_len):
295+
inputs = _make_inputs_from_fla(42, 1.0, 0.0, True, seq_len=seq_len)
296+
q, k, v, g, beta, h0 = inputs
297+
298+
with torch.no_grad():
299+
o_ours, s_ours = model(q, k, v, g, beta, h0)
300+
301+
o_ref, s_ref = naive_recurrent_gated_delta_rule(
302+
q=F.normalize(q, p=2, dim=-1),
303+
k=F.normalize(k, p=2, dim=-1),
304+
v=v,
305+
beta=beta,
306+
g=g,
307+
initial_state=h0,
308+
output_final_state=True,
309+
)
310+
311+
self.assertEqual(o_ours.shape, torch.Size([B, seq_len, H, V]))
312+
self.assertEqual(s_ours.shape, torch.Size([B, H, K, V]))
313+
o_diff = (o_ours.float() - o_ref.float()).abs().max().item()
314+
s_diff = (s_ours.float() - s_ref.float()).abs().max().item()
315+
self.assertLess(
316+
o_diff, 0.02, f"T={seq_len}: output diff {o_diff}"
317+
)
318+
self.assertLess(
319+
s_diff, 0.02, f"T={seq_len}: state diff {s_diff}"
320+
)
321+
255322
def test_export_cuda(self):
256323
with tempfile.TemporaryDirectory() as tmpdir:
257324
pte_path = export_chunk_gated_delta(tmpdir)

0 commit comments

Comments
 (0)