Skip to content

Commit 2be5190

Browse files
committed
Roll back more variable renamings
Signed-off-by: John St. John <jstjohn@nvidia.com>
1 parent 9532cba commit 2be5190

1 file changed

Lines changed: 9 additions & 12 deletions

File tree

  • bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_mixer.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,6 @@ def __init__(
119119
self.fast_conv_mixer = self.hyena_config.fast_conv_mixer
120120

121121
self.use_subquadratic_ops = self.transformer_config.use_subquadratic_ops
122-
# TODO: Re-enable B2BCausalConv1dModule for short/medium Hyena layers once
123-
# subquadratic-ops updates it to support causal_conv1d 1.6+ semantics.
124-
self.use_fused_b2b_causal_conv1d = False
125122

126123
# Per attention head and per partition values.
127124
assert torch.distributed.is_initialized()
@@ -200,9 +197,9 @@ def __init__(
200197
use_conv_bias=self.transformer_config.use_short_conv_bias,
201198
)
202199

203-
if self.use_fused_b2b_causal_conv1d:
204-
# Create a wrapper module that doesn't register parameters
205-
# Use the existing weights from the original model
200+
if self.use_subquadratic_ops:
201+
# The B2B kernel is guarded in hyena_utils and fails early if the local CUDA stack
202+
# cannot run subquadratic_ops_torch correctly.
206203
self.b2b_kernel = B2BCausalConv1dModule(
207204
self.hyena_proj_conv,
208205
self.mixer,
@@ -231,9 +228,9 @@ def __init__(
231228
max_sequence_length,
232229
)
233230

234-
if self.use_fused_b2b_causal_conv1d and self.operator_type == "hyena_medium_conv":
235-
# Create a wrapper module that doesn't register parameters
236-
# Use the existing weights from the original model
231+
if self.use_subquadratic_ops and self.operator_type == "hyena_medium_conv":
232+
# The B2B kernel is guarded in hyena_utils and fails early if the local CUDA stack
233+
# cannot run subquadratic_ops_torch correctly.
237234
self.b2b_kernel = B2BCausalConv1dModule(
238235
self.hyena_proj_conv,
239236
self.mixer,
@@ -311,12 +308,12 @@ def forward(self, x, layer_past=None, inference_context=None, _hyena_use_cp=True
311308
else:
312309
features = rearrange(features, "l b d -> b d l").contiguous()
313310

314-
is_b2b_eligible = self.use_fused_b2b_causal_conv1d and self.operator_type in [
311+
is_b2b_eligible = self.use_subquadratic_ops and self.operator_type in [
315312
"hyena_short_conv",
316313
"hyena_medium_conv",
317314
]
318-
# b2b runs during training (no inference_context) or during prefill (no FIR cache yet).
319-
# During decode (cache populated, L=1) we fall back to the regular per-token step path.
315+
# B2B runs during training (no inference_context) or during prefill (no FIR cache yet).
316+
# During decode, fall back to the regular per-token step path.
320317
is_prefill = inference_context is not None and id(self.hyena_proj_conv) not in getattr(
321318
inference_context, "fir_filter_state_dict", {}
322319
)

0 commit comments

Comments
 (0)