@@ -119,9 +119,6 @@ def __init__(
119119 self .fast_conv_mixer = self .hyena_config .fast_conv_mixer
120120
121121 self .use_subquadratic_ops = self .transformer_config .use_subquadratic_ops
122- # TODO: Re-enable B2BCausalConv1dModule for short/medium Hyena layers once
123- # subquadratic-ops updates it to support causal_conv1d 1.6+ semantics.
124- self .use_fused_b2b_causal_conv1d = False
125122
126123 # Per attention head and per partition values.
127124 assert torch .distributed .is_initialized ()
@@ -200,9 +197,9 @@ def __init__(
200197 use_conv_bias = self .transformer_config .use_short_conv_bias ,
201198 )
202199
203- if self .use_fused_b2b_causal_conv1d :
204- # Create a wrapper module that doesn't register parameters
205- # Use the existing weights from the original model
200+ if self .use_subquadratic_ops :
201+ # The B2B kernel is guarded in hyena_utils and fails early if the local CUDA stack
202+ # cannot run subquadratic_ops_torch correctly.
206203 self .b2b_kernel = B2BCausalConv1dModule (
207204 self .hyena_proj_conv ,
208205 self .mixer ,
@@ -231,9 +228,9 @@ def __init__(
231228 max_sequence_length ,
232229 )
233230
234- if self .use_fused_b2b_causal_conv1d and self .operator_type == "hyena_medium_conv" :
235- # Create a wrapper module that doesn't register parameters
236- # Use the existing weights from the original model
231+ if self .use_subquadratic_ops and self .operator_type == "hyena_medium_conv" :
232+ # The B2B kernel is guarded in hyena_utils and fails early if the local CUDA stack
233+ # cannot run subquadratic_ops_torch correctly.
237234 self .b2b_kernel = B2BCausalConv1dModule (
238235 self .hyena_proj_conv ,
239236 self .mixer ,
@@ -311,12 +308,12 @@ def forward(self, x, layer_past=None, inference_context=None, _hyena_use_cp=True
311308 else :
312309 features = rearrange (features , "l b d -> b d l" ).contiguous ()
313310
314- is_b2b_eligible = self .use_fused_b2b_causal_conv1d and self .operator_type in [
311+ is_b2b_eligible = self .use_subquadratic_ops and self .operator_type in [
315312 "hyena_short_conv" ,
316313 "hyena_medium_conv" ,
317314 ]
318- # b2b runs during training (no inference_context) or during prefill (no FIR cache yet).
319- # During decode (cache populated, L=1) we fall back to the regular per-token step path.
315+ # B2B runs during training (no inference_context) or during prefill (no FIR cache yet).
316+ # During decode, fall back to the regular per-token step path.
320317 is_prefill = inference_context is not None and id (self .hyena_proj_conv ) not in getattr (
321318 inference_context , "fir_filter_state_dict" , {}
322319 )
0 commit comments