@@ -358,6 +358,7 @@ def setup_inference_engine(
358358 vortex_style_fp8 : bool = False ,
359359 random_seed : int = 1234 ,
360360 prompt_segmentation_threshold : Optional [int ] = None ,
361+ use_subquadratic_ops : bool = False ,
361362) -> Evo2InferenceComponents :
362363 """Setup the Evo2 inference engine and related components.
363364
@@ -379,6 +380,9 @@ def setup_inference_engine(
379380 segmented during prefill to reduce peak memory. The first segment
380381 runs as a normal prefill; remaining tokens are processed one at a
381382 time before generation begins.
383+ use_subquadratic_ops: Use fused subquadratic-ops kernels (b2b causal
384+ conv1d in prefill, fft_causal_conv1d / causal_conv1d in
385+ parallel_fir).
382386
383387 Returns:
384388 Evo2InferenceComponents containing all inference components.
@@ -412,6 +416,7 @@ def setup_inference_engine(
412416 model_provider .sequence_parallel = False
413417
414418 model_provider .flash_decode = True
419+ model_provider .use_subquadratic_ops = use_subquadratic_ops
415420
416421 if vortex_style_fp8 :
417422 model_provider .vortex_style_fp8 = True
@@ -808,6 +813,14 @@ def parse_args() -> argparse.Namespace:
808813 "generation begins. Useful for long prompts that would otherwise OOM. "
809814 "Also settable via EVO2_PST env var." ,
810815 )
816+ ap .add_argument (
817+ "--use-subquadratic-ops" ,
818+ action = "store_true" ,
819+ default = False ,
820+ help = "Use fused subquadratic-ops CUDA kernels (b2b causal conv1d in prefill, "
821+ "fft_causal_conv1d / causal_conv1d in parallel_fir). Speeds up prompt processing "
822+ "but has no effect on per-token decode throughput." ,
823+ )
811824
812825 return ap .parse_args ()
813826
@@ -831,6 +844,7 @@ def infer(
831844 max_seq_length : int = 8192 ,
832845 max_batch_size : int = 1 ,
833846 prompt_segmentation_threshold : Optional [int ] = None ,
847+ use_subquadratic_ops : bool = False ,
834848) -> List [Dict [str , Any ]]:
835849 """Run autoregressive text generation with Evo2 using MCore inference.
836850
@@ -858,6 +872,7 @@ def infer(
858872 GPU memory proportional to this value. For large models, only 1 may fit.
859873 prompt_segmentation_threshold: If set, prompts longer than this are segmented
860874 during prefill to reduce peak memory.
875+ use_subquadratic_ops: Use fused subquadratic-ops kernels in the inference path.
861876
862877 Returns:
863878 List of JSONL-serialisable result dicts.
@@ -878,6 +893,7 @@ def infer(
878893 vortex_style_fp8 = vortex_style_fp8 ,
879894 random_seed = random_seed ,
880895 prompt_segmentation_threshold = prompt_segmentation_threshold ,
896+ use_subquadratic_ops = use_subquadratic_ops ,
881897 )
882898
883899 mem_after_setup_gb = torch .cuda .max_memory_allocated () / (1024 ** 3 )
@@ -1003,6 +1019,7 @@ def main() -> None:
10031019 max_seq_length = max_seq_length ,
10041020 max_batch_size = args .max_batch_size ,
10051021 prompt_segmentation_threshold = prompt_segmentation_threshold ,
1022+ use_subquadratic_ops = args .use_subquadratic_ops ,
10061023 )
10071024
10081025
0 commit comments