Skip to content

Commit 8b4bd11

Browse files
committed
add infer_evo2 --use-subquadratic-ops flag with test for matching baseline
Signed-off-by: Farhad Ramezanghorbani <farhadr@nvidia.com>
1 parent 17c42cf commit 8b4bd11

2 files changed

Lines changed: 62 additions & 0 deletions

File tree

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,7 @@ def setup_inference_engine(
358358
vortex_style_fp8: bool = False,
359359
random_seed: int = 1234,
360360
prompt_segmentation_threshold: Optional[int] = None,
361+
use_subquadratic_ops: bool = False,
361362
) -> Evo2InferenceComponents:
362363
"""Setup the Evo2 inference engine and related components.
363364
@@ -379,6 +380,9 @@ def setup_inference_engine(
379380
segmented during prefill to reduce peak memory. The first segment
380381
runs as a normal prefill; remaining tokens are processed one at a
381382
time before generation begins.
383+
use_subquadratic_ops: Use fused subquadratic-ops kernels (b2b causal
384+
conv1d in prefill, fft_causal_conv1d / causal_conv1d in
385+
parallel_fir).
382386
383387
Returns:
384388
Evo2InferenceComponents containing all inference components.
@@ -412,6 +416,7 @@ def setup_inference_engine(
412416
model_provider.sequence_parallel = False
413417

414418
model_provider.flash_decode = True
419+
model_provider.use_subquadratic_ops = use_subquadratic_ops
415420

416421
if vortex_style_fp8:
417422
model_provider.vortex_style_fp8 = True
@@ -808,6 +813,14 @@ def parse_args() -> argparse.Namespace:
808813
"generation begins. Useful for long prompts that would otherwise OOM. "
809814
"Also settable via EVO2_PST env var.",
810815
)
816+
ap.add_argument(
817+
"--use-subquadratic-ops",
818+
action="store_true",
819+
default=False,
820+
help="Use fused subquadratic-ops CUDA kernels (b2b causal conv1d in prefill, "
821+
"fft_causal_conv1d / causal_conv1d in parallel_fir). Speeds up prompt processing "
822+
"but has no effect on per-token decode throughput.",
823+
)
811824

812825
return ap.parse_args()
813826

@@ -831,6 +844,7 @@ def infer(
831844
max_seq_length: int = 8192,
832845
max_batch_size: int = 1,
833846
prompt_segmentation_threshold: Optional[int] = None,
847+
use_subquadratic_ops: bool = False,
834848
) -> List[Dict[str, Any]]:
835849
"""Run autoregressive text generation with Evo2 using MCore inference.
836850
@@ -858,6 +872,7 @@ def infer(
858872
GPU memory proportional to this value. For large models, only 1 may fit.
859873
prompt_segmentation_threshold: If set, prompts longer than this are segmented
860874
during prefill to reduce peak memory.
875+
use_subquadratic_ops: Use fused subquadratic-ops kernels in the inference path.
861876
862877
Returns:
863878
List of JSONL-serialisable result dicts.
@@ -878,6 +893,7 @@ def infer(
878893
vortex_style_fp8=vortex_style_fp8,
879894
random_seed=random_seed,
880895
prompt_segmentation_threshold=prompt_segmentation_threshold,
896+
use_subquadratic_ops=use_subquadratic_ops,
881897
)
882898

883899
mem_after_setup_gb = torch.cuda.max_memory_allocated() / (1024**3)
@@ -1003,6 +1019,7 @@ def main() -> None:
10031019
max_seq_length=max_seq_length,
10041020
max_batch_size=args.max_batch_size,
10051021
prompt_segmentation_threshold=prompt_segmentation_threshold,
1022+
use_subquadratic_ops=args.use_subquadratic_ops,
10061023
)
10071024

10081025

bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_infer.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ def run_infer_subprocess(
284284
temperature: float = 1.0,
285285
top_k: int = 1,
286286
seed: int = 42,
287+
use_subquadratic_ops: bool = False,
287288
):
288289
"""Helper function to run inference as a subprocess.
289290
@@ -295,6 +296,7 @@ def run_infer_subprocess(
295296
temperature: Sampling temperature
296297
top_k: Top-k sampling parameter (1 for greedy)
297298
seed: Random seed for reproducibility
299+
use_subquadratic_ops: Pass --use-subquadratic-ops to the CLI.
298300
299301
Returns:
300302
The generated completion text from the first JSONL record
@@ -326,6 +328,8 @@ def run_infer_subprocess(
326328
"--seed",
327329
str(seed),
328330
]
331+
if use_subquadratic_ops:
332+
cmd.append("--use-subquadratic-ops")
329333

330334
env = copy.deepcopy(PRETEST_ENV)
331335

@@ -517,6 +521,47 @@ def test_identical_prompts_should_be_identical(mbridge_checkpoint_path, tmp_path
517521
)
518522

519523

524+
def test_subquadratic_ops_matches_baseline(mbridge_checkpoint_path, tmp_path):
    """Verify the --use-subquadratic-ops CLI path reproduces the standard path.

    End-to-end correctness check for the subq-ops inference path: the same
    prompt is generated twice through the CLI helper, once without and once
    with ``use_subquadratic_ops``. Both runs use greedy decoding (top_k=1),
    the same seed and the same number of new tokens, so the two completions
    are required to be non-empty and exactly equal.
    """
    baseline_path = tmp_path / "output_baseline.jsonl"
    subq_path = tmp_path / "output_subq.jsonl"

    # Everything except the output file and the flag under test is shared
    # between the two runs, so any divergence is attributable to the flag.
    shared_kwargs = {
        "prompt": PROMPT_1,
        "max_new_tokens": 20,
        "temperature": 1.0,
        "top_k": 1,
        "seed": 42,
    }

    # Baseline run first, then the subq-ops run (same order as the CLI calls
    # have always been issued in this test).
    generated_baseline = run_infer_subprocess(
        mbridge_checkpoint_path,
        output_file=baseline_path,
        use_subquadratic_ops=False,
        **shared_kwargs,
    )
    generated_subq = run_infer_subprocess(
        mbridge_checkpoint_path,
        output_file=subq_path,
        use_subquadratic_ops=True,
        **shared_kwargs,
    )

    assert len(generated_baseline) > 0, "Baseline generation produced empty output"
    assert len(generated_subq) > 0, "Subq-ops generation produced empty output"
    assert generated_baseline == generated_subq, (
        f"Subq-ops path diverged from baseline:\nBaseline: {generated_baseline}\nSubq-ops: {generated_subq}"
    )
563+
564+
520565
def test_different_prompts_produce_different_outputs(mbridge_checkpoint_path, tmp_path):
521566
"""Test that different prompts produce different sequences.
522567

0 commit comments

Comments
 (0)