Skip to content

Commit 5497faa

Browse files
authored
Support subquadratic-ops kernels in evo2 autoregressive inference (#1565)
### Description Closes the gap noted in `hyena_mixer.py` (`# todo: support inference_context for b2b_kernel`) and the README caveat that `--use-subquadratic-ops` "does not apply to autoregressive inference (`infer_evo2`)". After this PR, the same fused kernels that accelerate training and batch prediction also accelerate the prefill phase of autoregressive inference. Summary of change: 1. **`engine.parallel_fir`** now accepts `use_subquadratic_ops` and routes to `fft_causal_conv1d` (filters ≥ 128) or `causal_conv1d` (short filters), wired through both call sites in `hyena_utils.py`. 2. **`HyenaMixer.forward`** detects prefill (no FIR cache yet) and runs `b2b_causal_conv1d` for the fused proj+mixer convolution. The kernel doesn't expose its intermediate, so we run a tiny windowed proj-conv on the last `K_proj + K_mixer − 2` input positions to materialize the `(x2*v)` tail and seed the mixer's FIR cache. Works for both `hyena_short_conv` and `hyena_medium_conv`. 3. Removed the `del self._parameters["short_conv_weight"]` micro-optimization in `ParallelCausalDepthwiseConv1dWithState._get_weight()` — `B2BCausalConv1dModule` reads that raw param on every prefill, so deleting it after first decode broke multi-prompt inference. Memory cost is ~4 MB for a 1B model. `infer_evo2` gets a `--use-subquadratic-ops` flag. ## Testing - New parametrization `test_forward_manual[1b-8k-bf16-subquadratic-ops-flash]` covers the `(flash_decode=True, subquadratic_ops=True)` combination that was previously skipped. - New `test_subquadratic_ops_matches_baseline` runs greedy autoregressive generation with and without `--use-subquadratic-ops` and asserts identical output — this is the strict check that Phase 2 state population is correct (a wrong cache would diverge during decode). - Existing kernel comparison tests (`test_hyena_mixer_kernel.py`) and inference-context unit tests pass unchanged. ## Performance `infer_evo2`, evo2/1b-8k-bf16, single A6000, multiple identical prompts in one process to amortize the one-time JIT compile cost (~15 s the first time each subq-ops kernel sees a new shape). Steady-state numbers from batches 3+: | Prompt | Generation | Baseline | Subq-ops | Speedup | |---|---|---|---|---| | 4 096 tokens | 5 tokens | 0.57 s | 0.51 s | ~10% | | 8 000 tokens | 1 token | 1.02 s | 0.87 s | ~15% | The speedup is concentrated in prefill. The relative improvement grows with prompt length and shrinks as more decode tokens are amortized in. ### Type of changes <!-- Mark the relevant option with an [x] --> - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Refactor - [x] Documentation update - [ ] Other (please describe): ### CI Pipeline Configuration Configure CI behavior by applying the relevant labels. By default, only basic unit tests are run. - [ciflow:skip](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:skip) - Skip all CI tests for this PR - [ciflow:notebooks](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:notebooks) - Run Jupyter notebooks execution tests - [ciflow:slow](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:slow) - Run slow single GPU integration tests marked as @pytest.mark.slow - [ciflow:all](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:all) - Run all tests (unit tests, slow tests, and notebooks). This label can be used to enforce running all framework tests. - [ciflow:all-recipes](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:all-recipes) - Run tests for all recipes (under bionemo-recipes). This label can be used to enforce running tests for all recipes. Unit tests marked as `@pytest.mark.multi_gpu` or `@pytest.mark.distributed` are not run in the PR pipeline. For more details, see [CONTRIBUTING](CONTRIBUTING.md) > [!NOTE] > By default, only basic unit tests are run. Add appropriate labels to enable an additional test coverage. #### Authorizing CI Runs We use [copy-pr-bot](https://docs.gha-runners.nvidia.com/apps/copy-pr-bot/#automation) to manage authorization of CI runs on NVIDIA's compute resources. - If a pull request is opened by a trusted user and contains only trusted changes, the pull request's code will automatically be copied to a pull-request/ prefixed branch in the source repository (e.g. pull-request/123) - If a pull request is opened by an untrusted user or contains untrusted changes, an NVIDIA org member must leave an `/ok to test` comment on the pull request to trigger CI. This will need to be done for each new commit. #### Triggering Code Rabbit AI Review To trigger a code review from code rabbit, comment on a pull request with one of these commands: - @coderabbitai review - Triggers a standard review - @coderabbitai full review - Triggers a comprehensive review See https://docs.coderabbit.ai/reference/review-commands for a full list of commands. ### Pre-submit Checklist - [x] I have tested these changes locally - [x] I have updated the documentation accordingly - [x] I have added/updated tests as needed - [x] All existing tests pass successfully <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added `--use-subquadratic-ops` CLI option to optimize prompt/prefill processing during inference while leaving per-token decode unchanged. * **Documentation** * Clarified subquadratic-ops kernel behavior and performance impact on prefill throughput. * **Tests** * Added end-to-end test confirming subquadratic-ops generates identical inference results as baseline. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Farhad Ramezanghorbani <farhadr@nvidia.com>
1 parent 31a9e62 commit 5497faa

7 files changed

Lines changed: 199 additions & 31 deletions

File tree

bionemo-recipes/recipes/evo2_megatron/README.md

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,12 @@ torchrun --nproc-per-node 2 --no-python \
6767
--use-subquadratic-ops
6868
```
6969

70-
> **Tip:** The `--use-subquadratic-ops` flag enables a fused back-to-back
71-
> causal convolution CUDA kernel for the Hyena short-conv layers. This
72-
> provides a meaningful speed-up for training and prediction and is
73-
> recommended for all production runs. It does not apply to autoregressive
74-
> inference (`infer_evo2`). There is a one-time compilation cost on first
75-
> use.
70+
> **Tip:** The `--use-subquadratic-ops` flag enables fused subquadratic-ops
71+
> CUDA kernels (`b2b_causal_conv1d` for proj+mixer fusion in prefill,
72+
> `fft_causal_conv1d` / `causal_conv1d` inside `engine.parallel_fir`). It
73+
> applies to training, batch prediction (`predict_evo2`), and the prefill
74+
> phase of autoregressive inference (`infer_evo2`); per-token decode is
75+
> already in optimal recurrent form and is unaffected.
7676
7777
### Autoregressive generation (`infer_evo2`)
7878

@@ -97,6 +97,9 @@ Options:
9797
- `--top-k` / `--top-p` — top-k or nucleus sampling (0 = disabled).
9898
- `--tensor-parallel-size` — tensor parallelism for large models (default: 1).
9999
- `--max-seq-length` — maximum sequence length (default: 8192).
100+
- `--use-subquadratic-ops` — use fused subquadratic-ops kernels for prefill
101+
(b2b causal conv, FFT/causal conv1d in `parallel_fir`). Recommended when
102+
processing many prompts in one process.
100103

101104
### Batch sequence scoring (`predict_evo2`)
102105

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/engine.py

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,15 @@
2020
from einops import rearrange
2121

2222

23+
try:
24+
from subquadratic_ops_torch.causal_conv1d import causal_conv1d as _subq_causal_conv1d
25+
from subquadratic_ops_torch.fft_causal_conv1d import fft_causal_conv1d as _subq_fft_causal_conv1d
26+
except ImportError as _subq_import_error:
27+
_subq_causal_conv1d = None
28+
_subq_fft_causal_conv1d = None
29+
_subq_error_msg = f"subquadratic_ops_torch not available: {_subq_import_error}"
30+
31+
2332
def adjust_filter_shape_for_broadcast(u, h):
2433
"""Adjust filter shape for broadcasting compatibility with input tensor."""
2534
h = h.squeeze() # Standardize to [D, L] from [1, D, L] and [D, 1, L]
@@ -63,27 +72,47 @@ def parallel_fir(
6372
gated_bias,
6473
fir_length,
6574
compute_state,
75+
use_subquadratic_ops=False,
6676
):
6777
"""Compute parallel finite impulse response filtering with optional state computation."""
6878
L = u.shape[1] # noqa: N806
6979
u = rearrange(u, "b l d -> b d l")
7080

81+
if use_subquadratic_ops and _subq_fft_causal_conv1d is None:
82+
raise ImportError(_subq_error_msg)
83+
7184
if fir_length >= 128:
72-
with torch.autocast("cuda"):
73-
z = fftconv_func(
74-
u=u.to(torch.float32),
75-
k=weight[:, :, :L].to(torch.float32),
76-
D=bias,
77-
).to(dtype=u.dtype)
85+
if use_subquadratic_ops:
86+
# subq-ops fft_causal_conv1d expects [B, D, L] input and [D, L] filter; dtypes must match
87+
k = weight[:, :, :L].squeeze(1) if weight.dim() == 3 else weight[:, :L]
88+
u_fp32 = u.to(torch.float32)
89+
z = _subq_fft_causal_conv1d(u_fp32, k.to(torch.float32))
90+
if bias is not None:
91+
z = z + u_fp32 * bias.unsqueeze(-1)
92+
z = z.to(u.dtype)
93+
else:
94+
with torch.autocast("cuda"):
95+
z = fftconv_func(
96+
u=u.to(torch.float32),
97+
k=weight[:, :, :L].to(torch.float32),
98+
D=bias,
99+
).to(dtype=u.dtype)
78100
else:
79-
z = F.conv1d(
80-
u.to(torch.float32),
81-
weight.to(torch.float32),
82-
bias=None,
83-
stride=1,
84-
padding=fir_length - 1,
85-
groups=u.shape[1], # always set to D, regardless of filter grouping
86-
)[..., :L]
101+
if use_subquadratic_ops:
102+
# subq-ops causal_conv1d expects pre-padded [B, D, L+pad] input and [D, K] weight; dtypes must match
103+
pad_size = fir_length - 1
104+
x_padded = F.pad(u.to(torch.float32), (pad_size, 0))
105+
w = weight.squeeze(1) if weight.dim() == 3 else weight
106+
z = _subq_causal_conv1d(x_padded, w.to(torch.float32))[..., pad_size:]
107+
else:
108+
z = F.conv1d(
109+
u.to(torch.float32),
110+
weight.to(torch.float32),
111+
bias=None,
112+
stride=1,
113+
padding=fir_length - 1,
114+
groups=u.shape[1], # always set to D, regardless of filter grouping
115+
)[..., :L]
87116

88117
z = z.to(u.dtype)
89118

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_mixer.py

Lines changed: 70 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import torch
2424
import torch.nn as nn
25+
import torch.nn.functional as F # noqa: N812
2526
from einops import rearrange
2627
from megatron.core.process_groups_config import ProcessGroupCollection
2728
from megatron.core.transformer.module import MegatronModule
@@ -307,14 +308,20 @@ def forward(self, x, layer_past=None, inference_context=None, _hyena_use_cp=True
307308
else:
308309
features = rearrange(features, "l b d -> b d l").contiguous()
309310

310-
if (
311-
self.use_subquadratic_ops
312-
and self.operator_type in ["hyena_short_conv", "hyena_medium_conv"]
313-
and inference_context is None
314-
):
315-
# todo: support inference_context for b2b_kernel
316-
# Use the B2BCausalConv1dModule wrapper with the existing weights from the original model
311+
is_b2b_eligible = self.use_subquadratic_ops and self.operator_type in [
312+
"hyena_short_conv",
313+
"hyena_medium_conv",
314+
]
315+
# b2b runs during training (no inference_context) or during prefill (no FIR cache yet).
316+
# During decode (cache populated, L=1) we fall back to the regular per-token step path.
317+
is_prefill = inference_context is not None and id(self.hyena_proj_conv) not in getattr(
318+
inference_context, "fir_filter_state_dict", {}
319+
)
320+
321+
if is_b2b_eligible and (inference_context is None or is_prefill):
317322
z = self.b2b_kernel(features, _use_cp=_proj_use_cp)
323+
if is_prefill:
324+
self._populate_b2b_inference_state(features, inference_context)
318325
else:
319326
features = self.hyena_proj_conv(
320327
features, _use_cp=_proj_use_cp, inference_context=inference_context
@@ -330,3 +337,59 @@ def forward(self, x, layer_past=None, inference_context=None, _hyena_use_cp=True
330337
z = rearrange(z, "b d l -> l b d").contiguous()
331338
y, bias = self.dense(z)
332339
return y, bias
340+
341+
def _populate_b2b_inference_state(self, features, inference_context):
342+
"""Populate FIR state for proj_conv and mixer after a b2b prefill.
343+
344+
The b2b kernel doesn't expose its post-projection intermediate, but subsequent
345+
decode steps need (a) the proj_conv input tail and (b) the tail of `x2 * v`
346+
— the gated stream that mixer's short_conv operates on. We get (b) by running
347+
a windowed proj_conv on just the last (K_proj + K_mixer - 2) input positions.
348+
"""
349+
proj_kernel_size = self.hyena_proj_conv.kernel_size
350+
351+
# (a) proj_conv FIR state: input tail in [B, D, K_proj-1]
352+
proj_state = features[..., -(proj_kernel_size - 1) :].contiguous()
353+
proj_dict = getattr(inference_context, "fir_filter_state_dict", {})
354+
proj_dict[id(self.hyena_proj_conv)] = proj_state
355+
setattr(inference_context, "fir_filter_state_dict", proj_dict)
356+
357+
# (b) mixer FIR state: tail of (x2 * v), the gated post-projection stream
358+
if self.operator_type == "hyena_short_conv":
359+
mixer_kernel_size = self.mixer.short_conv.kernel_size
360+
else: # hyena_medium_conv
361+
mixer_kernel_size = self.mixer.kernel_size
362+
363+
tail_in_len = proj_kernel_size + mixer_kernel_size - 2
364+
if features.shape[-1] < tail_in_len:
365+
tail_in = F.pad(features, (tail_in_len - features.shape[-1], 0))
366+
else:
367+
tail_in = features[..., -tail_in_len:].contiguous()
368+
369+
# Reuse the cached transformed weight from get_weight() (lru_cache'd).
370+
proj_weight = self.hyena_proj_conv.get_weight()
371+
372+
intermediate = F.conv1d(
373+
F.pad(tail_in.to(torch.float32), (proj_kernel_size - 1, 0)),
374+
proj_weight,
375+
bias=None,
376+
stride=1,
377+
padding=0,
378+
groups=tail_in.shape[1],
379+
)[..., -(mixer_kernel_size - 1) :].to(features.dtype)
380+
381+
x1, x2, v = rearrange(intermediate, "b (g dg p) l -> b (g dg) p l", p=3, g=self.num_groups_per_tp_rank).unbind(
382+
dim=2
383+
)
384+
mixer_input_tail = (x2 * v).contiguous() # [B, D, K_mixer-1]
385+
386+
if self.operator_type == "hyena_short_conv":
387+
mixer_state_owner_id = id(self.mixer.short_conv)
388+
mixer_dict_key = "fir_filter_state_dict"
389+
else: # hyena_medium_conv
390+
mixer_state_owner_id = id(self.mixer)
391+
mixer_dict_key = "inner_fir_filter_state_dict"
392+
393+
mixer_dict = getattr(inference_context, mixer_dict_key, {})
394+
mixer_dict[mixer_state_owner_id] = mixer_input_tail
395+
setattr(inference_context, mixer_dict_key, mixer_dict)

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,6 +1051,7 @@ def get_filter_state(filter_name):
10511051
L=L,
10521052
fir_length=self.kernel_size, # self.short_filter_length,
10531053
compute_state=inference_context is not None,
1054+
use_subquadratic_ops=self.use_subquadratic_ops,
10541055
)
10551056
y = rearrange(y, "b d l -> b l d")
10561057
y = y * x1
@@ -1656,12 +1657,16 @@ def __init__(self, *args, **kwargs):
16561657
self.get_weight = lru_cache(maxsize=1)(self._get_weight)
16571658

16581659
def _get_weight(self):
1659-
"""Expand and cache the convolution weight, freeing the raw parameter."""
1660+
"""Expand and cache the convolution weight in inference-friendly form."""
1661+
# previously deleted self._parameters["short_conv_weight"] here as a
1662+
# memory micro-optimization, but the raw param is also read directly by
1663+
# B2BCausalConv1dModule on every prefill call. With subq-ops enabled in
1664+
# inference, the second prompt's b2b call fails after decode triggers
1665+
# this method on the first prompt
16601666
weight = self.short_conv_weight
16611667
if len(weight.shape) == 2:
16621668
weight = weight.unsqueeze(1)
16631669
weight = weight.repeat_interleave(self.group_dim, dim=0).to(torch.float32)
1664-
del self._parameters["short_conv_weight"]
16651670
return weight
16661671

16671672
def forward(self, x, inference_context=None, _use_cp=True): # noqa: D102
@@ -1697,6 +1702,7 @@ def get_filter_state(filter_name):
16971702
gated_bias=False,
16981703
fir_length=self.kernel_size, # self.short_filter_length,
16991704
compute_state=inference_context is not None,
1705+
use_subquadratic_ops=self.use_subquadratic_ops,
17001706
)
17011707
else:
17021708
if len(u.shape) > 2:

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,7 @@ def setup_inference_engine(
358358
vortex_style_fp8: bool = False,
359359
random_seed: int = 1234,
360360
prompt_segmentation_threshold: Optional[int] = None,
361+
use_subquadratic_ops: bool = False,
361362
) -> Evo2InferenceComponents:
362363
"""Setup the Evo2 inference engine and related components.
363364
@@ -379,6 +380,9 @@ def setup_inference_engine(
379380
segmented during prefill to reduce peak memory. The first segment
380381
runs as a normal prefill; remaining tokens are processed one at a
381382
time before generation begins.
383+
use_subquadratic_ops: Use fused subquadratic-ops kernels (b2b causal
384+
conv1d in prefill, fft_causal_conv1d / causal_conv1d in
385+
parallel_fir).
382386
383387
Returns:
384388
Evo2InferenceComponents containing all inference components.
@@ -412,6 +416,7 @@ def setup_inference_engine(
412416
model_provider.sequence_parallel = False
413417

414418
model_provider.flash_decode = True
419+
model_provider.use_subquadratic_ops = use_subquadratic_ops
415420

416421
if vortex_style_fp8:
417422
model_provider.vortex_style_fp8 = True
@@ -808,6 +813,14 @@ def parse_args() -> argparse.Namespace:
808813
"generation begins. Useful for long prompts that would otherwise OOM. "
809814
"Also settable via EVO2_PST env var.",
810815
)
816+
ap.add_argument(
817+
"--use-subquadratic-ops",
818+
action="store_true",
819+
default=False,
820+
help="Use fused subquadratic-ops CUDA kernels (b2b causal conv1d in prefill, "
821+
"fft_causal_conv1d / causal_conv1d in parallel_fir). Speeds up prompt processing "
822+
"but has no effect on per-token decode throughput.",
823+
)
811824

812825
return ap.parse_args()
813826

@@ -831,6 +844,7 @@ def infer(
831844
max_seq_length: int = 8192,
832845
max_batch_size: int = 1,
833846
prompt_segmentation_threshold: Optional[int] = None,
847+
use_subquadratic_ops: bool = False,
834848
) -> List[Dict[str, Any]]:
835849
"""Run autoregressive text generation with Evo2 using MCore inference.
836850
@@ -858,6 +872,7 @@ def infer(
858872
GPU memory proportional to this value. For large models, only 1 may fit.
859873
prompt_segmentation_threshold: If set, prompts longer than this are segmented
860874
during prefill to reduce peak memory.
875+
use_subquadratic_ops: Use fused subquadratic-ops kernels in the inference path.
861876
862877
Returns:
863878
List of JSONL-serialisable result dicts.
@@ -878,6 +893,7 @@ def infer(
878893
vortex_style_fp8=vortex_style_fp8,
879894
random_seed=random_seed,
880895
prompt_segmentation_threshold=prompt_segmentation_threshold,
896+
use_subquadratic_ops=use_subquadratic_ops,
881897
)
882898

883899
mem_after_setup_gb = torch.cuda.max_memory_allocated() / (1024**3)
@@ -1003,6 +1019,7 @@ def main() -> None:
10031019
max_seq_length=max_seq_length,
10041020
max_batch_size=args.max_batch_size,
10051021
prompt_segmentation_threshold=prompt_segmentation_threshold,
1022+
use_subquadratic_ops=args.use_subquadratic_ops,
10061023
)
10071024

10081025

bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_infer.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ def run_infer_subprocess(
284284
temperature: float = 1.0,
285285
top_k: int = 1,
286286
seed: int = 42,
287+
use_subquadratic_ops: bool = False,
287288
):
288289
"""Helper function to run inference as a subprocess.
289290
@@ -295,6 +296,7 @@ def run_infer_subprocess(
295296
temperature: Sampling temperature
296297
top_k: Top-k sampling parameter (1 for greedy)
297298
seed: Random seed for reproducibility
299+
use_subquadratic_ops: Pass --use-subquadratic-ops to the CLI.
298300
299301
Returns:
300302
The generated completion text from the first JSONL record
@@ -326,6 +328,8 @@ def run_infer_subprocess(
326328
"--seed",
327329
str(seed),
328330
]
331+
if use_subquadratic_ops:
332+
cmd.append("--use-subquadratic-ops")
329333

330334
env = copy.deepcopy(PRETEST_ENV)
331335

@@ -517,6 +521,47 @@ def test_identical_prompts_should_be_identical(mbridge_checkpoint_path, tmp_path
517521
)
518522

519523

524+
def test_subquadratic_ops_matches_baseline(mbridge_checkpoint_path, tmp_path):
525+
"""Greedy generation with --use-subquadratic-ops must match the standard path.
526+
527+
This is the end-to-end correctness check for the subq-ops inference path:
528+
Phase 1 routes engine.parallel_fir through subq-ops kernels during prefill,
529+
Phase 2 fuses proj+mixer convs via b2b_causal_conv1d during prefill and
530+
populates FIR caches for the subsequent decode steps. With greedy decoding
531+
(top_k=1) and the same seed, both paths must produce identical output.
532+
"""
533+
output_baseline = tmp_path / "output_baseline.jsonl"
534+
output_subq = tmp_path / "output_subq.jsonl"
535+
536+
generated_baseline = run_infer_subprocess(
537+
mbridge_checkpoint_path,
538+
prompt=PROMPT_1,
539+
output_file=output_baseline,
540+
max_new_tokens=20,
541+
temperature=1.0,
542+
top_k=1,
543+
seed=42,
544+
use_subquadratic_ops=False,
545+
)
546+
547+
generated_subq = run_infer_subprocess(
548+
mbridge_checkpoint_path,
549+
prompt=PROMPT_1,
550+
output_file=output_subq,
551+
max_new_tokens=20,
552+
temperature=1.0,
553+
top_k=1,
554+
seed=42,
555+
use_subquadratic_ops=True,
556+
)
557+
558+
assert len(generated_baseline) > 0, "Baseline generation produced empty output"
559+
assert len(generated_subq) > 0, "Subq-ops generation produced empty output"
560+
assert generated_baseline == generated_subq, (
561+
f"Subq-ops path diverged from baseline:\nBaseline: {generated_baseline}\nSubq-ops: {generated_subq}"
562+
)
563+
564+
520565
def test_different_prompts_produce_different_outputs(mbridge_checkpoint_path, tmp_path):
521566
"""Test that different prompts produce different sequences.
522567

0 commit comments

Comments
 (0)