Commit 5497faa
authored
Support subquadratic-ops kernels in evo2 autoregressive inference (#1565)
### Description
Closes the gap noted in `hyena_mixer.py` (`# todo: support
inference_context for b2b_kernel`) and the README caveat that
`--use-subquadratic-ops` "does not apply to autoregressive inference
(`infer_evo2`)". After this PR, the same fused kernels that accelerate
training and batch prediction also accelerate the prefill phase of
autoregressive inference.
Summary of change:
1. **`engine.parallel_fir`** now accepts `use_subquadratic_ops` and
routes to `fft_causal_conv1d` (filters ≥ 128) or `causal_conv1d` (short
filters), wired through both call sites in `hyena_utils.py`.
2. **`HyenaMixer.forward`** detects prefill (no FIR cache yet) and runs
`b2b_causal_conv1d` for the fused proj+mixer convolution. The kernel
doesn't expose its intermediate, so we run a tiny windowed proj-conv on
the last `K_proj + K_mixer − 2` input positions to materialize the
`(x2*v)` tail and seed the mixer's FIR cache. Works for both
`hyena_short_conv` and `hyena_medium_conv`.
3. Removed the `del self._parameters["short_conv_weight"]`
micro-optimization in
`ParallelCausalDepthwiseConv1dWithState._get_weight()` —
`B2BCausalConv1dModule` reads that raw param on every prefill, so
deleting it after first decode broke multi-prompt inference. Memory cost
is ~4 MB for a 1B model.
`infer_evo2` gets a `--use-subquadratic-ops` flag.
## Testing
- New parametrization
`test_forward_manual[1b-8k-bf16-subquadratic-ops-flash]` covers the
`(flash_decode=True, subquadratic_ops=True)` combination that was
previously skipped.
- New `test_subquadratic_ops_matches_baseline` runs greedy
autoregressive generation with and without `--use-subquadratic-ops` and
asserts identical output — this is the strict check that Phase 2 state
population is correct (a wrong cache would diverge during decode).
- Existing kernel comparison tests (`test_hyena_mixer_kernel.py`) and
inference-context unit tests pass unchanged.
## Performance
`infer_evo2`, evo2/1b-8k-bf16, single A6000, multiple identical prompts
in one process to amortize the one-time JIT compile cost (~15 s the
first time each subq-ops kernel sees a new shape). Steady-state numbers
from batches 3+:
| Prompt | Generation | Baseline | Subq-ops | Speedup |
|---|---|---|---|---|
| 4 096 tokens | 5 tokens | 0.57 s | 0.51 s | ~10% |
| 8 000 tokens | 1 token | 1.02 s | 0.87 s | ~15% |
The speedup is concentrated in prefill. The relative improvement grows
with prompt length and shrinks as more decode tokens are amortized in.
### Type of changes
<!-- Mark the relevant option with an [x] -->
- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Refactor
- [x] Documentation update
- [ ] Other (please describe):
### CI Pipeline Configuration
Configure CI behavior by applying the relevant labels. By default, only
basic unit tests are run.
-
[ciflow:skip](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:skip)
- Skip all CI tests for this PR
-
[ciflow:notebooks](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:notebooks)
- Run Jupyter notebooks execution tests
-
[ciflow:slow](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:slow)
- Run slow single GPU integration tests marked as @pytest.mark.slow
-
[ciflow:all](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:all)
- Run all tests (unit tests, slow tests, and notebooks). This label can
be used to enforce running all framework tests.
-
[ciflow:all-recipes](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:all-recipes)
- Run tests for all recipes (under bionemo-recipes). This label can be
used to enforce running tests for all recipes.
Unit tests marked as `@pytest.mark.multi_gpu` or
`@pytest.mark.distributed` are not run in the PR pipeline.
For more details, see [CONTRIBUTING](CONTRIBUTING.md)
> [!NOTE]
> By default, only basic unit tests are run. Add appropriate labels to
enable an additional test coverage.
#### Authorizing CI Runs
We use
[copy-pr-bot](https://docs.gha-runners.nvidia.com/apps/copy-pr-bot/#automation)
to manage authorization of CI
runs on NVIDIA's compute resources.
- If a pull request is opened by a trusted user and contains only
trusted changes, the pull request's code will
automatically be copied to a pull-request/ prefixed branch in the source
repository (e.g. pull-request/123)
- If a pull request is opened by an untrusted user or contains untrusted
changes, an NVIDIA org member must leave an
`/ok to test` comment on the pull request to trigger CI. This will need
to be done for each new commit.
#### Triggering Code Rabbit AI Review
To trigger a code review from code rabbit, comment on a pull request
with one of these commands:
- @coderabbitai review - Triggers a standard review
- @coderabbitai full review - Triggers a comprehensive review
See https://docs.coderabbit.ai/reference/review-commands for a full list
of commands.
### Pre-submit Checklist
- [x] I have tested these changes locally
- [x] I have updated the documentation accordingly
- [x] I have added/updated tests as needed
- [x] All existing tests pass successfully
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **New Features**
* Added `--use-subquadratic-ops` CLI option to optimize prompt/prefill
processing during inference while leaving per-token decode unchanged.
* **Documentation**
* Clarified subquadratic-ops kernel behavior and performance impact on
prefill throughput.
* **Tests**
* Added end-to-end test confirming subquadratic-ops generates identical
inference results as baseline.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
---------
Signed-off-by: Farhad Ramezanghorbani <farhadr@nvidia.com>1 parent 31a9e62 commit 5497faa
7 files changed
Lines changed: 199 additions & 31 deletions
File tree
- bionemo-recipes/recipes/evo2_megatron
- src/bionemo/evo2
- models/megatron/hyena
- run
- tests/bionemo/evo2
- run
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
67 | 67 | | |
68 | 68 | | |
69 | 69 | | |
70 | | - | |
71 | | - | |
72 | | - | |
73 | | - | |
74 | | - | |
75 | | - | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
76 | 76 | | |
77 | 77 | | |
78 | 78 | | |
| |||
97 | 97 | | |
98 | 98 | | |
99 | 99 | | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
100 | 103 | | |
101 | 104 | | |
102 | 105 | | |
| |||
Lines changed: 43 additions & 14 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
20 | 20 | | |
21 | 21 | | |
22 | 22 | | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
23 | 32 | | |
24 | 33 | | |
25 | 34 | | |
| |||
63 | 72 | | |
64 | 73 | | |
65 | 74 | | |
| 75 | + | |
66 | 76 | | |
67 | 77 | | |
68 | 78 | | |
69 | 79 | | |
70 | 80 | | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
71 | 84 | | |
72 | | - | |
73 | | - | |
74 | | - | |
75 | | - | |
76 | | - | |
77 | | - | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
78 | 100 | | |
79 | | - | |
80 | | - | |
81 | | - | |
82 | | - | |
83 | | - | |
84 | | - | |
85 | | - | |
86 | | - | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
| 108 | + | |
| 109 | + | |
| 110 | + | |
| 111 | + | |
| 112 | + | |
| 113 | + | |
| 114 | + | |
| 115 | + | |
87 | 116 | | |
88 | 117 | | |
89 | 118 | | |
| |||
Lines changed: 70 additions & 7 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
22 | 22 | | |
23 | 23 | | |
24 | 24 | | |
| 25 | + | |
25 | 26 | | |
26 | 27 | | |
27 | 28 | | |
| |||
307 | 308 | | |
308 | 309 | | |
309 | 310 | | |
310 | | - | |
311 | | - | |
312 | | - | |
313 | | - | |
314 | | - | |
315 | | - | |
316 | | - | |
| 311 | + | |
| 312 | + | |
| 313 | + | |
| 314 | + | |
| 315 | + | |
| 316 | + | |
| 317 | + | |
| 318 | + | |
| 319 | + | |
| 320 | + | |
| 321 | + | |
317 | 322 | | |
| 323 | + | |
| 324 | + | |
318 | 325 | | |
319 | 326 | | |
320 | 327 | | |
| |||
330 | 337 | | |
331 | 338 | | |
332 | 339 | | |
| 340 | + | |
| 341 | + | |
| 342 | + | |
| 343 | + | |
| 344 | + | |
| 345 | + | |
| 346 | + | |
| 347 | + | |
| 348 | + | |
| 349 | + | |
| 350 | + | |
| 351 | + | |
| 352 | + | |
| 353 | + | |
| 354 | + | |
| 355 | + | |
| 356 | + | |
| 357 | + | |
| 358 | + | |
| 359 | + | |
| 360 | + | |
| 361 | + | |
| 362 | + | |
| 363 | + | |
| 364 | + | |
| 365 | + | |
| 366 | + | |
| 367 | + | |
| 368 | + | |
| 369 | + | |
| 370 | + | |
| 371 | + | |
| 372 | + | |
| 373 | + | |
| 374 | + | |
| 375 | + | |
| 376 | + | |
| 377 | + | |
| 378 | + | |
| 379 | + | |
| 380 | + | |
| 381 | + | |
| 382 | + | |
| 383 | + | |
| 384 | + | |
| 385 | + | |
| 386 | + | |
| 387 | + | |
| 388 | + | |
| 389 | + | |
| 390 | + | |
| 391 | + | |
| 392 | + | |
| 393 | + | |
| 394 | + | |
| 395 | + | |
Lines changed: 8 additions & 2 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1051 | 1051 | | |
1052 | 1052 | | |
1053 | 1053 | | |
| 1054 | + | |
1054 | 1055 | | |
1055 | 1056 | | |
1056 | 1057 | | |
| |||
1656 | 1657 | | |
1657 | 1658 | | |
1658 | 1659 | | |
1659 | | - | |
| 1660 | + | |
| 1661 | + | |
| 1662 | + | |
| 1663 | + | |
| 1664 | + | |
| 1665 | + | |
1660 | 1666 | | |
1661 | 1667 | | |
1662 | 1668 | | |
1663 | 1669 | | |
1664 | | - | |
1665 | 1670 | | |
1666 | 1671 | | |
1667 | 1672 | | |
| |||
1697 | 1702 | | |
1698 | 1703 | | |
1699 | 1704 | | |
| 1705 | + | |
1700 | 1706 | | |
1701 | 1707 | | |
1702 | 1708 | | |
| |||
Lines changed: 17 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
358 | 358 | | |
359 | 359 | | |
360 | 360 | | |
| 361 | + | |
361 | 362 | | |
362 | 363 | | |
363 | 364 | | |
| |||
379 | 380 | | |
380 | 381 | | |
381 | 382 | | |
| 383 | + | |
| 384 | + | |
| 385 | + | |
382 | 386 | | |
383 | 387 | | |
384 | 388 | | |
| |||
412 | 416 | | |
413 | 417 | | |
414 | 418 | | |
| 419 | + | |
415 | 420 | | |
416 | 421 | | |
417 | 422 | | |
| |||
808 | 813 | | |
809 | 814 | | |
810 | 815 | | |
| 816 | + | |
| 817 | + | |
| 818 | + | |
| 819 | + | |
| 820 | + | |
| 821 | + | |
| 822 | + | |
| 823 | + | |
811 | 824 | | |
812 | 825 | | |
813 | 826 | | |
| |||
831 | 844 | | |
832 | 845 | | |
833 | 846 | | |
| 847 | + | |
834 | 848 | | |
835 | 849 | | |
836 | 850 | | |
| |||
858 | 872 | | |
859 | 873 | | |
860 | 874 | | |
| 875 | + | |
861 | 876 | | |
862 | 877 | | |
863 | 878 | | |
| |||
878 | 893 | | |
879 | 894 | | |
880 | 895 | | |
| 896 | + | |
881 | 897 | | |
882 | 898 | | |
883 | 899 | | |
| |||
1003 | 1019 | | |
1004 | 1020 | | |
1005 | 1021 | | |
| 1022 | + | |
1006 | 1023 | | |
1007 | 1024 | | |
1008 | 1025 | | |
| |||
Lines changed: 45 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
284 | 284 | | |
285 | 285 | | |
286 | 286 | | |
| 287 | + | |
287 | 288 | | |
288 | 289 | | |
289 | 290 | | |
| |||
295 | 296 | | |
296 | 297 | | |
297 | 298 | | |
| 299 | + | |
298 | 300 | | |
299 | 301 | | |
300 | 302 | | |
| |||
326 | 328 | | |
327 | 329 | | |
328 | 330 | | |
| 331 | + | |
| 332 | + | |
329 | 333 | | |
330 | 334 | | |
331 | 335 | | |
| |||
517 | 521 | | |
518 | 522 | | |
519 | 523 | | |
| 524 | + | |
| 525 | + | |
| 526 | + | |
| 527 | + | |
| 528 | + | |
| 529 | + | |
| 530 | + | |
| 531 | + | |
| 532 | + | |
| 533 | + | |
| 534 | + | |
| 535 | + | |
| 536 | + | |
| 537 | + | |
| 538 | + | |
| 539 | + | |
| 540 | + | |
| 541 | + | |
| 542 | + | |
| 543 | + | |
| 544 | + | |
| 545 | + | |
| 546 | + | |
| 547 | + | |
| 548 | + | |
| 549 | + | |
| 550 | + | |
| 551 | + | |
| 552 | + | |
| 553 | + | |
| 554 | + | |
| 555 | + | |
| 556 | + | |
| 557 | + | |
| 558 | + | |
| 559 | + | |
| 560 | + | |
| 561 | + | |
| 562 | + | |
| 563 | + | |
| 564 | + | |
520 | 565 | | |
521 | 566 | | |
522 | 567 | | |
| |||
0 commit comments