
Commit 85a8dcc

Short sequence prefix-invariant evo2 implementation (NVIDIA-BioNeMo#1580)
### Description

Changes:

* codex added to top level devcontainer
* bump causal-conv1d, megatron-bridge, and associated dependencies
* add test coverage for prefix invariance when running evo2 on very short sequences through inference and training

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **Chores**
  * Updated development container configuration for improved build environment setup.
  * Added build dependencies to support enhanced model training infrastructure.
* **Bug Fixes**
  * Improved robustness of GPU kernel operations with enhanced validation checks.
  * Enhanced model compatibility across different system configurations.
* **Tests**
  * Expanded test coverage for model inference, training, and GPU kernel correctness.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St John <jstjohn@nvidia.com>
1 parent 0a594f5 commit 85a8dcc

26 files changed

Lines changed: 992 additions & 64 deletions

.devcontainer/Dockerfile

Lines changed: 9 additions & 2 deletions
@@ -10,9 +10,16 @@ RUN --mount=type=cache,target=/root/.cache/pip \
     --mount=type=bind,source=requirements.txt,target=/workspace/requirements.txt \
     PIP_CONSTRAINT= pip install -r /workspace/requirements.txt
 
+# Sandboxed agent CLIs use these helpers on Linux.
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    bubblewrap \
+    uidmap \
+    && rm -rf /var/lib/apt/lists/*
+
 COPY --from=ghcr.io/astral-sh/uv:latest /uv /bin/
+
 USER ubuntu
-RUN curl https://cursor.com/install -fsS | bash # Install cursor-agent CLI tool
-RUN curl -fsSL https://claude.ai/install.sh | bash # Install Claude CLI tool
+RUN curl https://cursor.com/install -fsS | bash || true # Install cursor-agent CLI tool
+RUN curl -fsSL https://claude.ai/install.sh | bash || true # Install Claude CLI tool
 RUN uv tool install pre-commit --with pre-commit-uv --force-reinstall
 ENV PATH="/home/ubuntu/.local/bin:${PATH}"

.devcontainer/devcontainer.json

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@
     "source=${localEnv:HOME}/.cache,target=/home/ubuntu/.cache,type=bind,consistency=cached",
     "source=${localEnv:HOME}/.claude,target=/home/ubuntu/.claude,type=bind,consistency=cached",
     "source=${localEnv:HOME}/.claude.json,target=/home/ubuntu/.claude.json,type=bind,consistency=cached",
+    "source=${localEnv:HOME}/.codex,target=/home/ubuntu/.codex,type=bind,consistency=cached",
     "source=${localEnv:HOME}/.config,target=/home/ubuntu/.config,type=bind,consistency=cached",
     "source=${localEnv:HOME}/.cursor,target=/home/ubuntu/.cursor,type=bind,consistency=cached",
     "source=${localEnv:HOME}/.gnupg,target=/home/ubuntu/.gnupg,type=bind,consistency=cached",

.devcontainer/initializeCommand.sh

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ mkdir -p ~/.gnupg
 mkdir -p ~/.config
 mkdir -p ~/.cursor
 mkdir -p ~/.claude
+mkdir -p ~/.codex
 [ ! -f ~/.netrc ] && touch ~/.netrc
 
 [ ! -f ~/.bash_history_devcontainer ] && touch ~/.bash_history_devcontainer

.devcontainer/postCreateCommand.sh

Lines changed: 5 additions & 0 deletions
@@ -1,5 +1,10 @@
 #!/usr/bin/env bash
 set -euo pipefail
+
+# Install Codex as the devcontainer user so the binary lands in the mounted user environment.
+if ! command -v codex >/dev/null 2>&1; then
+    curl -fsSL https://chatgpt.com/codex/install.sh | sh || true # do not fail if there are URL resolutions with codex
+fi
 # Run via uv to avoid relying on updated PATH in this shell
 if git rev-parse --is-inside-work-tree >/dev/null 2>&1; then
     # Some editors (VS Code, Cursor) set core.hooksPath in .git/config, which

.devcontainer/start.sh

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
 REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
 DEVCONTAINER_JSON="${SCRIPT_DIR}/devcontainer.json"
 CONTAINER_NAME="${BIONEMO_CONTAINER_NAME:-bionemo-devcontainer}"
-IMAGE_NAME="${BIONEMO_IMAGE_NAME:-bionemo-devcontainer:latest}"
+IMAGE_NAME="${BIONEMO_IMAGE_NAME:-${CONTAINER_NAME}:latest}"
 
 # ---------------------------------------------------------------------------
 # Helpers

bionemo-recipes/recipes/evo2_megatron/.ci_build.sh

Lines changed: 3 additions & 3 deletions
@@ -5,7 +5,7 @@
 rm -f /usr/local/lib/python*/dist-packages/transformer_engine-*.dist-info/direct_url.json
 export UV_LOCK_TIMEOUT=900 # increase to 15 minutes (900 seconds), adjust as needed
 export UV_LINK_MODE=copy
-uv venv --system-site-packages
+uv venv --clear --system-site-packages
 
 # 2. Activate the environment
 source .venv/bin/activate
@@ -38,8 +38,8 @@ for pkg_dir in "$RECIPE_ROOT/../../../sub-packages/bionemo-recipeutils" "$RECIPE
     fi
 done
 
-# 6. Install the recipe with all remaining dependencies
-uv pip install -c pip-constraints.txt -e . --no-build-isolation
+# 6. Install the recipe with all remaining dependencies, including test extras
+uv pip install -c pip-constraints.txt -e '.[test]' --no-build-isolation
 
 # 7. Restore original pyproject.toml (the edit was only needed for uv resolution)
 mv pyproject.toml.ci_bak pyproject.toml
Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
 poetry-core
+poetry_dynamic_versioning # build dep of nvidia-resiliency-ext (transitively pulled by megatron-bridge); needed in the venv because we install with --no-build-isolation
+grpcio-tools # build dep of nvidia-resiliency-ext: its setup.py shells out to `python -m grpc_tools.protoc` to compile *.proto files; --no-build-isolation means we have to provide it in the venv up-front
 wheel_stub
 ninja # should speed up causal-conv1d build

bionemo-recipes/recipes/evo2_megatron/pyproject.toml

Lines changed: 17 additions & 7 deletions
@@ -24,9 +24,10 @@ dependencies = [
     "causal_conv1d",
     "nv-grouped-gemm",
     "megatron-core",
-    "nvidia-resiliency-ext",
+    # nvidia-resiliency-ext is pulled transitively by megatron-bridge.
     "emerging_optimizers",
     "subquadratic-ops-torch-cu13",
+    "email-validator",
 
     # These are dependencies for examples only, but are useful for actually doing analyses with this model
     "biopython",
@@ -35,7 +36,9 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
-test = []
+test = [
+    "pytest>=8.0",
+]
 
 [project.scripts]
 torchrun = "torch.distributed.run:main"
@@ -88,22 +91,29 @@ override-dependencies = [
     "triton; sys_platform == 'never'",
     "transformer-engine; sys_platform == 'never'",
     "transformer-engine[pytorch]; sys_platform == 'never'",
+    # Avoid alpha Pydantic releases; langchain imports pulled by nvidia-resiliency-ext are not compatible.
+    "pydantic>=2.12,<2.14",
+    # Avoid optional log-pattern-mining dependency conflicts from nvidia-resiliency-ext.
+    "logsage; sys_platform == 'never'",
+    "drain3; sys_platform == 'never'",
 ]
 
 [tool.uv.sources]
 # Shared recipe utilities (framework-agnostic)
 # External dependencies with specific git tags/commits
-causal_conv1d = { git = "https://github.com/Dao-AILab/causal-conv1d.git", tag = "v1.5.4" }
+# 1.6.1 fixes a custom-op no-storage failure in no-grad/frozen forward paths.
+causal_conv1d = { git = "https://github.com/Dao-AILab/causal-conv1d.git", tag = "v1.6.1" }
 nv-grouped-gemm = { git = "https://github.com/fanshiqing/grouped_gemm", tag = "v1.1.4.post6" }
 
 # Internal dependencies
 bionemo-recipeutils = { git = "https://github.com/NVIDIA/bionemo-framework.git", branch = "main", subdirectory = "sub-packages/bionemo-recipeutils" }
 bionemo-core = { git = "https://github.com/NVIDIA/bionemo-framework.git", branch = "main", subdirectory = "sub-packages/bionemo-core" }
-nvidia-resiliency-ext = { git = "https://github.com/NVIDIA/nvidia-resiliency-ext.git", rev = "54f85fe422d296cf04ea524130014bd3a2c3add1" }
+# nvidia-resiliency-ext is intentionally left to Megatron-Bridge so the transitive pin stays consistent.
 
-# Megatron Bundle. This points to a version that still supports the deprecated no_weight_decay_cond field until the API for an alternative has been finalized.
-megatron-bridge = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git", rev = "549e3cb970c170b1d7a86d021261efe05e8a5d9f" }
-megatron-core = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git", rev = "549e3cb970c170b1d7a86d021261efe05e8a5d9f", subdirectory = "3rdparty/Megatron-LM" }
+# Megatron Bundle. MCore is sourced from the same Megatron-Bridge release tag.
+megatron-bridge = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git", tag = "v0.4.1" }
+megatron-core = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git", tag = "v0.4.1", subdirectory = "3rdparty/Megatron-LM" }
 
 [tool.uv.extra-build-dependencies]
 warp-lang = ["wheel_stub"]
+nvidia-resiliency-ext = ["poetry_dynamic_versioning"]

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py

Lines changed: 31 additions & 1 deletion
@@ -18,8 +18,10 @@
 
 
 import math
+import sys
 from dataclasses import dataclass
 from functools import partial
+from pathlib import Path
 from typing import Callable, Iterable, Literal, Optional, Type
 
 import torch
@@ -35,6 +37,7 @@
 from megatron.bridge.training.state import GlobalState
 from megatron.bridge.training.utils.packed_seq_utils import get_packed_seq_params
 from megatron.bridge.training.utils.pg_utils import get_pg_collection
+from megatron.bridge.utils.instantiate_utils import register_allowed_target_prefix
 from megatron.bridge.utils.vocab_utils import calculate_padded_vocab_size
 from megatron.core import parallel_state
 from megatron.core.inference.contexts import StaticInferenceContext
@@ -53,6 +56,33 @@
 from bionemo.evo2.models.megatron.hyena.hyena_utils import hyena_no_weight_decay_cond
 
 
+def _patch_megatron_dataset_helper_compile() -> None:
+    """Skip Megatron's runtime helper build when a wheel already ships the extension."""
+    from megatron.core.datasets import utils as dataset_utils
+
+    original_compile_helpers = dataset_utils.compile_helpers
+    if getattr(original_compile_helpers, "_evo2_prebuilt_helper_guard", False):
+        guarded_compile_helpers = original_compile_helpers
+    else:
+
+        def guarded_compile_helpers() -> None:
+            datasets_dir = Path(dataset_utils.__file__).resolve().parent
+            if not (datasets_dir / "Makefile").exists() and list(datasets_dir.glob("helpers_cpp*.so")):
+                return None
+            return original_compile_helpers()
+
+        guarded_compile_helpers._evo2_prebuilt_helper_guard = True
+        dataset_utils.compile_helpers = guarded_compile_helpers
+
+    bridge_initialize = sys.modules.get("megatron.bridge.training.initialize")
+    if bridge_initialize is not None:
+        bridge_initialize.compile_helpers = guarded_compile_helpers
+
+
+_patch_megatron_dataset_helper_compile()
+register_allowed_target_prefix("bionemo.evo2.")
+
+
 def get_vocab_size(*args, **kwargs):
     raise NotImplementedError("FIXME get_vocab_size is not implemented Find it in megatron bridge")
 
@@ -306,7 +336,7 @@ class HyenaModelProvider(TransformerConfig, ModelProviderMixin[MCoreHyenaModel])
     apply_rope_fusion: bool = True
     make_vocab_size_divisible_by: int = 128
     gated_linear_unit: bool = True
-    fp32_residual_connection: bool = True
+    fp32_residual_connection: bool = False
     normalization: str = "RMSNorm"
     add_bias_linear: bool = False
     hidden_dropout: float = 0.0
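
Review note: `_patch_megatron_dataset_helper_compile` tags its wrapper with a marker attribute so that re-importing the module does not stack a second wrapper on top of the first. The sketch below isolates that idempotent-wrap pattern without any Megatron dependency. `install_guard`, `fake_module`, and the `should_skip` callback are hypothetical names chosen for illustration; they are not part of Megatron, Megatron-Bridge, or this commit.

```python
from types import SimpleNamespace


def install_guard(module, should_skip, marker="_guard_installed"):
    """Wrap module.compile_helpers once; repeated calls reuse the existing wrapper."""
    current = module.compile_helpers
    if getattr(current, marker, False):
        return current  # already wrapped, so do not stack another layer

    def guarded():
        if should_skip():
            return None  # e.g. a prebuilt extension is already available
        return current()

    setattr(guarded, marker, True)
    module.compile_helpers = guarded
    return guarded


# A stand-in "module" whose compile step records that it ran.
calls = []
fake_module = SimpleNamespace(compile_helpers=lambda: calls.append("compiled"))
prebuilt = {"present": True}

first = install_guard(fake_module, lambda: prebuilt["present"])
second = install_guard(fake_module, lambda: prebuilt["present"])

fake_module.compile_helpers()   # skipped while the prebuilt flag is set
prebuilt["present"] = False
fake_module.compile_helpers()   # falls through to the original function

print(first is second, calls)
```

The second `install_guard` call returns the same wrapper object, which is the property the marker attribute is there to provide.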

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/engine.py

Lines changed: 17 additions & 7 deletions
@@ -19,13 +19,17 @@
 import torch.nn.functional as F # noqa: N812
 from einops import rearrange
 
+from bionemo.evo2.models.megatron.hyena.fft_utils import linear_causal_fft_size
+
 
 try:
     from subquadratic_ops_torch.causal_conv1d import causal_conv1d as _subq_causal_conv1d
     from subquadratic_ops_torch.fft_causal_conv1d import fft_causal_conv1d as _subq_fft_causal_conv1d
+    from subquadratic_ops_torch.rearrange import rearrange as _subq_rearrange
 except ImportError as _subq_import_error:
     _subq_causal_conv1d = None
     _subq_fft_causal_conv1d = None
+    _subq_rearrange = None
     _subq_error_msg = f"subquadratic_ops_torch not available: {_subq_import_error}"
 
 
@@ -50,7 +54,7 @@ def fftconv_func(*, u, k, D): # noqa: N803
     The convolution is computed in the frequency domain and then transformed back to the time domain.
     """
     seqlen = u.shape[-1]
-    fft_size = 2 * seqlen
+    fft_size = linear_causal_fft_size(seqlen, k.shape[-1])
 
     k_f = torch.fft.rfft(k, n=fft_size) / fft_size
     k_f = adjust_filter_shape_for_broadcast(u, k_f)
@@ -76,11 +80,15 @@ def parallel_fir(
 ):
     """Compute parallel finite impulse response filtering with optional state computation."""
     L = u.shape[1] # noqa: N806
-    u = rearrange(u, "b l d -> b d l")
 
     if use_subquadratic_ops and _subq_fft_causal_conv1d is None:
         raise ImportError(_subq_error_msg)
 
+    if use_subquadratic_ops:
+        u = _subq_rearrange(u.transpose(0, 1), bhl_to_lbh=False)
+    else:
+        u = rearrange(u, "b l d -> b d l")
+
     if fir_length >= 128:
         if use_subquadratic_ops:
             # subq-ops fft_causal_conv1d expects [B, D, L] input and [D, L] filter; dtypes must match
@@ -99,7 +107,9 @@ def parallel_fir(
             ).to(dtype=u.dtype)
     else:
         if use_subquadratic_ops:
-            # subq-ops causal_conv1d expects pre-padded [B, D, L+pad] input and [D, K] weight; dtypes must match
+            if _subq_causal_conv1d is None:
+                raise ImportError(_subq_error_msg)
+            # subq-ops causal_conv1d expects pre-padded [B, D, L+pad] input and [D, K] weight.
             pad_size = fir_length - 1
             x_padded = F.pad(u.to(torch.float32), (pad_size, 0))
             w = weight.squeeze(1) if weight.dim() == 3 else weight
@@ -111,7 +121,7 @@ def parallel_fir(
                 bias=None,
                 stride=1,
                 padding=fir_length - 1,
-                groups=u.shape[1], # always set to D, regardless of filter grouping
+                groups=u.shape[1],
             )[..., :L]
 
     z = z.to(u.dtype)
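
Review note: the two short-filter branches of `parallel_fir` reach a causal output in different ways. The subquadratic-ops branch left-pads the input by `fir_length - 1`, while the fallback passes `padding=fir_length - 1` and slices `[..., :L]`. Assuming the fallback is a conv1d-style cross-correlation that pads both sides (the call itself lies outside the hunk), the plain-Python sketch below shows why the two agree and why the result is prefix-invariant. All function names here are illustrative and do not come from the commit.

```python
def cross_correlate(x, w):
    """Valid cross-correlation, the operation conv1d-style layers compute."""
    k = len(w)
    return [sum(w[j] * x[i + j] for j in range(k)) for i in range(len(x) - k + 1)]


def causal_left_pad(u, w):
    """Pad K-1 zeros on the left only; the output has exactly len(u) samples."""
    pad = [0.0] * (len(w) - 1)
    return cross_correlate(pad + list(u), w)


def causal_symmetric_pad_then_crop(u, w):
    """Pad K-1 zeros on both sides, then keep the first len(u) outputs."""
    pad = [0.0] * (len(w) - 1)
    return cross_correlate(pad + list(u) + pad, w)[: len(u)]


u = [1.0, 2.0, 3.0, 4.0]
w = [0.5, -1.0, 2.0]  # K = 3

left = causal_left_pad(u, w)
cropped = causal_symmetric_pad_then_crop(u, w)
prefix = causal_left_pad(u[:2], w)  # run on a shorter prefix of the same input

print(left, cropped, prefix)
```

Each output sample only sees the current and earlier inputs, so running on `u[:2]` reproduces the first two samples of the full run.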
@@ -130,7 +140,7 @@ def parallel_fir(
 
 def parallel_iir(*, z_pre, h, D, L, poles, t, hidden_size, compute_state): # noqa: N803
     """Compute the output state of the short convolutional filter."""
-    fft_size = 2 * L
+    fft_size = linear_causal_fft_size(L, h.shape[-1])
     x1, x2, v = z_pre.split([hidden_size, hidden_size, hidden_size], dim=1)
 
     x1v = x1 * v
@@ -221,9 +231,9 @@ def prefill_via_modal_fft(*, x1v, L, poles, t, X_s): # noqa: N803
     # When the model has a long convolution derived from a recurrence in modal form and prefill_style is "fft",
     # we split the filter into poles and residues and reuse FFT computation on the input.
     bs = x1v.shape[0]
-    fft_size = 2 * L
+    fft_size = X_s.shape[-1]
     state_s = (poles.to(torch.float32) * t).exp()
-    state_S = torch.fft.fft(state_s, n=fft_size).repeat(bs, 1, 1, 1) # noqa N806: B, D, state_dim, 2 * L
+    state_S = torch.fft.fft(state_s, n=fft_size).repeat(bs, 1, 1, 1) # noqa N806: B, D, state_dim, fft_size
     state = torch.fft.ifft(X_s[..., None, :] * state_S, n=fft_size)
     # Do not try to fix `UserWarning: Casting complex values to real discards
     # the imaginary part` by inserting state.real conversion anywhere before