Skip to content

Commit 2d7c2e7

Browse files
committed
fix benchmark fwd+bwd
1 parent 5a67775 commit 2d7c2e7

1 file changed

Lines changed: 4 additions & 0 deletions

File tree

users/zeyer/nn_rf/encoder/chunked_conformer_v2_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,10 @@ def _bench(fn, label, *, with_grad: bool = False, **extra_globals):
354354
# leaf tensor for RF: wrap raw leaf in RF Tensor so _apply_rope_real sees it
355355
x_leaf_rf = x_s.copy_template()
356356
x_leaf_rf.raw_tensor = x_s.raw_tensor.clone().requires_grad_(True)
357+
# Warmup: torch.compile compiles the backward lazily on first call;
358+
# without this the first timed iteration includes compilation time.
359+
_rope_compiled_raw(x_leaf_c, pe_raw).sum().backward()
360+
_apply_rope_real(x_leaf_rf, pe_s, head_dim).raw_tensor.sum().backward()
357361
t_c = _bench(
358362
lambda: _rope_compiled_raw(x_leaf_c, pe_raw).sum().backward(),
359363
f"compiled fwd+bwd T={T}",

0 commit comments

Comments
 (0)