Skip to content

Commit c372637

Browse files
committed
Reduce the number of inner loop checks for compatibility
Signed-off-by: John St. John <jstjohn@nvidia.com>
1 parent 4f57ec2 commit c372637

9 files changed

Lines changed: 45 additions & 97 deletions

File tree

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/engine.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,6 @@
1919
import torch.nn.functional as F # noqa: N812
2020
from einops import rearrange
2121

22-
from bionemo.evo2.models.megatron.hyena.subquadratic_safety import (
23-
ensure_subquadratic_causal_conv1d_supported,
24-
ensure_subquadratic_fft_causal_conv1d_supported,
25-
)
26-
2722

2823
try:
2924
from subquadratic_ops_torch.causal_conv1d import causal_conv1d as _subq_causal_conv1d
@@ -83,7 +78,6 @@ def parallel_fir(
8378
fir_length,
8479
compute_state,
8580
use_subquadratic_ops=False,
86-
check_subquadratic_ops=True,
8781
):
8882
"""Compute parallel finite impulse response filtering with optional state computation."""
8983
L = u.shape[1] # noqa: N806
@@ -95,8 +89,6 @@ def parallel_fir(
9589
if fir_length >= 128:
9690
if use_subquadratic_ops:
9791
# subq-ops fft_causal_conv1d expects [B, D, L] input and [D, L] filter; dtypes must match
98-
if check_subquadratic_ops and u.is_cuda:
99-
ensure_subquadratic_fft_causal_conv1d_supported()
10092
k = weight[:, :, :L].squeeze(1) if weight.dim() == 3 else weight[:, :L]
10193
u_fp32 = u.to(torch.float32)
10294
z = _subq_fft_causal_conv1d(u_fp32, k.to(torch.float32))
@@ -115,8 +107,6 @@ def parallel_fir(
115107
if _subq_causal_conv1d is None:
116108
raise ImportError(_subq_error_msg)
117109
# subq-ops causal_conv1d expects pre-padded [B, D, L+pad] input and [D, K] weight.
118-
if check_subquadratic_ops and u.is_cuda:
119-
ensure_subquadratic_causal_conv1d_supported()
120110
pad_size = fir_length - 1
121111
x_padded = F.pad(u.to(torch.float32), (pad_size, 0))
122112
w = weight.squeeze(1) if weight.dim() == 3 else weight

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,6 @@
3333
from torch.autograd.function import Function
3434

3535
from bionemo.evo2.models.megatron.hyena.hyena_config import HyenaConfig
36-
from bionemo.evo2.models.megatron.hyena.subquadratic_safety import (
37-
ensure_subquadratic_b2b_causal_conv1d_supported,
38-
ensure_subquadratic_causal_conv1d_supported,
39-
ensure_subquadratic_fft_causal_conv1d_supported,
40-
)
4136

4237

4338
try:
@@ -469,7 +464,6 @@ def fftconv_func(
469464
k_rev=None,
470465
bidirectional=False,
471466
use_subquadratic_ops=False,
472-
check_subquadratic_ops=True,
473467
):
474468
"""Apply a 1D convolution to the input sequence u using the filter k and the shortcut D."""
475469
seqlen = u.shape[-1]
@@ -503,8 +497,6 @@ def fftconv_func(
503497
# causal
504498
else:
505499
if use_subquadratic_ops:
506-
if check_subquadratic_ops and u.is_cuda:
507-
ensure_subquadratic_fft_causal_conv1d_supported()
508500
y = fft_causal_conv1d(u, k.squeeze(0))
509501
else:
510502
fft_size = max(fft_size, 2 * k.shape[-1])
@@ -903,7 +895,6 @@ def __init__(
903895
self.zigzag = zigzag
904896

905897
self.use_subquadratic_ops = transformer_config.use_subquadratic_ops
906-
self._subquadratic_ops_checked = False
907898

908899
self.model_parallel_size = self.pg_collection.tp.size() if self.pg_collection.tp is not None else 1
909900
self.model_parallel_rank = self.pg_collection.tp.rank() if self.pg_collection.tp is not None else 0
@@ -986,16 +977,6 @@ def reset_parameters(self):
986977
bounds = math.sqrt(1 / self.kernel_size)
987978
torch.nn.init.uniform_(self.conv_bias, a=-bounds, b=bounds)
988979

989-
def _ensure_subquadratic_ops_supported(self):
990-
"""Run expensive subquadratic-op CUDA self-tests once per operator instance."""
991-
if self._subquadratic_ops_checked or not self.use_subquadratic_ops:
992-
return
993-
if self.operator_type == "hyena_medium_conv" and self.kernel_size < 128:
994-
ensure_subquadratic_causal_conv1d_supported()
995-
else:
996-
ensure_subquadratic_fft_causal_conv1d_supported()
997-
self._subquadratic_ops_checked = True
998-
999980
def forward_long(self, *, x1, x2, v, h, bias, inference_context):
1000981
"""Forward pass long."""
1001982
import bionemo.evo2.models.megatron.hyena.engine as engine
@@ -1086,7 +1067,6 @@ def get_filter_state(filter_name):
10861067
fir_length=self.kernel_size, # self.short_filter_length,
10871068
compute_state=inference_context is not None,
10881069
use_subquadratic_ops=self.use_subquadratic_ops,
1089-
check_subquadratic_ops=False,
10901070
)
10911071
y = rearrange(y, "b d l -> b l d")
10921072
y = y * x1
@@ -1112,8 +1092,6 @@ def forward(self, x1, x2, v, _hyena_use_cp=True, inference_context=None):
11121092
Input shapes: bs, (num_groups, group_size), seq_length
11131093
Output shapes: bs, (num_groups, group_size), seq_length
11141094
"""
1115-
if x1.is_cuda:
1116-
self._ensure_subquadratic_ops_supported()
11171095
B, GDG, L = x1.shape # noqa: N806
11181096
x1, x2, v = x1[..., :L], x2[..., :L], v[..., :L]
11191097

@@ -1204,7 +1182,6 @@ def forward(self, x1, x2, v, _hyena_use_cp=True, inference_context=None):
12041182
gelu=False,
12051183
bidirectional=self.bidirectional,
12061184
use_subquadratic_ops=self.use_subquadratic_ops,
1207-
check_subquadratic_ops=False,
12081185
)
12091186
z = z.to(v.dtype)
12101187

@@ -1404,7 +1381,6 @@ def __init__(
14041381
self.num_groups = num_groups
14051382
self.transformer_config = transformer_config
14061383
self.use_subquadratic_ops = transformer_config.use_subquadratic_ops
1407-
self._subquadratic_ops_checked = False
14081384
self.short_conv_L = hyena_config.short_conv_L
14091385
self.local_init = local_init
14101386
if pg_collection is None:
@@ -1496,9 +1472,6 @@ def forward(self, x, inference_context=None, _use_cp=True):
14961472
# Projection conv is fused with SE/MR layers by B2BCausalConv1dModule when available.
14971473
if self.use_fast_causal_conv: # hyena_proj_conv case
14981474
if self.use_subquadratic_ops:
1499-
if x.is_cuda and not self._subquadratic_ops_checked:
1500-
ensure_subquadratic_causal_conv1d_supported()
1501-
self._subquadratic_ops_checked = True
15021475
y = causal_conv1d(x, weight)[..., pad_size:]
15031476
else:
15041477
y = causal_conv1d_fn(x, weight, bias=None, activation=None)[..., pad_size:]
@@ -1566,7 +1539,6 @@ def __init__(
15661539
"""
15671540
super().__init__()
15681541
self.b2b_causal_conv1d_fn = b2b_causal_conv1d
1569-
self._check_subquadratic_ops = b2b_causal_conv1d is globals()["b2b_causal_conv1d"]
15701542
if pg_collection is None:
15711543
pg_collection = ProcessGroupCollection.use_mpu_process_groups()
15721544
self.pg_collection = pg_collection
@@ -1591,14 +1563,6 @@ def __init__(
15911563
raise ValueError(f"Operator type {operator_type} not supported")
15921564

15931565
self.effective_pad_size = (self._mixer_kernel_size - 1) + (self._proj_conv_kernel_size - 1)
1594-
self._subquadratic_ops_checked = False
1595-
1596-
def _ensure_subquadratic_ops_supported(self):
1597-
"""Run the B2B CUDA self-test once per wrapper instance."""
1598-
if self._subquadratic_ops_checked or not self._check_subquadratic_ops:
1599-
return
1600-
ensure_subquadratic_b2b_causal_conv1d_supported()
1601-
self._subquadratic_ops_checked = True
16021566

16031567
def forward(self, x, _use_cp=True):
16041568
"""Forward pass for the B2BCausalConv1dModule.
@@ -1612,8 +1576,6 @@ def forward(self, x, _use_cp=True):
16121576
# Validate input dimensions
16131577
if x.dim() != 3:
16141578
raise ValueError("Input tensor must be 3D [batch_size, hidden_dim, seq_len]")
1615-
if x.is_cuda:
1616-
self._ensure_subquadratic_ops_supported()
16171579

16181580
# Extract weights at runtime to avoid parameter registration
16191581
proj_weight = self._proj_conv_module.short_conv_weight
@@ -1747,9 +1709,6 @@ def get_filter_state(filter_name):
17471709
L = u.shape[1] # noqa: N806
17481710
fir_state = get_filter_state("fir")
17491711
if fir_state is None:
1750-
if self.use_subquadratic_ops and u.is_cuda and not self._subquadratic_ops_checked:
1751-
ensure_subquadratic_causal_conv1d_supported()
1752-
self._subquadratic_ops_checked = True
17531712
z_pre, fir_state = engine.parallel_fir(
17541713
u=u,
17551714
weight=torch.tensor(weight), # self.short_filter_weight,
@@ -1759,7 +1718,6 @@ def get_filter_state(filter_name):
17591718
fir_length=self.kernel_size, # self.short_filter_length,
17601719
compute_state=inference_context is not None,
17611720
use_subquadratic_ops=self.use_subquadratic_ops,
1762-
check_subquadratic_ops=False,
17631721
)
17641722
else:
17651723
if len(u.shape) > 2:

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/subquadratic_safety.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,14 @@ def _assert_close_or_raise(op_name: str, actual: torch.Tensor, expected: torch.T
4141
_raise_subquadratic_self_test_error(op_name, f"max_diff={max_diff:.6g}, rel={rel:.6g}")
4242

4343

44+
@lru_cache(maxsize=None)
45+
def ensure_subquadratic_ops_supported(device_index: int | None = None) -> None:
46+
"""Validate all subquadratic_ops_torch CUDA kernels used by Evo2."""
47+
ensure_subquadratic_causal_conv1d_supported(device_index)
48+
ensure_subquadratic_fft_causal_conv1d_supported(device_index)
49+
ensure_subquadratic_b2b_causal_conv1d_supported(device_index)
50+
51+
4452
@lru_cache(maxsize=None)
4553
def ensure_subquadratic_causal_conv1d_supported(device_index: int | None = None) -> None:
4654
"""Validate subquadratic_ops_torch.causal_conv1d before using it for model data."""

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@
107107

108108
from bionemo.evo2.data.dataset_tokenizer import DEFAULT_HF_TOKENIZER_MODEL_PATH
109109
from bionemo.evo2.models.evo2_provider import HyenaInferenceContext
110+
from bionemo.evo2.models.megatron.hyena.subquadratic_safety import ensure_subquadratic_ops_supported
110111
from bionemo.evo2.run.predict import initialize_inference_distributed, resolve_checkpoint_path
111112
from bionemo.evo2.run.text_generation_controller import Evo2TextGenerationController
112113

@@ -469,6 +470,8 @@ def setup_inference_engine(
469470
dist_config=dist_config,
470471
)
471472
logger.info("Initialized distributed environment")
473+
if use_subquadratic_ops:
474+
ensure_subquadratic_ops_supported()
472475

473476
# -------------------------------------------------------------------------
474477
# Step 5: Create model and load weights

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@
106106

107107
from bionemo.evo2.data.dataset_tokenizer import DEFAULT_HF_TOKENIZER_MODEL_PATH
108108
from bionemo.evo2.data.fasta_dataset import SimpleFastaDataset
109+
from bionemo.evo2.models.megatron.hyena.subquadratic_safety import ensure_subquadratic_ops_supported
109110
from bionemo.recipeutils.inference.collation import batch_collator
110111

111112

@@ -1093,6 +1094,8 @@ def predict(
10931094
dist_config=dist_config,
10941095
)
10951096
logger.info("Initialized distributed environment")
1097+
if use_subquadratic_ops:
1098+
ensure_subquadratic_ops_supported()
10961099

10971100
# -------------------------------------------------------------------------
10981101
# Step 5: Create model and load weights

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/train.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,11 @@
3535
from megatron.bridge.training.mixed_precision import MIXED_PRECISION_RECIPES
3636
from megatron.bridge.training.post_training.checkpointing import has_modelopt_state
3737
from megatron.bridge.training.pretrain import pretrain
38-
from megatron.bridge.utils.common_utils import get_rank_safe
38+
from megatron.bridge.utils.common_utils import get_local_rank_preinit, get_rank_safe
3939

4040
from bionemo.evo2.data.dataset_tokenizer import DEFAULT_HF_TOKENIZER_MODEL_PATH
4141
from bionemo.evo2.models.evo2_provider import MODEL_OPTIONS, hyena_forward_step, infer_model_type
42+
from bionemo.evo2.models.megatron.hyena.subquadratic_safety import ensure_subquadratic_ops_supported
4243
from bionemo.evo2.recipes.evo2 import evo2_1b_pretrain_config as pretrain_config
4344

4445

@@ -885,7 +886,9 @@ def train(args: argparse.Namespace) -> None:
885886
if args.num_layers:
886887
cfg.model.num_layers = args.num_layers
887888
if args.use_subquadratic_ops:
888-
# TODO assert that it is installed
889+
if torch.cuda.is_available():
890+
torch.cuda.set_device(get_local_rank_preinit())
891+
ensure_subquadratic_ops_supported()
889892
cfg.model.use_subquadratic_ops = True
890893

891894
if args.no_activation_checkpointing:

bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_engine.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import torch.nn.functional as F # noqa: N812
1919

2020
from bionemo.evo2.models.megatron.hyena import engine
21+
from bionemo.evo2.models.megatron.hyena.subquadratic_safety import ensure_subquadratic_ops_supported
2122

2223

2324
def test_fftconv_func_is_prefix_invariant_when_filter_is_longer_than_input():
@@ -83,6 +84,11 @@ def test_parallel_fir_short_cuda_path_matches_torch_depthwise_conv1d(use_subquad
8384
"""Short FIR prefill should match F.conv1d or fail before returning bad subq output."""
8485
if not torch.cuda.is_available():
8586
pytest.skip("short FIR CUDA path requires CUDA")
87+
if use_subquadratic_ops:
88+
try:
89+
ensure_subquadratic_ops_supported()
90+
except RuntimeError as e:
91+
pytest.xfail(str(e))
8692

8793
torch.manual_seed(1234)
8894
batch_size = 2
@@ -95,21 +101,16 @@ def test_parallel_fir_short_cuda_path_matches_torch_depthwise_conv1d(use_subquad
95101
weight = torch.randn(hidden_size, 1, kernel_size, device=device)
96102
bias = torch.randn(hidden_size, device=device)
97103

98-
try:
99-
actual, state = engine.parallel_fir(
100-
u=u,
101-
weight=weight,
102-
bias=bias,
103-
L=seq_len,
104-
gated_bias=True,
105-
fir_length=kernel_size,
106-
compute_state=True,
107-
use_subquadratic_ops=use_subquadratic_ops,
108-
)
109-
except RuntimeError as e:
110-
if use_subquadratic_ops and "failed a CUDA self-test" in str(e):
111-
pytest.xfail(str(e))
112-
raise
104+
actual, state = engine.parallel_fir(
105+
u=u,
106+
weight=weight,
107+
bias=bias,
108+
L=seq_len,
109+
gated_bias=True,
110+
fir_length=kernel_size,
111+
compute_state=True,
112+
use_subquadratic_ops=use_subquadratic_ops,
113+
)
113114

114115
u_bdl = u.transpose(1, 2).contiguous()
115116
expected = F.conv1d(

bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_mixer_kernel.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from bionemo.evo2.models.megatron.hyena.hyena_layer_specs import hyena_stack_spec_no_te
2727
from bionemo.evo2.models.megatron.hyena.hyena_mixer import HyenaMixer
2828
from bionemo.evo2.models.megatron.hyena.hyena_utils import ImplicitModalFilter
29+
from bionemo.evo2.models.megatron.hyena.subquadratic_safety import ensure_subquadratic_ops_supported
2930

3031
from ....utils import distributed_model_parallel_state
3132

@@ -254,6 +255,10 @@ def test_subquadratic_ops_kernel( # noqa: D103
254255
# Skip bf16 with short convolution due to numerical instability
255256
if test_config.params_dtype == torch.bfloat16 and operator_type == "hyena_short_conv":
256257
pytest.skip("bf16 with short convolution is skipped due to numerical instability")
258+
try:
259+
ensure_subquadratic_ops_supported()
260+
except RuntimeError as e:
261+
pytest.xfail(str(e))
257262

258263
with distributed_model_parallel_state():
259264
# Create both models inside the same distributed context

bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py

Lines changed: 5 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
wang_init_method,
3838
zigzag_get_overlapping_patches,
3939
)
40+
from bionemo.evo2.models.megatron.hyena.subquadratic_safety import ensure_subquadratic_ops_supported
4041

4142

4243
class MockProcessGroup:
@@ -137,7 +138,6 @@ def test_parallel_causal_depthwise_conv1d_uses_subquadratic_fast_conv(
137138
pg_collection=types.SimpleNamespace(cp=None),
138139
use_fast_causal_conv=True,
139140
use_subquadratic_ops=True,
140-
_subquadratic_ops_checked=False,
141141
)
142142

143143
y = ParallelCausalDepthwiseConv1d.forward(module, x, _use_cp=False)
@@ -304,29 +304,6 @@ def test_b2b_causal_conv1d_module_device_handling(): # noqa: D103
304304
assert result_cuda.device == x_cuda.device, "Device mismatch on CUDA"
305305

306306

307-
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for subquadratic guard test")
308-
@patch("bionemo.evo2.models.megatron.hyena.hyena_utils.ensure_subquadratic_b2b_causal_conv1d_supported")
309-
@patch("bionemo.evo2.models.megatron.hyena.hyena_utils.b2b_causal_conv1d")
310-
def test_b2b_causal_conv1d_module_checks_subquadratic_kernel_once(mock_b2b, mock_ensure): # noqa: D103
311-
mock_b2b.side_effect = mock_b2b_causal_conv1d
312-
proj_conv = MockProjConv(kernel_size=3)
313-
mixer = MockMixer(kernel_size=5)
314-
b2b_module = B2BCausalConv1dModule(
315-
proj_conv,
316-
mixer,
317-
operator_type="hyena_short_conv",
318-
b2b_causal_conv1d=mock_b2b,
319-
pg_collection=MockProcessGroupCollection(),
320-
)
321-
322-
x = torch.randn(2, 96, 32, device="cuda")
323-
b2b_module(x)
324-
b2b_module(x)
325-
326-
assert mock_ensure.call_count == 1
327-
assert mock_b2b.call_count == 2
328-
329-
330307
def test_b2b_causal_conv1d_effective_padding_size():
331308
"""Test the zigzag pattern for data distribution in context parallel mode."""
332309
proj_conv = MockProjConv(kernel_size=3)
@@ -344,14 +321,14 @@ def test_b2b_causal_conv1d_effective_padding_size():
344321
assert b2b_module.effective_pad_size == expected_pad_size
345322

346323

347-
@pytest.mark.xfail(
348-
reason="subquadratic-ops fused B2B kernel may fail CUDA/PTX self-test on unsupported GPUs",
349-
strict=True,
350-
)
351324
def test_b2b_causal_conv1d_module_matches_sequential_reference():
352325
"""Document the isolated B2B CUDA kernel behavior before relying on the fused path."""
353326
if not torch.cuda.is_available():
354327
pytest.skip("B2B causal conv isolation test requires CUDA")
328+
try:
329+
ensure_subquadratic_ops_supported()
330+
except RuntimeError as e:
331+
pytest.xfail(str(e))
355332

356333
torch.manual_seed(1234)
357334
batch_size = 2

0 commit comments

Comments
 (0)