Skip to content

Commit 49be647

Browse files
committed
Go back to the original subq version; assume it works on other GPUs, and fail loudly if the CUDA_ERROR_UNSUPPORTED_PTX_VERSION error comes up.
Signed-off-by: John St. John <jstjohn@nvidia.com>
1 parent f2cd6a1 commit 49be647

10 files changed

Lines changed: 288 additions & 24 deletions

File tree

bionemo-recipes/recipes/evo2_megatron/pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ dependencies = [
2727
# nvidia-resiliency-ext is pulled transitively by megatron-bridge.
2828
"emerging_optimizers",
2929
"subquadratic-ops-torch-cu13",
30+
"email-validator",
3031

3132
# These are dependencies for examples only, but are useful for actually doing analyses with this model
3233
"biopython",
@@ -88,6 +89,8 @@ override-dependencies = [
8889
"triton; sys_platform == 'never'",
8990
"transformer-engine; sys_platform == 'never'",
9091
"transformer-engine[pytorch]; sys_platform == 'never'",
92+
# Avoid alpha Pydantic releases; langchain imports pulled by nvidia-resiliency-ext are not compatible.
93+
"pydantic>=2.12,<2.14",
9194
# Avoid optional log-pattern-mining dependency conflicts from nvidia-resiliency-ext.
9295
"logsage; sys_platform == 'never'",
9396
"drain3; sys_platform == 'never'",

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/engine.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,17 @@
1919
import torch.nn.functional as F # noqa: N812
2020
from einops import rearrange
2121

22+
from bionemo.evo2.models.megatron.hyena.subquadratic_safety import (
23+
ensure_subquadratic_causal_conv1d_supported,
24+
ensure_subquadratic_fft_causal_conv1d_supported,
25+
)
26+
2227

2328
try:
29+
from subquadratic_ops_torch.causal_conv1d import causal_conv1d as _subq_causal_conv1d
2430
from subquadratic_ops_torch.fft_causal_conv1d import fft_causal_conv1d as _subq_fft_causal_conv1d
2531
except ImportError as _subq_import_error:
32+
_subq_causal_conv1d = None
2633
_subq_fft_causal_conv1d = None
2734
_subq_error_msg = f"subquadratic_ops_torch not available: {_subq_import_error}"
2835

@@ -87,6 +94,7 @@ def parallel_fir(
8794
if fir_length >= 128:
8895
if use_subquadratic_ops:
8996
# subq-ops fft_causal_conv1d expects [B, D, L] input and [D, L] filter; dtypes must match
97+
ensure_subquadratic_fft_causal_conv1d_supported()
9098
k = weight[:, :, :L].squeeze(1) if weight.dim() == 3 else weight[:, :L]
9199
u_fp32 = u.to(torch.float32)
92100
z = _subq_fft_causal_conv1d(u_fp32, k.to(torch.float32))
@@ -101,14 +109,24 @@ def parallel_fir(
101109
D=bias,
102110
).to(dtype=u.dtype)
103111
else:
104-
z = F.conv1d(
105-
u.to(torch.float32),
106-
weight.to(torch.float32),
107-
bias=None,
108-
stride=1,
109-
padding=fir_length - 1,
110-
groups=u.shape[1], # always set to D, regardless of filter grouping
111-
)[..., :L]
112+
if use_subquadratic_ops:
113+
if _subq_causal_conv1d is None:
114+
raise ImportError(_subq_error_msg)
115+
# subq-ops causal_conv1d expects pre-padded [B, D, L+pad] input and [D, K] weight.
116+
ensure_subquadratic_causal_conv1d_supported()
117+
pad_size = fir_length - 1
118+
x_padded = F.pad(u.to(torch.float32), (pad_size, 0))
119+
w = weight.squeeze(1) if weight.dim() == 3 else weight
120+
z = _subq_causal_conv1d(x_padded, w.to(torch.float32))[..., pad_size:]
121+
else:
122+
z = F.conv1d(
123+
u.to(torch.float32),
124+
weight.to(torch.float32),
125+
bias=None,
126+
stride=1,
127+
padding=fir_length - 1,
128+
groups=u.shape[1],
129+
)[..., :L]
112130

113131
z = z.to(u.dtype)
114132

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_mixer.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,7 @@ def __init__(
119119
self.fast_conv_mixer = self.hyena_config.fast_conv_mixer
120120

121121
self.use_subquadratic_ops = self.transformer_config.use_subquadratic_ops
122-
# TODO: Re-enable B2BCausalConv1dModule for short/medium Hyena layers once
123-
# subquadratic-ops updates it to support causal_conv1d 1.6+ semantics.
124-
self.use_fused_b2b_causal_conv1d = False
122+
self.use_fused_b2b_causal_conv1d = self.use_subquadratic_ops
125123

126124
# Per attention head and per partition values.
127125
assert torch.distributed.is_initialized()

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@
3333
from torch.autograd.function import Function
3434

3535
from bionemo.evo2.models.megatron.hyena.hyena_config import HyenaConfig
36+
from bionemo.evo2.models.megatron.hyena.subquadratic_safety import (
37+
ensure_subquadratic_b2b_causal_conv1d_supported,
38+
ensure_subquadratic_causal_conv1d_supported,
39+
ensure_subquadratic_fft_causal_conv1d_supported,
40+
)
3641

3742

3843
try:
@@ -50,10 +55,25 @@ def causal_conv1d_fn(*args, **kwargs):
5055

5156

5257
try:
53-
from subquadratic_ops_torch.b2b_causal_conv1d import b2b_causal_conv1d
54-
from subquadratic_ops_torch.causal_conv1d import causal_conv1d
55-
from subquadratic_ops_torch.fft_causal_conv1d import fft_causal_conv1d
58+
from subquadratic_ops_torch.b2b_causal_conv1d import b2b_causal_conv1d as _subq_b2b_causal_conv1d
59+
from subquadratic_ops_torch.causal_conv1d import causal_conv1d as _subq_causal_conv1d
60+
from subquadratic_ops_torch.fft_causal_conv1d import fft_causal_conv1d as _subq_fft_causal_conv1d
5661
from subquadratic_ops_torch.implicit_filter import implicit_filter
62+
63+
def causal_conv1d(*args, **kwargs):
64+
"""Run guarded subquadratic causal_conv1d."""
65+
ensure_subquadratic_causal_conv1d_supported()
66+
return _subq_causal_conv1d(*args, **kwargs)
67+
68+
def b2b_causal_conv1d(*args, **kwargs):
69+
"""Run guarded subquadratic b2b_causal_conv1d."""
70+
ensure_subquadratic_b2b_causal_conv1d_supported()
71+
return _subq_b2b_causal_conv1d(*args, **kwargs)
72+
73+
def fft_causal_conv1d(*args, **kwargs):
74+
"""Run guarded subquadratic fft_causal_conv1d."""
75+
ensure_subquadratic_fft_causal_conv1d_supported()
76+
return _subq_fft_causal_conv1d(*args, **kwargs)
5777
except ImportError as e:
5878
msg_causal_conv1d = f"Problem importing subquadratic_ops: {e}. causal_conv1d is not available."
5979
msg_b2b_causal_conv1d = f"Problem importing subquadratic_ops: {e}. b2b_causal_conv1d is not available."
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-Apache2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from functools import lru_cache
17+
18+
import torch
19+
import torch.nn.functional as F # noqa: N812
20+
21+
22+
def _raise_subquadratic_self_test_error(op_name: str, detail: str) -> None:
23+
raise RuntimeError(
24+
f"subquadratic_ops_torch.{op_name} failed a CUDA self-test ({detail}). "
25+
"This often happens with CUDA_ERROR_UNSUPPORTED_PTX_VERSION or unsupported GPU/toolchain "
26+
"combinations. Refusing to run this subquadratic kernel because it can otherwise return "
27+
"invalid outputs without raising."
28+
)
29+
30+
31+
def _assert_close_or_raise(op_name: str, actual: torch.Tensor, expected: torch.Tensor) -> None:
32+
torch.cuda.synchronize(actual.device)
33+
if not torch.isfinite(actual).all():
34+
_raise_subquadratic_self_test_error(op_name, "non-finite output")
35+
36+
if not torch.allclose(actual, expected, rtol=1e-4, atol=1e-4):
37+
max_diff = (actual.float() - expected.float()).abs().max().item()
38+
rel = (
39+
(actual.float() - expected.float()).pow(2).sum().sqrt() / (expected.float().pow(2).sum().sqrt() + 1e-30)
40+
).item()
41+
_raise_subquadratic_self_test_error(op_name, f"max_diff={max_diff:.6g}, rel={rel:.6g}")
42+
43+
44+
@lru_cache(maxsize=None)
45+
def ensure_subquadratic_causal_conv1d_supported(device_index: int | None = None) -> None:
46+
"""Validate subquadratic_ops_torch.causal_conv1d before using it for model data."""
47+
if not torch.cuda.is_available():
48+
return
49+
50+
device_index = torch.cuda.current_device() if device_index is None else device_index
51+
device = torch.device("cuda", device_index)
52+
53+
from subquadratic_ops_torch.causal_conv1d import causal_conv1d as subq_causal_conv1d
54+
55+
batch_size = 1
56+
hidden_size = 4
57+
seq_len = 8
58+
kernel_size = 3
59+
pad_size = kernel_size - 1
60+
61+
u = torch.linspace(-1.0, 1.0, steps=batch_size * hidden_size * seq_len, device=device).reshape(
62+
batch_size, hidden_size, seq_len
63+
)
64+
weight = torch.linspace(-0.5, 0.5, steps=hidden_size * kernel_size, device=device).reshape(
65+
hidden_size, kernel_size
66+
)
67+
68+
expected = F.conv1d(
69+
u,
70+
weight.unsqueeze(1),
71+
bias=None,
72+
stride=1,
73+
padding=pad_size,
74+
groups=hidden_size,
75+
)[..., :seq_len]
76+
actual = subq_causal_conv1d(F.pad(u, (pad_size, 0)), weight)[..., pad_size:]
77+
_assert_close_or_raise("causal_conv1d", actual, expected)
78+
79+
80+
@lru_cache(maxsize=None)
81+
def ensure_subquadratic_fft_causal_conv1d_supported(device_index: int | None = None) -> None:
82+
"""Validate subquadratic_ops_torch.fft_causal_conv1d before using it for model data."""
83+
if not torch.cuda.is_available():
84+
return
85+
86+
device_index = torch.cuda.current_device() if device_index is None else device_index
87+
device = torch.device("cuda", device_index)
88+
89+
from subquadratic_ops_torch.fft_causal_conv1d import fft_causal_conv1d as subq_fft_causal_conv1d
90+
91+
batch_size = 1
92+
hidden_size = 4
93+
seq_len = 8
94+
kernel_size = 5
95+
96+
u = torch.linspace(-1.0, 1.0, steps=batch_size * hidden_size * seq_len, device=device).reshape(
97+
batch_size, hidden_size, seq_len
98+
)
99+
weight = torch.linspace(-0.5, 0.5, steps=hidden_size * kernel_size, device=device).reshape(
100+
hidden_size, kernel_size
101+
)
102+
103+
expected = F.conv1d(
104+
u,
105+
weight.flip(-1).unsqueeze(1),
106+
bias=None,
107+
stride=1,
108+
padding=kernel_size - 1,
109+
groups=hidden_size,
110+
)[..., :seq_len]
111+
actual = subq_fft_causal_conv1d(u, weight)
112+
_assert_close_or_raise("fft_causal_conv1d", actual, expected)
113+
114+
115+
@lru_cache(maxsize=None)
116+
def ensure_subquadratic_b2b_causal_conv1d_supported(device_index: int | None = None) -> None:
117+
"""Validate subquadratic_ops_torch.b2b_causal_conv1d before using it for model data."""
118+
if not torch.cuda.is_available():
119+
return
120+
121+
device_index = torch.cuda.current_device() if device_index is None else device_index
122+
device = torch.device("cuda", device_index)
123+
124+
from subquadratic_ops_torch.b2b_causal_conv1d import b2b_causal_conv1d as subq_b2b_causal_conv1d
125+
126+
batch_size = 1
127+
hidden_size = 2
128+
seq_len = 10
129+
proj_kernel_size = 3
130+
mixer_kernel_size = 7
131+
132+
x = torch.linspace(-1.0, 1.0, steps=batch_size * 3 * hidden_size * seq_len, device=device).reshape(
133+
batch_size, 3 * hidden_size, seq_len
134+
)
135+
proj_weight = torch.linspace(-0.5, 0.5, steps=3 * hidden_size * proj_kernel_size, device=device).reshape(
136+
3 * hidden_size, proj_kernel_size
137+
)
138+
mixer_weight = torch.linspace(-0.25, 0.25, steps=hidden_size * mixer_kernel_size, device=device).reshape(
139+
hidden_size, mixer_kernel_size
140+
)
141+
bias = torch.linspace(-0.1, 0.1, steps=hidden_size, device=device)
142+
143+
actual = subq_b2b_causal_conv1d(x, proj_weight, mixer_weight, bias)
144+
145+
projected = F.conv1d(
146+
F.pad(x, (proj_kernel_size - 1, 0)),
147+
proj_weight.flip(-1).unsqueeze(1),
148+
groups=3 * hidden_size,
149+
)
150+
x1, x2, v = projected[:, ::3], projected[:, 1::3], projected[:, 2::3]
151+
z = x2 * v
152+
mixed = F.conv1d(
153+
F.pad(z, (mixer_kernel_size - 1, 0)),
154+
mixer_weight.flip(-1).unsqueeze(1),
155+
groups=hidden_size,
156+
)
157+
expected = x1 * (mixed + bias[None, :, None] * z)
158+
_assert_close_or_raise("b2b_causal_conv1d", actual, expected)

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,14 @@
7777
)
7878
from megatron.bridge.training.config import DistributedInitConfig, RNGConfig
7979
from megatron.bridge.training.mixed_precision import get_mixed_precision_config
80-
from megatron.bridge.training.tokenizers.tokenizer import _HuggingFaceTokenizer
80+
81+
82+
try:
83+
from megatron.bridge.training.tokenizers.tokenizer import _HuggingFaceTokenizer
84+
except ImportError:
85+
from megatron.core.tokenizers.text.libraries.huggingface_tokenizer import (
86+
HuggingFaceTokenizer as _HuggingFaceTokenizer,
87+
)
8188
from megatron.bridge.training.utils.checkpoint_utils import (
8289
file_exists,
8390
get_checkpoint_run_config_filename,

bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_engine.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,9 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
17-
# SPDX-License-Identifier: LicenseRef-Apache2
18-
16+
import pytest
1917
import torch
18+
import torch.nn.functional as F # noqa: N812
2019

2120
from bionemo.evo2.models.megatron.hyena import engine
2221

@@ -77,3 +76,51 @@ def test_parallel_iir_is_prefix_invariant_when_filter_is_longer_than_input():
7776
)
7877

7978
torch.testing.assert_close(short_out, long_out[:, :short_len], rtol=1e-5, atol=1e-5)
79+
80+
81+
@pytest.mark.parametrize("use_subquadratic_ops", [False, True], ids=["torch", "subq"])
def test_parallel_fir_short_cuda_path_matches_torch_depthwise_conv1d(use_subquadratic_ops):
    """Check short-FIR prefill against a depthwise ``F.conv1d`` reference on CUDA.

    The subq variant is reported as xfail when the guarded kernel raises its
    "failed a CUDA self-test" error instead of returning output.
    """
    if not torch.cuda.is_available():
        pytest.skip("short FIR CUDA path requires CUDA")

    torch.manual_seed(1234)
    n_batch, n_steps, n_channels, n_taps = 2, 17, 8, 7
    cuda = torch.device("cuda")

    # Draw order (u, weight, bias) is kept so the seeded inputs stay the same.
    u = torch.randn(n_batch, n_steps, n_channels, device=cuda)
    weight = torch.randn(n_channels, 1, n_taps, device=cuda)
    bias = torch.randn(n_channels, device=cuda)

    fir_kwargs = {
        "u": u,
        "weight": weight,
        "bias": bias,
        "L": n_steps,
        "gated_bias": True,
        "fir_length": n_taps,
        "compute_state": True,
        "use_subquadratic_ops": use_subquadratic_ops,
    }
    try:
        actual, state = engine.parallel_fir(**fir_kwargs)
    except RuntimeError as err:
        is_self_test_failure = use_subquadratic_ops and "failed a CUDA self-test" in str(err)
        if not is_self_test_failure:
            raise
        pytest.xfail(str(err))

    # Reference path: channels-first depthwise conv, truncated to the input length.
    channels_first = u.transpose(1, 2).contiguous()
    reference = F.conv1d(
        channels_first.float(),
        weight.float(),
        bias=None,
        stride=1,
        padding=n_taps - 1,
        groups=n_channels,
    )
    reference = reference[..., :n_steps].to(u.dtype) + bias[None, :, None] * channels_first

    torch.testing.assert_close(actual, reference, rtol=1e-5, atol=1e-5)
    # The returned state is compared with the last n_taps - 1 input steps.
    torch.testing.assert_close(state, channels_first[..., -(n_taps - 1) :])

bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,11 +296,11 @@ def test_b2b_causal_conv1d_effective_padding_size():
296296

297297

298298
@pytest.mark.xfail(
299-
reason="subquadratic-ops fused B2B kernel does not match causal_conv1d 1.6+ short-conv semantics",
299+
reason="subquadratic-ops fused B2B kernel may fail CUDA/PTX self-test on unsupported GPUs",
300300
strict=True,
301301
)
302302
def test_b2b_causal_conv1d_module_matches_sequential_reference():
303-
"""Document the isolated B2B mismatch before re-enabling the fused path."""
303+
"""Document the isolated B2B CUDA kernel behavior before relying on the fused path."""
304304
if not torch.cuda.is_available():
305305
pytest.skip("B2B causal conv isolation test requires CUDA")
306306

0 commit comments

Comments
 (0)