Skip to content

Commit ed60ab2

Browse files
committed
perf(attention_mask): vectorise dense_mask_to_jagged_arbitrary_func
Replace the per-row Python loop with a cumsum + nonzero scatter so the function issues a single host sync (for `max_intervals`) instead of one per row × per interval × per .item() call. Why --- Greptile flagged this as P1: the loop has 4 host-syncing ops in the inner body — `row.any()`, two `.nonzero()` materialisations, and `start_pos[iv].item()` / `end_pos[iv].item()`. For B=64, seqlen=1024, ~2 intervals/row, that's ≈500 k forced GPU→CPU syncs per call. The function is on the jagged-FA fallback path in `SIDGRModel.decoder_step` (when the caller passes a dense `attention_mask` instead of a prebuilt `arbitrary_func`), so this dominates training step time on that path. How --- - `starts` / `ends` boundary detection was already vectorised; keep that. - Mask out positions outside each sample's `[0, seq_len)` so padded rows/cols don't produce spurious intervals. - `starts.cumsum(dim=-1)` assigns each transition a 1-based interval index without any sync. - `starts.nonzero()` gives all (b, q, k) coordinates in one shot; index into `af` via vectorised assignment. One nonzero call per side replaces ~N × seq_len of them. - Same for `ends`, with the existing `+1` (exclusive) offset preserved. Verification ------------ Add `TestDenseMaskToJaggedVectorisedMatchesLoop` comparing the new vectorised path against the existing loop-based test helper across: jagged causal, target-grouped (4 beam_width × candidate_len cases), all-zero mask, multi-interval per row, uneven seq_lens. Local: 27/27 pass (was 20), pre-commit clean, no behaviour change for the existing 20 tests. Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
1 parent 5efdadc commit ed60ab2

2 files changed

Lines changed: 117 additions & 26 deletions

File tree

examples/sid_gr/model/attention_mask.py

