From 121c57eea2425c459bb37802860e162c98c40e0e Mon Sep 17 00:00:00 2001 From: "John St. John" Date: Fri, 22 May 2026 11:38:53 -0700 Subject: [PATCH 01/15] Initial commit of short sequence prefix-invariant evo2 implementation Signed-off-by: John St. John --- .devcontainer/Dockerfile | 19 +++ .devcontainer/devcontainer.json | 1 + .devcontainer/initializeCommand.sh | 1 + .devcontainer/start.sh | 2 +- .../recipes/evo2_megatron/.ci_build.sh | 4 +- .../recipes/evo2_megatron/pyproject.toml | 3 +- .../evo2/models/megatron/hyena/engine.py | 38 +++--- .../evo2/models/megatron/hyena/hyena_block.py | 10 +- .../evo2/models/megatron/hyena/hyena_layer.py | 6 +- .../evo2/models/megatron/hyena/hyena_mixer.py | 11 +- .../evo2/models/megatron/hyena/hyena_utils.py | 9 +- .../src/bionemo/evo2/run/predict.py | 20 ++- .../tests/bionemo/evo2/conftest.py | 28 +++- .../evo2/models/megatron/hyena/test_engine.py | 79 +++++++++++ .../models/megatron/hyena/test_hyena_utils.py | 61 +++++++++ .../tests/bionemo/evo2/run/test_infer.py | 127 +++++++++++++++++- .../tests/bionemo/evo2/run/test_predict.py | 100 ++++++++++++++ 17 files changed, 474 insertions(+), 45 deletions(-) create mode 100644 bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_engine.py diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 8d8d51b7b3..72f3ba4c62 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -10,7 +10,26 @@ RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,source=requirements.txt,target=/workspace/requirements.txt \ PIP_CONSTRAINT= pip install -r /workspace/requirements.txt +# Sandboxed agent CLIs use these helpers on Linux. +RUN apt-get update && apt-get install -y --no-install-recommends \ + bubblewrap \ + uidmap \ + && rm -rf /var/lib/apt/lists/* + COPY --from=ghcr.io/astral-sh/uv:latest /uv /bin/ + +# Install Node.js 22 LTS and the OpenAI Codex CLI. 
+RUN curl -fsSL https://deb.nodesource.com/setup_22.x | bash - \ + && apt-get install -y --no-install-recommends nodejs \ + && rm -rf /var/lib/apt/lists/* \ + && npm install -g --no-fund --no-audit @openai/codex \ + && npm cache clean --force + +# Default Codex to Landlock where nested namespaces are restricted. +RUN mkdir -p /home/ubuntu/.codex \ + && printf '[features]\nuse_legacy_landlock = true\n' > /home/ubuntu/.codex/config.toml \ + && chown -R ubuntu:ubuntu /home/ubuntu/.codex + USER ubuntu RUN curl https://cursor.com/install -fsS | bash # Install cursor-agent CLI tool RUN curl -fsSL https://claude.ai/install.sh | bash # Install Claude CLI tool diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index a896c36a22..deac11c06e 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -8,6 +8,7 @@ "source=${localEnv:HOME}/.cache,target=/home/ubuntu/.cache,type=bind,consistency=cached", "source=${localEnv:HOME}/.claude,target=/home/ubuntu/.claude,type=bind,consistency=cached", "source=${localEnv:HOME}/.claude.json,target=/home/ubuntu/.claude.json,type=bind,consistency=cached", + "source=${localEnv:HOME}/.codex,target=/home/ubuntu/.codex,type=bind,consistency=cached", "source=${localEnv:HOME}/.config,target=/home/ubuntu/.config,type=bind,consistency=cached", "source=${localEnv:HOME}/.cursor,target=/home/ubuntu/.cursor,type=bind,consistency=cached", "source=${localEnv:HOME}/.gnupg,target=/home/ubuntu/.gnupg,type=bind,consistency=cached", diff --git a/.devcontainer/initializeCommand.sh b/.devcontainer/initializeCommand.sh index 182db741f8..aae34bd54f 100755 --- a/.devcontainer/initializeCommand.sh +++ b/.devcontainer/initializeCommand.sh @@ -8,6 +8,7 @@ mkdir -p ~/.gnupg mkdir -p ~/.config mkdir -p ~/.cursor mkdir -p ~/.claude +mkdir -p ~/.codex [ ! -f ~/.netrc ] && touch ~/.netrc [ ! 
-f ~/.bash_history_devcontainer ] && touch ~/.bash_history_devcontainer diff --git a/.devcontainer/start.sh b/.devcontainer/start.sh index 6f75fbe6b0..834aabf871 100755 --- a/.devcontainer/start.sh +++ b/.devcontainer/start.sh @@ -9,7 +9,7 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" DEVCONTAINER_JSON="${SCRIPT_DIR}/devcontainer.json" CONTAINER_NAME="${BIONEMO_CONTAINER_NAME:-bionemo-devcontainer}" -IMAGE_NAME="${BIONEMO_IMAGE_NAME:-bionemo-devcontainer:latest}" +IMAGE_NAME="${BIONEMO_IMAGE_NAME:-${CONTAINER_NAME}:latest}" # --------------------------------------------------------------------------- # Helpers diff --git a/bionemo-recipes/recipes/evo2_megatron/.ci_build.sh b/bionemo-recipes/recipes/evo2_megatron/.ci_build.sh index 58ee94d1f3..92997a6343 100755 --- a/bionemo-recipes/recipes/evo2_megatron/.ci_build.sh +++ b/bionemo-recipes/recipes/evo2_megatron/.ci_build.sh @@ -38,8 +38,8 @@ for pkg_dir in "$RECIPE_ROOT/../../../sub-packages/bionemo-recipeutils" "$RECIPE fi done -# 6. Install the recipe with all remaining dependencies -uv pip install -c pip-constraints.txt -e . --no-build-isolation +# 6. Install the recipe with all remaining dependencies, including test extras +uv pip install -c pip-constraints.txt -e '.[test]' --no-build-isolation # 7. 
Restore original pyproject.toml (the edit was only needed for uv resolution) mv pyproject.toml.ci_bak pyproject.toml diff --git a/bionemo-recipes/recipes/evo2_megatron/pyproject.toml b/bionemo-recipes/recipes/evo2_megatron/pyproject.toml index e83ac2c76c..1392215b66 100644 --- a/bionemo-recipes/recipes/evo2_megatron/pyproject.toml +++ b/bionemo-recipes/recipes/evo2_megatron/pyproject.toml @@ -93,7 +93,8 @@ override-dependencies = [ [tool.uv.sources] # Shared recipe utilities (framework-agnostic) # External dependencies with specific git tags/commits -causal_conv1d = { git = "https://github.com/Dao-AILab/causal-conv1d.git", tag = "v1.5.4" } +# 1.6.1 fixes a custom-op no-storage failure in no-grad/frozen forward paths. +causal_conv1d = { git = "https://github.com/Dao-AILab/causal-conv1d.git", tag = "v1.6.1" } nv-grouped-gemm = { git = "https://github.com/fanshiqing/grouped_gemm", tag = "v1.1.4.post6" } # Internal dependencies diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/engine.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/engine.py index c9ff82ba00..047d123730 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/engine.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/engine.py @@ -21,14 +21,17 @@ try: - from subquadratic_ops_torch.causal_conv1d import causal_conv1d as _subq_causal_conv1d from subquadratic_ops_torch.fft_causal_conv1d import fft_causal_conv1d as _subq_fft_causal_conv1d except ImportError as _subq_import_error: - _subq_causal_conv1d = None _subq_fft_causal_conv1d = None _subq_error_msg = f"subquadratic_ops_torch not available: {_subq_import_error}" +def _linear_causal_fft_size(input_len: int, filter_len: int) -> int: + """Return an FFT size that cannot alias a causal convolution prefix.""" + return max(2 * input_len, 2 * filter_len) + + def adjust_filter_shape_for_broadcast(u, h): """Adjust filter 
shape for broadcasting compatibility with input tensor.""" h = h.squeeze() # Standardize to [D, L] from [1, D, L] and [D, 1, L] @@ -50,7 +53,7 @@ def fftconv_func(*, u, k, D): # noqa: N803 The convolution is computed in the frequency domain and then transformed back to the time domain. """ seqlen = u.shape[-1] - fft_size = 2 * seqlen + fft_size = _linear_causal_fft_size(seqlen, k.shape[-1]) k_f = torch.fft.rfft(k, n=fft_size) / fft_size k_f = adjust_filter_shape_for_broadcast(u, k_f) @@ -98,21 +101,14 @@ def parallel_fir( D=bias, ).to(dtype=u.dtype) else: - if use_subquadratic_ops: - # subq-ops causal_conv1d expects pre-padded [B, D, L+pad] input and [D, K] weight; dtypes must match - pad_size = fir_length - 1 - x_padded = F.pad(u.to(torch.float32), (pad_size, 0)) - w = weight.squeeze(1) if weight.dim() == 3 else weight - z = _subq_causal_conv1d(x_padded, w.to(torch.float32))[..., pad_size:] - else: - z = F.conv1d( - u.to(torch.float32), - weight.to(torch.float32), - bias=None, - stride=1, - padding=fir_length - 1, - groups=u.shape[1], # always set to D, regardless of filter grouping - )[..., :L] + z = F.conv1d( + u.to(torch.float32), + weight.to(torch.float32), + bias=None, + stride=1, + padding=fir_length - 1, + groups=u.shape[1], # always set to D, regardless of filter grouping + )[..., :L] z = z.to(u.dtype) @@ -130,7 +126,7 @@ def parallel_fir( def parallel_iir(*, z_pre, h, D, L, poles, t, hidden_size, compute_state): # noqa: N803 """Compute the output state of the short convolutional filter.""" - fft_size = 2 * L + fft_size = _linear_causal_fft_size(L, h.shape[-1]) x1, x2, v = z_pre.split([hidden_size, hidden_size, hidden_size], dim=1) x1v = x1 * v @@ -221,9 +217,9 @@ def prefill_via_modal_fft(*, x1v, L, poles, t, X_s): # noqa: N803 # When the model has a long convolution derived from a recurrence in modal form and prefill_style is "fft", # we split the filter into poles and residues and reuse FFT computation on the input. 
bs = x1v.shape[0] - fft_size = 2 * L + fft_size = X_s.shape[-1] state_s = (poles.to(torch.float32) * t).exp() - state_S = torch.fft.fft(state_s, n=fft_size).repeat(bs, 1, 1, 1) # noqa N806: B, D, state_dim, 2 * L + state_S = torch.fft.fft(state_s, n=fft_size).repeat(bs, 1, 1, 1) # noqa N806: B, D, state_dim, fft_size state = torch.fft.ifft(X_s[..., None, :] * state_S, n=fft_size) # Do not try to fix `UserWarning: Casting complex values to real discards # the imaginary part` by inserting state.real conversion anywhere before diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_block.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_block.py index 93acf70e57..c413dd94b7 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_block.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_block.py @@ -16,6 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect from contextlib import nullcontext from dataclasses import dataclass from typing import Optional, Union @@ -120,13 +121,20 @@ def __init__( pp_layer_offset, layer_type_list = self._select_layers_for_pipeline_parallel(layer_type_list) if get_cpu_offload_context is not None: - (self.offload_context, self.group_prefetch_offload_commit_async) = get_cpu_offload_context( + # Megatron Core changed this helper from six to seven positional arguments + # across releases. Pass only the arguments accepted by the installed version. 
+ offload_args = [ self.config.cpu_offloading, self.config.cpu_offloading_num_layers, self.config.num_layers, self.config.cpu_offloading_activations, self.config.cpu_offloading_weights, self.config.cpu_offloading_double_buffering, + getattr(self.config, "cpu_offloading_retain_pinned_cpu_buffers", False), + ] + num_offload_params = len(inspect.signature(get_cpu_offload_context).parameters) + (self.offload_context, self.group_prefetch_offload_commit_async) = get_cpu_offload_context( + *offload_args[:num_offload_params], ) self.config._cpu_offloading_context = self.offload_context if self.config.cpu_offloading else None else: diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_layer.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_layer.py index 18867f3689..577bf9af2a 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_layer.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_layer.py @@ -108,8 +108,10 @@ def bias_dropout_add_exec_handler(self): if self.training: return torch.enable_grad else: - # Validation, Test, Inference, Etc. - return torch.inference_mode + # torch.inference_mode marks outputs as inference tensors. Those flags + # persist after leaving the context and can break downstream autograd or + # torch.library custom ops that consume frozen model outputs. 
+ return torch.no_grad def forward( self, diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_mixer.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_mixer.py index d8cb265e31..425060ac96 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_mixer.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_mixer.py @@ -42,6 +42,7 @@ logger = logging.getLogger(__name__) + try: from transformer_engine.common.recipe import DelayedScaling, Format except ImportError: @@ -117,8 +118,10 @@ def __init__( self.fast_conv_proj = self.hyena_config.fast_conv_proj self.fast_conv_mixer = self.hyena_config.fast_conv_mixer - # Use b2b causal conv1d self.use_subquadratic_ops = self.transformer_config.use_subquadratic_ops + # TODO: Re-enable B2BCausalConv1dModule for short/medium Hyena layers once + # subquadratic-ops updates it to support causal_conv1d 1.6+ semantics. + self.use_fused_b2b_causal_conv1d = False # Per attention head and per partition values. 
assert torch.distributed.is_initialized() @@ -197,7 +200,7 @@ def __init__( use_conv_bias=self.transformer_config.use_short_conv_bias, ) - if self.use_subquadratic_ops: + if self.use_fused_b2b_causal_conv1d: # Create a wrapper module that doesn't register parameters # Use the existing weights from the original model self.b2b_kernel = B2BCausalConv1dModule( @@ -228,7 +231,7 @@ def __init__( max_sequence_length, ) - if self.use_subquadratic_ops and self.operator_type == "hyena_medium_conv": + if self.use_fused_b2b_causal_conv1d and self.operator_type == "hyena_medium_conv": # Create a wrapper module that doesn't register parameters # Use the existing weights from the original model self.b2b_kernel = B2BCausalConv1dModule( @@ -308,7 +311,7 @@ def forward(self, x, layer_past=None, inference_context=None, _hyena_use_cp=True else: features = rearrange(features, "l b d -> b d l").contiguous() - is_b2b_eligible = self.use_subquadratic_ops and self.operator_type in [ + is_b2b_eligible = self.use_fused_b2b_causal_conv1d and self.operator_type in [ "hyena_short_conv", "hyena_medium_conv", ] diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py index 44be308aaf..436e11fae5 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py @@ -486,6 +486,7 @@ def fftconv_func(u, k, D, dropout_mask, gelu=True, k_rev=None, bidirectional=Fal if use_subquadratic_ops: y = fft_causal_conv1d(u, k.squeeze(0)) else: + fft_size = max(fft_size, 2 * k.shape[-1]) k_f = torch.fft.rfft(k, n=fft_size) / fft_size if k_rev is not None: k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size @@ -754,7 +755,8 @@ def forward(self, L, *args, **kwargs): # noqa: N803 """ return self.filter(L, *args, **kwargs) - 
@torch.compile(mode="max-autotune") + # Keep this eager. Compiling this helper can leave global dispatcher state + # that interferes with unrelated custom autograd/custom-op call sites. def filter(self, L, *args, **kwargs): # noqa: N803 """Compute the filter as a function of h and decay for the requested sequence length.""" h = self.h[:, :L] @@ -1456,10 +1458,7 @@ def forward(self, x, inference_context=None, _use_cp=True): # subquadratic_ops causal_conv1d is only applied to the projection conv of Hyena LI layer # Projection conv is fused with SE/MR layers (B2BCausalConv1dModule) if self.use_fast_causal_conv: # hyena_proj_conv case - if self.use_subquadratic_ops: # hyena_proj_conv of LI layer when subquadratic_ops is enabled - y = causal_conv1d(x, weight)[..., pad_size:] - else: - y = causal_conv1d_fn(x, weight, bias=None, activation=None)[..., pad_size:] + y = causal_conv1d_fn(x, weight, bias=None, activation=None)[..., pad_size:] else: # hyena_short_conv case y = F.conv1d( x, diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py index c888d46c94..66ba70b1af 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py @@ -76,7 +76,6 @@ ) from megatron.bridge.training.config import DistributedInitConfig, RNGConfig from megatron.bridge.training.mixed_precision import MIXED_PRECISION_RECIPES, get_mixed_precision_config -from megatron.bridge.training.tokenizers.tokenizer import _HuggingFaceTokenizer from megatron.bridge.training.utils.checkpoint_utils import ( file_exists, get_checkpoint_run_config_filename, @@ -97,6 +96,14 @@ from megatron.core.utils import get_batch_on_this_cp_rank from torch import Tensor + +try: + from megatron.bridge.training.tokenizers.tokenizer import _HuggingFaceTokenizer +except ImportError: + from 
megatron.core.tokenizers.text.libraries.huggingface_tokenizer import ( + HuggingFaceTokenizer as _HuggingFaceTokenizer, + ) + from bionemo.evo2.data.dataset_tokenizer import DEFAULT_HF_TOKENIZER_MODEL_PATH from bionemo.evo2.data.fasta_dataset import SimpleFastaDataset from bionemo.recipeutils.inference.collation import batch_collator @@ -656,7 +663,6 @@ def _predict_step( if not parallel_state.is_pipeline_last_stage(): return None - # Forward pass output_tensor = model( input_ids=batch["tokens"], position_ids=batch["position_ids"], @@ -1036,9 +1042,11 @@ def predict( f"Valid range: -{original_num_layers} to {original_num_layers - 1}." ) - # Set the model to use fewer layers and skip post-processing (output heads) + # Set the model to use fewer layers and skip post-processing (output heads). model_provider.num_layers = target_num_layers model_provider.post_process = False + if hasattr(model_provider, "fp32_residual_connection"): + model_provider.fp32_residual_connection = False # Also truncate the hybrid_override_pattern if it exists, since it must match num_layers if hasattr(model_provider, "hybrid_override_pattern") and model_provider.hybrid_override_pattern is not None: @@ -1276,6 +1284,12 @@ def predict( def main() -> None: """CLI entry point for Evo2 prediction.""" args = parse_args() + try: + from megatron.bridge.utils.instantiate_utils import register_allowed_target_prefix + + register_allowed_target_prefix("bionemo.") + except ImportError: + pass predict( fasta_path=args.fasta, ckpt_dir=args.ckpt_dir, diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/conftest.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/conftest.py index 35b66ed455..e9e427282b 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/conftest.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/conftest.py @@ -15,15 +15,41 @@ # conftest.py +# ruff: noqa: E402, I001 import copy import gc +import importlib import os import shlex 
import subprocess +import sys from pathlib import Path +from types import ModuleType import pytest -import torch + +# Load Transformer Engine's core library before torch; this avoids a CUDA library +# symbol-resolution failure seen when Megatron imports torch before TE is loaded. +importlib.import_module("transformer_engine.pytorch") +torch = importlib.import_module("torch") + +# Megatron Bridge imports Transformers conversion modules during collection. In +# this NGC image, torchvision is present but its custom ops are unavailable, so +# make Transformers skip optional vision imports for these non-vision tests. +transformers_import_utils = importlib.import_module("transformers.utils.import_utils") +transformers_import_utils._torchvision_available = False + + +# Megatron Bridge eagerly imports VLM recipe modules through its recipes package. +# Those modules import qwen_vl_utils, which imports torchvision. Provide the +# non-vision tests with a stub so Evo2 collection does not depend on VLM extras. +def _unavailable_process_vision_info(*args, **kwargs): + raise RuntimeError("qwen_vl_utils is unavailable in Evo2 tests") + + +qwen_vl_utils = ModuleType("qwen_vl_utils") +qwen_vl_utils.process_vision_info = _unavailable_process_vision_info +sys.modules["qwen_vl_utils"] = qwen_vl_utils from bionemo.core.data.load import load as bionemo_load from bionemo.evo2.data.dataset_tokenizer import DEFAULT_HF_TOKENIZER_MODEL_PATH_512 diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_engine.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_engine.py new file mode 100644 index 0000000000..02b169296f --- /dev/null +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_engine.py @@ -0,0 +1,79 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 + +import torch + +from bionemo.evo2.models.megatron.hyena import engine + + +def test_fftconv_func_is_prefix_invariant_when_filter_is_longer_than_input(): + """Short-input FFT convolution should match the prefix of a longer-input convolution.""" + torch.manual_seed(1234) + batch_size = 2 + hidden_size = 4 + short_len = 5 + long_len = 128 + filter_len = 128 + + u_long = torch.randn(batch_size, hidden_size, long_len) + u_short = u_long[..., :short_len].contiguous() + k = torch.randn(hidden_size, 1, filter_len) + d = torch.randn(hidden_size) + + short_out = engine.fftconv_func(u=u_short, k=k, D=d) + long_out = engine.fftconv_func(u=u_long, k=k, D=d)[..., :short_len] + + torch.testing.assert_close(short_out, long_out, rtol=1e-5, atol=1e-5) + + +def test_parallel_iir_is_prefix_invariant_when_filter_is_longer_than_input(): + """The IIR prefill convolution should not circularly alias short prefixes.""" + torch.manual_seed(1234) + batch_size = 2 + hidden_size = 4 + short_len = 5 + long_len = 128 + filter_len = 128 + + z_long = torch.randn(batch_size, 3 * hidden_size, long_len) + z_short = z_long[..., :short_len].contiguous() + h = torch.randn(hidden_size, filter_len) + d = torch.randn(hidden_size) + + short_out, _ = engine.parallel_iir( 
+ z_pre=z_short, + h=h, + D=d, + L=short_len, + poles=None, + t=None, + hidden_size=hidden_size, + compute_state=False, + ) + long_out, _ = engine.parallel_iir( + z_pre=z_long, + h=h, + D=d, + L=long_len, + poles=None, + t=None, + hidden_size=hidden_size, + compute_state=False, + ) + + torch.testing.assert_close(short_out, long_out[:, :short_len], rtol=1e-5, atol=1e-5) diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py index 2ba21709fa..1a1ce86627 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py @@ -20,6 +20,7 @@ import pytest import torch import torch.distributed as dist +import torch.nn.functional as F # noqa: N812 from bionemo.evo2.models.megatron.hyena.hyena_utils import ( B2BCausalConv1dModule, @@ -294,6 +295,66 @@ def test_b2b_causal_conv1d_effective_padding_size(): assert b2b_module.effective_pad_size == expected_pad_size +@pytest.mark.xfail( + reason="subquadratic-ops fused B2B kernel does not match causal_conv1d 1.6+ short-conv semantics", + strict=True, +) +def test_b2b_causal_conv1d_module_matches_sequential_reference(): + """Document the isolated B2B mismatch before re-enabling the fused path.""" + if not torch.cuda.is_available(): + pytest.skip("B2B causal conv isolation test requires CUDA") + + torch.manual_seed(1234) + batch_size = 2 + hidden_size = 4 + seq_len = 16 + proj_kernel_size = 3 + mixer_kernel_size = 7 + device = torch.device("cuda") + + x = torch.randn(batch_size, 3 * hidden_size, seq_len, device=device) + proj_weight = torch.randn(3 * hidden_size, proj_kernel_size, device=device) + mixer_weight = torch.randn(hidden_size, mixer_kernel_size, device=device) + bias = torch.randn(hidden_size, device=device) + + 
proj_conv = torch.nn.Module() + proj_conv.kernel_size = proj_kernel_size + proj_conv.short_conv_weight = proj_weight + proj_conv.group_dim = 1 + + mixer = torch.nn.Module() + mixer.use_conv_bias = True + mixer.group_dim = 1 + mixer.conv_bias = bias + mixer.short_conv = torch.nn.Module() + mixer.short_conv.kernel_size = mixer_kernel_size + mixer.short_conv.short_conv_weight = mixer_weight.unsqueeze(1) + + b2b_module = B2BCausalConv1dModule( + proj_conv, + mixer, + operator_type="hyena_short_conv", + pg_collection=MockProcessGroupCollection(), + ) + + fused = b2b_module(x).float() + projected = F.conv1d( + F.pad(x.float(), (proj_kernel_size - 1, 0)), + proj_weight.float().flip(-1).unsqueeze(1), + groups=3 * hidden_size, + ) + x1, x2, v = projected[:, ::3], projected[:, 1::3], projected[:, 2::3] + z = x2 * v + mixed = F.conv1d( + F.pad(z, (mixer_kernel_size - 1, 0)), + mixer_weight.float().flip(-1).unsqueeze(1), + groups=hidden_size, + ) + reference = x1 * (mixed + bias.float()[None, :, None] * z) + + torch.testing.assert_close(fused, reference, rtol=1e-4, atol=1e-4) + + def test_zigzag_get_overlapping_patches(): # noqa: D103 # Test the actual output of zigzag_get_overlapping_patches data = torch.arange(8).reshape(2, 4) # shape [2, 4] diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_infer.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_infer.py index e2428f4b4f..4cc1e2cf13 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_infer.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_infer.py @@ -390,6 +390,125 @@ def _write_prompts_jsonl(prompt_file: Path, prompts: list[tuple[str, str]]) -> N f.writelines(json.dumps({"id": prompt_id, "prompt": prompt_text}) + "\n" for prompt_id, prompt_text in prompts) +@pytest.fixture( + params=[False, True], + ids=["causal-conv1d", "subquadratic-ops"], +) +def infer_use_subquadratic_ops(request): + """Whether infer should use 
subquadratic Hyena kernels.""" + return request.param + + +def _run_infer_prompt_file( + *, + mbridge_checkpoint_path: Path, + prompt_file: Path, + output_file: Path, + max_batch_size: int, + use_subquadratic_ops: bool, +) -> dict[str, dict]: + open_port = find_free_network_port() + cmd = [ + "torchrun", + "--nproc_per_node", + "1", + "--nnodes", + "1", + "--master_port", + str(open_port), + "-m", + "bionemo.evo2.run.infer", + "--ckpt-dir", + str(mbridge_checkpoint_path), + "--prompt-file", + str(prompt_file), + "--max-new-tokens", + "1", + "--output-file", + str(output_file), + "--temperature", + "1.0", + "--top-k", + "1", + "--seed", + "1234", + "--max-batch-size", + str(max_batch_size), + "--max-seq-length", + "512", + "--return-log-probs", + ] + if use_subquadratic_ops: + cmd.append("--use-subquadratic-ops") + + result = subprocess.run( + cmd, + check=False, + capture_output=True, + text=True, + timeout=512, + env=copy.deepcopy(PRETEST_ENV), + ) + assert result.returncode == 0, f"infer command failed:\nstdout: {result.stdout}\nstderr: {result.stderr}" + records = _read_jsonl_results(output_file) + return {record["id"]: record for record in records} + + +def _completion_logprobs(record: dict) -> torch.Tensor: + logprobs = record.get("logprobs", {}).get("completion_logprobs") + assert logprobs is not None, f"Missing completion logprobs in record: {record}" + tensor = torch.as_tensor(logprobs, dtype=torch.float32).flatten() + assert tensor.numel() == 1 + return tensor + + +@pytest.mark.timeout(512) +@pytest.mark.slow +def test_infer_evo2_short_prefill_is_prefix_invariant_across_batch_padding( + mbridge_checkpoint_path, + tmp_path, + infer_use_subquadratic_ops: bool, +): + """A short prefill should generate the same next token alone or in a padded batch.""" + if torch.cuda.device_count() < 1: + pytest.skip("Inference prefill prefix-invariance test requires a GPU") + + short_prompt = "ACGTACGTAA" + padding_prompt = 
("GGCCGGGCGCGGTGGCTCACGCCTGTAATCCCAGCACTTTGGGAGGCCGAGGCGGGCGGATCACGAGGTC" * 4)[:256] + + alone_prompt_file = tmp_path / "short_alone_prompts.jsonl" + padded_prompt_file = tmp_path / "short_padded_prompts.jsonl" + _write_prompts_jsonl(alone_prompt_file, [("short", short_prompt)]) + _write_prompts_jsonl(padded_prompt_file, [("padding", padding_prompt), ("short", short_prompt)]) + + alone_records = _run_infer_prompt_file( + mbridge_checkpoint_path=mbridge_checkpoint_path, + prompt_file=alone_prompt_file, + output_file=tmp_path / "alone_output.jsonl", + max_batch_size=1, + use_subquadratic_ops=infer_use_subquadratic_ops, + ) + padded_records = _run_infer_prompt_file( + mbridge_checkpoint_path=mbridge_checkpoint_path, + prompt_file=padded_prompt_file, + output_file=tmp_path / "padded_output.jsonl", + max_batch_size=2, + use_subquadratic_ops=infer_use_subquadratic_ops, + ) + + assert set(alone_records) == {"short"} + assert set(padded_records) == {"padding", "short"} + assert padded_records["short"]["prompt"] == short_prompt + assert alone_records["short"]["completion"] == padded_records["short"]["completion"] + + torch.testing.assert_close( + _completion_logprobs(alone_records["short"]), + _completion_logprobs(padded_records["short"]), + rtol=2e-2, + atol=5e-2, + ) + + def run_infer_subprocess_parallel( mbridge_checkpoint_path, prompt_file: Path, @@ -524,10 +643,10 @@ def test_identical_prompts_should_be_identical(mbridge_checkpoint_path, tmp_path def test_subquadratic_ops_matches_baseline(mbridge_checkpoint_path, tmp_path): """Greedy generation with --use-subquadratic-ops must match the standard path. - This is the end-to-end correctness check for the subq-ops inference path: - Phase 1 routes engine.parallel_fir through subq-ops kernels during prefill, - Phase 2 fuses proj+mixer convs via b2b_causal_conv1d during prefill and - populates FIR caches for the subsequent decode steps. 
With greedy decoding + This is the end-to-end correctness check for the subq-ops inference path. + The currently enabled subq path uses fft_causal_conv1d for FFT-sized filters; + short direct kernels and fused B2B prefill stay on the standard path until + the fused kernels support causal_conv1d 1.6+ semantics. With greedy decoding (top_k=1) and the same seed, both paths must produce identical output. """ output_baseline = tmp_path / "output_baseline.jsonl" diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_predict.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_predict.py index e28006cb74..b319720fa9 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_predict.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_predict.py @@ -622,6 +622,106 @@ def test_predict_evo2_embedding_extraction( assert len(seq_idx_map) == num_sequences +@pytest.fixture( + params=[False, True], + ids=["causal-conv1d", "subquadratic-ops"], +) +def use_subquadratic_ops(request): + """Whether predict should use subquadratic Hyena kernels.""" + return request.param + + +@pytest.mark.timeout(512) +@pytest.mark.slow +def test_predict_evo2_short_embedding_is_prefix_invariant_across_batch_padding( + tmp_path, + mbridge_checkpoint_1b_8k_bf16_path: Path, + use_subquadratic_ops: bool, +): + """A short sequence should embed the same alone or padded in a longer batch.""" + if torch.cuda.device_count() < 1: + pytest.skip("Embedding prediction test requires a GPU") + + short_sequence = "ACGTACGTAA" + padding_sequence = (ALU_SEQUENCE * (256 // len(ALU_SEQUENCE) + 1))[:256] + + def _write_fasta(fasta_path: Path, records: dict[str, str]) -> None: + fasta_path.write_text("".join(f">{name}\n{sequence}\n" for name, sequence in records.items())) + + def _run_predict(fasta_path: Path, output_dir: Path) -> tuple[dict[str, torch.Tensor], dict[str, int]]: + open_port = find_free_network_port() + subquadratic_arg = " 
--use-subquadratic-ops" if use_subquadratic_ops else "" + command = ( + f"torchrun --nproc_per_node 1 --nnodes 1 --master_port {open_port} " + f"-m bionemo.evo2.run.predict --fasta {fasta_path} --ckpt-dir {mbridge_checkpoint_1b_8k_bf16_path} " + f"--output-dir {output_dir} --micro-batch-size 2 --write-interval epoch --embedding-layer -1" + f"{subquadratic_arg}" + ) + result = subprocess.run( + shlex.split(command), + check=False, + cwd=tmp_path, + capture_output=True, + text=True, + ) + if result.returncode != 0: + print("STDOUT:\n" + result.stdout) + print("STDERR:\n" + result.stderr) + assert result.returncode == 0, f"predict_evo2 command failed with code {result.returncode}" + + pred_files = sorted(glob.glob(str(output_dir / "predictions__rank_*__dp_rank_*.pt"))) + assert len(pred_files) == 1, f"Expected 1 prediction file, got {len(pred_files)}" + with open(output_dir / "seq_idx_map.json") as f: + seq_idx_map = json.load(f) + return torch.load(pred_files[0], weights_only=True), seq_idx_map + + def _unpadded_dna_embeddings( + preds: dict[str, torch.Tensor], + seq_idx_map: dict[str, int], + seqid: str, + dna_length: int, + ) -> torch.Tensor: + matches = (preds["seq_idx"] == seq_idx_map[seqid]).nonzero(as_tuple=True)[0] + assert matches.numel() == 1 + row = matches.item() + assert preds["pad_mask"][row].sum().item() == dna_length + return preds["hidden_embeddings"][row, :dna_length].to(torch.float32) + + def _relative_frobenius_error(left: torch.Tensor, right: torch.Tensor) -> float: + numerator = (left - right).float().pow(2).sum().sqrt() + denominator = right.float().pow(2).sum().sqrt() + return float(numerator / (denominator + 1e-30)) + + def _assert_prefix_embeddings_close(left: torch.Tensor, right: torch.Tensor) -> None: + rel_error = _relative_frobenius_error(left, right) + bound = 4.0 * (1.03**33) * float(torch.finfo(torch.bfloat16).eps) + if rel_error <= bound: + return + + rel_shuffled_hidden = _relative_frobenius_error(left, torch.roll(right, shifts=-1, 
dims=-1)) + rel_shuffled_sequence = _relative_frobenius_error(left, torch.roll(right, shifts=-1, dims=0)) + max_abs_diff = (left - right).abs().max().item() + raise AssertionError( + "Prefix embeddings exceeded bf16 relative-norm tolerance: " + f"rel={rel_error}, bound={bound}, rel_shuffled_hidden={rel_shuffled_hidden}, " + f"rel_shuffled_sequence={rel_shuffled_sequence}, max_abs_diff={max_abs_diff}" + ) + + alone_fasta = tmp_path / "short_alone.fasta" + padded_fasta = tmp_path / "short_padded.fasta" + _write_fasta(alone_fasta, {"short": short_sequence}) + _write_fasta(padded_fasta, {"short": short_sequence, "padding": padding_sequence}) + alone_preds, alone_seq_idx_map = _run_predict(alone_fasta, tmp_path / "alone_output") + padded_preds, padded_seq_idx_map = _run_predict(padded_fasta, tmp_path / "padded_output") + assert alone_preds["hidden_embeddings"].shape[1] == len(short_sequence) + assert padded_preds["hidden_embeddings"].shape[1] == len(padding_sequence) + + alone_embeddings = _unpadded_dna_embeddings(alone_preds, alone_seq_idx_map, "short", len(short_sequence)) + padded_embeddings = _unpadded_dna_embeddings(padded_preds, padded_seq_idx_map, "short", len(short_sequence)) + + _assert_prefix_embeddings_close(alone_embeddings, padded_embeddings) + + @pytest.mark.slow def test_predict_evo2_embedding_layer_validation( tmp_path, From f2cd6a1367ac8dc8e6c0994cc09aa15047d92103 Mon Sep 17 00:00:00 2001 From: "John St. John" Date: Fri, 22 May 2026 11:52:55 -0700 Subject: [PATCH 02/15] Bump megatron bridge dep Signed-off-by: John St. 
John --- bionemo-recipes/recipes/evo2_megatron/.ci_build.sh | 2 +- .../recipes/evo2_megatron/build_requirements.txt | 2 ++ .../recipes/evo2_megatron/pyproject.toml | 14 +++++++++----- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/bionemo-recipes/recipes/evo2_megatron/.ci_build.sh b/bionemo-recipes/recipes/evo2_megatron/.ci_build.sh index 92997a6343..4eb8c487aa 100755 --- a/bionemo-recipes/recipes/evo2_megatron/.ci_build.sh +++ b/bionemo-recipes/recipes/evo2_megatron/.ci_build.sh @@ -5,7 +5,7 @@ rm -f /usr/local/lib/python*/dist-packages/transformer_engine-*.dist-info/direct_url.json export UV_LOCK_TIMEOUT=900 # increase to 15 minutes (900 seconds), adjust as needed export UV_LINK_MODE=copy -uv venv --system-site-packages +uv venv --clear --system-site-packages # 2. Activate the environment source .venv/bin/activate diff --git a/bionemo-recipes/recipes/evo2_megatron/build_requirements.txt b/bionemo-recipes/recipes/evo2_megatron/build_requirements.txt index 5c7e5906a1..38cbd1fe91 100644 --- a/bionemo-recipes/recipes/evo2_megatron/build_requirements.txt +++ b/bionemo-recipes/recipes/evo2_megatron/build_requirements.txt @@ -1,3 +1,5 @@ poetry-core +poetry_dynamic_versioning # build dep of nvidia-resiliency-ext (transitively pulled by megatron-bridge); needed in the venv because we install with --no-build-isolation +grpcio-tools # build dep of nvidia-resiliency-ext: its setup.py shells out to `python -m grpc_tools.protoc` to compile *.proto files; --no-build-isolation means we have to provide it in the venv up-front wheel_stub ninja # should speed up causal-conv1d build diff --git a/bionemo-recipes/recipes/evo2_megatron/pyproject.toml b/bionemo-recipes/recipes/evo2_megatron/pyproject.toml index 1392215b66..74b395c5e8 100644 --- a/bionemo-recipes/recipes/evo2_megatron/pyproject.toml +++ b/bionemo-recipes/recipes/evo2_megatron/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ "causal_conv1d", "nv-grouped-gemm", "megatron-core", - "nvidia-resiliency-ext", 
+ # nvidia-resiliency-ext is pulled transitively by megatron-bridge. "emerging_optimizers", "subquadratic-ops-torch-cu13", @@ -88,6 +88,9 @@ override-dependencies = [ "triton; sys_platform == 'never'", "transformer-engine; sys_platform == 'never'", "transformer-engine[pytorch]; sys_platform == 'never'", + # Avoid optional log-pattern-mining dependency conflicts from nvidia-resiliency-ext. + "logsage; sys_platform == 'never'", + "drain3; sys_platform == 'never'", ] [tool.uv.sources] @@ -100,11 +103,12 @@ nv-grouped-gemm = { git = "https://github.com/fanshiqing/grouped_gemm", tag = "v # Internal dependencies bionemo-recipeutils = { git = "https://github.com/NVIDIA/bionemo-framework.git", branch = "main", subdirectory = "sub-packages/bionemo-recipeutils" } bionemo-core = { git = "https://github.com/NVIDIA/bionemo-framework.git", branch = "main", subdirectory = "sub-packages/bionemo-core" } -nvidia-resiliency-ext = { git = "https://github.com/NVIDIA/nvidia-resiliency-ext.git", rev = "54f85fe422d296cf04ea524130014bd3a2c3add1" } +# nvidia-resiliency-ext is intentionally left to Megatron-Bridge so the transitive pin stays consistent. -# Megatron Bundle. This points to a version that still supports the deprecated no_weight_decay_cond field until the API for an alternative has been finalized. -megatron-bridge = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git", rev = "549e3cb970c170b1d7a86d021261efe05e8a5d9f" } -megatron-core = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git", rev = "549e3cb970c170b1d7a86d021261efe05e8a5d9f", subdirectory = "3rdparty/Megatron-LM" } +# Megatron Bundle. MCore is sourced from the same Megatron-Bridge release tag. 
+megatron-bridge = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git", tag = "v0.4.1" } +megatron-core = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git", tag = "v0.4.1", subdirectory = "3rdparty/Megatron-LM" } [tool.uv.extra-build-dependencies] warp-lang = ["wheel_stub"] +nvidia-resiliency-ext = ["poetry_dynamic_versioning"] From 49be647490dce74a1dcca68cd9f87d080d8bc6f7 Mon Sep 17 00:00:00 2001 From: "John St. John" Date: Fri, 22 May 2026 12:51:22 -0700 Subject: [PATCH 03/15] Go back to the original subq version, assume it works on other gpus and fail loudly if the CUDA_ERROR_UNSUPPORTED_PTX_VERSION error comes up Signed-off-by: John St. John --- .../recipes/evo2_megatron/pyproject.toml | 3 + .../evo2/models/megatron/hyena/engine.py | 34 +++- .../evo2/models/megatron/hyena/hyena_mixer.py | 4 +- .../evo2/models/megatron/hyena/hyena_utils.py | 26 ++- .../megatron/hyena/subquadratic_safety.py | 158 ++++++++++++++++++ .../src/bionemo/evo2/run/infer.py | 9 +- .../evo2/models/megatron/hyena/test_engine.py | 53 +++++- .../models/megatron/hyena/test_hyena_utils.py | 4 +- .../tests/bionemo/evo2/run/test_infer.py | 15 +- .../tests/bionemo/evo2/run/test_predict.py | 6 + 10 files changed, 288 insertions(+), 24 deletions(-) create mode 100644 bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/subquadratic_safety.py diff --git a/bionemo-recipes/recipes/evo2_megatron/pyproject.toml b/bionemo-recipes/recipes/evo2_megatron/pyproject.toml index 74b395c5e8..93246f161f 100644 --- a/bionemo-recipes/recipes/evo2_megatron/pyproject.toml +++ b/bionemo-recipes/recipes/evo2_megatron/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ # nvidia-resiliency-ext is pulled transitively by megatron-bridge. 
"emerging_optimizers", "subquadratic-ops-torch-cu13", + "email-validator", # These are dependencies for examples only, but are useful for actually doing analyses with this model "biopython", @@ -88,6 +89,8 @@ override-dependencies = [ "triton; sys_platform == 'never'", "transformer-engine; sys_platform == 'never'", "transformer-engine[pytorch]; sys_platform == 'never'", + # Avoid alpha Pydantic releases; langchain imports pulled by nvidia-resiliency-ext are not compatible. + "pydantic>=2.12,<2.14", # Avoid optional log-pattern-mining dependency conflicts from nvidia-resiliency-ext. "logsage; sys_platform == 'never'", "drain3; sys_platform == 'never'", diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/engine.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/engine.py index 047d123730..470a909ec9 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/engine.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/engine.py @@ -19,10 +19,17 @@ import torch.nn.functional as F # noqa: N812 from einops import rearrange +from bionemo.evo2.models.megatron.hyena.subquadratic_safety import ( + ensure_subquadratic_causal_conv1d_supported, + ensure_subquadratic_fft_causal_conv1d_supported, +) + try: + from subquadratic_ops_torch.causal_conv1d import causal_conv1d as _subq_causal_conv1d from subquadratic_ops_torch.fft_causal_conv1d import fft_causal_conv1d as _subq_fft_causal_conv1d except ImportError as _subq_import_error: + _subq_causal_conv1d = None _subq_fft_causal_conv1d = None _subq_error_msg = f"subquadratic_ops_torch not available: {_subq_import_error}" @@ -87,6 +94,7 @@ def parallel_fir( if fir_length >= 128: if use_subquadratic_ops: # subq-ops fft_causal_conv1d expects [B, D, L] input and [D, L] filter; dtypes must match + ensure_subquadratic_fft_causal_conv1d_supported() k = weight[:, :, :L].squeeze(1) if weight.dim() == 3 else 
weight[:, :L] u_fp32 = u.to(torch.float32) z = _subq_fft_causal_conv1d(u_fp32, k.to(torch.float32)) @@ -101,14 +109,24 @@ def parallel_fir( D=bias, ).to(dtype=u.dtype) else: - z = F.conv1d( - u.to(torch.float32), - weight.to(torch.float32), - bias=None, - stride=1, - padding=fir_length - 1, - groups=u.shape[1], # always set to D, regardless of filter grouping - )[..., :L] + if use_subquadratic_ops: + if _subq_causal_conv1d is None: + raise ImportError(_subq_error_msg) + # subq-ops causal_conv1d expects pre-padded [B, D, L+pad] input and [D, K] weight. + ensure_subquadratic_causal_conv1d_supported() + pad_size = fir_length - 1 + x_padded = F.pad(u.to(torch.float32), (pad_size, 0)) + w = weight.squeeze(1) if weight.dim() == 3 else weight + z = _subq_causal_conv1d(x_padded, w.to(torch.float32))[..., pad_size:] + else: + z = F.conv1d( + u.to(torch.float32), + weight.to(torch.float32), + bias=None, + stride=1, + padding=fir_length - 1, + groups=u.shape[1], + )[..., :L] z = z.to(u.dtype) diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_mixer.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_mixer.py index 425060ac96..508f670205 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_mixer.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_mixer.py @@ -119,9 +119,7 @@ def __init__( self.fast_conv_mixer = self.hyena_config.fast_conv_mixer self.use_subquadratic_ops = self.transformer_config.use_subquadratic_ops - # TODO: Re-enable B2BCausalConv1dModule for short/medium Hyena layers once - # subquadratic-ops updates it to support causal_conv1d 1.6+ semantics. - self.use_fused_b2b_causal_conv1d = False + self.use_fused_b2b_causal_conv1d = self.use_subquadratic_ops # Per attention head and per partition values. 
assert torch.distributed.is_initialized() diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py index 436e11fae5..a3ebc98e86 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py @@ -33,6 +33,11 @@ from torch.autograd.function import Function from bionemo.evo2.models.megatron.hyena.hyena_config import HyenaConfig +from bionemo.evo2.models.megatron.hyena.subquadratic_safety import ( + ensure_subquadratic_b2b_causal_conv1d_supported, + ensure_subquadratic_causal_conv1d_supported, + ensure_subquadratic_fft_causal_conv1d_supported, +) try: @@ -50,10 +55,25 @@ def causal_conv1d_fn(*args, **kwargs): try: - from subquadratic_ops_torch.b2b_causal_conv1d import b2b_causal_conv1d - from subquadratic_ops_torch.causal_conv1d import causal_conv1d - from subquadratic_ops_torch.fft_causal_conv1d import fft_causal_conv1d + from subquadratic_ops_torch.b2b_causal_conv1d import b2b_causal_conv1d as _subq_b2b_causal_conv1d + from subquadratic_ops_torch.causal_conv1d import causal_conv1d as _subq_causal_conv1d + from subquadratic_ops_torch.fft_causal_conv1d import fft_causal_conv1d as _subq_fft_causal_conv1d from subquadratic_ops_torch.implicit_filter import implicit_filter + + def causal_conv1d(*args, **kwargs): + """Run guarded subquadratic causal_conv1d.""" + ensure_subquadratic_causal_conv1d_supported() + return _subq_causal_conv1d(*args, **kwargs) + + def b2b_causal_conv1d(*args, **kwargs): + """Run guarded subquadratic b2b_causal_conv1d.""" + ensure_subquadratic_b2b_causal_conv1d_supported() + return _subq_b2b_causal_conv1d(*args, **kwargs) + + def fft_causal_conv1d(*args, **kwargs): + """Run guarded subquadratic fft_causal_conv1d.""" + ensure_subquadratic_fft_causal_conv1d_supported() + 
return _subq_fft_causal_conv1d(*args, **kwargs) except ImportError as e: msg_causal_conv1d = f"Problem importing subquadratic_ops: {e}. causal_conv1d is not available." msg_b2b_causal_conv1d = f"Problem importing subquadratic_ops: {e}. b2b_causal_conv1d is not available." diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/subquadratic_safety.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/subquadratic_safety.py new file mode 100644 index 0000000000..1ef01ee87f --- /dev/null +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/subquadratic_safety.py @@ -0,0 +1,158 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import lru_cache + +import torch +import torch.nn.functional as F # noqa: N812 + + +def _raise_subquadratic_self_test_error(op_name: str, detail: str) -> None: + raise RuntimeError( + f"subquadratic_ops_torch.{op_name} failed a CUDA self-test ({detail}). " + "This often happens with CUDA_ERROR_UNSUPPORTED_PTX_VERSION or unsupported GPU/toolchain " + "combinations. Refusing to run this subquadratic kernel because it can otherwise return " + "invalid outputs without raising." 
+ ) + + +def _assert_close_or_raise(op_name: str, actual: torch.Tensor, expected: torch.Tensor) -> None: + torch.cuda.synchronize(actual.device) + if not torch.isfinite(actual).all(): + _raise_subquadratic_self_test_error(op_name, "non-finite output") + + if not torch.allclose(actual, expected, rtol=1e-4, atol=1e-4): + max_diff = (actual.float() - expected.float()).abs().max().item() + rel = ( + (actual.float() - expected.float()).pow(2).sum().sqrt() / (expected.float().pow(2).sum().sqrt() + 1e-30) + ).item() + _raise_subquadratic_self_test_error(op_name, f"max_diff={max_diff:.6g}, rel={rel:.6g}") + + +@lru_cache(maxsize=None) +def ensure_subquadratic_causal_conv1d_supported(device_index: int | None = None) -> None: + """Validate subquadratic_ops_torch.causal_conv1d before using it for model data.""" + if not torch.cuda.is_available(): + return + + device_index = torch.cuda.current_device() if device_index is None else device_index + device = torch.device("cuda", device_index) + + from subquadratic_ops_torch.causal_conv1d import causal_conv1d as subq_causal_conv1d + + batch_size = 1 + hidden_size = 4 + seq_len = 8 + kernel_size = 3 + pad_size = kernel_size - 1 + + u = torch.linspace(-1.0, 1.0, steps=batch_size * hidden_size * seq_len, device=device).reshape( + batch_size, hidden_size, seq_len + ) + weight = torch.linspace(-0.5, 0.5, steps=hidden_size * kernel_size, device=device).reshape( + hidden_size, kernel_size + ) + + expected = F.conv1d( + u, + weight.unsqueeze(1), + bias=None, + stride=1, + padding=pad_size, + groups=hidden_size, + )[..., :seq_len] + actual = subq_causal_conv1d(F.pad(u, (pad_size, 0)), weight)[..., pad_size:] + _assert_close_or_raise("causal_conv1d", actual, expected) + + +@lru_cache(maxsize=None) +def ensure_subquadratic_fft_causal_conv1d_supported(device_index: int | None = None) -> None: + """Validate subquadratic_ops_torch.fft_causal_conv1d before using it for model data.""" + if not torch.cuda.is_available(): + return + + device_index = 
torch.cuda.current_device() if device_index is None else device_index + device = torch.device("cuda", device_index) + + from subquadratic_ops_torch.fft_causal_conv1d import fft_causal_conv1d as subq_fft_causal_conv1d + + batch_size = 1 + hidden_size = 4 + seq_len = 8 + kernel_size = 5 + + u = torch.linspace(-1.0, 1.0, steps=batch_size * hidden_size * seq_len, device=device).reshape( + batch_size, hidden_size, seq_len + ) + weight = torch.linspace(-0.5, 0.5, steps=hidden_size * kernel_size, device=device).reshape( + hidden_size, kernel_size + ) + + expected = F.conv1d( + u, + weight.flip(-1).unsqueeze(1), + bias=None, + stride=1, + padding=kernel_size - 1, + groups=hidden_size, + )[..., :seq_len] + actual = subq_fft_causal_conv1d(u, weight) + _assert_close_or_raise("fft_causal_conv1d", actual, expected) + + +@lru_cache(maxsize=None) +def ensure_subquadratic_b2b_causal_conv1d_supported(device_index: int | None = None) -> None: + """Validate subquadratic_ops_torch.b2b_causal_conv1d before using it for model data.""" + if not torch.cuda.is_available(): + return + + device_index = torch.cuda.current_device() if device_index is None else device_index + device = torch.device("cuda", device_index) + + from subquadratic_ops_torch.b2b_causal_conv1d import b2b_causal_conv1d as subq_b2b_causal_conv1d + + batch_size = 1 + hidden_size = 2 + seq_len = 10 + proj_kernel_size = 3 + mixer_kernel_size = 7 + + x = torch.linspace(-1.0, 1.0, steps=batch_size * 3 * hidden_size * seq_len, device=device).reshape( + batch_size, 3 * hidden_size, seq_len + ) + proj_weight = torch.linspace(-0.5, 0.5, steps=3 * hidden_size * proj_kernel_size, device=device).reshape( + 3 * hidden_size, proj_kernel_size + ) + mixer_weight = torch.linspace(-0.25, 0.25, steps=hidden_size * mixer_kernel_size, device=device).reshape( + hidden_size, mixer_kernel_size + ) + bias = torch.linspace(-0.1, 0.1, steps=hidden_size, device=device) + + actual = subq_b2b_causal_conv1d(x, proj_weight, mixer_weight, bias) + + 
projected = F.conv1d( + F.pad(x, (proj_kernel_size - 1, 0)), + proj_weight.flip(-1).unsqueeze(1), + groups=3 * hidden_size, + ) + x1, x2, v = projected[:, ::3], projected[:, 1::3], projected[:, 2::3] + z = x2 * v + mixed = F.conv1d( + F.pad(z, (mixer_kernel_size - 1, 0)), + mixer_weight.flip(-1).unsqueeze(1), + groups=hidden_size, + ) + expected = x1 * (mixed + bias[None, :, None] * z) + _assert_close_or_raise("b2b_causal_conv1d", actual, expected) diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py index 0ed9d4defb..6219d5854c 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py @@ -77,7 +77,14 @@ ) from megatron.bridge.training.config import DistributedInitConfig, RNGConfig from megatron.bridge.training.mixed_precision import get_mixed_precision_config -from megatron.bridge.training.tokenizers.tokenizer import _HuggingFaceTokenizer + + +try: + from megatron.bridge.training.tokenizers.tokenizer import _HuggingFaceTokenizer +except ImportError: + from megatron.core.tokenizers.text.libraries.huggingface_tokenizer import ( + HuggingFaceTokenizer as _HuggingFaceTokenizer, + ) from megatron.bridge.training.utils.checkpoint_utils import ( file_exists, get_checkpoint_run_config_filename, diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_engine.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_engine.py index 02b169296f..e0b1c5e5b6 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_engine.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_engine.py @@ -13,10 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 - +import pytest import torch +import torch.nn.functional as F # noqa: N812 from bionemo.evo2.models.megatron.hyena import engine @@ -77,3 +76,51 @@ def test_parallel_iir_is_prefix_invariant_when_filter_is_longer_than_input(): ) torch.testing.assert_close(short_out, long_out[:, :short_len], rtol=1e-5, atol=1e-5) + + +@pytest.mark.parametrize("use_subquadratic_ops", [False, True], ids=["torch", "subq"]) +def test_parallel_fir_short_cuda_path_matches_torch_depthwise_conv1d(use_subquadratic_ops): + """Short FIR prefill should match F.conv1d or fail before returning bad subq output.""" + if not torch.cuda.is_available(): + pytest.skip("short FIR CUDA path requires CUDA") + + torch.manual_seed(1234) + batch_size = 2 + seq_len = 17 + hidden_size = 8 + kernel_size = 7 + device = torch.device("cuda") + + u = torch.randn(batch_size, seq_len, hidden_size, device=device) + weight = torch.randn(hidden_size, 1, kernel_size, device=device) + bias = torch.randn(hidden_size, device=device) + + try: + actual, state = engine.parallel_fir( + u=u, + weight=weight, + bias=bias, + L=seq_len, + gated_bias=True, + fir_length=kernel_size, + compute_state=True, + use_subquadratic_ops=use_subquadratic_ops, + ) + except RuntimeError as e: + if use_subquadratic_ops and "failed a CUDA self-test" in str(e): + pytest.xfail(str(e)) + raise + + u_bdl = u.transpose(1, 2).contiguous() + expected = F.conv1d( + u_bdl.float(), + weight.float(), + bias=None, + stride=1, + padding=kernel_size - 1, + groups=hidden_size, + )[..., :seq_len] + expected = expected.to(u.dtype) + bias[None, :, None] * u_bdl + + torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5) + torch.testing.assert_close(state, u_bdl[..., -(kernel_size - 1) :]) diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py 
b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py index 1a1ce86627..f350700cb9 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py @@ -296,11 +296,11 @@ def test_b2b_causal_conv1d_effective_padding_size(): @pytest.mark.xfail( - reason="subquadratic-ops fused B2B kernel does not match causal_conv1d 1.6+ short-conv semantics", + reason="subquadratic-ops fused B2B kernel may fail CUDA/PTX self-test on unsupported GPUs", strict=True, ) def test_b2b_causal_conv1d_module_matches_sequential_reference(): - """Document the isolated B2B mismatch before re-enabling the fused path.""" + """Document the isolated B2B CUDA kernel behavior before relying on the fused path.""" if not torch.cuda.is_available(): pytest.skip("B2B causal conv isolation test requires CUDA") diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_infer.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_infer.py index 4cc1e2cf13..618a0f6f20 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_infer.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_infer.py @@ -51,6 +51,11 @@ # Note: mbridge_checkpoint_path fixture is provided by conftest.py at session scope +def _xfail_if_unsupported_subquadratic_ops(result: subprocess.CompletedProcess, use_subquadratic_ops: bool) -> None: + if use_subquadratic_ops and "failed a CUDA self-test" in result.stderr: + pytest.xfail("subquadratic_ops_torch CUDA kernels are unsupported in this environment") + + def _read_jsonl_results(output_file: Path) -> list[dict]: """Read JSONL output file and return parsed records.""" records = [] @@ -342,6 +347,7 @@ def run_infer_subprocess( env=env, ) + _xfail_if_unsupported_subquadratic_ops(result, use_subquadratic_ops) assert 
result.returncode == 0, f"infer command failed:\nstdout: {result.stdout}\nstderr: {result.stderr}" assert output_file.exists(), "Output file was not created" @@ -449,6 +455,7 @@ def _run_infer_prompt_file( timeout=512, env=copy.deepcopy(PRETEST_ENV), ) + _xfail_if_unsupported_subquadratic_ops(result, use_subquadratic_ops) assert result.returncode == 0, f"infer command failed:\nstdout: {result.stdout}\nstderr: {result.stderr}" records = _read_jsonl_results(output_file) return {record["id"]: record for record in records} @@ -644,10 +651,10 @@ def test_subquadratic_ops_matches_baseline(mbridge_checkpoint_path, tmp_path): """Greedy generation with --use-subquadratic-ops must match the standard path. This is the end-to-end correctness check for the subq-ops inference path. - The currently enabled subq path uses fft_causal_conv1d for FFT-sized filters; - short direct kernels and fused B2B prefill stay on the standard path until - the fused kernels support causal_conv1d 1.6+ semantics. With greedy decoding - (top_k=1) and the same seed, both paths must produce identical output. + The subq path uses guarded subquadratic kernels. If the local CUDA/GPU + combination cannot run those kernels correctly, the guard raises before + invalid outputs can propagate. With greedy decoding (top_k=1) and the same + seed, supported subq kernels must produce identical output. 
""" output_baseline = tmp_path / "output_baseline.jsonl" output_subq = tmp_path / "output_subq.jsonl" diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_predict.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_predict.py index b319720fa9..07ed1102d2 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_predict.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_predict.py @@ -43,6 +43,11 @@ PRETEST_ENV = copy.deepcopy(os.environ) +def _xfail_if_unsupported_subquadratic_ops(result: subprocess.CompletedProcess, use_subquadratic_ops: bool) -> None: + if use_subquadratic_ops and "failed a CUDA self-test" in result.stderr: + pytest.xfail("subquadratic_ops_torch CUDA kernels are unsupported in this environment") + + @pytest.fixture(scope="module") def mbridge_checkpoint_1b_8k_bf16_path(mbridge_checkpoint_1b_8k_bf16) -> Path: """Module-scoped alias for the session-scoped 1b-8k-bf16 checkpoint. @@ -664,6 +669,7 @@ def _run_predict(fasta_path: Path, output_dir: Path) -> tuple[dict[str, torch.Te capture_output=True, text=True, ) + _xfail_if_unsupported_subquadratic_ops(result, use_subquadratic_ops) if result.returncode != 0: print("STDOUT:\n" + result.stdout) print("STDERR:\n" + result.stderr) From 9532cba44cf15a819d2d4ea6bea4250083d03118 Mon Sep 17 00:00:00 2001 From: "John St. John" Date: Fri, 22 May 2026 12:55:01 -0700 Subject: [PATCH 04/15] Roll back hyena_mixer diffs Signed-off-by: John St. 
John --- .../src/bionemo/evo2/models/megatron/hyena/hyena_mixer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_mixer.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_mixer.py index 508f670205..425060ac96 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_mixer.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_mixer.py @@ -119,7 +119,9 @@ def __init__( self.fast_conv_mixer = self.hyena_config.fast_conv_mixer self.use_subquadratic_ops = self.transformer_config.use_subquadratic_ops - self.use_fused_b2b_causal_conv1d = self.use_subquadratic_ops + # TODO: Re-enable B2BCausalConv1dModule for short/medium Hyena layers once + # subquadratic-ops updates it to support causal_conv1d 1.6+ semantics. + self.use_fused_b2b_causal_conv1d = False # Per attention head and per partition values. assert torch.distributed.is_initialized() From 2be519053f6cf99ca6f4ffbe2af2affbd9b3db71 Mon Sep 17 00:00:00 2001 From: "John St. John" Date: Fri, 22 May 2026 13:14:41 -0700 Subject: [PATCH 05/15] Roll back more variable renamings Signed-off-by: John St. 
John --- .../evo2/models/megatron/hyena/hyena_mixer.py | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_mixer.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_mixer.py index 425060ac96..81271c8e47 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_mixer.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_mixer.py @@ -119,9 +119,6 @@ def __init__( self.fast_conv_mixer = self.hyena_config.fast_conv_mixer self.use_subquadratic_ops = self.transformer_config.use_subquadratic_ops - # TODO: Re-enable B2BCausalConv1dModule for short/medium Hyena layers once - # subquadratic-ops updates it to support causal_conv1d 1.6+ semantics. - self.use_fused_b2b_causal_conv1d = False # Per attention head and per partition values. assert torch.distributed.is_initialized() @@ -200,9 +197,9 @@ def __init__( use_conv_bias=self.transformer_config.use_short_conv_bias, ) - if self.use_fused_b2b_causal_conv1d: - # Create a wrapper module that doesn't register parameters - # Use the existing weights from the original model + if self.use_subquadratic_ops: + # The B2B kernel is guarded in hyena_utils and fails early if the local CUDA stack + # cannot run subquadratic_ops_torch correctly. self.b2b_kernel = B2BCausalConv1dModule( self.hyena_proj_conv, self.mixer, @@ -231,9 +228,9 @@ def __init__( max_sequence_length, ) - if self.use_fused_b2b_causal_conv1d and self.operator_type == "hyena_medium_conv": - # Create a wrapper module that doesn't register parameters - # Use the existing weights from the original model + if self.use_subquadratic_ops and self.operator_type == "hyena_medium_conv": + # The B2B kernel is guarded in hyena_utils and fails early if the local CUDA stack + # cannot run subquadratic_ops_torch correctly. 
self.b2b_kernel = B2BCausalConv1dModule( self.hyena_proj_conv, self.mixer, @@ -311,12 +308,12 @@ def forward(self, x, layer_past=None, inference_context=None, _hyena_use_cp=True else: features = rearrange(features, "l b d -> b d l").contiguous() - is_b2b_eligible = self.use_fused_b2b_causal_conv1d and self.operator_type in [ + is_b2b_eligible = self.use_subquadratic_ops and self.operator_type in [ "hyena_short_conv", "hyena_medium_conv", ] - # b2b runs during training (no inference_context) or during prefill (no FIR cache yet). - # During decode (cache populated, L=1) we fall back to the regular per-token step path. + # B2B runs during training (no inference_context) or during prefill (no FIR cache yet). + # During decode, fall back to the regular per-token step path. is_prefill = inference_context is not None and id(self.hyena_proj_conv) not in getattr( inference_context, "fir_filter_state_dict", {} ) From 4ed5d0d57e463e686a853f5a81e1d92fe9b41495 Mon Sep 17 00:00:00 2001 From: "John St. John" Date: Fri, 22 May 2026 13:29:40 -0700 Subject: [PATCH 06/15] Remove overly granular checks on compatibility Signed-off-by: John St. 
John --- .../evo2/models/megatron/hyena/engine.py | 7 ++- .../evo2/models/megatron/hyena/hyena_utils.py | 62 ++++++++++++++----- .../models/megatron/hyena/test_hyena_utils.py | 23 +++++++ 3 files changed, 75 insertions(+), 17 deletions(-) diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/engine.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/engine.py index 470a909ec9..652dac3759 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/engine.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/engine.py @@ -83,6 +83,7 @@ def parallel_fir( fir_length, compute_state, use_subquadratic_ops=False, + check_subquadratic_ops=True, ): """Compute parallel finite impulse response filtering with optional state computation.""" L = u.shape[1] # noqa: N806 @@ -94,7 +95,8 @@ def parallel_fir( if fir_length >= 128: if use_subquadratic_ops: # subq-ops fft_causal_conv1d expects [B, D, L] input and [D, L] filter; dtypes must match - ensure_subquadratic_fft_causal_conv1d_supported() + if check_subquadratic_ops and u.is_cuda: + ensure_subquadratic_fft_causal_conv1d_supported() k = weight[:, :, :L].squeeze(1) if weight.dim() == 3 else weight[:, :L] u_fp32 = u.to(torch.float32) z = _subq_fft_causal_conv1d(u_fp32, k.to(torch.float32)) @@ -113,7 +115,8 @@ def parallel_fir( if _subq_causal_conv1d is None: raise ImportError(_subq_error_msg) # subq-ops causal_conv1d expects pre-padded [B, D, L+pad] input and [D, K] weight. 
- ensure_subquadratic_causal_conv1d_supported() + if check_subquadratic_ops and u.is_cuda: + ensure_subquadratic_causal_conv1d_supported() pad_size = fir_length - 1 x_padded = F.pad(u.to(torch.float32), (pad_size, 0)) w = weight.squeeze(1) if weight.dim() == 3 else weight diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py index a3ebc98e86..aae76e8f77 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py @@ -60,20 +60,9 @@ def causal_conv1d_fn(*args, **kwargs): from subquadratic_ops_torch.fft_causal_conv1d import fft_causal_conv1d as _subq_fft_causal_conv1d from subquadratic_ops_torch.implicit_filter import implicit_filter - def causal_conv1d(*args, **kwargs): - """Run guarded subquadratic causal_conv1d.""" - ensure_subquadratic_causal_conv1d_supported() - return _subq_causal_conv1d(*args, **kwargs) - - def b2b_causal_conv1d(*args, **kwargs): - """Run guarded subquadratic b2b_causal_conv1d.""" - ensure_subquadratic_b2b_causal_conv1d_supported() - return _subq_b2b_causal_conv1d(*args, **kwargs) - - def fft_causal_conv1d(*args, **kwargs): - """Run guarded subquadratic fft_causal_conv1d.""" - ensure_subquadratic_fft_causal_conv1d_supported() - return _subq_fft_causal_conv1d(*args, **kwargs) + causal_conv1d = _subq_causal_conv1d + b2b_causal_conv1d = _subq_b2b_causal_conv1d + fft_causal_conv1d = _subq_fft_causal_conv1d except ImportError as e: msg_causal_conv1d = f"Problem importing subquadratic_ops: {e}. causal_conv1d is not available." msg_b2b_causal_conv1d = f"Problem importing subquadratic_ops: {e}. b2b_causal_conv1d is not available." 
@@ -471,7 +460,17 @@ def hyena_no_weight_decay_cond_with_embeddings(name, param): return ("embedding" in name) or hyena_no_weight_decay_cond(name, param) -def fftconv_func(u, k, D, dropout_mask, gelu=True, k_rev=None, bidirectional=False, use_subquadratic_ops=False): # noqa: N803 +def fftconv_func( + u, + k, + D, # noqa: N803 + dropout_mask, + gelu=True, + k_rev=None, + bidirectional=False, + use_subquadratic_ops=False, + check_subquadratic_ops=True, +): """Apply a 1D convolution to the input sequence u using the filter k and the shortcut D.""" seqlen = u.shape[-1] fft_size = 2 * seqlen @@ -504,6 +503,8 @@ def fftconv_func(u, k, D, dropout_mask, gelu=True, k_rev=None, bidirectional=Fal # causal else: if use_subquadratic_ops: + if check_subquadratic_ops and u.is_cuda: + ensure_subquadratic_fft_causal_conv1d_supported() y = fft_causal_conv1d(u, k.squeeze(0)) else: fft_size = max(fft_size, 2 * k.shape[-1]) @@ -902,6 +903,7 @@ def __init__( self.zigzag = zigzag self.use_subquadratic_ops = transformer_config.use_subquadratic_ops + self._subquadratic_ops_checked = False self.model_parallel_size = self.pg_collection.tp.size() if self.pg_collection.tp is not None else 1 self.model_parallel_rank = self.pg_collection.tp.rank() if self.pg_collection.tp is not None else 0 @@ -984,6 +986,16 @@ def reset_parameters(self): bounds = math.sqrt(1 / self.kernel_size) torch.nn.init.uniform_(self.conv_bias, a=-bounds, b=bounds) + def _ensure_subquadratic_ops_supported(self): + """Run expensive subquadratic-op CUDA self-tests once per operator instance.""" + if self._subquadratic_ops_checked or not self.use_subquadratic_ops: + return + if self.operator_type == "hyena_medium_conv" and self.kernel_size < 128: + ensure_subquadratic_causal_conv1d_supported() + else: + ensure_subquadratic_fft_causal_conv1d_supported() + self._subquadratic_ops_checked = True + def forward_long(self, *, x1, x2, v, h, bias, inference_context): """Forward pass long.""" import 
bionemo.evo2.models.megatron.hyena.engine as engine @@ -1074,6 +1086,7 @@ def get_filter_state(filter_name): fir_length=self.kernel_size, # self.short_filter_length, compute_state=inference_context is not None, use_subquadratic_ops=self.use_subquadratic_ops, + check_subquadratic_ops=False, ) y = rearrange(y, "b d l -> b l d") y = y * x1 @@ -1099,6 +1112,8 @@ def forward(self, x1, x2, v, _hyena_use_cp=True, inference_context=None): Input shapes: bs, (num_groups, group_size), seq_length Output shapes: bs, (num_groups, group_size), seq_length """ + if x1.is_cuda: + self._ensure_subquadratic_ops_supported() B, GDG, L = x1.shape # noqa: N806 x1, x2, v = x1[..., :L], x2[..., :L], v[..., :L] @@ -1189,6 +1204,7 @@ def forward(self, x1, x2, v, _hyena_use_cp=True, inference_context=None): gelu=False, bidirectional=self.bidirectional, use_subquadratic_ops=self.use_subquadratic_ops, + check_subquadratic_ops=False, ) z = z.to(v.dtype) @@ -1388,6 +1404,7 @@ def __init__( self.num_groups = num_groups self.transformer_config = transformer_config self.use_subquadratic_ops = transformer_config.use_subquadratic_ops + self._subquadratic_ops_checked = False self.short_conv_L = hyena_config.short_conv_L self.local_init = local_init if pg_collection is None: @@ -1543,6 +1560,7 @@ def __init__( """ super().__init__() self.b2b_causal_conv1d_fn = b2b_causal_conv1d + self._check_subquadratic_ops = b2b_causal_conv1d is globals()["b2b_causal_conv1d"] if pg_collection is None: pg_collection = ProcessGroupCollection.use_mpu_process_groups() self.pg_collection = pg_collection @@ -1567,6 +1585,14 @@ def __init__( raise ValueError(f"Operator type {operator_type} not supported") self.effective_pad_size = (self._mixer_kernel_size - 1) + (self._proj_conv_kernel_size - 1) + self._subquadratic_ops_checked = False + + def _ensure_subquadratic_ops_supported(self): + """Run the B2B CUDA self-test once per wrapper instance.""" + if self._subquadratic_ops_checked or not self._check_subquadratic_ops: + return 
+ ensure_subquadratic_b2b_causal_conv1d_supported() + self._subquadratic_ops_checked = True def forward(self, x, _use_cp=True): """Forward pass for the B2BCausalConv1dModule. @@ -1580,6 +1606,8 @@ def forward(self, x, _use_cp=True): # Validate input dimensions if x.dim() != 3: raise ValueError("Input tensor must be 3D [batch_size, hidden_dim, seq_len]") + if x.is_cuda: + self._ensure_subquadratic_ops_supported() # Extract weights at runtime to avoid parameter registration proj_weight = self._proj_conv_module.short_conv_weight @@ -1713,6 +1741,9 @@ def get_filter_state(filter_name): L = u.shape[1] # noqa: N806 fir_state = get_filter_state("fir") if fir_state is None: + if self.use_subquadratic_ops and u.is_cuda and not self._subquadratic_ops_checked: + ensure_subquadratic_causal_conv1d_supported() + self._subquadratic_ops_checked = True z_pre, fir_state = engine.parallel_fir( u=u, weight=torch.tensor(weight), # self.short_filter_weight, @@ -1722,6 +1753,7 @@ def get_filter_state(filter_name): fir_length=self.kernel_size, # self.short_filter_length, compute_state=inference_context is not None, use_subquadratic_ops=self.use_subquadratic_ops, + check_subquadratic_ops=False, ) else: if len(u.shape) > 2: diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py index f350700cb9..cb50dad10f 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py @@ -278,6 +278,29 @@ def test_b2b_causal_conv1d_module_device_handling(): # noqa: D103 assert result_cuda.device == x_cuda.device, "Device mismatch on CUDA" +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for subquadratic guard test") 
+@patch("bionemo.evo2.models.megatron.hyena.hyena_utils.ensure_subquadratic_b2b_causal_conv1d_supported") +@patch("bionemo.evo2.models.megatron.hyena.hyena_utils.b2b_causal_conv1d") +def test_b2b_causal_conv1d_module_checks_subquadratic_kernel_once(mock_b2b, mock_ensure): # noqa: D103 + mock_b2b.side_effect = mock_b2b_causal_conv1d + proj_conv = MockProjConv(kernel_size=3) + mixer = MockMixer(kernel_size=5) + b2b_module = B2BCausalConv1dModule( + proj_conv, + mixer, + operator_type="hyena_short_conv", + b2b_causal_conv1d=mock_b2b, + pg_collection=MockProcessGroupCollection(), + ) + + x = torch.randn(2, 96, 32, device="cuda") + b2b_module(x) + b2b_module(x) + + assert mock_ensure.call_count == 1 + assert mock_b2b.call_count == 2 + + def test_b2b_causal_conv1d_effective_padding_size(): """Test the zigzag pattern for data distribution in context parallel mode.""" proj_conv = MockProjConv(kernel_size=3) From 5736e04efecb28aca908d6bf55fab96a3e1ca8fe Mon Sep 17 00:00:00 2001 From: "John St. John" Date: Fri, 22 May 2026 13:36:49 -0700 Subject: [PATCH 07/15] Add back more of the removed subq ops calls Signed-off-by: John St. 
John --- .../evo2/models/megatron/hyena/hyena_utils.py | 12 ++++++--- .../models/megatron/hyena/test_hyena_utils.py | 26 +++++++++++++++++++ 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py index aae76e8f77..d91342f7c3 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py @@ -1492,10 +1492,16 @@ def forward(self, x, inference_context=None, _use_cp=True): else: x = F.pad(x, (pad_size, 0)) - # subquadratic_ops causal_conv1d is only applied to the projection conv of Hyena LI layer - # Projection conv is fused with SE/MR layers (B2BCausalConv1dModule) + # subquadratic_ops causal_conv1d is only applied to the projection conv of Hyena LI layer. + # Projection conv is fused with SE/MR layers by B2BCausalConv1dModule when available. 
if self.use_fast_causal_conv: # hyena_proj_conv case - y = causal_conv1d_fn(x, weight, bias=None, activation=None)[..., pad_size:] + if self.use_subquadratic_ops: + if x.is_cuda and not self._subquadratic_ops_checked: + ensure_subquadratic_causal_conv1d_supported() + self._subquadratic_ops_checked = True + y = causal_conv1d(x, weight)[..., pad_size:] + else: + y = causal_conv1d_fn(x, weight, bias=None, activation=None)[..., pad_size:] else: # hyena_short_conv case y = F.conv1d( x, diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py index cb50dad10f..c00efd780f 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py @@ -25,6 +25,7 @@ from bionemo.evo2.models.megatron.hyena.hyena_utils import ( B2BCausalConv1dModule, ExchangeOverlappingRegionsCausal, + ParallelCausalDepthwiseConv1d, _get_inverse_zigzag_indices, _get_zigzag_indices, divide, @@ -121,6 +122,31 @@ def mock_b2b_causal_conv1d(x, weight_proj, weight_mixer, skip_bias): return x +@patch("bionemo.evo2.models.megatron.hyena.hyena_utils.causal_conv1d_fn") +@patch("bionemo.evo2.models.megatron.hyena.hyena_utils.causal_conv1d") +def test_parallel_causal_depthwise_conv1d_uses_subquadratic_fast_conv( + mock_subq_causal_conv1d, mock_fast_causal_conv1d +): + """Fast projection conv should honor use_subquadratic_ops.""" + mock_subq_causal_conv1d.side_effect = lambda x, weight: torch.zeros_like(x) + x = torch.randn(2, 4, 8) + module = types.SimpleNamespace( + kernel_size=3, + short_conv_weight=torch.ones(4, 3), + group_dim=1, + pg_collection=types.SimpleNamespace(cp=None), + use_fast_causal_conv=True, + use_subquadratic_ops=True, + _subquadratic_ops_checked=False, + ) + + y = 
ParallelCausalDepthwiseConv1d.forward(module, x, _use_cp=False) + + assert y.shape == x.shape + mock_subq_causal_conv1d.assert_called_once() + mock_fast_causal_conv1d.assert_not_called() + + @pytest.mark.parametrize("operator_type", ["hyena_short_conv", "hyena_medium_conv"]) def test_b2b_causal_conv1d_module_initialization(operator_type): # noqa: D103 proj_conv = MockProjConv(kernel_size=3) From 4f57ec22cbf83a16564a1f8797f10c0fb902eb0f Mon Sep 17 00:00:00 2001 From: "John St. John" Date: Fri, 22 May 2026 13:45:14 -0700 Subject: [PATCH 08/15] Remove conftest diffs Signed-off-by: John St. John --- .../tests/bionemo/evo2/conftest.py | 29 +------------------ 1 file changed, 1 insertion(+), 28 deletions(-) diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/conftest.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/conftest.py index e9e427282b..45dc2b0ee7 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/conftest.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/conftest.py @@ -14,42 +14,15 @@ # limitations under the License. -# conftest.py -# ruff: noqa: E402, I001 import copy import gc -import importlib import os import shlex import subprocess -import sys from pathlib import Path -from types import ModuleType import pytest - -# Load Transformer Engine's core library before torch; this avoids a CUDA library -# symbol-resolution failure seen when Megatron imports torch before TE is loaded. -importlib.import_module("transformer_engine.pytorch") -torch = importlib.import_module("torch") - -# Megatron Bridge imports Transformers conversion modules during collection. In -# this NGC image, torchvision is present but its custom ops are unavailable, so -# make Transformers skip optional vision imports for these non-vision tests. 
-transformers_import_utils = importlib.import_module("transformers.utils.import_utils") -transformers_import_utils._torchvision_available = False - - -# Megatron Bridge eagerly imports VLM recipe modules through its recipes package. -# Those modules import qwen_vl_utils, which imports torchvision. Provide the -# non-vision tests with a stub so Evo2 collection does not depend on VLM extras. -def _unavailable_process_vision_info(*args, **kwargs): - raise RuntimeError("qwen_vl_utils is unavailable in Evo2 tests") - - -qwen_vl_utils = ModuleType("qwen_vl_utils") -qwen_vl_utils.process_vision_info = _unavailable_process_vision_info -sys.modules["qwen_vl_utils"] = qwen_vl_utils +import torch from bionemo.core.data.load import load as bionemo_load from bionemo.evo2.data.dataset_tokenizer import DEFAULT_HF_TOKENIZER_MODEL_PATH_512 From c372637172a8ab549ad2edcced5d3a63d6fd6d5d Mon Sep 17 00:00:00 2001 From: "John St. John" Date: Fri, 22 May 2026 13:59:01 -0700 Subject: [PATCH 09/15] Reduce the number of inner loop checks for compatibility Signed-off-by: John St. 
John --- .../evo2/models/megatron/hyena/engine.py | 10 ----- .../evo2/models/megatron/hyena/hyena_utils.py | 42 ------------------- .../megatron/hyena/subquadratic_safety.py | 8 ++++ .../src/bionemo/evo2/run/infer.py | 3 ++ .../src/bionemo/evo2/run/predict.py | 3 ++ .../src/bionemo/evo2/run/train.py | 7 +++- .../evo2/models/megatron/hyena/test_engine.py | 31 +++++++------- .../megatron/hyena/test_hyena_mixer_kernel.py | 5 +++ .../models/megatron/hyena/test_hyena_utils.py | 33 +++------------ 9 files changed, 45 insertions(+), 97 deletions(-) diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/engine.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/engine.py index 652dac3759..83a720446f 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/engine.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/engine.py @@ -19,11 +19,6 @@ import torch.nn.functional as F # noqa: N812 from einops import rearrange -from bionemo.evo2.models.megatron.hyena.subquadratic_safety import ( - ensure_subquadratic_causal_conv1d_supported, - ensure_subquadratic_fft_causal_conv1d_supported, -) - try: from subquadratic_ops_torch.causal_conv1d import causal_conv1d as _subq_causal_conv1d @@ -83,7 +78,6 @@ def parallel_fir( fir_length, compute_state, use_subquadratic_ops=False, - check_subquadratic_ops=True, ): """Compute parallel finite impulse response filtering with optional state computation.""" L = u.shape[1] # noqa: N806 @@ -95,8 +89,6 @@ def parallel_fir( if fir_length >= 128: if use_subquadratic_ops: # subq-ops fft_causal_conv1d expects [B, D, L] input and [D, L] filter; dtypes must match - if check_subquadratic_ops and u.is_cuda: - ensure_subquadratic_fft_causal_conv1d_supported() k = weight[:, :, :L].squeeze(1) if weight.dim() == 3 else weight[:, :L] u_fp32 = u.to(torch.float32) z = _subq_fft_causal_conv1d(u_fp32, k.to(torch.float32)) @@ -115,8 
+107,6 @@ def parallel_fir( if _subq_causal_conv1d is None: raise ImportError(_subq_error_msg) # subq-ops causal_conv1d expects pre-padded [B, D, L+pad] input and [D, K] weight. - if check_subquadratic_ops and u.is_cuda: - ensure_subquadratic_causal_conv1d_supported() pad_size = fir_length - 1 x_padded = F.pad(u.to(torch.float32), (pad_size, 0)) w = weight.squeeze(1) if weight.dim() == 3 else weight diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py index d91342f7c3..08e66b2d6a 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py @@ -33,11 +33,6 @@ from torch.autograd.function import Function from bionemo.evo2.models.megatron.hyena.hyena_config import HyenaConfig -from bionemo.evo2.models.megatron.hyena.subquadratic_safety import ( - ensure_subquadratic_b2b_causal_conv1d_supported, - ensure_subquadratic_causal_conv1d_supported, - ensure_subquadratic_fft_causal_conv1d_supported, -) try: @@ -469,7 +464,6 @@ def fftconv_func( k_rev=None, bidirectional=False, use_subquadratic_ops=False, - check_subquadratic_ops=True, ): """Apply a 1D convolution to the input sequence u using the filter k and the shortcut D.""" seqlen = u.shape[-1] @@ -503,8 +497,6 @@ def fftconv_func( # causal else: if use_subquadratic_ops: - if check_subquadratic_ops and u.is_cuda: - ensure_subquadratic_fft_causal_conv1d_supported() y = fft_causal_conv1d(u, k.squeeze(0)) else: fft_size = max(fft_size, 2 * k.shape[-1]) @@ -903,7 +895,6 @@ def __init__( self.zigzag = zigzag self.use_subquadratic_ops = transformer_config.use_subquadratic_ops - self._subquadratic_ops_checked = False self.model_parallel_size = self.pg_collection.tp.size() if self.pg_collection.tp is not None else 1 self.model_parallel_rank = 
self.pg_collection.tp.rank() if self.pg_collection.tp is not None else 0 @@ -986,16 +977,6 @@ def reset_parameters(self): bounds = math.sqrt(1 / self.kernel_size) torch.nn.init.uniform_(self.conv_bias, a=-bounds, b=bounds) - def _ensure_subquadratic_ops_supported(self): - """Run expensive subquadratic-op CUDA self-tests once per operator instance.""" - if self._subquadratic_ops_checked or not self.use_subquadratic_ops: - return - if self.operator_type == "hyena_medium_conv" and self.kernel_size < 128: - ensure_subquadratic_causal_conv1d_supported() - else: - ensure_subquadratic_fft_causal_conv1d_supported() - self._subquadratic_ops_checked = True - def forward_long(self, *, x1, x2, v, h, bias, inference_context): """Forward pass long.""" import bionemo.evo2.models.megatron.hyena.engine as engine @@ -1086,7 +1067,6 @@ def get_filter_state(filter_name): fir_length=self.kernel_size, # self.short_filter_length, compute_state=inference_context is not None, use_subquadratic_ops=self.use_subquadratic_ops, - check_subquadratic_ops=False, ) y = rearrange(y, "b d l -> b l d") y = y * x1 @@ -1112,8 +1092,6 @@ def forward(self, x1, x2, v, _hyena_use_cp=True, inference_context=None): Input shapes: bs, (num_groups, group_size), seq_length Output shapes: bs, (num_groups, group_size), seq_length """ - if x1.is_cuda: - self._ensure_subquadratic_ops_supported() B, GDG, L = x1.shape # noqa: N806 x1, x2, v = x1[..., :L], x2[..., :L], v[..., :L] @@ -1204,7 +1182,6 @@ def forward(self, x1, x2, v, _hyena_use_cp=True, inference_context=None): gelu=False, bidirectional=self.bidirectional, use_subquadratic_ops=self.use_subquadratic_ops, - check_subquadratic_ops=False, ) z = z.to(v.dtype) @@ -1404,7 +1381,6 @@ def __init__( self.num_groups = num_groups self.transformer_config = transformer_config self.use_subquadratic_ops = transformer_config.use_subquadratic_ops - self._subquadratic_ops_checked = False self.short_conv_L = hyena_config.short_conv_L self.local_init = local_init if 
pg_collection is None: @@ -1496,9 +1472,6 @@ def forward(self, x, inference_context=None, _use_cp=True): # Projection conv is fused with SE/MR layers by B2BCausalConv1dModule when available. if self.use_fast_causal_conv: # hyena_proj_conv case if self.use_subquadratic_ops: - if x.is_cuda and not self._subquadratic_ops_checked: - ensure_subquadratic_causal_conv1d_supported() - self._subquadratic_ops_checked = True y = causal_conv1d(x, weight)[..., pad_size:] else: y = causal_conv1d_fn(x, weight, bias=None, activation=None)[..., pad_size:] @@ -1566,7 +1539,6 @@ def __init__( """ super().__init__() self.b2b_causal_conv1d_fn = b2b_causal_conv1d - self._check_subquadratic_ops = b2b_causal_conv1d is globals()["b2b_causal_conv1d"] if pg_collection is None: pg_collection = ProcessGroupCollection.use_mpu_process_groups() self.pg_collection = pg_collection @@ -1591,14 +1563,6 @@ def __init__( raise ValueError(f"Operator type {operator_type} not supported") self.effective_pad_size = (self._mixer_kernel_size - 1) + (self._proj_conv_kernel_size - 1) - self._subquadratic_ops_checked = False - - def _ensure_subquadratic_ops_supported(self): - """Run the B2B CUDA self-test once per wrapper instance.""" - if self._subquadratic_ops_checked or not self._check_subquadratic_ops: - return - ensure_subquadratic_b2b_causal_conv1d_supported() - self._subquadratic_ops_checked = True def forward(self, x, _use_cp=True): """Forward pass for the B2BCausalConv1dModule. 
@@ -1612,8 +1576,6 @@ def forward(self, x, _use_cp=True): # Validate input dimensions if x.dim() != 3: raise ValueError("Input tensor must be 3D [batch_size, hidden_dim, seq_len]") - if x.is_cuda: - self._ensure_subquadratic_ops_supported() # Extract weights at runtime to avoid parameter registration proj_weight = self._proj_conv_module.short_conv_weight @@ -1747,9 +1709,6 @@ def get_filter_state(filter_name): L = u.shape[1] # noqa: N806 fir_state = get_filter_state("fir") if fir_state is None: - if self.use_subquadratic_ops and u.is_cuda and not self._subquadratic_ops_checked: - ensure_subquadratic_causal_conv1d_supported() - self._subquadratic_ops_checked = True z_pre, fir_state = engine.parallel_fir( u=u, weight=torch.tensor(weight), # self.short_filter_weight, @@ -1759,7 +1718,6 @@ def get_filter_state(filter_name): fir_length=self.kernel_size, # self.short_filter_length, compute_state=inference_context is not None, use_subquadratic_ops=self.use_subquadratic_ops, - check_subquadratic_ops=False, ) else: if len(u.shape) > 2: diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/subquadratic_safety.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/subquadratic_safety.py index 1ef01ee87f..7ef53c901a 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/subquadratic_safety.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/subquadratic_safety.py @@ -41,6 +41,14 @@ def _assert_close_or_raise(op_name: str, actual: torch.Tensor, expected: torch.T _raise_subquadratic_self_test_error(op_name, f"max_diff={max_diff:.6g}, rel={rel:.6g}") +@lru_cache(maxsize=None) +def ensure_subquadratic_ops_supported(device_index: int | None = None) -> None: + """Validate all subquadratic_ops_torch CUDA kernels used by Evo2.""" + ensure_subquadratic_causal_conv1d_supported(device_index) + ensure_subquadratic_fft_causal_conv1d_supported(device_index) + 
ensure_subquadratic_b2b_causal_conv1d_supported(device_index) + + @lru_cache(maxsize=None) def ensure_subquadratic_causal_conv1d_supported(device_index: int | None = None) -> None: """Validate subquadratic_ops_torch.causal_conv1d before using it for model data.""" diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py index 6219d5854c..af0420b1a2 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py @@ -107,6 +107,7 @@ from bionemo.evo2.data.dataset_tokenizer import DEFAULT_HF_TOKENIZER_MODEL_PATH from bionemo.evo2.models.evo2_provider import HyenaInferenceContext +from bionemo.evo2.models.megatron.hyena.subquadratic_safety import ensure_subquadratic_ops_supported from bionemo.evo2.run.predict import initialize_inference_distributed, resolve_checkpoint_path from bionemo.evo2.run.text_generation_controller import Evo2TextGenerationController @@ -469,6 +470,8 @@ def setup_inference_engine( dist_config=dist_config, ) logger.info("Initialized distributed environment") + if use_subquadratic_ops: + ensure_subquadratic_ops_supported() # ------------------------------------------------------------------------- # Step 5: Create model and load weights diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py index 66ba70b1af..9b83c45844 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py @@ -106,6 +106,7 @@ from bionemo.evo2.data.dataset_tokenizer import DEFAULT_HF_TOKENIZER_MODEL_PATH from bionemo.evo2.data.fasta_dataset import SimpleFastaDataset +from bionemo.evo2.models.megatron.hyena.subquadratic_safety import ensure_subquadratic_ops_supported from 
bionemo.recipeutils.inference.collation import batch_collator @@ -1093,6 +1094,8 @@ def predict( dist_config=dist_config, ) logger.info("Initialized distributed environment") + if use_subquadratic_ops: + ensure_subquadratic_ops_supported() # ------------------------------------------------------------------------- # Step 5: Create model and load weights diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/train.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/train.py index 09b5b9df23..0b5d628dc9 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/train.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/train.py @@ -35,10 +35,11 @@ from megatron.bridge.training.mixed_precision import MIXED_PRECISION_RECIPES from megatron.bridge.training.post_training.checkpointing import has_modelopt_state from megatron.bridge.training.pretrain import pretrain -from megatron.bridge.utils.common_utils import get_rank_safe +from megatron.bridge.utils.common_utils import get_local_rank_preinit, get_rank_safe from bionemo.evo2.data.dataset_tokenizer import DEFAULT_HF_TOKENIZER_MODEL_PATH from bionemo.evo2.models.evo2_provider import MODEL_OPTIONS, hyena_forward_step, infer_model_type +from bionemo.evo2.models.megatron.hyena.subquadratic_safety import ensure_subquadratic_ops_supported from bionemo.evo2.recipes.evo2 import evo2_1b_pretrain_config as pretrain_config @@ -885,7 +886,9 @@ def train(args: argparse.Namespace) -> None: if args.num_layers: cfg.model.num_layers = args.num_layers if args.use_subquadratic_ops: - # TODO assert that it is installed + if torch.cuda.is_available(): + torch.cuda.set_device(get_local_rank_preinit()) + ensure_subquadratic_ops_supported() cfg.model.use_subquadratic_ops = True if args.no_activation_checkpointing: diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_engine.py 
b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_engine.py index e0b1c5e5b6..b5701d2ddc 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_engine.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_engine.py @@ -18,6 +18,7 @@ import torch.nn.functional as F # noqa: N812 from bionemo.evo2.models.megatron.hyena import engine +from bionemo.evo2.models.megatron.hyena.subquadratic_safety import ensure_subquadratic_ops_supported def test_fftconv_func_is_prefix_invariant_when_filter_is_longer_than_input(): @@ -83,6 +84,11 @@ def test_parallel_fir_short_cuda_path_matches_torch_depthwise_conv1d(use_subquad """Short FIR prefill should match F.conv1d or fail before returning bad subq output.""" if not torch.cuda.is_available(): pytest.skip("short FIR CUDA path requires CUDA") + if use_subquadratic_ops: + try: + ensure_subquadratic_ops_supported() + except RuntimeError as e: + pytest.xfail(str(e)) torch.manual_seed(1234) batch_size = 2 @@ -95,21 +101,16 @@ def test_parallel_fir_short_cuda_path_matches_torch_depthwise_conv1d(use_subquad weight = torch.randn(hidden_size, 1, kernel_size, device=device) bias = torch.randn(hidden_size, device=device) - try: - actual, state = engine.parallel_fir( - u=u, - weight=weight, - bias=bias, - L=seq_len, - gated_bias=True, - fir_length=kernel_size, - compute_state=True, - use_subquadratic_ops=use_subquadratic_ops, - ) - except RuntimeError as e: - if use_subquadratic_ops and "failed a CUDA self-test" in str(e): - pytest.xfail(str(e)) - raise + actual, state = engine.parallel_fir( + u=u, + weight=weight, + bias=bias, + L=seq_len, + gated_bias=True, + fir_length=kernel_size, + compute_state=True, + use_subquadratic_ops=use_subquadratic_ops, + ) u_bdl = u.transpose(1, 2).contiguous() expected = F.conv1d( diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_mixer_kernel.py 
b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_mixer_kernel.py index cb7a6ea564..5a21f89fdb 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_mixer_kernel.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_mixer_kernel.py @@ -26,6 +26,7 @@ from bionemo.evo2.models.megatron.hyena.hyena_layer_specs import hyena_stack_spec_no_te from bionemo.evo2.models.megatron.hyena.hyena_mixer import HyenaMixer from bionemo.evo2.models.megatron.hyena.hyena_utils import ImplicitModalFilter +from bionemo.evo2.models.megatron.hyena.subquadratic_safety import ensure_subquadratic_ops_supported from ....utils import distributed_model_parallel_state @@ -254,6 +255,10 @@ def test_subquadratic_ops_kernel( # noqa: D103 # Skip bf16 with short convolution due to numerical instability if test_config.params_dtype == torch.bfloat16 and operator_type == "hyena_short_conv": pytest.skip("bf16 with short convolution is skipped due to numerical instability") + try: + ensure_subquadratic_ops_supported() + except RuntimeError as e: + pytest.xfail(str(e)) with distributed_model_parallel_state(): # Create both models inside the same distributed context diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py index c00efd780f..a34cae04eb 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py @@ -37,6 +37,7 @@ wang_init_method, zigzag_get_overlapping_patches, ) +from bionemo.evo2.models.megatron.hyena.subquadratic_safety import ensure_subquadratic_ops_supported class MockProcessGroup: @@ -137,7 +138,6 @@ def 
test_parallel_causal_depthwise_conv1d_uses_subquadratic_fast_conv( pg_collection=types.SimpleNamespace(cp=None), use_fast_causal_conv=True, use_subquadratic_ops=True, - _subquadratic_ops_checked=False, ) y = ParallelCausalDepthwiseConv1d.forward(module, x, _use_cp=False) @@ -304,29 +304,6 @@ def test_b2b_causal_conv1d_module_device_handling(): # noqa: D103 assert result_cuda.device == x_cuda.device, "Device mismatch on CUDA" -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for subquadratic guard test") -@patch("bionemo.evo2.models.megatron.hyena.hyena_utils.ensure_subquadratic_b2b_causal_conv1d_supported") -@patch("bionemo.evo2.models.megatron.hyena.hyena_utils.b2b_causal_conv1d") -def test_b2b_causal_conv1d_module_checks_subquadratic_kernel_once(mock_b2b, mock_ensure): # noqa: D103 - mock_b2b.side_effect = mock_b2b_causal_conv1d - proj_conv = MockProjConv(kernel_size=3) - mixer = MockMixer(kernel_size=5) - b2b_module = B2BCausalConv1dModule( - proj_conv, - mixer, - operator_type="hyena_short_conv", - b2b_causal_conv1d=mock_b2b, - pg_collection=MockProcessGroupCollection(), - ) - - x = torch.randn(2, 96, 32, device="cuda") - b2b_module(x) - b2b_module(x) - - assert mock_ensure.call_count == 1 - assert mock_b2b.call_count == 2 - - def test_b2b_causal_conv1d_effective_padding_size(): """Test the zigzag pattern for data distribution in context parallel mode.""" proj_conv = MockProjConv(kernel_size=3) @@ -344,14 +321,14 @@ def test_b2b_causal_conv1d_effective_padding_size(): assert b2b_module.effective_pad_size == expected_pad_size -@pytest.mark.xfail( - reason="subquadratic-ops fused B2B kernel may fail CUDA/PTX self-test on unsupported GPUs", - strict=True, -) def test_b2b_causal_conv1d_module_matches_sequential_reference(): """Document the isolated B2B CUDA kernel behavior before relying on the fused path.""" if not torch.cuda.is_available(): pytest.skip("B2B causal conv isolation test requires CUDA") + try: + 
ensure_subquadratic_ops_supported() + except RuntimeError as e: + pytest.xfail(str(e)) torch.manual_seed(1234) batch_size = 2 From 5411e0937fff83f4216445b1dee2738995ff7b5b Mon Sep 17 00:00:00 2001 From: "John St. John" Date: Fri, 22 May 2026 14:17:28 -0700 Subject: [PATCH 10/15] Address PR feedback Signed-off-by: John St. John --- .../evo2/models/megatron/hyena/hyena_block.py | 16 ++++++++++---- .../evo2/models/megatron/hyena/hyena_utils.py | 8 +++---- .../models/megatron/hyena/test_hyena_utils.py | 21 +++++++++++++++++++ 3 files changed, 37 insertions(+), 8 deletions(-) diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_block.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_block.py index c413dd94b7..ac03b7a999 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_block.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_block.py @@ -121,8 +121,9 @@ def __init__( pp_layer_offset, layer_type_list = self._select_layers_for_pipeline_parallel(layer_type_list) if get_cpu_offload_context is not None: - # Megatron Core changed this helper from six to seven positional arguments - # across releases. Pass only the arguments accepted by the installed version. + # MCore 0.x has shipped both six- and seven-argument variants of this helper. + # Pass only the arguments accepted by the installed version; if a future helper + # uses *args, pass the full compatibility list rather than counting *args as one slot. 
offload_args = [ self.config.cpu_offloading, self.config.cpu_offloading_num_layers, @@ -132,9 +133,16 @@ def __init__( self.config.cpu_offloading_double_buffering, getattr(self.config, "cpu_offloading_retain_pinned_cpu_buffers", False), ] - num_offload_params = len(inspect.signature(get_cpu_offload_context).parameters) + offload_params = tuple(inspect.signature(get_cpu_offload_context).parameters.values()) + if any(param.kind is inspect.Parameter.VAR_POSITIONAL for param in offload_params): + num_offload_args = len(offload_args) + else: + num_offload_args = sum( + param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD) + for param in offload_params + ) (self.offload_context, self.group_prefetch_offload_commit_async) = get_cpu_offload_context( - *offload_args[:num_offload_params], + *offload_args[:num_offload_args], ) self.config._cpu_offloading_context = self.offload_context if self.config.cpu_offloading else None else: diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py index 08e66b2d6a..6f5ae3abae 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py @@ -467,7 +467,7 @@ def fftconv_func( ): """Apply a 1D convolution to the input sequence u using the filter k and the shortcut D.""" seqlen = u.shape[-1] - fft_size = 2 * seqlen + fft_size = max(2 * seqlen, 2 * k.shape[-1]) # check if k is less than seqlen -- subquadratic_ops input does not need padding if not use_subquadratic_ops and k.shape[-1] < seqlen: @@ -499,7 +499,6 @@ def fftconv_func( if use_subquadratic_ops: y = fft_causal_conv1d(u, k.squeeze(0)) else: - fft_size = max(fft_size, 2 * k.shape[-1]) k_f = torch.fft.rfft(k, n=fft_size) / fft_size if k_rev is not None: k_rev_f = 
torch.fft.rfft(k_rev, n=fft_size) / fft_size @@ -646,6 +645,8 @@ def compute_filter(self, L, t, glogp, R): # noqa: N803 return h, None + # Keep this eager. The short-prefill prefix-invariance tests in tests/bionemo/evo2/run + # cover the prior torch.compile regression with dynamic filter lengths and custom ops. def filter(self, L, *args, **kwargs): # noqa: N803 """Get t and the convolution filter for t and the requested sequence length.""" if self._cp_size > 1: @@ -768,8 +769,7 @@ def forward(self, L, *args, **kwargs): # noqa: N803 """ return self.filter(L, *args, **kwargs) - # Keep this eager. Compiling this helper can leave global dispatcher state - # that interferes with unrelated custom autograd/custom-op call sites. + # Keep this eager for the same short-prefill prefix-invariance reproducer as ImplicitModalFilter.filter. def filter(self, L, *args, **kwargs): # noqa: N803 """Compute the filter as a function of h and decay for the requested sequence length.""" h = self.h[:, :L] diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py index a34cae04eb..eabbc62838 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py @@ -537,6 +537,27 @@ def test_fftconv_func(): assert output_short.shape == u.shape +def test_fftconv_func_bidirectional_is_prefix_invariant_when_filter_is_longer_than_input(): + """Bidirectional FFT convolution should not alias short prefixes when the filter is long.""" + torch.manual_seed(1234) + batch_size = 2 + short_len = 5 + long_len = 64 + hidden_size = 4 + filter_len = 64 + + u_short = torch.randn(batch_size, hidden_size, short_len) + u_long = torch.zeros(batch_size, hidden_size, long_len) + u_long[..., :short_len] = u_short + k = 
torch.randn(1, 2 * hidden_size, filter_len) + D = torch.randn(hidden_size) # noqa: N806 + + short_out = fftconv_func(u_short, k, D, None, gelu=False, bidirectional=True) + long_out = fftconv_func(u_long, k, D, None, gelu=False, bidirectional=True)[..., :short_len] + + torch.testing.assert_close(short_out, long_out, rtol=1e-5, atol=1e-5) + + def test_fftconv_func_high_dimensional_input(): """Test fftconv_func with high-dimensional input to cover the len(u.shape) > 3 case.""" batch_size = 2 From 1fa72608748fc6741ecdcfda73b72d7ae26e3f4b Mon Sep 17 00:00:00 2001 From: "John St. John" Date: Fri, 22 May 2026 14:21:47 -0700 Subject: [PATCH 11/15] Update default fp32 residual Signed-off-by: John St. John --- .../evo2_megatron/src/bionemo/evo2/models/evo2_provider.py | 2 +- .../recipes/evo2_megatron/src/bionemo/evo2/run/predict.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py index 6dcab77575..6decfbb50b 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py @@ -306,7 +306,7 @@ class HyenaModelProvider(TransformerConfig, ModelProviderMixin[MCoreHyenaModel]) apply_rope_fusion: bool = True make_vocab_size_divisible_by: int = 128 gated_linear_unit: bool = True - fp32_residual_connection: bool = True + fp32_residual_connection: bool = False normalization: str = "RMSNorm" add_bias_linear: bool = False hidden_dropout: float = 0.0 diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py index 9b83c45844..ab30732586 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py @@ -1046,9 
+1046,6 @@ def predict( # Set the model to use fewer layers and skip post-processing (output heads). model_provider.num_layers = target_num_layers model_provider.post_process = False - if hasattr(model_provider, "fp32_residual_connection"): - model_provider.fp32_residual_connection = False - # Also truncate the hybrid_override_pattern if it exists, since it must match num_layers if hasattr(model_provider, "hybrid_override_pattern") and model_provider.hybrid_override_pattern is not None: original_pattern = model_provider.hybrid_override_pattern From e8fb3c3b78638f810746504a7a1ccc0cbd546702 Mon Sep 17 00:00:00 2001 From: "John St. John" Date: Fri, 22 May 2026 15:40:13 -0700 Subject: [PATCH 12/15] Add missing pytest dep Signed-off-by: John St. John --- bionemo-recipes/recipes/evo2_megatron/pyproject.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/bionemo-recipes/recipes/evo2_megatron/pyproject.toml b/bionemo-recipes/recipes/evo2_megatron/pyproject.toml index 93246f161f..c5a0d9ffb4 100644 --- a/bionemo-recipes/recipes/evo2_megatron/pyproject.toml +++ b/bionemo-recipes/recipes/evo2_megatron/pyproject.toml @@ -36,7 +36,9 @@ dependencies = [ ] [project.optional-dependencies] -test = [] +test = [ + "pytest>=8.0", +] [project.scripts] torchrun = "torch.distributed.run:main" From 53fdc45c7ccb09000532928d623e2a8c50ec4aaf Mon Sep 17 00:00:00 2001 From: "John St. John" Date: Fri, 22 May 2026 16:03:58 -0700 Subject: [PATCH 13/15] Fix changed import in infer.py Signed-off-by: John St. 
John --- .../src/bionemo/evo2/run/infer.py | 87 +++++++++++++++++-- 1 file changed, 82 insertions(+), 5 deletions(-) diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py index af0420b1a2..5fa22e8840 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py @@ -59,6 +59,7 @@ import argparse import gc +import inspect import json import logging import os @@ -91,16 +92,40 @@ read_run_config, ) from megatron.bridge.utils.common_utils import get_world_size_safe -from megatron.bridge.utils.instantiate_utils import instantiate +from megatron.bridge.utils.instantiate_utils import instantiate, register_allowed_target_prefix from megatron.core import dist_checkpointing, parallel_state from megatron.core.inference.contexts import StaticInferenceContext from megatron.core.inference.engines.static_engine import StaticInferenceEngine from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import ( AbstractModelInferenceWrapper, ) -from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( - InferenceWrapperConfig, -) + + +try: + from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( + InferenceWrapperConfig, + ) +except ImportError: + + @dataclass + class InferenceWrapperConfig: + """Compatibility shim for MCore versions that removed InferenceWrapperConfig.""" + + hidden_size: int + inference_max_requests: int + inference_max_seq_length: int + inference_batch_times_seqlen_threshold: int + params_dtype: torch.dtype + padded_vocab_size: int + nccl_all_reduce_for_prefill: bool = False + moe_pad_experts_for_cuda_graph_inference: bool = False + + def add_attributes(self, attributes: dict[str, Any]) -> None: + """Match the old MCore config helper used by Evo2TextGenerationController.""" + for name, 
value in attributes.items(): + setattr(self, name, value) + + from megatron.core.inference.sampling_params import SamplingParams from megatron.core.transformer.module import Float16Module from megatron.core.utils import get_model_config @@ -115,6 +140,52 @@ logger: logging.Logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) +_WRAPPER_INIT_ACCEPTS_CONFIG = ( + "inference_wrapper_config" in inspect.signature(AbstractModelInferenceWrapper.__init__).parameters +) + + +class _TextGenerationTokenizerAdapter: + """Expose the tokenizer methods expected by MCore's static text-generation path.""" + + def __init__(self, tokenizer: _HuggingFaceTokenizer): + self._tokenizer = tokenizer + + def __getattr__(self, name: str) -> Any: + return getattr(self._tokenizer, name) + + @property + def vocab_size(self) -> int: + return self._tokenizer.vocab_size + + @property + def bos(self) -> Optional[int]: + return getattr(self._tokenizer, "bos", None) + + @property + def eod(self) -> Optional[int]: + return getattr(self._tokenizer, "eod", None) + + def tokenize(self, text: str) -> list[int]: + if hasattr(self._tokenizer, "tokenize"): + return self._tokenizer.tokenize(text) + return self._tokenizer.text_to_ids(text) + + def detokenize(self, tokens: list[int], skip_special_tokens: bool = True) -> str: + if hasattr(self._tokenizer, "detokenize"): + return self._tokenizer.detokenize(tokens, skip_special_tokens=skip_special_tokens) + return self._tokenizer.ids_to_text(tokens) + + def offsets(self, tokens: list[int], text: str) -> list[int]: + if hasattr(self._tokenizer, "offsets"): + return self._tokenizer.offsets(tokens, text) + offsets = [] + position = 0 + for token in tokens: + offsets.append(position) + position += len(self.detokenize([token], skip_special_tokens=False)) + return offsets + # ============================================================================= # Hardware-Aware Defaults @@ -243,7 +314,11 @@ def __init__( inference_wrapper_config: Configuration with 
hidden size, vocab size, etc. inference_context: Context for managing state and sequence offsets. """ - super().__init__(model, inference_wrapper_config, inference_context) + self.inference_wrapper_config = inference_wrapper_config + if _WRAPPER_INIT_ACCEPTS_CONFIG: + super().__init__(model, inference_wrapper_config, inference_context) + else: + super().__init__(model, inference_context) def prep_inference_input(self, prompts_tokens: torch.Tensor) -> Dict[str, Any]: """Prepare the inference input data. @@ -410,6 +485,7 @@ def setup_inference_engine( raise FileNotFoundError(f"run_config.yaml not found at {run_config_filename}") run_config = read_run_config(run_config_filename) + register_allowed_target_prefix("bionemo.") model_provider = instantiate(run_config["model"]) logger.info(f"Instantiated model provider: {type(model_provider).__name__}") @@ -446,6 +522,7 @@ def setup_inference_engine( tokenizer = _HuggingFaceTokenizer(tokenizer_dir) else: tokenizer = _HuggingFaceTokenizer(DEFAULT_HF_TOKENIZER_MODEL_PATH) + tokenizer = _TextGenerationTokenizerAdapter(tokenizer) model_provider.vocab_size = tokenizer.vocab_size model_provider.should_pad_vocab = True From 3b8ebc202ab8c588a1ece65d1443bf063cd1f660 Mon Sep 17 00:00:00 2001 From: "John St. John" Date: Fri, 22 May 2026 17:11:51 -0700 Subject: [PATCH 14/15] Register allowed prefix Signed-off-by: John St. 
John --- .../evo2_megatron/src/bionemo/evo2/models/evo2_provider.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py index 6decfbb50b..8f310676ff 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py @@ -35,6 +35,7 @@ from megatron.bridge.training.state import GlobalState from megatron.bridge.training.utils.packed_seq_utils import get_packed_seq_params from megatron.bridge.training.utils.pg_utils import get_pg_collection +from megatron.bridge.utils.instantiate_utils import register_allowed_target_prefix from megatron.bridge.utils.vocab_utils import calculate_padded_vocab_size from megatron.core import parallel_state from megatron.core.inference.contexts import StaticInferenceContext @@ -53,6 +54,9 @@ from bionemo.evo2.models.megatron.hyena.hyena_utils import hyena_no_weight_decay_cond +register_allowed_target_prefix("bionemo.evo2.") + + def get_vocab_size(*args, **kwargs): raise NotImplementedError("FIXME get_vocab_size is not implemented Find it in megatron bridge") From 74ac48f8ba45a7a5a67a7ecd2ce266c8c928b7c5 Mon Sep 17 00:00:00 2001 From: "John St. John" Date: Tue, 26 May 2026 13:40:54 -0700 Subject: [PATCH 15/15] Attempt to address failing CI Signed-off-by: John St. 
John --- .../src/bionemo/evo2/models/evo2_provider.py | 26 +++++++++++ .../src/bionemo/evo2/run/train.py | 23 +++++++--- .../bionemo/evo2/test_model_providers.py | 43 +++++++++++++++++++ 3 files changed, 85 insertions(+), 7 deletions(-) diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py index 8f310676ff..ce4721634f 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py @@ -18,8 +18,10 @@ import math +import sys from dataclasses import dataclass from functools import partial +from pathlib import Path from typing import Callable, Iterable, Literal, Optional, Type import torch @@ -54,6 +56,30 @@ from bionemo.evo2.models.megatron.hyena.hyena_utils import hyena_no_weight_decay_cond +def _patch_megatron_dataset_helper_compile() -> None: + """Skip Megatron's runtime helper build when a wheel already ships the extension.""" + from megatron.core.datasets import utils as dataset_utils + + original_compile_helpers = dataset_utils.compile_helpers + if getattr(original_compile_helpers, "_evo2_prebuilt_helper_guard", False): + guarded_compile_helpers = original_compile_helpers + else: + + def guarded_compile_helpers() -> None: + datasets_dir = Path(dataset_utils.__file__).resolve().parent + if not (datasets_dir / "Makefile").exists() and list(datasets_dir.glob("helpers_cpp*.so")): + return None + return original_compile_helpers() + + guarded_compile_helpers._evo2_prebuilt_helper_guard = True + dataset_utils.compile_helpers = guarded_compile_helpers + + bridge_initialize = sys.modules.get("megatron.bridge.training.initialize") + if bridge_initialize is not None: + bridge_initialize.compile_helpers = guarded_compile_helpers + + +_patch_megatron_dataset_helper_compile() register_allowed_target_prefix("bionemo.evo2.") diff --git 
a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/train.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/train.py index 0b5d628dc9..77fd434bd8 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/train.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/train.py @@ -408,11 +408,20 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace: help="Use the faster, but maybe less accurate fused form of cross entropy, " "which also has bf16 grads internally.", ) # DONE - parser.add_argument( - "--no-fp32-residual-connection", + fp32_residual_group = parser.add_mutually_exclusive_group(required=False) + fp32_residual_group.add_argument( + "--fp32-residual-connection", + dest="fp32_residual_connection", action="store_true", - default=False, - help="If set, turn off fp32 residual connections which may be faster but may impact accuracy.", + default=None, + help="Enable fp32 residual connections. Defaults to the selected model provider setting.", + ) + fp32_residual_group.add_argument( + "--no-fp32-residual-connection", + dest="fp32_residual_connection", + action="store_false", + default=None, + help="Disable fp32 residual connections. 
Defaults to the selected model provider setting.", ) # DONE parser.add_argument( "--debug-ddp-parity-freq", @@ -859,11 +868,11 @@ def train(args: argparse.Namespace) -> None: cfg.model.seq_len_interpolation_factor = args.seq_len_interpolation_factor cfg.model.calculate_per_token_loss = not args.no_calculate_per_token_loss model_type = infer_model_type(args.model_size) - if model_type != "hyena" and not args.no_fp32_residual_connection: + if args.fp32_residual_connection is not None: + cfg.model.fp32_residual_connection = args.fp32_residual_connection + if model_type != "hyena" and cfg.model.fp32_residual_connection: logger.info("Disabling fp32_residual_connection for non-Hyena model (not compatible with TE layers)") cfg.model.fp32_residual_connection = False - else: - cfg.model.fp32_residual_connection = not args.no_fp32_residual_connection cfg.model.cross_entropy_loss_fusion = args.cross_entropy_loss_fusion # cfg.model.cuda_graph_impl = "local" # or "transformer_engine" # cfg.model.cuda_graph_scope = "full_iteration" diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_model_providers.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_model_providers.py index 3f7d605f49..afe565dacc 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_model_providers.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_model_providers.py @@ -15,6 +15,8 @@ """Tests for model provider instantiation, naming, and checkpoint converters.""" +from pathlib import Path + import pytest import torch @@ -22,6 +24,7 @@ HYENA_MODEL_OPTIONS, MODEL_OPTIONS, Hyena1bModelProvider, + _patch_megatron_dataset_helper_compile, infer_model_type, ) from bionemo.evo2.utils.checkpoint.mbridge_to_vortex import _split_fc1, mbridge_to_vortex_state_dict @@ -63,6 +66,46 @@ def test_infer_model_type_unknown(): infer_model_type("nonexistent_model") +@pytest.mark.parametrize( + ("has_makefile", "has_prebuilt_extension", 
"expected_original_calls"), + [ + (False, True, 0), + (True, True, 1), + (False, False, 1), + ], +) +def test_megatron_dataset_helper_compile_guard( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + has_makefile: bool, + has_prebuilt_extension: bool, + expected_original_calls: int, +): + """Skip Megatron's runtime make step only when a prebuilt helper extension exists.""" + from megatron.bridge.training import initialize as bridge_initialize + from megatron.core.datasets import utils as dataset_utils + + calls = [] + + def original_compile_helpers(): + calls.append("called") + + if has_makefile: + (tmp_path / "Makefile").write_text("all:\n") + if has_prebuilt_extension: + (tmp_path / "helpers_cpp.cpython-312-x86_64-linux-gnu.so").touch() + + monkeypatch.setattr(dataset_utils, "__file__", str(tmp_path / "utils.py")) + monkeypatch.setattr(dataset_utils, "compile_helpers", original_compile_helpers) + monkeypatch.setattr(bridge_initialize, "compile_helpers", original_compile_helpers) + + _patch_megatron_dataset_helper_compile() + + dataset_utils.compile_helpers() + assert bridge_initialize.compile_helpers is dataset_utils.compile_helpers + assert len(calls) == expected_original_calls + + def _make_mock_savanna_sd(pattern: str) -> dict[str, torch.Tensor]: """Create a minimal mock savanna state dict for the given pattern.