Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions .devcontainer/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,26 @@ RUN --mount=type=cache,target=/root/.cache/pip \
--mount=type=bind,source=requirements.txt,target=/workspace/requirements.txt \
PIP_CONSTRAINT= pip install -r /workspace/requirements.txt

# Sandboxed agent CLIs use these helpers on Linux.
RUN apt-get update && apt-get install -y --no-install-recommends \
bubblewrap \
uidmap \
&& rm -rf /var/lib/apt/lists/*

COPY --from=ghcr.io/astral-sh/uv:latest /uv /bin/

# Install Node.js 22 LTS and the OpenAI Codex CLI.
RUN curl -fsSL https://deb.nodesource.com/setup_22.x | bash - \
&& apt-get install -y --no-install-recommends nodejs \
&& rm -rf /var/lib/apt/lists/* \
&& npm install -g --no-fund --no-audit @openai/codex \
&& npm cache clean --force

# Default Codex to Landlock where nested namespaces are restricted.
RUN mkdir -p /home/ubuntu/.codex \
&& printf '[features]\nuse_legacy_landlock = true\n' > /home/ubuntu/.codex/config.toml \
&& chown -R ubuntu:ubuntu /home/ubuntu/.codex

USER ubuntu
RUN curl https://cursor.com/install -fsS | bash # Install cursor-agent CLI tool
RUN curl -fsSL https://claude.ai/install.sh | bash # Install Claude CLI tool
Expand Down
1 change: 1 addition & 0 deletions .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"source=${localEnv:HOME}/.cache,target=/home/ubuntu/.cache,type=bind,consistency=cached",
"source=${localEnv:HOME}/.claude,target=/home/ubuntu/.claude,type=bind,consistency=cached",
"source=${localEnv:HOME}/.claude.json,target=/home/ubuntu/.claude.json,type=bind,consistency=cached",
"source=${localEnv:HOME}/.codex,target=/home/ubuntu/.codex,type=bind,consistency=cached",
"source=${localEnv:HOME}/.config,target=/home/ubuntu/.config,type=bind,consistency=cached",
"source=${localEnv:HOME}/.cursor,target=/home/ubuntu/.cursor,type=bind,consistency=cached",
"source=${localEnv:HOME}/.gnupg,target=/home/ubuntu/.gnupg,type=bind,consistency=cached",
Expand Down
1 change: 1 addition & 0 deletions .devcontainer/initializeCommand.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ mkdir -p ~/.gnupg
mkdir -p ~/.config
mkdir -p ~/.cursor
mkdir -p ~/.claude
mkdir -p ~/.codex
[ ! -f ~/.netrc ] && touch ~/.netrc

[ ! -f ~/.bash_history_devcontainer ] && touch ~/.bash_history_devcontainer
Expand Down
2 changes: 1 addition & 1 deletion .devcontainer/start.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
DEVCONTAINER_JSON="${SCRIPT_DIR}/devcontainer.json"
CONTAINER_NAME="${BIONEMO_CONTAINER_NAME:-bionemo-devcontainer}"
IMAGE_NAME="${BIONEMO_IMAGE_NAME:-bionemo-devcontainer:latest}"
IMAGE_NAME="${BIONEMO_IMAGE_NAME:-${CONTAINER_NAME}:latest}"

# ---------------------------------------------------------------------------
# Helpers
Expand Down
6 changes: 3 additions & 3 deletions bionemo-recipes/recipes/evo2_megatron/.ci_build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
rm -f /usr/local/lib/python*/dist-packages/transformer_engine-*.dist-info/direct_url.json
export UV_LOCK_TIMEOUT=900 # increase to 15 minutes (900 seconds), adjust as needed
export UV_LINK_MODE=copy
uv venv --system-site-packages
uv venv --clear --system-site-packages

# 2. Activate the environment
source .venv/bin/activate
Expand Down Expand Up @@ -38,8 +38,8 @@ for pkg_dir in "$RECIPE_ROOT/../../../sub-packages/bionemo-recipeutils" "$RECIPE
fi
done