Lines changed: 39 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ def dense_mask_to_jagged_arbitrary_func(
447447
B, N, _ = valid_mask.shape
448448
device = valid_mask.device
449449

450-
# Detect interval boundaries via transitions
450+
# Detect interval boundaries via transitions (vectorised on [B, N, N]).
451451
shifted = torch.zeros_like(valid_mask)
452452
shifted[:, :, 1:] = valid_mask[:, :, :-1]
453453
starts = valid_mask & ~shifted
@@ -457,35 +457,49 @@ def dense_mask_to_jagged_arbitrary_func(
457457
ends = valid_mask & ~ends_shifted
458458

459459
max_intervals = int(starts.sum(dim=-1).max().item())
460-
# max(2 * max_intervals + 1, 3) is always odd, so no extra parity fix-up.
460+
# 2 * max_intervals + 1 is always odd, so no extra parity fix-up.
461461
n_func = max(2 * max_intervals + 1, 3)
462462

463463
af = torch.zeros(
464464
1, 1, n_func, total_tokens + padding, dtype=torch.int32, device=device
465465
)
466-
467-
for b in range(B):
468-
batch_start = offsets[b].item()
469-
batch_end = offsets[b + 1].item()
470-
seq_len = batch_end - batch_start
471-
472-
for local_q in range(seq_len):
473-
global_q = batch_start + local_q
474-
row = valid_mask[b, local_q, :seq_len]
475-
if not row.any():
476-
continue
477-
478-
start_pos = starts[b, local_q, :seq_len].nonzero(as_tuple=False).squeeze(-1)
479-
end_pos = ends[b, local_q, :seq_len].nonzero(as_tuple=False).squeeze(-1) + 1
480-
481-
# In flattened coordinates, the first visible key is at
482-
# batch_start (not 0), so F0 is always 0. All intervals go
483-
# into the explicit (F1,F2), (F3,F4), ... slots.
484-
for iv in range(len(start_pos)):
485-
s = start_pos[iv].item() + batch_start
486-
e = end_pos[iv].item() + batch_start
487-
af[0, 0, 2 * iv + 1, global_q] = s
488-
af[0, 0, 2 * iv + 2, global_q] = e
466+
if max_intervals == 0:
467+
return af # mask is all-False; rows stay zero (F0=0 ⇒ no keys visible)
468+
469+
# Mask out positions outside each sample's [0, seq_len) range so the
470+
# padded rows/cols never contribute spurious intervals.
471+
seq_lens = offsets[1:] - offsets[:-1] # [B]
472+
batch_starts = offsets[:-1] # [B]
473+
arange_n = torch.arange(N, device=device)
474+
in_range = arange_n.unsqueeze(0) < seq_lens.unsqueeze(-1) # [B, N]
475+
in_qk = in_range.unsqueeze(-1) & in_range.unsqueeze(-2) # [B, N, N]
476+
starts = starts & in_qk
477+
ends = ends & in_qk
478+
479+
# cumsum along the key axis assigns a 1-based interval index to each
480+
# transition position; e.g. the 3rd True in starts[b, q, :] gets value 3.
481+
iv_starts = starts.cumsum(dim=-1) # [B, N, N]
482+
iv_ends = ends.cumsum(dim=-1) # [B, N, N]
483+
484+
# Scatter all start transitions into af in a single op.
485+
sc = starts.nonzero(as_tuple=False) # [Ns, 3] = (b, q, k)
486+
if sc.numel() > 0:
487+
bs, qs, ks = sc[:, 0], sc[:, 1], sc[:, 2]
488+
ivs = iv_starts[bs, qs, ks] # [Ns] 1-based
489+
global_qs = batch_starts[bs] + qs
490+
global_ks = (batch_starts[bs] + ks).to(torch.int32)
491+
af_row = (2 * (ivs - 1) + 1).long()
492+
af[0, 0, af_row, global_qs] = global_ks
493+
494+
# Same for ends (exclusive: +1 because the loop version added +1).
495+
ec = ends.nonzero(as_tuple=False)
496+
if ec.numel() > 0:
497+
be, qe, ke = ec[:, 0], ec[:, 1], ec[:, 2]
498+
ive = iv_ends[be, qe, ke]
499+
global_qe = batch_starts[be] + qe
500+
global_ke = (batch_starts[be] + ke + 1).to(torch.int32)
501+
af_row_e = (2 * (ive - 1) + 2).long()
502+
af[0, 0, af_row_e, global_qe] = global_ke
489503

490504
return af
491505

examples/sid_gr/tests/test_dense_mask_to_arbitrary_func.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,11 @@
2626
from attention_mask import (
2727
build_jagged_causal_arbitrary_func,
2828
dense_mask_to_arbitrary_func,
29-
padded_target_aware_causal_mask,
3029
)
30+
from attention_mask import (
31+
dense_mask_to_jagged_arbitrary_func as dense_mask_to_jagged_arbitrary_func_vec,
32+
)
33+
from attention_mask import padded_target_aware_causal_mask
3134

3235
sys.path.pop(0)
3336

@@ -262,3 +265,77 @@ def test_dense_to_jagged_target_grouped(self, beam_width, candidate_len):
262265
expected[s : s + sl, s : s + sl] = valid_3d[b, :sl, :sl]
263266

264267
assert torch.equal(expected, recon)
268+
269+
270+
class TestDenseMaskToJaggedVectorisedMatchesLoop:
271+
"""The model's ``dense_mask_to_jagged_arbitrary_func`` is vectorised
272+
(cumsum + scatter, one host sync) while the loop-based helper above
273+
spells out the same algorithm row by row. Verify they agree on every
274+
mask shape exercised by the rest of this file plus a couple of
275+
pathological cases.
276+
"""
277+
278+
@staticmethod
279+
def _both(valid_mask, offsets, total):
280+
a = dense_mask_to_jagged_arbitrary_func(valid_mask, offsets, total)
281+
b = dense_mask_to_jagged_arbitrary_func_vec(valid_mask, offsets, total)
282+
assert a.shape == b.shape, (a.shape, b.shape)
283+
assert torch.equal(a, b), f"vectorised != loop\nloop:\n{a}\nvec:\n{b}"
284+
285+
def test_jagged_causal(self):
286+
offsets = torch.tensor([0, 3, 7], device="cuda")
287+
B, total, max_seqlen = 2, 7, 4
288+
per_batch = torch.zeros(
289+
B, max_seqlen, max_seqlen, dtype=torch.bool, device="cuda"
290+
)
291+
for b in range(B):
292+
sl = (offsets[b + 1] - offsets[b]).item()
293+
per_batch[b, :sl, :sl] = torch.tril(
294+
torch.ones(sl, sl, dtype=torch.bool, device="cuda")
295+
)
296+
self._both(per_batch, offsets, total)
297+
298+
@pytest.mark.parametrize("beam_width", [2, 3])
299+
@pytest.mark.parametrize("candidate_len", [1, 3])
300+
def test_target_grouped(self, beam_width, candidate_len):
301+
hist_lens = torch.tensor([5, 3], device="cuda")
302+
max_hist = 5
303+
inverted = padded_target_aware_causal_mask(
304+
hist_lens, max_hist, beam_width, candidate_len
305+
)
306+
valid = ~inverted
307+
total_per_batch = (hist_lens + beam_width * candidate_len).tolist()
308+
offsets = torch.tensor(
309+
[0] + [sum(total_per_batch[: i + 1]) for i in range(2)],
310+
device="cuda",
311+
)
312+
total = offsets[-1].item()
313+
self._both(valid, offsets, total)
314+
315+
def test_all_zero_mask(self):
316+
offsets = torch.tensor([0, 4, 8], device="cuda")
317+
valid = torch.zeros(2, 4, 4, dtype=torch.bool, device="cuda")
318+
self._both(valid, offsets, 8)
319+
320+
def test_multi_interval_per_row(self):
321+
# Row 2 has TWO disjoint intervals (gap mid-row); exercises iv > 1.
322+
offsets = torch.tensor([0, 5], device="cuda")
323+
valid = torch.zeros(1, 5, 5, dtype=torch.bool, device="cuda")
324+
valid[0, 2, 0] = True
325+
valid[0, 2, 1] = True
326+
valid[0, 2, 3] = True
327+
valid[0, 2, 4] = True
328+
self._both(valid, offsets, 5)
329+
330+
def test_uneven_seq_lens(self):
331+
offsets = torch.tensor([0, 2, 8, 9], device="cuda")
332+
B, total, max_seqlen = 3, 9, 6
333+
per_batch = torch.zeros(
334+
B, max_seqlen, max_seqlen, dtype=torch.bool, device="cuda"
335+
)
336+
for b in range(B):
337+
sl = (offsets[b + 1] - offsets[b]).item()
338+
per_batch[b, :sl, :sl] = torch.tril(
339+
torch.ones(sl, sl, dtype=torch.bool, device="cuda")
340+
)
341+
self._both(per_batch, offsets, total)

0 commit comments

Comments
 (0)