Commit ed60ab2
perf(attention_mask): vectorise dense_mask_to_jagged_arbitrary_func
Replace the per-row Python loop with a cumsum + nonzero scatter so the
function issues a single host sync (for `max_intervals`) instead of
several per row plus one `.item()` call per interval boundary.
Why
---
Greptile flagged this as P1: the loop has 4 host-syncing ops in the
inner body — `row.any()`, two `.nonzero()` materialisations, and
`start_pos[iv].item()` / `end_pos[iv].item()`. For B=64, seqlen=1024,
~2 intervals/row, that's ≈500 k forced GPU→CPU syncs per call. The
function is on the jagged-FA fallback path in `SIDGRModel.decoder_step`
(when the caller passes a dense `attention_mask` instead of a
prebuilt `arbitrary_func`), so this dominates training step time on
that path.
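A back-of-the-envelope check of the sync count quoted above. The per-row accounting (one `row.any()`, two `.nonzero()` calls, and a start/end `.item()` pair per interval) is an assumption read off the description of the loop, and `estimate_loop_syncs` is a hypothetical helper, not part of the codebase:

```python
def estimate_loop_syncs(batch: int, seqlen: int, intervals_per_row: int) -> int:
    """Rough count of host syncs issued by the old per-row loop (assumed accounting)."""
    rows = batch * seqlen                      # one loop iteration per (sample, query row)
    per_row = 1 + 2 + 2 * intervals_per_row    # any() + 2 nonzero() + start/end .item() per interval
    return rows * per_row


# B=64, seqlen=1024, ~2 intervals/row, as in the message above
print(estimate_loop_syncs(64, 1024, 2))  # 458752, i.e. roughly 500 k
```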
How
---
- `starts` / `ends` boundary detection was already vectorised; keep
that.
- Mask out positions outside each sample's `[0, seq_len)` so padded
rows/cols don't produce spurious intervals.
- `starts.cumsum(dim=-1)` assigns each transition a 1-based interval
index without any sync.
- `starts.nonzero()` gives all (b, q, k) coordinates in one shot; write
them into `af` with a single vectorised indexed assignment. One nonzero
call per side replaces ~B × seq_len of them.
- Same for `ends`, with the existing `+1` (exclusive) offset preserved.
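
The steps above can be sketched as follows. This is a minimal illustration, not the committed code: it uses NumPy as a stand-in for the PyTorch ops (`cumsum` and `nonzero` have the same semantics there), and the function name `dense_mask_to_intervals` and the output layout (two `(B, S, max_intervals)` arrays, zero-padded, end exclusive) are assumptions; the real `af` layout is not shown in this commit message.

```python
import numpy as np


def dense_mask_to_intervals(mask, seq_lens):
    """Convert a dense (B, S, S) bool mask to per-row [start, end) intervals.

    Hypothetical layout: returns (starts, ends), each (B, S, max_intervals),
    zero-padded; an unused slot has start == end == 0 (empty interval).
    """
    B, S, _ = mask.shape
    pos = np.arange(S)
    valid = pos[None, :] < seq_lens[:, None]            # (B, S) positions inside each sample
    m = mask & valid[:, :, None] & valid[:, None, :]    # drop padded rows and columns

    # Boundary detection: compare each key position with its left/right neighbour.
    prev = np.zeros_like(m)
    prev[..., 1:] = m[..., :-1]
    nxt = np.zeros_like(m)
    nxt[..., :-1] = m[..., 1:]
    starts = m & ~prev    # first position of each run of ones
    ends = m & ~nxt       # last position of each run of ones

    # cumsum gives every boundary its 1-based interval index within the row.
    start_idx = np.cumsum(starts, axis=-1)
    end_idx = np.cumsum(ends, axis=-1)
    # The only scalar read-back; in the GPU version this is the host sync.
    max_intervals = int(start_idx.max(initial=0))

    out_start = np.zeros((B, S, max_intervals), dtype=np.int64)
    out_end = np.zeros((B, S, max_intervals), dtype=np.int64)

    # One nonzero per side, then a single scattered assignment.
    b, q, k = np.nonzero(starts)
    out_start[b, q, start_idx[b, q, k] - 1] = k
    b, q, k = np.nonzero(ends)
    out_end[b, q, end_idx[b, q, k] - 1] = k + 1         # +1 makes the end exclusive
    return out_start, out_end
```

Because each (b, q, interval index) triple is written at most once, the scattered assignment has no write conflicts, which is what makes the loop-free form safe.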
Verification
------------
Add `TestDenseMaskToJaggedVectorisedMatchesLoop` comparing the new
vectorised path against the existing loop-based test helper across:
jagged causal, target-grouped (4 beam_width × candidate_len cases),
all-zero mask, multi-interval per row, uneven seq_lens.
Local: 27/27 pass (was 20), pre-commit clean, no behaviour change for
the existing 20 tests.
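
For illustration only, a loop-based reference of the kind such a comparison test relies on might look like the sketch below. `row_to_intervals` and the three cases are hypothetical and are not the helper or the test class from this commit; they only show the shape of the check (a slow, obviously-correct loop as the oracle).

```python
def row_to_intervals(row):
    """Return the [start, end) runs of truthy values in one mask row."""
    intervals = []
    start = None
    for k, v in enumerate(row):
        if v and start is None:
            start = k                          # a run of ones begins here
        elif not v and start is not None:
            intervals.append((start, k))       # run ended just before k (exclusive end)
            start = None
    if start is not None:
        intervals.append((start, len(row)))    # run extends to the end of the row
    return intervals


# A few of the case shapes named above, as (row, expected intervals).
cases = {
    "all_zero": ([0, 0, 0, 0], []),
    "causal_row": ([1, 1, 1, 0], [(0, 3)]),
    "multi_interval": ([1, 0, 1, 1, 0, 1], [(0, 1), (2, 4), (5, 6)]),
}
for name, (row, expected) in cases.items():
    assert row_to_intervals(row) == expected, name
```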
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
1 parent 5efdadc commit ed60ab2
2 files changed
Lines changed: 117 additions & 26 deletions
File 1 of 2 (file name and diff content not preserved): 39 additions & 25 deletions, in hunks at original lines 447-453 and 457-491 (new lines 447-453 and 457-505).
Lines changed: 78 additions & 1 deletion
File 2 of 2 (file name and diff content not preserved): hunks at original lines 26-33 and 262-264 (new lines 26-36 and 265-341).