# 6. Install the recipe with all remaining dependencies
uv pip install -c pip-constraints.txt -e . --no-build-isolation
# 6. Install the recipe with all remaining dependencies, including test extras
uv pip install -c pip-constraints.txt -e '.[test]' --no-build-isolation

# 7. Restore original pyproject.toml (the edit was only needed for uv resolution)
mv pyproject.toml.ci_bak pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
poetry-core
poetry_dynamic_versioning # build dep of nvidia-resiliency-ext (transitively pulled by megatron-bridge); needed in the venv because we install with --no-build-isolation
grpcio-tools # build dep of nvidia-resiliency-ext: its setup.py shells out to `python -m grpc_tools.protoc` to compile *.proto files; --no-build-isolation means we have to provide it in the venv up-front
wheel_stub
ninja # should speed up causal-conv1d build
24 changes: 17 additions & 7 deletions bionemo-recipes/recipes/evo2_megatron/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ dependencies = [
"causal_conv1d",
"nv-grouped-gemm",
"megatron-core",
"nvidia-resiliency-ext",
# nvidia-resiliency-ext is pulled transitively by megatron-bridge.
"emerging_optimizers",
"subquadratic-ops-torch-cu13",
"email-validator",

# These are dependencies for examples only, but are useful for actually doing analyses with this model
"biopython",
Expand All @@ -35,7 +36,9 @@ dependencies = [
]

[project.optional-dependencies]
test = []
test = [
"pytest>=8.0",
]

[project.scripts]
torchrun = "torch.distributed.run:main"
Expand Down Expand Up @@ -88,22 +91,29 @@ override-dependencies = [
"triton; sys_platform == 'never'",
"transformer-engine; sys_platform == 'never'",
"transformer-engine[pytorch]; sys_platform == 'never'",
# Avoid alpha Pydantic releases; langchain imports pulled by nvidia-resiliency-ext are not compatible.
"pydantic>=2.12,<2.14",
# Avoid optional log-pattern-mining dependency conflicts from nvidia-resiliency-ext.
"logsage; sys_platform == 'never'",
"drain3; sys_platform == 'never'",
]

[tool.uv.sources]
# Shared recipe utilities (framework-agnostic)
# External dependencies with specific git tags/commits
causal_conv1d = { git = "https://github.com/Dao-AILab/causal-conv1d.git", tag = "v1.5.4" }
# 1.6.1 fixes a custom-op no-storage failure in no-grad/frozen forward paths.
causal_conv1d = { git = "https://github.com/Dao-AILab/causal-conv1d.git", tag = "v1.6.1" }
nv-grouped-gemm = { git = "https://github.com/fanshiqing/grouped_gemm", tag = "v1.1.4.post6" }

# Internal dependencies
bionemo-recipeutils = { git = "https://github.com/NVIDIA/bionemo-framework.git", branch = "main", subdirectory = "sub-packages/bionemo-recipeutils" }
bionemo-core = { git = "https://github.com/NVIDIA/bionemo-framework.git", branch = "main", subdirectory = "sub-packages/bionemo-core" }
nvidia-resiliency-ext = { git = "https://github.com/NVIDIA/nvidia-resiliency-ext.git", rev = "54f85fe422d296cf04ea524130014bd3a2c3add1" }
# nvidia-resiliency-ext is intentionally left to Megatron-Bridge so the transitive pin stays consistent.

# Megatron Bundle. This points to a version that still supports the deprecated no_weight_decay_cond field until the API for an alternative has been finalized.
megatron-bridge = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git", rev = "549e3cb970c170b1d7a86d021261efe05e8a5d9f" }
megatron-core = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git", rev = "549e3cb970c170b1d7a86d021261efe05e8a5d9f", subdirectory = "3rdparty/Megatron-LM" }
# Megatron Bundle. MCore is sourced from the same Megatron-Bridge release tag.
megatron-bridge = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git", tag = "v0.4.1" }
megatron-core = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git", tag = "v0.4.1", subdirectory = "3rdparty/Megatron-LM" }

[tool.uv.extra-build-dependencies]
warp-lang = ["wheel_stub"]
nvidia-resiliency-ext = ["poetry_dynamic_versioning"]
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@


import math
import sys
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Callable, Iterable, Literal, Optional, Type

import torch
Expand All @@ -35,6 +37,7 @@
from megatron.bridge.training.state import GlobalState
from megatron.bridge.training.utils.packed_seq_utils import get_packed_seq_params
from megatron.bridge.training.utils.pg_utils import get_pg_collection
from megatron.bridge.utils.instantiate_utils import register_allowed_target_prefix
from megatron.bridge.utils.vocab_utils import calculate_padded_vocab_size
from megatron.core import parallel_state
from megatron.core.inference.contexts import StaticInferenceContext
Expand All @@ -53,6 +56,33 @@
from bionemo.evo2.models.megatron.hyena.hyena_utils import hyena_no_weight_decay_cond


def _patch_megatron_dataset_helper_compile() -> None:
    """Wrap Megatron's ``compile_helpers`` so a prebuilt extension is not rebuilt.

    The wrapper returns without calling the upstream function when the directory
    holding ``megatron.core.datasets.utils`` has no ``Makefile`` and contains at
    least one ``helpers_cpp*.so`` file; in every other case it delegates to the
    upstream ``compile_helpers``.

    The wrapper is tagged with ``_evo2_prebuilt_helper_guard`` so that calling
    this function again reuses the installed wrapper instead of nesting a new one.
    If ``megatron.bridge.training.initialize`` is already present in
    ``sys.modules``, its ``compile_helpers`` attribute is rebound to the wrapper
    as well.
    """
    from megatron.core.datasets import utils as dataset_utils

    upstream_compile = dataset_utils.compile_helpers

    if not getattr(upstream_compile, "_evo2_prebuilt_helper_guard", False):

        def guarded_compile_helpers() -> None:
            helpers_dir = Path(dataset_utils.__file__).resolve().parent
            # A Makefile means a source tree is present: always defer to upstream.
            # The glob is only evaluated when no Makefile exists.
            has_makefile = (helpers_dir / "Makefile").exists()
            if has_makefile or not list(helpers_dir.glob("helpers_cpp*.so")):
                return upstream_compile()
            return None

        guarded_compile_helpers._evo2_prebuilt_helper_guard = True
        dataset_utils.compile_helpers = guarded_compile_helpers
    else:
        # Already patched by an earlier call: reuse the existing wrapper.
        guarded_compile_helpers = upstream_compile

    # Modules that were imported before the patch hold their own reference to
    # compile_helpers; rebind it on the bridge initializer if it is loaded.
    loaded_initialize = sys.modules.get("megatron.bridge.training.initialize")
    if loaded_initialize is not None:
        loaded_initialize.compile_helpers = guarded_compile_helpers


# Import-time side effect: install the guard around Megatron's compile_helpers.
_patch_megatron_dataset_helper_compile()
# Import-time side effect: register the "bionemo.evo2." prefix with Megatron-Bridge's
# instantiate utilities. NOTE(review): presumably this allows bionemo.evo2.* targets
# to be instantiated from configs — confirm against megatron.bridge.utils.instantiate_utils.
register_allowed_target_prefix("bionemo.evo2.")


def get_vocab_size(*args, **kwargs):
    """Placeholder that always fails; accepts and ignores any arguments.

    Raises:
        NotImplementedError: unconditionally, on every call.
    """
    message = "FIXME get_vocab_size is not implemented Find it in megatron bridge"
    raise NotImplementedError(message)

Expand Down Expand Up @@ -306,7 +336,7 @@ class HyenaModelProvider(TransformerConfig, ModelProviderMixin[MCoreHyenaModel])
apply_rope_fusion: bool = True
make_vocab_size_divisible_by: int = 128
gated_linear_unit: bool = True
fp32_residual_connection: bool = True
fp32_residual_connection: bool = False
normalization: str = "RMSNorm"
add_bias_linear: bool = False
hidden_dropout: float = 0.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@
_subq_error_msg = f"subquadratic_ops_torch not available: {_subq_import_error}"


def _linear_causal_fft_size(input_len: int, filter_len: int) -> int:
"""Return an FFT size that cannot alias a causal convolution prefix."""
return max(2 * input_len, 2 * filter_len)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here you can do this:

if filter_len <= 2 * input_len:
    return min(input_len + filter_len, 2 * filter_len)
return 2 * max(input_len, filter_len)



def adjust_filter_shape_for_broadcast(u, h):
"""Adjust filter shape for broadcasting compatibility with input tensor."""
h = h.squeeze() # Standardize to [D, L] from [1, D, L] and [D, 1, L]
Expand All @@ -50,7 +55,7 @@ def fftconv_func(*, u, k, D): # noqa: N803
The convolution is computed in the frequency domain and then transformed back to the time domain.
"""
seqlen = u.shape[-1]
fft_size = 2 * seqlen
fft_size = _linear_causal_fft_size(seqlen, k.shape[-1])

k_f = torch.fft.rfft(k, n=fft_size) / fft_size
k_f = adjust_filter_shape_for_broadcast(u, k_f)
Expand Down Expand Up @@ -99,7 +104,9 @@ def parallel_fir(
).to(dtype=u.dtype)
else:
if use_subquadratic_ops:
# subq-ops causal_conv1d expects pre-padded [B, D, L+pad] input and [D, K] weight; dtypes must match
if _subq_causal_conv1d is None:
raise ImportError(_subq_error_msg)
# subq-ops causal_conv1d expects pre-padded [B, D, L+pad] input and [D, K] weight.
pad_size = fir_length - 1
x_padded = F.pad(u.to(torch.float32), (pad_size, 0))
w = weight.squeeze(1) if weight.dim() == 3 else weight
Expand All @@ -111,7 +118,7 @@ def parallel_fir(
bias=None,
stride=1,
padding=fir_length - 1,
groups=u.shape[1], # always set to D, regardless of filter grouping
groups=u.shape[1],
)[..., :L]

z = z.to(u.dtype)
Expand All @@ -130,7 +137,7 @@ def parallel_fir(

def parallel_iir(*, z_pre, h, D, L, poles, t, hidden_size, compute_state): # noqa: N803
"""Compute the output state of the short convolutional filter."""
fft_size = 2 * L
fft_size = _linear_causal_fft_size(L, h.shape[-1])
x1, x2, v = z_pre.split([hidden_size, hidden_size, hidden_size], dim=1)

x1v = x1 * v
Expand Down Expand Up @@ -221,9 +228,9 @@ def prefill_via_modal_fft(*, x1v, L, poles, t, X_s): # noqa: N803
# When the model has a long convolution derived from a recurrence in modal form and prefill_style is "fft",
# we split the filter into poles and residues and reuse FFT computation on the input.
bs = x1v.shape[0]
fft_size = 2 * L
fft_size = X_s.shape[-1]
state_s = (poles.to(torch.float32) * t).exp()
state_S = torch.fft.fft(state_s, n=fft_size).repeat(bs, 1, 1, 1) # noqa N806: B, D, state_dim, 2 * L
state_S = torch.fft.fft(state_s, n=fft_size).repeat(bs, 1, 1, 1) # noqa N806: B, D, state_dim, fft_size
state = torch.fft.ifft(X_s[..., None, :] * state_S, n=fft_size)
# Do not try to fix `UserWarning: Casting complex values to real discards
# the imaginary part` by inserting state.real conversion anywhere before
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Optional, Union
Expand Down Expand Up @@ -120,13 +121,28 @@ def __init__(
pp_layer_offset, layer_type_list = self._select_layers_for_pipeline_parallel(layer_type_list)

if get_cpu_offload_context is not None:
(self.offload_context, self.group_prefetch_offload_commit_async) = get_cpu_offload_context(
# MCore 0.x has shipped both six- and seven-argument variants of this helper.
# Pass only the arguments accepted by the installed version; if a future helper
# uses *args, pass the full compatibility list rather than counting *args as one slot.
offload_args = [
self.config.cpu_offloading,
self.config.cpu_offloading_num_layers,
self.config.num_layers,
self.config.cpu_offloading_activations,
self.config.cpu_offloading_weights,
self.config.cpu_offloading_double_buffering,
getattr(self.config, "cpu_offloading_retain_pinned_cpu_buffers", False),
]
offload_params = tuple(inspect.signature(get_cpu_offload_context).parameters.values())
if any(param.kind is inspect.Parameter.VAR_POSITIONAL for param in offload_params):
num_offload_args = len(offload_args)
else:
num_offload_args = sum(
param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
for param in offload_params
)
(self.offload_context, self.group_prefetch_offload_commit_async) = get_cpu_offload_context(
*offload_args[:num_offload_args],
)
self.config._cpu_offloading_context = self.offload_context if self.config.cpu_offloading else None
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,10 @@ def bias_dropout_add_exec_handler(self):
if self.training:
return torch.enable_grad
else:
# Validation, Test, Inference, Etc.
return torch.inference_mode
# torch.inference_mode marks outputs as inference tensors. Those flags
# persist after leaving the context and can break downstream autograd or
# torch.library custom ops that consume frozen model outputs.
return torch.no_grad

def forward(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@

logger = logging.getLogger(__name__)


try:
from transformer_engine.common.recipe import DelayedScaling, Format
except ImportError:
Expand Down Expand Up @@ -117,7 +118,6 @@ def __init__(
self.fast_conv_proj = self.hyena_config.fast_conv_proj
self.fast_conv_mixer = self.hyena_config.fast_conv_mixer

# Use b2b causal conv1d
self.use_subquadratic_ops = self.transformer_config.use_subquadratic_ops

# Per attention head and per partition values.
Expand Down Expand Up @@ -198,8 +198,8 @@ def __init__(
)

if self.use_subquadratic_ops:
# Create a wrapper module that doesn't register parameters
# Use the existing weights from the original model
# The B2B kernel is guarded in hyena_utils and fails early if the local CUDA stack
# cannot run subquadratic_ops_torch correctly.
self.b2b_kernel = B2BCausalConv1dModule(
self.hyena_proj_conv,
self.mixer,
Expand Down Expand Up @@ -229,8 +229,8 @@ def __init__(
)

if self.use_subquadratic_ops and self.operator_type == "hyena_medium_conv":
# Create a wrapper module that doesn't register parameters
# Use the existing weights from the original model
# The B2B kernel is guarded in hyena_utils and fails early if the local CUDA stack
# cannot run subquadratic_ops_torch correctly.
self.b2b_kernel = B2BCausalConv1dModule(
self.hyena_proj_conv,
self.mixer,
Expand Down Expand Up @@ -312,8 +312,8 @@ def forward(self, x, layer_past=None, inference_context=None, _hyena_use_cp=True
"hyena_short_conv",
"hyena_medium_conv",
]
# b2b runs during training (no inference_context) or during prefill (no FIR cache yet).
# During decode (cache populated, L=1) we fall back to the regular per-token step path.
# B2B runs during training (no inference_context) or during prefill (no FIR cache yet).
# During decode, fall back to the regular per-token step path.
is_prefill = inference_context is not None and id(self.hyena_proj_conv) not in getattr(
inference_context, "fir_filter_state_dict", {}
)
Expand Down
Loading
Loading