diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 8d8d51b7b3..72f3ba4c62 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -10,7 +10,26 @@ RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,source=requirements.txt,target=/workspace/requirements.txt \ PIP_CONSTRAINT= pip install -r /workspace/requirements.txt +# Sandboxed agent CLIs use these helpers on Linux. +RUN apt-get update && apt-get install -y --no-install-recommends \ + bubblewrap \ + uidmap \ + && rm -rf /var/lib/apt/lists/* + COPY --from=ghcr.io/astral-sh/uv:latest /uv /bin/ + +# Install Node.js 22 LTS and the OpenAI Codex CLI. +RUN curl -fsSL https://deb.nodesource.com/setup_22.x | bash - \ + && apt-get install -y --no-install-recommends nodejs \ + && rm -rf /var/lib/apt/lists/* \ + && npm install -g --no-fund --no-audit @openai/codex \ + && npm cache clean --force + +# Default Codex to Landlock where nested namespaces are restricted. +RUN mkdir -p /home/ubuntu/.codex \ + && printf '[features]\nuse_legacy_landlock = true\n' > /home/ubuntu/.codex/config.toml \ + && chown -R ubuntu:ubuntu /home/ubuntu/.codex + USER ubuntu RUN curl https://cursor.com/install -fsS | bash # Install cursor-agent CLI tool RUN curl -fsSL https://claude.ai/install.sh | bash # Install Claude CLI tool diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index a896c36a22..deac11c06e 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -8,6 +8,7 @@ "source=${localEnv:HOME}/.cache,target=/home/ubuntu/.cache,type=bind,consistency=cached", "source=${localEnv:HOME}/.claude,target=/home/ubuntu/.claude,type=bind,consistency=cached", "source=${localEnv:HOME}/.claude.json,target=/home/ubuntu/.claude.json,type=bind,consistency=cached", + "source=${localEnv:HOME}/.codex,target=/home/ubuntu/.codex,type=bind,consistency=cached", "source=${localEnv:HOME}/.config,target=/home/ubuntu/.config,type=bind,consistency=cached", "source=${localEnv:HOME}/.cursor,target=/home/ubuntu/.cursor,type=bind,consistency=cached", "source=${localEnv:HOME}/.gnupg,target=/home/ubuntu/.gnupg,type=bind,consistency=cached", diff --git a/.devcontainer/initializeCommand.sh b/.devcontainer/initializeCommand.sh index 182db741f8..aae34bd54f 100755 --- a/.devcontainer/initializeCommand.sh +++ b/.devcontainer/initializeCommand.sh @@ -8,6 +8,7 @@ mkdir -p ~/.gnupg mkdir -p ~/.config mkdir -p ~/.cursor mkdir -p ~/.claude +mkdir -p ~/.codex [ ! -f ~/.netrc ] && touch ~/.netrc [ ! -f ~/.bash_history_devcontainer ] && touch ~/.bash_history_devcontainer diff --git a/.devcontainer/start.sh b/.devcontainer/start.sh index 6f75fbe6b0..834aabf871 100755 --- a/.devcontainer/start.sh +++ b/.devcontainer/start.sh @@ -9,7 +9,7 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" DEVCONTAINER_JSON="${SCRIPT_DIR}/devcontainer.json" CONTAINER_NAME="${BIONEMO_CONTAINER_NAME:-bionemo-devcontainer}" -IMAGE_NAME="${BIONEMO_IMAGE_NAME:-bionemo-devcontainer:latest}" +IMAGE_NAME="${BIONEMO_IMAGE_NAME:-${CONTAINER_NAME}:latest}" # --------------------------------------------------------------------------- # Helpers diff --git a/bionemo-recipes/recipes/evo2_megatron/.ci_build.sh b/bionemo-recipes/recipes/evo2_megatron/.ci_build.sh index 58ee94d1f3..4eb8c487aa 100755 --- a/bionemo-recipes/recipes/evo2_megatron/.ci_build.sh +++ b/bionemo-recipes/recipes/evo2_megatron/.ci_build.sh @@ -5,7 +5,7 @@ rm -f /usr/local/lib/python*/dist-packages/transformer_engine-*.dist-info/direct_url.json export UV_LOCK_TIMEOUT=900 # increase to 15 minutes (900 seconds), adjust as needed export UV_LINK_MODE=copy -uv venv --system-site-packages +uv venv --clear --system-site-packages # 2. Activate the environment source .venv/bin/activate @@ -38,8 +38,8 @@ for pkg_dir in "$RECIPE_ROOT/../../../sub-packages/bionemo-recipeutils" "$RECIPE fi done -# 6. Install the recipe with all remaining dependencies -uv pip install -c pip-constraints.txt -e . --no-build-isolation +# 6. Install the recipe with all remaining dependencies, including test extras +uv pip install -c pip-constraints.txt -e '.[test]' --no-build-isolation # 7. Restore original pyproject.toml (the edit was only needed for uv resolution) mv pyproject.toml.ci_bak pyproject.toml diff --git a/bionemo-recipes/recipes/evo2_megatron/build_requirements.txt b/bionemo-recipes/recipes/evo2_megatron/build_requirements.txt index 5c7e5906a1..38cbd1fe91 100644 --- a/bionemo-recipes/recipes/evo2_megatron/build_requirements.txt +++ b/bionemo-recipes/recipes/evo2_megatron/build_requirements.txt @@ -1,3 +1,5 @@ poetry-core +poetry_dynamic_versioning # build dep of nvidia-resiliency-ext (transitively pulled by megatron-bridge); needed in the venv because we install with --no-build-isolation +grpcio-tools # build dep of nvidia-resiliency-ext: its setup.py shells out to `python -m grpc_tools.protoc` to compile *.proto files; --no-build-isolation means we have to provide it in the venv up-front wheel_stub ninja # should speed up causal-conv1d build diff --git a/bionemo-recipes/recipes/evo2_megatron/pyproject.toml b/bionemo-recipes/recipes/evo2_megatron/pyproject.toml index e83ac2c76c..c5a0d9ffb4 100644 --- a/bionemo-recipes/recipes/evo2_megatron/pyproject.toml +++ b/bionemo-recipes/recipes/evo2_megatron/pyproject.toml @@ -24,9 +24,10 @@ dependencies = [ "causal_conv1d", "nv-grouped-gemm", "megatron-core", - "nvidia-resiliency-ext", + # nvidia-resiliency-ext is pulled transitively by megatron-bridge. "emerging_optimizers", "subquadratic-ops-torch-cu13", + "email-validator", # These are dependencies for examples only, but are useful for actually doing analyses with this model "biopython", @@ -35,7 +36,9 @@ dependencies = [ ] [project.optional-dependencies] -test = [] +test = [ + "pytest>=8.0", +] [project.scripts] torchrun = "torch.distributed.run:main" @@ -88,22 +91,29 @@ override-dependencies = [ "triton; sys_platform == 'never'", "transformer-engine; sys_platform == 'never'", "transformer-engine[pytorch]; sys_platform == 'never'", + # Avoid alpha Pydantic releases; langchain imports pulled by nvidia-resiliency-ext are not compatible. + "pydantic>=2.12,<2.14", + # Avoid optional log-pattern-mining dependency conflicts from nvidia-resiliency-ext. + "logsage; sys_platform == 'never'", + "drain3; sys_platform == 'never'", ] [tool.uv.sources] # Shared recipe utilities (framework-agnostic) # External dependencies with specific git tags/commits -causal_conv1d = { git = "https://github.com/Dao-AILab/causal-conv1d.git", tag = "v1.5.4" } +# 1.6.1 fixes a custom-op no-storage failure in no-grad/frozen forward paths. +causal_conv1d = { git = "https://github.com/Dao-AILab/causal-conv1d.git", tag = "v1.6.1" } nv-grouped-gemm = { git = "https://github.com/fanshiqing/grouped_gemm", tag = "v1.1.4.post6" } # Internal dependencies bionemo-recipeutils = { git = "https://github.com/NVIDIA/bionemo-framework.git", branch = "main", subdirectory = "sub-packages/bionemo-recipeutils" } bionemo-core = { git = "https://github.com/NVIDIA/bionemo-framework.git", branch = "main", subdirectory = "sub-packages/bionemo-core" } -nvidia-resiliency-ext = { git = "https://github.com/NVIDIA/nvidia-resiliency-ext.git", rev = "54f85fe422d296cf04ea524130014bd3a2c3add1" } +# nvidia-resiliency-ext is intentionally left to Megatron-Bridge so the transitive pin stays consistent. -# Megatron Bundle. This points to a version that still supports the deprecated no_weight_decay_cond field until the API for an alternative has been finalized. -megatron-bridge = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git", rev = "549e3cb970c170b1d7a86d021261efe05e8a5d9f" } -megatron-core = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git", rev = "549e3cb970c170b1d7a86d021261efe05e8a5d9f", subdirectory = "3rdparty/Megatron-LM" } +# Megatron Bundle. MCore is sourced from the same Megatron-Bridge release tag. +megatron-bridge = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git", tag = "v0.4.1" } +megatron-core = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git", tag = "v0.4.1", subdirectory = "3rdparty/Megatron-LM" } [tool.uv.extra-build-dependencies] warp-lang = ["wheel_stub"] +nvidia-resiliency-ext = ["poetry_dynamic_versioning"] diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py index 6dcab77575..ce4721634f 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py @@ -18,8 +18,10 @@ import math +import sys from dataclasses import dataclass from functools import partial +from pathlib import Path from typing import Callable, Iterable, Literal, Optional, Type import torch @@ -35,6 +37,7 @@ from megatron.bridge.training.state import GlobalState from megatron.bridge.training.utils.packed_seq_utils import get_packed_seq_params from megatron.bridge.training.utils.pg_utils import get_pg_collection +from megatron.bridge.utils.instantiate_utils import register_allowed_target_prefix from megatron.bridge.utils.vocab_utils import calculate_padded_vocab_size from megatron.core import parallel_state from megatron.core.inference.contexts import StaticInferenceContext @@ -53,6 +56,33 @@ from bionemo.evo2.models.megatron.hyena.hyena_utils import hyena_no_weight_decay_cond +def _patch_megatron_dataset_helper_compile() -> None: + """Skip Megatron's runtime helper build when a wheel already ships the extension.""" + from megatron.core.datasets import utils as dataset_utils + + original_compile_helpers = dataset_utils.compile_helpers + if getattr(original_compile_helpers, "_evo2_prebuilt_helper_guard", False): + guarded_compile_helpers = original_compile_helpers + else: + + def guarded_compile_helpers() -> None: + datasets_dir = Path(dataset_utils.__file__).resolve().parent + if not (datasets_dir / "Makefile").exists() and list(datasets_dir.glob("helpers_cpp*.so")): + return None + return original_compile_helpers() + + guarded_compile_helpers._evo2_prebuilt_helper_guard = True + dataset_utils.compile_helpers = guarded_compile_helpers + + bridge_initialize = sys.modules.get("megatron.bridge.training.initialize") + if bridge_initialize is not None: + bridge_initialize.compile_helpers = guarded_compile_helpers + + +_patch_megatron_dataset_helper_compile() +register_allowed_target_prefix("bionemo.evo2.") + + def get_vocab_size(*args, **kwargs): raise NotImplementedError("FIXME get_vocab_size is not implemented Find it in megatron bridge") @@ -306,7 +336,7 @@ class HyenaModelProvider(TransformerConfig, ModelProviderMixin[MCoreHyenaModel]) apply_rope_fusion: bool = True make_vocab_size_divisible_by: int = 128 gated_linear_unit: bool = True - fp32_residual_connection: bool = True + fp32_residual_connection: bool = False normalization: str = "RMSNorm" add_bias_linear: bool = False hidden_dropout: float = 0.0 diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/engine.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/engine.py index c9ff82ba00..83a720446f 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/engine.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/engine.py @@ -29,6 +29,11 @@ _subq_error_msg = f"subquadratic_ops_torch not available: {_subq_import_error}" +def _linear_causal_fft_size(input_len: int, filter_len: int) -> int: + """Return an FFT size that cannot alias a causal convolution prefix.""" + return max(2 * input_len, 2 * filter_len) + + def adjust_filter_shape_for_broadcast(u, h): """Adjust filter shape for broadcasting compatibility with input tensor.""" h = h.squeeze() # Standardize to [D, L] from [1, D, L] and [D, 1, L] @@ -50,7 +55,7 @@ def fftconv_func(*, u, k, D): # noqa: N803 The convolution is computed in the frequency domain and then transformed back to the time domain. """ seqlen = u.shape[-1] - fft_size = 2 * seqlen + fft_size = _linear_causal_fft_size(seqlen, k.shape[-1]) k_f = torch.fft.rfft(k, n=fft_size) / fft_size k_f = adjust_filter_shape_for_broadcast(u, k_f) @@ -99,7 +104,9 @@ def parallel_fir( ).to(dtype=u.dtype) else: if use_subquadratic_ops: - # subq-ops causal_conv1d expects pre-padded [B, D, L+pad] input and [D, K] weight; dtypes must match + if _subq_causal_conv1d is None: + raise ImportError(_subq_error_msg) + # subq-ops causal_conv1d expects pre-padded [B, D, L+pad] input and [D, K] weight. pad_size = fir_length - 1 x_padded = F.pad(u.to(torch.float32), (pad_size, 0)) w = weight.squeeze(1) if weight.dim() == 3 else weight @@ -111,7 +118,7 @@ def parallel_fir( bias=None, stride=1, padding=fir_length - 1, - groups=u.shape[1], # always set to D, regardless of filter grouping + groups=u.shape[1], )[..., :L] z = z.to(u.dtype) @@ -130,7 +137,7 @@ def parallel_fir( def parallel_iir(*, z_pre, h, D, L, poles, t, hidden_size, compute_state): # noqa: N803 """Compute the output state of the short convolutional filter.""" - fft_size = 2 * L + fft_size = _linear_causal_fft_size(L, h.shape[-1]) x1, x2, v = z_pre.split([hidden_size, hidden_size, hidden_size], dim=1) x1v = x1 * v @@ -221,9 +228,9 @@ def prefill_via_modal_fft(*, x1v, L, poles, t, X_s): # noqa: N803 # When the model has a long convolution derived from a recurrence in modal form and prefill_style is "fft", # we split the filter into poles and residues and reuse FFT computation on the input. bs = x1v.shape[0] - fft_size = 2 * L + fft_size = X_s.shape[-1] state_s = (poles.to(torch.float32) * t).exp() - state_S = torch.fft.fft(state_s, n=fft_size).repeat(bs, 1, 1, 1) # noqa N806: B, D, state_dim, 2 * L + state_S = torch.fft.fft(state_s, n=fft_size).repeat(bs, 1, 1, 1) # noqa N806: B, D, state_dim, fft_size state = torch.fft.ifft(X_s[..., None, :] * state_S, n=fft_size) # Do not try to fix `UserWarning: Casting complex values to real discards # the imaginary part` by inserting state.real conversion anywhere before diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_block.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_block.py index 93acf70e57..ac03b7a999 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_block.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_block.py @@ -16,6 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect from contextlib import nullcontext from dataclasses import dataclass from typing import Optional, Union @@ -120,13 +121,28 @@ def __init__( pp_layer_offset, layer_type_list = self._select_layers_for_pipeline_parallel(layer_type_list) if get_cpu_offload_context is not None: - (self.offload_context, self.group_prefetch_offload_commit_async) = get_cpu_offload_context( + # MCore 0.x has shipped both six- and seven-argument variants of this helper. + # Pass only the arguments accepted by the installed version; if a future helper + # uses *args, pass the full compatibility list rather than counting *args as one slot. + offload_args = [ self.config.cpu_offloading, self.config.cpu_offloading_num_layers, self.config.num_layers, self.config.cpu_offloading_activations, self.config.cpu_offloading_weights, self.config.cpu_offloading_double_buffering, + getattr(self.config, "cpu_offloading_retain_pinned_cpu_buffers", False), + ] + offload_params = tuple(inspect.signature(get_cpu_offload_context).parameters.values()) + if any(param.kind is inspect.Parameter.VAR_POSITIONAL for param in offload_params): + num_offload_args = len(offload_args) + else: + num_offload_args = sum( + param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD) + for param in offload_params + ) + (self.offload_context, self.group_prefetch_offload_commit_async) = get_cpu_offload_context( + *offload_args[:num_offload_args], ) self.config._cpu_offloading_context = self.offload_context if self.config.cpu_offloading else None else: diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_layer.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_layer.py index 18867f3689..577bf9af2a 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_layer.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_layer.py @@ -108,8 +108,10 @@ def bias_dropout_add_exec_handler(self): if self.training: return torch.enable_grad else: - # Validation, Test, Inference, Etc. - return torch.inference_mode + # torch.inference_mode marks outputs as inference tensors. Those flags + # persist after leaving the context and can break downstream autograd or + # torch.library custom ops that consume frozen model outputs. + return torch.no_grad def forward( self, diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_mixer.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_mixer.py index d8cb265e31..81271c8e47 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_mixer.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_mixer.py @@ -42,6 +42,7 @@ logger = logging.getLogger(__name__) + try: from transformer_engine.common.recipe import DelayedScaling, Format except ImportError: @@ -117,7 +118,6 @@ def __init__( self.fast_conv_proj = self.hyena_config.fast_conv_proj self.fast_conv_mixer = self.hyena_config.fast_conv_mixer - # Use b2b causal conv1d self.use_subquadratic_ops = self.transformer_config.use_subquadratic_ops # Per attention head and per partition values. @@ -198,8 +198,8 @@ def __init__( ) if self.use_subquadratic_ops: - # Create a wrapper module that doesn't register parameters - # Use the existing weights from the original model + # The B2B kernel is guarded in hyena_utils and fails early if the local CUDA stack + # cannot run subquadratic_ops_torch correctly. self.b2b_kernel = B2BCausalConv1dModule( self.hyena_proj_conv, self.mixer, @@ -229,8 +229,8 @@ def __init__( ) if self.use_subquadratic_ops and self.operator_type == "hyena_medium_conv": - # Create a wrapper module that doesn't register parameters - # Use the existing weights from the original model + # The B2B kernel is guarded in hyena_utils and fails early if the local CUDA stack + # cannot run subquadratic_ops_torch correctly. self.b2b_kernel = B2BCausalConv1dModule( self.hyena_proj_conv, self.mixer, @@ -312,8 +312,8 @@ def forward(self, x, layer_past=None, inference_context=None, _hyena_use_cp=True "hyena_short_conv", "hyena_medium_conv", ] - # b2b runs during training (no inference_context) or during prefill (no FIR cache yet). - # During decode (cache populated, L=1) we fall back to the regular per-token step path. + # B2B runs during training (no inference_context) or during prefill (no FIR cache yet). + # During decode, fall back to the regular per-token step path. is_prefill = inference_context is not None and id(self.hyena_proj_conv) not in getattr( inference_context, "fir_filter_state_dict", {} ) diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py index 44be308aaf..6f5ae3abae 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py @@ -50,10 +50,14 @@ def causal_conv1d_fn(*args, **kwargs): try: - from subquadratic_ops_torch.b2b_causal_conv1d import b2b_causal_conv1d - from subquadratic_ops_torch.causal_conv1d import causal_conv1d - from subquadratic_ops_torch.fft_causal_conv1d import fft_causal_conv1d + from subquadratic_ops_torch.b2b_causal_conv1d import b2b_causal_conv1d as _subq_b2b_causal_conv1d + from subquadratic_ops_torch.causal_conv1d import causal_conv1d as _subq_causal_conv1d + from subquadratic_ops_torch.fft_causal_conv1d import fft_causal_conv1d as _subq_fft_causal_conv1d from subquadratic_ops_torch.implicit_filter import implicit_filter + + causal_conv1d = _subq_causal_conv1d + b2b_causal_conv1d = _subq_b2b_causal_conv1d + fft_causal_conv1d = _subq_fft_causal_conv1d except ImportError as e: msg_causal_conv1d = f"Problem importing subquadratic_ops: {e}. causal_conv1d is not available." msg_b2b_causal_conv1d = f"Problem importing subquadratic_ops: {e}. b2b_causal_conv1d is not available." @@ -451,10 +455,19 @@ def hyena_no_weight_decay_cond_with_embeddings(name, param): return ("embedding" in name) or hyena_no_weight_decay_cond(name, param) -def fftconv_func(u, k, D, dropout_mask, gelu=True, k_rev=None, bidirectional=False, use_subquadratic_ops=False): # noqa: N803 +def fftconv_func( + u, + k, + D, # noqa: N803 + dropout_mask, + gelu=True, + k_rev=None, + bidirectional=False, + use_subquadratic_ops=False, +): """Apply a 1D convolution to the input sequence u using the filter k and the shortcut D.""" seqlen = u.shape[-1] - fft_size = 2 * seqlen + fft_size = max(2 * seqlen, 2 * k.shape[-1]) # check if k is less than seqlen -- subquadratic_ops input does not need padding if not use_subquadratic_ops and k.shape[-1] < seqlen: @@ -632,6 +645,8 @@ def compute_filter(self, L, t, glogp, R): # noqa: N803 return h, None + # Keep this eager. The short-prefill prefix-invariance tests in tests/bionemo/evo2/run + # cover the prior torch.compile regression with dynamic filter lengths and custom ops. def filter(self, L, *args, **kwargs): # noqa: N803 """Get t and the convolution filter for t and the requested sequence length.""" if self._cp_size > 1: @@ -754,7 +769,7 @@ def forward(self, L, *args, **kwargs): # noqa: N803 """ return self.filter(L, *args, **kwargs) - @torch.compile(mode="max-autotune") + # Keep this eager for the same short-prefill prefix-invariance reproducer as ImplicitModalFilter.filter. def filter(self, L, *args, **kwargs): # noqa: N803 """Compute the filter as a function of h and decay for the requested sequence length.""" h = self.h[:, :L] @@ -1453,10 +1468,10 @@ def forward(self, x, inference_context=None, _use_cp=True): else: x = F.pad(x, (pad_size, 0)) - # subquadratic_ops causal_conv1d is only applied to the projection conv of Hyena LI layer - # Projection conv is fused with SE/MR layers (B2BCausalConv1dModule) + # subquadratic_ops causal_conv1d is only applied to the projection conv of Hyena LI layer. + # Projection conv is fused with SE/MR layers by B2BCausalConv1dModule when available. if self.use_fast_causal_conv: # hyena_proj_conv case - if self.use_subquadratic_ops: # hyena_proj_conv of LI layer when subquadratic_ops is enabled + if self.use_subquadratic_ops: y = causal_conv1d(x, weight)[..., pad_size:] else: y = causal_conv1d_fn(x, weight, bias=None, activation=None)[..., pad_size:] diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/subquadratic_safety.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/subquadratic_safety.py new file mode 100644 index 0000000000..7ef53c901a --- /dev/null +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/subquadratic_safety.py @@ -0,0 +1,166 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import lru_cache + +import torch +import torch.nn.functional as F # noqa: N812 + + +def _raise_subquadratic_self_test_error(op_name: str, detail: str) -> None: + raise RuntimeError( + f"subquadratic_ops_torch.{op_name} failed a CUDA self-test ({detail}). " + "This often happens with CUDA_ERROR_UNSUPPORTED_PTX_VERSION or unsupported GPU/toolchain " + "combinations. Refusing to run this subquadratic kernel because it can otherwise return " + "invalid outputs without raising." + ) + + +def _assert_close_or_raise(op_name: str, actual: torch.Tensor, expected: torch.Tensor) -> None: + torch.cuda.synchronize(actual.device) + if not torch.isfinite(actual).all(): + _raise_subquadratic_self_test_error(op_name, "non-finite output") + + if not torch.allclose(actual, expected, rtol=1e-4, atol=1e-4): + max_diff = (actual.float() - expected.float()).abs().max().item() + rel = ( + (actual.float() - expected.float()).pow(2).sum().sqrt() / (expected.float().pow(2).sum().sqrt() + 1e-30) + ).item() + _raise_subquadratic_self_test_error(op_name, f"max_diff={max_diff:.6g}, rel={rel:.6g}") + + +@lru_cache(maxsize=None) +def ensure_subquadratic_ops_supported(device_index: int | None = None) -> None: + """Validate all subquadratic_ops_torch CUDA kernels used by Evo2.""" + ensure_subquadratic_causal_conv1d_supported(device_index) + ensure_subquadratic_fft_causal_conv1d_supported(device_index) + ensure_subquadratic_b2b_causal_conv1d_supported(device_index) + + +@lru_cache(maxsize=None) +def ensure_subquadratic_causal_conv1d_supported(device_index: int | None = None) -> None: + """Validate subquadratic_ops_torch.causal_conv1d before using it for model data.""" + if not torch.cuda.is_available(): + return + + device_index = torch.cuda.current_device() if device_index is None else device_index + device = torch.device("cuda", device_index) + + from subquadratic_ops_torch.causal_conv1d import causal_conv1d as subq_causal_conv1d + + batch_size = 1 + hidden_size = 4 + seq_len = 8 + kernel_size = 3 + pad_size = kernel_size - 1 + + u = torch.linspace(-1.0, 1.0, steps=batch_size * hidden_size * seq_len, device=device).reshape( + batch_size, hidden_size, seq_len + ) + weight = torch.linspace(-0.5, 0.5, steps=hidden_size * kernel_size, device=device).reshape( + hidden_size, kernel_size + ) + + expected = F.conv1d( + u, + weight.unsqueeze(1), + bias=None, + stride=1, + padding=pad_size, + groups=hidden_size, + )[..., :seq_len] + actual = subq_causal_conv1d(F.pad(u, (pad_size, 0)), weight)[..., pad_size:] + _assert_close_or_raise("causal_conv1d", actual, expected) + + +@lru_cache(maxsize=None) +def ensure_subquadratic_fft_causal_conv1d_supported(device_index: int | None = None) -> None: + """Validate subquadratic_ops_torch.fft_causal_conv1d before using it for model data.""" + if not torch.cuda.is_available(): + return + + device_index = torch.cuda.current_device() if device_index is None else device_index + device = torch.device("cuda", device_index) + + from subquadratic_ops_torch.fft_causal_conv1d import fft_causal_conv1d as subq_fft_causal_conv1d + + batch_size = 1 + hidden_size = 4 + seq_len = 8 + kernel_size = 5 + + u = torch.linspace(-1.0, 1.0, steps=batch_size * hidden_size * seq_len, device=device).reshape( + batch_size, hidden_size, seq_len + ) + weight = torch.linspace(-0.5, 0.5, steps=hidden_size * kernel_size, device=device).reshape( + hidden_size, kernel_size + ) + + expected = F.conv1d( + u, + weight.flip(-1).unsqueeze(1), + bias=None, + stride=1, + padding=kernel_size - 1, + groups=hidden_size, + )[..., :seq_len] + actual = subq_fft_causal_conv1d(u, weight) + _assert_close_or_raise("fft_causal_conv1d", actual, expected) + + +@lru_cache(maxsize=None) +def ensure_subquadratic_b2b_causal_conv1d_supported(device_index: int | None = None) -> None: + """Validate subquadratic_ops_torch.b2b_causal_conv1d before using it for model data.""" + if not torch.cuda.is_available(): + return + + device_index = torch.cuda.current_device() if device_index is None else device_index + device = torch.device("cuda", device_index) + + from subquadratic_ops_torch.b2b_causal_conv1d import b2b_causal_conv1d as subq_b2b_causal_conv1d + + batch_size = 1 + hidden_size = 2 + seq_len = 10 + proj_kernel_size = 3 + mixer_kernel_size = 7 + + x = torch.linspace(-1.0, 1.0, steps=batch_size * 3 * hidden_size * seq_len, device=device).reshape( + batch_size, 3 * hidden_size, seq_len + ) + proj_weight = torch.linspace(-0.5, 0.5, steps=3 * hidden_size * proj_kernel_size, device=device).reshape( + 3 * hidden_size, proj_kernel_size + ) + mixer_weight = torch.linspace(-0.25, 0.25, steps=hidden_size * mixer_kernel_size, device=device).reshape( + hidden_size, mixer_kernel_size + ) + bias = torch.linspace(-0.1, 0.1, steps=hidden_size, device=device) + + actual = subq_b2b_causal_conv1d(x, proj_weight, mixer_weight, bias) + + projected = F.conv1d( + F.pad(x, (proj_kernel_size - 1, 0)), + proj_weight.flip(-1).unsqueeze(1), + groups=3 * hidden_size, + ) + x1, x2, v = projected[:, ::3], projected[:, 1::3], projected[:, 2::3] + z = x2 * v + mixed = F.conv1d( + F.pad(z, (mixer_kernel_size - 1, 0)), + mixer_weight.flip(-1).unsqueeze(1), + groups=hidden_size, + ) + expected = x1 * (mixed + bias[None, :, None] * z) + _assert_close_or_raise("b2b_causal_conv1d", actual, expected) diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py index 0ed9d4defb..5fa22e8840 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py @@ -59,6 +59,7 @@ import argparse import gc +import inspect import json import logging import os @@ -77,29 +78,61 @@ ) from megatron.bridge.training.config import DistributedInitConfig, RNGConfig from megatron.bridge.training.mixed_precision import get_mixed_precision_config -from megatron.bridge.training.tokenizers.tokenizer import _HuggingFaceTokenizer + + +try: + from megatron.bridge.training.tokenizers.tokenizer import _HuggingFaceTokenizer +except ImportError: + from megatron.core.tokenizers.text.libraries.huggingface_tokenizer import ( + HuggingFaceTokenizer as _HuggingFaceTokenizer, + ) from megatron.bridge.training.utils.checkpoint_utils import ( file_exists, get_checkpoint_run_config_filename, read_run_config, ) from megatron.bridge.utils.common_utils import get_world_size_safe -from megatron.bridge.utils.instantiate_utils import instantiate +from megatron.bridge.utils.instantiate_utils import instantiate, register_allowed_target_prefix from megatron.core import dist_checkpointing, parallel_state from megatron.core.inference.contexts import StaticInferenceContext from megatron.core.inference.engines.static_engine import StaticInferenceEngine from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import ( AbstractModelInferenceWrapper, ) -from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( - InferenceWrapperConfig, -) + + +try: + from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( + InferenceWrapperConfig, + ) +except ImportError: + + @dataclass + class InferenceWrapperConfig: + """Compatibility shim for MCore versions that removed InferenceWrapperConfig.""" + + hidden_size: int + inference_max_requests: int + inference_max_seq_length: int + inference_batch_times_seqlen_threshold: int + params_dtype: torch.dtype + padded_vocab_size: int + nccl_all_reduce_for_prefill: bool = False + moe_pad_experts_for_cuda_graph_inference: bool = False + + def add_attributes(self, attributes: dict[str, Any]) -> None: + """Match the old MCore config helper used by Evo2TextGenerationController.""" + for name, value in attributes.items(): + setattr(self, name, value) + + from megatron.core.inference.sampling_params import SamplingParams from megatron.core.transformer.module import Float16Module from megatron.core.utils import get_model_config from bionemo.evo2.data.dataset_tokenizer import DEFAULT_HF_TOKENIZER_MODEL_PATH from bionemo.evo2.models.evo2_provider import HyenaInferenceContext +from bionemo.evo2.models.megatron.hyena.subquadratic_safety import ensure_subquadratic_ops_supported from bionemo.evo2.run.predict import initialize_inference_distributed, resolve_checkpoint_path from bionemo.evo2.run.text_generation_controller import Evo2TextGenerationController @@ -107,6 +140,52 @@ logger: logging.Logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) +_WRAPPER_INIT_ACCEPTS_CONFIG = ( + "inference_wrapper_config" in inspect.signature(AbstractModelInferenceWrapper.__init__).parameters +) + + +class _TextGenerationTokenizerAdapter: + """Expose the tokenizer methods expected by MCore's static text-generation path.""" + + def __init__(self, tokenizer: _HuggingFaceTokenizer): + self._tokenizer = tokenizer + + def __getattr__(self, name: str) -> Any: + return getattr(self._tokenizer, name) + + @property + def vocab_size(self) -> int: + return self._tokenizer.vocab_size + + @property + def bos(self) -> Optional[int]: + return getattr(self._tokenizer, "bos", None) + + @property + def eod(self) -> Optional[int]: + return getattr(self._tokenizer, "eod", None) + + def tokenize(self, text: str) -> list[int]: + if hasattr(self._tokenizer, "tokenize"): + return self._tokenizer.tokenize(text) + return self._tokenizer.text_to_ids(text) + + def detokenize(self, tokens: list[int], skip_special_tokens: bool = True) -> str: + if hasattr(self._tokenizer, "detokenize"): + return self._tokenizer.detokenize(tokens, skip_special_tokens=skip_special_tokens) + return self._tokenizer.ids_to_text(tokens) + + def offsets(self, tokens: list[int], text: str) -> list[int]: + if hasattr(self._tokenizer, "offsets"): + return self._tokenizer.offsets(tokens, text) + offsets = [] + position = 0 + for token in tokens: + offsets.append(position) + position += len(self.detokenize([token], skip_special_tokens=False)) + return offsets + # ============================================================================= # Hardware-Aware Defaults @@ -235,7 +314,11 @@ def __init__( inference_wrapper_config: Configuration with hidden size, vocab size, etc. inference_context: Context for managing state and sequence offsets. """ - super().__init__(model, inference_wrapper_config, inference_context) + self.inference_wrapper_config = inference_wrapper_config + if _WRAPPER_INIT_ACCEPTS_CONFIG: + super().__init__(model, inference_wrapper_config, inference_context) + else: + super().__init__(model, inference_context) def prep_inference_input(self, prompts_tokens: torch.Tensor) -> Dict[str, Any]: """Prepare the inference input data. @@ -402,6 +485,7 @@ def setup_inference_engine( raise FileNotFoundError(f"run_config.yaml not found at {run_config_filename}") run_config = read_run_config(run_config_filename) + register_allowed_target_prefix("bionemo.") model_provider = instantiate(run_config["model"]) logger.info(f"Instantiated model provider: {type(model_provider).__name__}") @@ -438,6 +522,7 @@ def setup_inference_engine( tokenizer = _HuggingFaceTokenizer(tokenizer_dir) else: tokenizer = _HuggingFaceTokenizer(DEFAULT_HF_TOKENIZER_MODEL_PATH) + tokenizer = _TextGenerationTokenizerAdapter(tokenizer) model_provider.vocab_size = tokenizer.vocab_size model_provider.should_pad_vocab = True @@ -462,6 +547,8 @@ def setup_inference_engine( dist_config=dist_config, ) logger.info("Initialized distributed environment") + if use_subquadratic_ops: + ensure_subquadratic_ops_supported() # ------------------------------------------------------------------------- # Step 5: Create model and load weights diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py index c888d46c94..ab30732586 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py @@ -76,7 +76,6 @@ ) from megatron.bridge.training.config import DistributedInitConfig, RNGConfig from megatron.bridge.training.mixed_precision import MIXED_PRECISION_RECIPES, get_mixed_precision_config -from megatron.bridge.training.tokenizers.tokenizer import _HuggingFaceTokenizer from megatron.bridge.training.utils.checkpoint_utils import ( file_exists, get_checkpoint_run_config_filename, @@ -97,8 +96,17 @@ from megatron.core.utils import get_batch_on_this_cp_rank from torch import Tensor + +try: + from megatron.bridge.training.tokenizers.tokenizer import _HuggingFaceTokenizer +except ImportError: + from megatron.core.tokenizers.text.libraries.huggingface_tokenizer import ( + HuggingFaceTokenizer as _HuggingFaceTokenizer, + ) + from bionemo.evo2.data.dataset_tokenizer import DEFAULT_HF_TOKENIZER_MODEL_PATH from bionemo.evo2.data.fasta_dataset import SimpleFastaDataset +from bionemo.evo2.models.megatron.hyena.subquadratic_safety import ensure_subquadratic_ops_supported from bionemo.recipeutils.inference.collation import batch_collator @@ -656,7 +664,6 @@ def _predict_step( if not parallel_state.is_pipeline_last_stage(): return None - # Forward pass output_tensor = model( input_ids=batch["tokens"], position_ids=batch["position_ids"], @@ -1036,10 +1043,9 @@ def predict( f"Valid range: -{original_num_layers} to {original_num_layers - 1}." ) - # Set the model to use fewer layers and skip post-processing (output heads) + # Set the model to use fewer layers and skip post-processing (output heads). model_provider.num_layers = target_num_layers model_provider.post_process = False - # Also truncate the hybrid_override_pattern if it exists, since it must match num_layers if hasattr(model_provider, "hybrid_override_pattern") and model_provider.hybrid_override_pattern is not None: original_pattern = model_provider.hybrid_override_pattern @@ -1085,6 +1091,8 @@ def predict( dist_config=dist_config, ) logger.info("Initialized distributed environment") + if use_subquadratic_ops: + ensure_subquadratic_ops_supported() # ------------------------------------------------------------------------- # Step 5: Create model and load weights @@ -1276,6 +1284,12 @@ def predict( def main() -> None: """CLI entry point for Evo2 prediction.""" args = parse_args() + try: + from megatron.bridge.utils.instantiate_utils import register_allowed_target_prefix + + register_allowed_target_prefix("bionemo.") + except ImportError: + pass predict( fasta_path=args.fasta, ckpt_dir=args.ckpt_dir, diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/train.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/train.py index 09b5b9df23..77fd434bd8 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/train.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/train.py @@ -35,10 +35,11 @@ from megatron.bridge.training.mixed_precision import MIXED_PRECISION_RECIPES from megatron.bridge.training.post_training.checkpointing import has_modelopt_state from megatron.bridge.training.pretrain import pretrain -from megatron.bridge.utils.common_utils import get_rank_safe +from megatron.bridge.utils.common_utils import get_local_rank_preinit, get_rank_safe from bionemo.evo2.data.dataset_tokenizer import DEFAULT_HF_TOKENIZER_MODEL_PATH from bionemo.evo2.models.evo2_provider import MODEL_OPTIONS, hyena_forward_step, infer_model_type +from bionemo.evo2.models.megatron.hyena.subquadratic_safety import ensure_subquadratic_ops_supported from bionemo.evo2.recipes.evo2 import evo2_1b_pretrain_config as pretrain_config @@ -407,11 +408,20 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace: help="Use the faster, but maybe less accurate fused form of cross entropy, " "which also has bf16 grads internally.", ) # DONE - parser.add_argument( - "--no-fp32-residual-connection", + fp32_residual_group = parser.add_mutually_exclusive_group(required=False) + fp32_residual_group.add_argument( + "--fp32-residual-connection", + dest="fp32_residual_connection", action="store_true", - default=False, - help="If set, turn off fp32 residual connections which may be faster but may impact accuracy.", + default=None, + help="Enable fp32 residual connections. Defaults to the selected model provider setting.", + ) + fp32_residual_group.add_argument( + "--no-fp32-residual-connection", + dest="fp32_residual_connection", + action="store_false", + default=None, + help="Disable fp32 residual connections. Defaults to the selected model provider setting.", ) # DONE parser.add_argument( "--debug-ddp-parity-freq", @@ -858,11 +868,11 @@ def train(args: argparse.Namespace) -> None: cfg.model.seq_len_interpolation_factor = args.seq_len_interpolation_factor cfg.model.calculate_per_token_loss = not args.no_calculate_per_token_loss model_type = infer_model_type(args.model_size) - if model_type != "hyena" and not args.no_fp32_residual_connection: + if args.fp32_residual_connection is not None: + cfg.model.fp32_residual_connection = args.fp32_residual_connection + if model_type != "hyena" and cfg.model.fp32_residual_connection: logger.info("Disabling fp32_residual_connection for non-Hyena model (not compatible with TE layers)") cfg.model.fp32_residual_connection = False - else: - cfg.model.fp32_residual_connection = not args.no_fp32_residual_connection cfg.model.cross_entropy_loss_fusion = args.cross_entropy_loss_fusion # cfg.model.cuda_graph_impl = "local" # or "transformer_engine" # cfg.model.cuda_graph_scope = "full_iteration" @@ -885,7 +895,9 @@ def train(args: argparse.Namespace) -> None: if args.num_layers: cfg.model.num_layers = args.num_layers if args.use_subquadratic_ops: - # TODO assert that it is installed + if torch.cuda.is_available(): + torch.cuda.set_device(get_local_rank_preinit()) + ensure_subquadratic_ops_supported() cfg.model.use_subquadratic_ops = True if args.no_activation_checkpointing: diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/conftest.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/conftest.py index 35b66ed455..45dc2b0ee7 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/conftest.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/conftest.py @@ -14,7 +14,6 @@ # limitations under the License. -# conftest.py import copy import gc import os diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_engine.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_engine.py new file mode 100644 index 0000000000..b5701d2ddc --- /dev/null +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_engine.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +import torch.nn.functional as F # noqa: N812 + +from bionemo.evo2.models.megatron.hyena import engine +from bionemo.evo2.models.megatron.hyena.subquadratic_safety import ensure_subquadratic_ops_supported + + +def test_fftconv_func_is_prefix_invariant_when_filter_is_longer_than_input(): + """Short-input FFT convolution should match the prefix of a longer-input convolution.""" + torch.manual_seed(1234) + batch_size = 2 + hidden_size = 4 + short_len = 5 + long_len = 128 + filter_len = 128 + + u_long = torch.randn(batch_size, hidden_size, long_len) + u_short = u_long[..., :short_len].contiguous() + k = torch.randn(hidden_size, 1, filter_len) + d = torch.randn(hidden_size) + + short_out = engine.fftconv_func(u=u_short, k=k, D=d) + long_out = engine.fftconv_func(u=u_long, k=k, D=d)[..., :short_len] + + torch.testing.assert_close(short_out, long_out, rtol=1e-5, atol=1e-5) + + +def test_parallel_iir_is_prefix_invariant_when_filter_is_longer_than_input(): + """The IIR prefill convolution should not circularly alias short prefixes.""" + torch.manual_seed(1234) + batch_size = 2 + hidden_size = 4 + short_len = 5 + long_len = 128 + filter_len = 128 + + z_long = torch.randn(batch_size, 3 * hidden_size, long_len) + z_short = z_long[..., :short_len].contiguous() + h = torch.randn(hidden_size, filter_len) + d = torch.randn(hidden_size) + + short_out, _ = engine.parallel_iir( + z_pre=z_short, + h=h, + D=d, + L=short_len, + poles=None, + t=None, + hidden_size=hidden_size, + compute_state=False, + ) + long_out, _ = engine.parallel_iir( + z_pre=z_long, + h=h, + D=d, + L=long_len, + poles=None, + t=None, + hidden_size=hidden_size, + compute_state=False, + ) + + torch.testing.assert_close(short_out, long_out[:, :short_len], rtol=1e-5, atol=1e-5) + + +@pytest.mark.parametrize("use_subquadratic_ops", [False, True], ids=["torch", "subq"]) +def test_parallel_fir_short_cuda_path_matches_torch_depthwise_conv1d(use_subquadratic_ops): + """Short FIR prefill should match F.conv1d or fail before returning bad subq output.""" + if not torch.cuda.is_available(): + pytest.skip("short FIR CUDA path requires CUDA") + if use_subquadratic_ops: + try: + ensure_subquadratic_ops_supported() + except RuntimeError as e: + pytest.xfail(str(e)) + + torch.manual_seed(1234) + batch_size = 2 + seq_len = 17 + hidden_size = 8 + kernel_size = 7 + device = torch.device("cuda") + + u = torch.randn(batch_size, seq_len, hidden_size, device=device) + weight = torch.randn(hidden_size, 1, kernel_size, device=device) + bias = torch.randn(hidden_size, device=device) + + actual, state = engine.parallel_fir( + u=u, + weight=weight, + bias=bias, + L=seq_len, + gated_bias=True, + fir_length=kernel_size, + compute_state=True, + use_subquadratic_ops=use_subquadratic_ops, + ) + + u_bdl = u.transpose(1, 2).contiguous() + expected = F.conv1d( + u_bdl.float(), + weight.float(), + bias=None, + stride=1, + padding=kernel_size - 1, + groups=hidden_size, + )[..., :seq_len] + expected = expected.to(u.dtype) + bias[None, :, None] * u_bdl + + torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5) + torch.testing.assert_close(state, u_bdl[..., -(kernel_size - 1) :]) diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_mixer_kernel.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_mixer_kernel.py index cb7a6ea564..5a21f89fdb 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_mixer_kernel.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_mixer_kernel.py @@ -26,6 +26,7 @@ from bionemo.evo2.models.megatron.hyena.hyena_layer_specs import hyena_stack_spec_no_te from bionemo.evo2.models.megatron.hyena.hyena_mixer import HyenaMixer from bionemo.evo2.models.megatron.hyena.hyena_utils import ImplicitModalFilter +from bionemo.evo2.models.megatron.hyena.subquadratic_safety import ensure_subquadratic_ops_supported from ....utils import distributed_model_parallel_state @@ -254,6 +255,10 @@ def test_subquadratic_ops_kernel( # noqa: D103 # Skip bf16 with short convolution due to numerical instability if test_config.params_dtype == torch.bfloat16 and operator_type == "hyena_short_conv": pytest.skip("bf16 with short convolution is skipped due to numerical instability") + try: + ensure_subquadratic_ops_supported() + except RuntimeError as e: + pytest.xfail(str(e)) with distributed_model_parallel_state(): # Create both models inside the same distributed context diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py index 2ba21709fa..eabbc62838 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py @@ -20,10 +20,12 @@ import pytest import torch import torch.distributed as dist +import torch.nn.functional as F # noqa: N812 from bionemo.evo2.models.megatron.hyena.hyena_utils import ( B2BCausalConv1dModule, ExchangeOverlappingRegionsCausal, + ParallelCausalDepthwiseConv1d, _get_inverse_zigzag_indices, _get_zigzag_indices, divide, @@ -35,6 +37,7 @@ wang_init_method, zigzag_get_overlapping_patches, ) +from bionemo.evo2.models.megatron.hyena.subquadratic_safety import ensure_subquadratic_ops_supported class MockProcessGroup: @@ -120,6 +123,30 @@ def mock_b2b_causal_conv1d(x, weight_proj, weight_mixer, skip_bias): return x +@patch("bionemo.evo2.models.megatron.hyena.hyena_utils.causal_conv1d_fn") +@patch("bionemo.evo2.models.megatron.hyena.hyena_utils.causal_conv1d") +def test_parallel_causal_depthwise_conv1d_uses_subquadratic_fast_conv( + mock_subq_causal_conv1d, mock_fast_causal_conv1d +): + """Fast projection conv should honor use_subquadratic_ops.""" + mock_subq_causal_conv1d.side_effect = lambda x, weight: torch.zeros_like(x) + x = torch.randn(2, 4, 8) + module = types.SimpleNamespace( + kernel_size=3, + short_conv_weight=torch.ones(4, 3), + group_dim=1, + pg_collection=types.SimpleNamespace(cp=None), + use_fast_causal_conv=True, + use_subquadratic_ops=True, + ) + + y = ParallelCausalDepthwiseConv1d.forward(module, x, _use_cp=False) + + assert y.shape == x.shape + mock_subq_causal_conv1d.assert_called_once() + mock_fast_causal_conv1d.assert_not_called() + + @pytest.mark.parametrize("operator_type", ["hyena_short_conv", "hyena_medium_conv"]) def test_b2b_causal_conv1d_module_initialization(operator_type): # noqa: D103 proj_conv = MockProjConv(kernel_size=3) @@ -294,6 +321,66 @@ def test_b2b_causal_conv1d_effective_padding_size(): assert b2b_module.effective_pad_size == expected_pad_size +def test_b2b_causal_conv1d_module_matches_sequential_reference(): + """Document the isolated B2B CUDA kernel behavior before relying on the fused path.""" + if not torch.cuda.is_available(): + pytest.skip("B2B causal conv isolation test requires CUDA") + try: + ensure_subquadratic_ops_supported() + except RuntimeError as e: + pytest.xfail(str(e)) + + torch.manual_seed(1234) + batch_size = 2 + hidden_size = 4 + seq_len = 16 + proj_kernel_size = 3 + mixer_kernel_size = 7 + device = torch.device("cuda") + + x = torch.randn(batch_size, 3 * hidden_size, seq_len, device=device) + proj_weight = torch.randn(3 * hidden_size, proj_kernel_size, device=device) + mixer_weight = torch.randn(hidden_size, mixer_kernel_size, device=device) + bias = torch.randn(hidden_size, device=device) + + proj_conv = torch.nn.Module() + proj_conv.kernel_size = proj_kernel_size + proj_conv.short_conv_weight = proj_weight + proj_conv.group_dim = 1 + + mixer = torch.nn.Module() + mixer.use_conv_bias = True + mixer.group_dim = 1 + mixer.conv_bias = bias + mixer.short_conv = torch.nn.Module() + mixer.short_conv.kernel_size = mixer_kernel_size + mixer.short_conv.short_conv_weight = mixer_weight.unsqueeze(1) + + b2b_module = B2BCausalConv1dModule( + proj_conv, + mixer, + operator_type="hyena_short_conv", + pg_collection=MockProcessGroupCollection(), + ) + + fused = b2b_module(x).float() + projected = F.conv1d( + F.pad(x.float(), (proj_kernel_size - 1, 0)), + proj_weight.float().flip(-1).unsqueeze(1), + groups=3 * hidden_size, + ) + x1, x2, v = projected[:, ::3], projected[:, 1::3], projected[:, 2::3] + z = x2 * v + mixed = F.conv1d( + F.pad(z, (mixer_kernel_size - 1, 0)), + mixer_weight.float().flip(-1).unsqueeze(1), + groups=hidden_size, + ) + reference = x1 * (mixed + bias.float()[None, :, None] * z) + + torch.testing.assert_close(fused, reference, rtol=1e-4, atol=1e-4) + + def test_zigzag_get_overlapping_patches(): # noqa: D103 # Test the actual output of zigzag_get_overlapping_patches data = torch.arange(8).reshape(2, 4) # shape [2, 4] @@ -450,6 +537,27 @@ def test_fftconv_func(): assert output_short.shape == u.shape +def test_fftconv_func_bidirectional_is_prefix_invariant_when_filter_is_longer_than_input(): + """Bidirectional FFT convolution should not alias short prefixes when the filter is long.""" + torch.manual_seed(1234) + batch_size = 2 + short_len = 5 + long_len = 64 + hidden_size = 4 + filter_len = 64 + + u_short = torch.randn(batch_size, hidden_size, short_len) + u_long = torch.zeros(batch_size, hidden_size, long_len) + u_long[..., :short_len] = u_short + k = torch.randn(1, 2 * hidden_size, filter_len) + D = torch.randn(hidden_size) # noqa: N806 + + short_out = fftconv_func(u_short, k, D, None, gelu=False, bidirectional=True) + long_out = fftconv_func(u_long, k, D, None, gelu=False, bidirectional=True)[..., :short_len] + + torch.testing.assert_close(short_out, long_out, rtol=1e-5, atol=1e-5) + + def test_fftconv_func_high_dimensional_input(): """Test fftconv_func with high-dimensional input to cover the len(u.shape) > 3 case.""" batch_size = 2 diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_infer.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_infer.py index e2428f4b4f..618a0f6f20 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_infer.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_infer.py @@ -51,6 +51,11 @@ # Note: mbridge_checkpoint_path fixture is provided by conftest.py at session scope +def _xfail_if_unsupported_subquadratic_ops(result: subprocess.CompletedProcess, use_subquadratic_ops: bool) -> None: + if use_subquadratic_ops and "failed a CUDA self-test" in result.stderr: + pytest.xfail("subquadratic_ops_torch CUDA kernels are unsupported in this environment") + + def _read_jsonl_results(output_file: Path) -> list[dict]: """Read JSONL output file and return parsed records.""" records = [] @@ -342,6 +347,7 @@ def run_infer_subprocess( env=env, ) + _xfail_if_unsupported_subquadratic_ops(result, use_subquadratic_ops) assert result.returncode == 0, f"infer command failed:\nstdout: {result.stdout}\nstderr: {result.stderr}" assert output_file.exists(), "Output file was not created" @@ -390,6 +396,126 @@ def _write_prompts_jsonl(prompt_file: Path, prompts: list[tuple[str, str]]) -> N f.writelines(json.dumps({"id": prompt_id, "prompt": prompt_text}) + "\n" for prompt_id, prompt_text in prompts) +@pytest.fixture( + params=[False, True], + ids=["causal-conv1d", "subquadratic-ops"], +) +def infer_use_subquadratic_ops(request): + """Whether infer should use subquadratic Hyena kernels.""" + return request.param + + +def _run_infer_prompt_file( + *, + mbridge_checkpoint_path: Path, + prompt_file: Path, + output_file: Path, + max_batch_size: int, + use_subquadratic_ops: bool, +) -> dict[str, dict]: + open_port = find_free_network_port() + cmd = [ + "torchrun", + "--nproc_per_node", + "1", + "--nnodes", + "1", + "--master_port", + str(open_port), + "-m", + "bionemo.evo2.run.infer", + "--ckpt-dir", + str(mbridge_checkpoint_path), + "--prompt-file", + str(prompt_file), + "--max-new-tokens", + "1", + "--output-file", + str(output_file), + "--temperature", + "1.0", + "--top-k", + "1", + "--seed", + "1234", + "--max-batch-size", + str(max_batch_size), + "--max-seq-length", + "512", + "--return-log-probs", + ] + if use_subquadratic_ops: + cmd.append("--use-subquadratic-ops") + + result = subprocess.run( + cmd, + check=False, + capture_output=True, + text=True, + timeout=512, + env=copy.deepcopy(PRETEST_ENV), + ) + _xfail_if_unsupported_subquadratic_ops(result, use_subquadratic_ops) + assert result.returncode == 0, f"infer command failed:\nstdout: {result.stdout}\nstderr: {result.stderr}" + records = _read_jsonl_results(output_file) + return {record["id"]: record for record in records} + + +def _completion_logprobs(record: dict) -> torch.Tensor: + logprobs = record.get("logprobs", {}).get("completion_logprobs") + assert logprobs is not None, f"Missing completion logprobs in record: {record}" + tensor = torch.as_tensor(logprobs, dtype=torch.float32).flatten() + assert tensor.numel() == 1 + return tensor + + +@pytest.mark.timeout(512) +@pytest.mark.slow +def test_infer_evo2_short_prefill_is_prefix_invariant_across_batch_padding( + mbridge_checkpoint_path, + tmp_path, + infer_use_subquadratic_ops: bool, +): + """A short prefill should generate the same next token alone or in a padded batch.""" + if torch.cuda.device_count() < 1: + pytest.skip("Inference prefill prefix-invariance test requires a GPU") + + short_prompt = "ACGTACGTAA" + padding_prompt = ("GGCCGGGCGCGGTGGCTCACGCCTGTAATCCCAGCACTTTGGGAGGCCGAGGCGGGCGGATCACGAGGTC" * 4)[:256] + + alone_prompt_file = tmp_path / "short_alone_prompts.jsonl" + padded_prompt_file = tmp_path / "short_padded_prompts.jsonl" + _write_prompts_jsonl(alone_prompt_file, [("short", short_prompt)]) + _write_prompts_jsonl(padded_prompt_file, [("padding", padding_prompt), ("short", short_prompt)]) + + alone_records = _run_infer_prompt_file( + mbridge_checkpoint_path=mbridge_checkpoint_path, + prompt_file=alone_prompt_file, + output_file=tmp_path / "alone_output.jsonl", + max_batch_size=1, + use_subquadratic_ops=infer_use_subquadratic_ops, + ) + padded_records = _run_infer_prompt_file( + mbridge_checkpoint_path=mbridge_checkpoint_path, + prompt_file=padded_prompt_file, + output_file=tmp_path / "padded_output.jsonl", + max_batch_size=2, + use_subquadratic_ops=infer_use_subquadratic_ops, + ) + + assert set(alone_records) == {"short"} + assert set(padded_records) == {"padding", "short"} + assert padded_records["short"]["prompt"] == short_prompt + assert alone_records["short"]["completion"] == padded_records["short"]["completion"] + + torch.testing.assert_close( + _completion_logprobs(alone_records["short"]), + _completion_logprobs(padded_records["short"]), + rtol=2e-2, + atol=5e-2, + ) + + def run_infer_subprocess_parallel( mbridge_checkpoint_path, prompt_file: Path, @@ -524,11 +650,11 @@ def test_identical_prompts_should_be_identical(mbridge_checkpoint_path, tmp_path def test_subquadratic_ops_matches_baseline(mbridge_checkpoint_path, tmp_path): """Greedy generation with --use-subquadratic-ops must match the standard path. - This is the end-to-end correctness check for the subq-ops inference path: - Phase 1 routes engine.parallel_fir through subq-ops kernels during prefill, - Phase 2 fuses proj+mixer convs via b2b_causal_conv1d during prefill and - populates FIR caches for the subsequent decode steps. With greedy decoding - (top_k=1) and the same seed, both paths must produce identical output. + This is the end-to-end correctness check for the subq-ops inference path. + The subq path uses guarded subquadratic kernels. If the local CUDA/GPU + combination cannot run those kernels correctly, the guard raises before + invalid outputs can propagate. With greedy decoding (top_k=1) and the same + seed, supported subq kernels must produce identical output. """ output_baseline = tmp_path / "output_baseline.jsonl" output_subq = tmp_path / "output_subq.jsonl" diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_predict.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_predict.py index e28006cb74..07ed1102d2 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_predict.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_predict.py @@ -43,6 +43,11 @@ PRETEST_ENV = copy.deepcopy(os.environ) +def _xfail_if_unsupported_subquadratic_ops(result: subprocess.CompletedProcess, use_subquadratic_ops: bool) -> None: + if use_subquadratic_ops and "failed a CUDA self-test" in result.stderr: + pytest.xfail("subquadratic_ops_torch CUDA kernels are unsupported in this environment") + + @pytest.fixture(scope="module") def mbridge_checkpoint_1b_8k_bf16_path(mbridge_checkpoint_1b_8k_bf16) -> Path: """Module-scoped alias for the session-scoped 1b-8k-bf16 checkpoint. @@ -622,6 +627,107 @@ def test_predict_evo2_embedding_extraction( assert len(seq_idx_map) == num_sequences +@pytest.fixture( + params=[False, True], + ids=["causal-conv1d", "subquadratic-ops"], +) +def use_subquadratic_ops(request): + """Whether predict should use subquadratic Hyena kernels.""" + return request.param + + +@pytest.mark.timeout(512) +@pytest.mark.slow +def test_predict_evo2_short_embedding_is_prefix_invariant_across_batch_padding( + tmp_path, + mbridge_checkpoint_1b_8k_bf16_path: Path, + use_subquadratic_ops: bool, +): + """A short sequence should embed the same alone or padded in a longer batch.""" + if torch.cuda.device_count() < 1: + pytest.skip("Embedding prediction test requires a GPU") + + short_sequence = "ACGTACGTAA" + padding_sequence = (ALU_SEQUENCE * (256 // len(ALU_SEQUENCE) + 1))[:256] + + def _write_fasta(fasta_path: Path, records: dict[str, str]) -> None: + fasta_path.write_text("".join(f">{name}\n{sequence}\n" for name, sequence in records.items())) + + def _run_predict(fasta_path: Path, output_dir: Path) -> tuple[dict[str, torch.Tensor], dict[str, int]]: + open_port = find_free_network_port() + subquadratic_arg = " --use-subquadratic-ops" if use_subquadratic_ops else "" + command = ( + f"torchrun --nproc_per_node 1 --nnodes 1 --master_port {open_port} " + f"-m bionemo.evo2.run.predict --fasta {fasta_path} --ckpt-dir {mbridge_checkpoint_1b_8k_bf16_path} " + f"--output-dir {output_dir} --micro-batch-size 2 --write-interval epoch --embedding-layer -1" + f"{subquadratic_arg}" + ) + result = subprocess.run( + shlex.split(command), + check=False, + cwd=tmp_path, + capture_output=True, + text=True, + ) + _xfail_if_unsupported_subquadratic_ops(result, use_subquadratic_ops) + if result.returncode != 0: + print("STDOUT:\n" + result.stdout) + print("STDERR:\n" + result.stderr) + assert result.returncode == 0, f"predict_evo2 command failed with code {result.returncode}" + + pred_files = sorted(glob.glob(str(output_dir / "predictions__rank_*__dp_rank_*.pt"))) + assert len(pred_files) == 1, f"Expected 1 prediction file, got {len(pred_files)}" + with open(output_dir / "seq_idx_map.json") as f: + seq_idx_map = json.load(f) + return torch.load(pred_files[0], weights_only=True), seq_idx_map + + def _unpadded_dna_embeddings( + preds: dict[str, torch.Tensor], + seq_idx_map: dict[str, int], + seqid: str, + dna_length: int, + ) -> torch.Tensor: + matches = (preds["seq_idx"] == seq_idx_map[seqid]).nonzero(as_tuple=True)[0] + assert matches.numel() == 1 + row = matches.item() + assert preds["pad_mask"][row].sum().item() == dna_length + return preds["hidden_embeddings"][row, :dna_length].to(torch.float32) + + def _relative_frobenius_error(left: torch.Tensor, right: torch.Tensor) -> float: + numerator = (left - right).float().pow(2).sum().sqrt() + denominator = right.float().pow(2).sum().sqrt() + return float(numerator / (denominator + 1e-30)) + + def _assert_prefix_embeddings_close(left: torch.Tensor, right: torch.Tensor) -> None: + rel_error = _relative_frobenius_error(left, right) + bound = 4.0 * (1.03**33) * float(torch.finfo(torch.bfloat16).eps) + if rel_error <= bound: + return + + rel_shuffled_hidden = _relative_frobenius_error(left, torch.roll(right, shifts=-1, dims=-1)) + rel_shuffled_sequence = _relative_frobenius_error(left, torch.roll(right, shifts=-1, dims=0)) + max_abs_diff = (left - right).abs().max().item() + raise AssertionError( + "Prefix embeddings exceeded bf16 relative-norm tolerance: " + f"rel={rel_error}, bound={bound}, rel_shuffled_hidden={rel_shuffled_hidden}, " + f"rel_shuffled_sequence={rel_shuffled_sequence}, max_abs_diff={max_abs_diff}" + ) + + alone_fasta = tmp_path / "short_alone.fasta" + padded_fasta = tmp_path / "short_padded.fasta" + _write_fasta(alone_fasta, {"short": short_sequence}) + _write_fasta(padded_fasta, {"short": short_sequence, "padding": padding_sequence}) + alone_preds, alone_seq_idx_map = _run_predict(alone_fasta, tmp_path / "alone_output") + padded_preds, padded_seq_idx_map = _run_predict(padded_fasta, tmp_path / "padded_output") + assert alone_preds["hidden_embeddings"].shape[1] == len(short_sequence) + assert padded_preds["hidden_embeddings"].shape[1] == len(padding_sequence) + + alone_embeddings = _unpadded_dna_embeddings(alone_preds, alone_seq_idx_map, "short", len(short_sequence)) + padded_embeddings = _unpadded_dna_embeddings(padded_preds, padded_seq_idx_map, "short", len(short_sequence)) + + _assert_prefix_embeddings_close(alone_embeddings, padded_embeddings) + + @pytest.mark.slow def test_predict_evo2_embedding_layer_validation( tmp_path, diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_model_providers.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_model_providers.py index 3f7d605f49..afe565dacc 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_model_providers.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_model_providers.py @@ -15,6 +15,8 @@ """Tests for model provider instantiation, naming, and checkpoint converters.""" +from pathlib import Path + import pytest import torch @@ -22,6 +24,7 @@ HYENA_MODEL_OPTIONS, MODEL_OPTIONS, Hyena1bModelProvider, + _patch_megatron_dataset_helper_compile, infer_model_type, ) from bionemo.evo2.utils.checkpoint.mbridge_to_vortex import _split_fc1, mbridge_to_vortex_state_dict @@ -63,6 +66,46 @@ def test_infer_model_type_unknown(): infer_model_type("nonexistent_model") +@pytest.mark.parametrize( + ("has_makefile", "has_prebuilt_extension", "expected_original_calls"), + [ + (False, True, 0), + (True, True, 1), + (False, False, 1), + ], +) +def test_megatron_dataset_helper_compile_guard( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + has_makefile: bool, + has_prebuilt_extension: bool, + expected_original_calls: int, +): + """Skip Megatron's runtime make step only when a prebuilt helper extension exists.""" + from megatron.bridge.training import initialize as bridge_initialize + from megatron.core.datasets import utils as dataset_utils + + calls = [] + + def original_compile_helpers(): + calls.append("called") + + if has_makefile: + (tmp_path / "Makefile").write_text("all:\n") + if has_prebuilt_extension: + (tmp_path / "helpers_cpp.cpython-312-x86_64-linux-gnu.so").touch() + + monkeypatch.setattr(dataset_utils, "__file__", str(tmp_path / "utils.py")) + monkeypatch.setattr(dataset_utils, "compile_helpers", original_compile_helpers) + monkeypatch.setattr(bridge_initialize, "compile_helpers", original_compile_helpers) + + _patch_megatron_dataset_helper_compile() + + dataset_utils.compile_helpers() + assert bridge_initialize.compile_helpers is dataset_utils.compile_helpers + assert len(calls) == expected_original_calls + + def _make_mock_savanna_sd(pattern: str) -> dict[str, torch.Tensor]: """Create a minimal mock savanna state dict for the given pattern.