Short sequence prefix-invariant evo2 implementation#1580

Open
jstjohn wants to merge 15 commits into main from jstjohn/prefix_invariance_evo2

Collaborator

@jstjohn jstjohn commented May 22, 2026

Description

Changes:

  • codex added to top level devcontainer
  • bump causal-conv1d, megatron-bridge, and associated dependencies
  • add test coverage for prefix invariance when running evo2 on very short sequences through inference and training

jstjohn added 2 commits May 22, 2026 11:38
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
@coderabbitai
Contributor

coderabbitai Bot commented May 22, 2026

Important

Review skipped

Auto reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: af8a8900-b869-4fb8-8fd3-3b44ccad7dda

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

@jstjohn jstjohn requested review from farhadrgh and moradza May 22, 2026 18:54
Collaborator

@farhadrgh farhadrgh left a comment

Two bugs, two questions, and the last section flags that this PR regresses subq-ops inference support that already landed in #1565.

Bugs

1. hyena_utils.fftconv_func fix is incomplete; the bidirectional path is still broken.

The fix lands only inside the else: # causal branch:

if use_subquadratic_ops:
    y = fft_causal_conv1d(u, k.squeeze(0))
else:
    fft_size = max(fft_size, 2 * k.shape[-1])   # <-- only here
    k_f = torch.fft.rfft(k, n=fft_size) / fft_size

The if bidirectional: branch immediately above still does torch.fft.rfft(k, n=fft_size) with the original fft_size = 2 * seqlen. Same truncation bug if anyone runs the bidirectional path with seqlen < K. Suggest hoisting the max(...) line to right after fft_size = 2 * seqlen so both branches benefit.
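
The truncation described above can be reproduced without PyTorch. The following is a minimal sketch, assuming that an FFT convolution of size fft_size behaves like a circular convolution after each operand is truncated or zero-padded to fft_size (which is what torch.fft.rfft(x, n=fft_size) does to its input). The helper names are hypothetical and are not taken from hyena_utils.

```python
def causal_conv_reference(u, k):
    """Direct causal convolution: y[n] = sum_j k[j] * u[n - j] for n < len(u)."""
    return [
        sum(k[j] * u[n - j] for j in range(min(n + 1, len(k))))
        for n in range(len(u))
    ]


def causal_conv_via_circular(u, k, fft_size):
    """Emulate an FFT convolution of size fft_size with a circular convolution.

    Assumption: both operands are truncated or zero-padded to fft_size first,
    mirroring the effect of the n= argument of an FFT call.
    """

    def fit(x):
        return (list(x) + [0] * fft_size)[:fft_size]

    u_p, k_p = fit(u), fit(k)
    full = [
        sum(k_p[j] * u_p[(n - j) % fft_size] for j in range(fft_size))
        for n in range(fft_size)
    ]
    return full[: len(u)]


u = [1, 2]                 # seqlen = 2
k = [1, 1, 1, 1, 1, 1]     # filter length K = 6, so seqlen < K

expected = causal_conv_reference(u, k)                          # [1, 3]
buggy = causal_conv_via_circular(u, k, fft_size=2 * len(u))     # [3, 3]
fixed = causal_conv_via_circular(
    u, k, fft_size=max(2 * len(u), 2 * len(k))
)                                                               # [1, 3]
```

With fft_size = 2 * seqlen the filter is cut to four taps and the tail wraps around into the first outputs. Because the size is computed once before the branch, hoisting the max(...) line would presumably give the bidirectional path the same protection.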

2. The short-filter causal_conv1d subq path was reverted, but the xfail only covers the fused B2B path.

Two separate code paths got removed in this PR, but the xfail (test_b2b_causal_conv1d_module_matches_sequential_reference) only documents one:

  • engine.parallel_fir lost its if use_subquadratic_ops: _subq_causal_conv1d(...) arm in the < 128 branch.
  • ParallelCausalDepthwiseConv1d.forward now always uses causal_conv1d_fn instead of dispatching to subq when use_subquadratic_ops=True.

Neither of those is the fused B2B kernel; they're plain depthwise short-filter convolutions. Is the issue actually with subq's causal_conv1d under causal-conv1d 1.6+, or did these get caught in the same revert? If it's the latter, they're worth keeping: they're the easier speedup, with no fusion semantics to verify.

Questions

3. @torch.compile removed from ImplicitModalFilter.filter: does the comment refer to a specific reproducer? A pointer in the comment would help future readers. If the bad interaction is narrow in scope, we may be able to keep @torch.compile with dynamic=False, or wrap the offending call site in torch.compiler.disable, instead of dropping it altogether.

4. hyena_block.py variable-arity get_cpu_offload_context call: a clean fix for the 6-vs-7-arg drift, but len(inspect.signature(...).parameters) is a brittle proxy (it counts a *args parameter as 1, which would silently break the slice). Worth adding a # tied to MCore <= 0.x note so future readers know to revisit if MCore changes the signature again.
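
The arity concern in item 4 can be illustrated with the standard library alone. In this sketch the two functions are hypothetical stand-ins, not the real get_cpu_offload_context signatures, and positional_capacity is an illustrative helper rather than anything in hyena_block.py.

```python
import inspect


def offload_ctx_fixed(a, b, c, d, e, f):
    """Hypothetical stand-in for a 6-argument signature."""


def offload_ctx_varargs(a, *args):
    """Hypothetical stand-in that accepts any number of positional arguments."""


def positional_capacity(fn):
    """Return the max number of positional args fn accepts, or None if unbounded."""
    params = list(inspect.signature(fn).parameters.values())
    if any(p.kind is inspect.Parameter.VAR_POSITIONAL for p in params):
        return None  # *args present: a plain count would be misleading
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    return sum(p.kind in positional_kinds for p in params)


# The naive count treats *args as a single parameter.
naive_fixed = len(inspect.signature(offload_ctx_fixed).parameters)      # 6
naive_varargs = len(inspect.signature(offload_ctx_varargs).parameters)  # 2
```

Checking the parameter kind makes the unbounded case explicit, so an argument slice derived from the count would fail loudly rather than quietly drop arguments.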

Regression of #1565 (already on main)

This PR removes the two inference subq-ops code paths that landed in #1565 (merged 2026-04-30):

  • engine.parallel_fir short branch: the if use_subquadratic_ops: _subq_causal_conv1d(...) arm from #1565 is removed (item 2 above).
  • HyenaMixer.forward prefill: #1565 added _populate_b2b_inference_state and gated the fused b2b kernel on use_subquadratic_ops. This PR forces the gate off via self.use_fused_b2b_causal_conv1d = False (hardcoded), so the fused path can never fire even when the user passes --use-subquadratic-ops. This also disables the original training and predict_evo2 b2b path that predates #1565.

Net effect for infer_evo2 --use-subquadratic-ops after this PR lands:

  • The flag still routes long-filter FFT convs through subq-ops (_subq_fft_causal_conv1d), so the existing test_subquadratic_ops_matches_baseline correctness test will still pass.
  • But the short-filter and fused-B2B prefill paths are gone, so the measured ~15% prefill speedup at 8K prompt on the 1B model (single A6000) goes back to zero. Users get the CLI flag without the performance it was added for.

I get why this is happening: the xfail in test_hyena_utils.py shows the fused B2B kernel doesn't match the reference under causal-conv1d 1.6+. That's a real kernel-side bug. But two things:

(a) The fix for the fused-B2B mismatch shouldn't take out the short-filter causal_conv1d path too. They're independent (see item 2 above). If the subq short-filter kernel is also broken under 1.6+, a passing/failing test would clarify; if it isn't broken, please keep that path.

(b) Disabling the fused B2B path is reasonable as a temporary measure, but hardcoding the flag to False makes the regression permanent until someone re-edits the file. Please make it a real config attribute so it can be flipped back on once subquadratic-ops ships the 1.6+ fix, without another PR. Suggested:

self.use_fused_b2b_causal_conv1d = getattr(
    transformer_config, "use_fused_b2b_causal_conv1d", False
)

That way #1565's runtime behavior is recoverable via config, and we don't lose the speedup permanently. (And anyone hitting a predict_evo2 perf regression after this lands can re-enable it for the training/predict path independently.)
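
A minimal sketch of how the suggested getattr fallback behaves. SimpleNamespace stands in for the real transformer config object, and resolve_fused_b2b_flag is a hypothetical helper used only for illustration.

```python
from types import SimpleNamespace


def resolve_fused_b2b_flag(transformer_config) -> bool:
    """Read the flag from the config, defaulting to False when it is absent."""
    return bool(getattr(transformer_config, "use_fused_b2b_causal_conv1d", False))


# A config that predates the attribute keeps the fused path disabled.
legacy_config = SimpleNamespace()

# A config that opts in re-enables the fused path without a code change.
opted_in_config = SimpleNamespace(use_fused_b2b_causal_conv1d=True)

legacy_flag = resolve_fused_b2b_flag(legacy_config)      # False
opted_in_flag = resolve_fused_b2b_flag(opted_in_config)  # True
```

The default of False preserves the temporary disablement, while the attribute leaves a switch that can be flipped once the kernel fix ships.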

jstjohn added 11 commits May 22, 2026 12:51
…nd fail loudly if the CUDA_ERROR_UNSUPPORTED_PTX_VERSION error comes up

Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>

def _linear_causal_fft_size(input_len: int, filter_len: int) -> int:
    """Return an FFT size that cannot alias a causal convolution prefix."""
    return max(2 * input_len, 2 * filter_len)
Collaborator

Here you could do this instead:

if filter_len <= 2 * input_len:
    return min(input_len + filter_len, 2 * filter_len)
return 2 * max(input_len, filter_len)
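
Any candidate size rule can be checked by brute force before it is adopted. The sketch below assumes the same model as the FFT path (operands truncated or zero-padded to the chosen size, then circularly convolved) and compares the first input_len outputs against a direct causal convolution over a small grid of shapes. All helper names are hypothetical.

```python
from itertools import product


def causal_conv_reference(u, k):
    """Direct causal convolution, first len(u) outputs."""
    return [
        sum(k[j] * u[n - j] for j in range(min(n + 1, len(k))))
        for n in range(len(u))
    ]


def circular_prefix(u, k, size):
    """First len(u) outputs of a size-point circular convolution."""

    def fit(x):
        return (list(x) + [0] * size)[:size]

    u_p, k_p = fit(u), fit(k)
    full = [
        sum(k_p[j] * u_p[(n - j) % size] for j in range(size))
        for n in range(size)
    ]
    return full[: len(u)]


def rule_is_alias_free(size_rule, max_len=6):
    """Check size_rule(input_len, filter_len) on every shape up to max_len."""
    for input_len, filter_len in product(range(1, max_len + 1), repeat=2):
        u = list(range(1, input_len + 1))   # strictly positive test data
        k = list(range(1, filter_len + 1))
        size = size_rule(input_len, filter_len)
        if circular_prefix(u, k, size) != causal_conv_reference(u, k):
            return False
    return True


pr_rule_ok = rule_is_alias_free(lambda n, m: max(2 * n, 2 * m))  # True
tight_rule_ok = rule_is_alias_free(lambda n, m: n + m - 1)       # True
old_rule_ok = rule_is_alias_free(lambda n, m: 2 * n)             # False
```

Under these assumptions, input_len + filter_len - 1 points are sufficient, and for input_len >= 2 any rule that can return fewer than that fails the check for some shape. A rule smaller than the PR's max(2 * input_len, 2 * filter_len) is therefore worth running through a check like this first.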

):
    """Compute parallel finite impulse response filtering with optional state computation."""
    L = u.shape[1]  # noqa: N806
    u = rearrange(u, "b l d -> b d l")
Collaborator

We can use the subquadratic_ops rearrange here.

     """Apply a 1D convolution to the input sequence u using the filter k and the shortcut D."""
     seqlen = u.shape[-1]
-    fft_size = 2 * seqlen
+    fft_size = max(2 * seqlen, 2 * k.shape[-1])
Collaborator

Same as the first comment: use the same fft_size selection algorithm here.

Signed-off-by: John St. John <jstjohn@nvidia.com>

3 participants