Short sequence prefix-invariant evo2 implementation#1580
Conversation
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
|
Important Review skippedAuto reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Enterprise Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
farhadrgh
left a comment
There was a problem hiding this comment.
Two bugs, two questions, and the last section flags that this PR regresses subq-ops inference support that already landed in #1565.
Bugs
1. hyena_utils.fftconv_func fix is incomplete, bidirectional path still broken.
The fix lands only inside the else: # causal branch:
if use_subquadratic_ops:
y = fft_causal_conv1d(u, k.squeeze(0))
else:
fft_size = max(fft_size, 2 * k.shape[-1]) # <-- only here
k_f = torch.fft.rfft(k, n=fft_size) / fft_sizeThe if bidirectional: branch immediately above still does torch.fft.rfft(k, n=fft_size) with the original fft_size = 2 * seqlen. Same truncation bug if anyone runs the bidirectional path with seqlen < K. Suggest hoisting the max(...) line to right after fft_size = 2 * seqlen so both branches benefit.
2. The short-filter causal_conv1d subq path was reverted, but the xfail only covers the fused B2B path.
Two separate code paths got removed in this PR, but the xfail (test_b2b_causal_conv1d_module_matches_sequential_reference) only documents one:
engine.parallel_firlost itsif use_subquadratic_ops: _subq_causal_conv1d(...)arm in the< 128branch.ParallelCausalDepthwiseConv1d.forwardnow always usescausal_conv1d_fninstead of dispatching to subq whenuse_subquadratic_ops=True.
Neither of those is the fused B2B kernel, they're plain depthwise short-filter convolutions. Is the issue actually with subq's causal_conv1d under causal-conv1d 1.6+, or did these get caught in the same revert? If it's the latter, worth keeping them — they're the easier speedup with no fusion semantics to verify.
Questions
3. @torch.compile removed from ImplicitModalFilter.filter does the comment refer to a specific reproducer? A pointer in the comment would help future readers, and if the bad-interaction scope is narrow we may be able to keep @torch.compile with dynamic=False or wrap the offending call site in torch.compiler.disable instead of dropping it altogether.
4. hyena_block.py variable-arity get_cpu_offload_context call, clean fix for the 6-vs-7-arg drift, but len(inspect.signature(...).parameters) is a brittle proxy (it counts a *args parameter as 1, which would silently break the slice). Worth a # tied to MCore <= 0.x note so future readers know to revisit if MCore changes the signature again.
Regression of #1565 (already on main)
This PR removes the two inference subq-ops code paths that landed in #1565 (merged 2026-04-30):
engine.parallel_firshort branch: theif use_subquadratic_ops: _subq_causal_conv1d(...)arm from #1565 is removed (item 2 above).HyenaMixer.forwardprefill: #1565 added_populate_b2b_inference_stateand gated the fused b2b kernel onuse_subquadratic_ops. This PR forces the gate off viaself.use_fused_b2b_causal_conv1d = False(hardcoded), so the fused path can never fire even when the user passes--use-subquadratic-ops. This also disables the original training andpredict_evo2b2b path that predates #1565.
Net effect for infer_evo2 --use-subquadratic-ops after this PR lands:
- The flag still routes long-filter FFT convs through subq-ops (
_subq_fft_causal_conv1d), so the existingtest_subquadratic_ops_matches_baselinecorrectness test will still pass. - But the short-filter and fused-B2B prefill paths are gone, so the measured ~15% prefill speedup at 8K prompt on the 1B model (single A6000) goes back to zero. Users get the CLI flag without the performance it was added for.
I get why this is happening, the xfail in test_hyena_utils.py shows the fused B2B kernel doesn't match the reference under causal-conv1d 1.6+. That's a real kernel-side bug. But two things:
(a) The fix for the fused-B2B mismatch shouldn't take out the short-filter causal_conv1d path too. They're independent (see item 2 above). If the subq short-filter kernel is also broken under 1.6+, a passing/failing test would clarify; if it isn't broken, please keep that path.
(b) Disabling the fused B2B path is reasonable as a temporary measure, but hardcoding the flag to False makes the regression permanent until someone re-edits the file. Please make it a real config attribute so it can be flipped back on once subquadratic-ops ships the 1.6+ fix, without another PR. Suggested:
self.use_fused_b2b_causal_conv1d = getattr(
transformer_config, "use_fused_b2b_causal_conv1d", False
)That way #1565's runtime behavior is recoverable via config, and we don't lose the speedup permanently. (And anyone hitting a predict_evo2 perf regression after this lands can re-enable it for the training/predict path independently.)
…nd fail loudly if the CUDA_ERROR_UNSUPPORTED_PTX_VERSION error comes up Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
|
|
||
| def _linear_causal_fft_size(input_len: int, filter_len: int) -> int: | ||
| """Return an FFT size that cannot alias a causal convolution prefix.""" | ||
| return max(2 * input_len, 2 * filter_len) |
There was a problem hiding this comment.
Here you can do this,
if filter_len <= 2* input_len:
return min(input_len + filter_len, 2 * filter_len)
return 2 * max(input_len, filter_len)
| ): | ||
| """Compute parallel finite impulse response filtering with optional state computation.""" | ||
| L = u.shape[1] # noqa: N806 | ||
| u = rearrange(u, "b l d -> b d l") |
There was a problem hiding this comment.
we can use subquadratic_ops rearrange here. rearrange
| """Apply a 1D convolution to the input sequence u using the filter k and the shortcut D.""" | ||
| seqlen = u.shape[-1] | ||
| fft_size = 2 * seqlen | ||
| fft_size = max(2 * seqlen, 2 * k.shape[-1]) |
There was a problem hiding this comment.
same as first fft_size selection algorithm.
Signed-off-by: John St. John <jstjohn@nvidia.com>
Description
Changes: