Skip to content

Commit 4ed5d0d

Browse files
committed
Remove overly granular checks on compatability
Signed-off-by: John St. John <jstjohn@nvidia.com>
1 parent 2be5190 commit 4ed5d0d

3 files changed

Lines changed: 75 additions & 17 deletions

File tree

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/engine.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def parallel_fir(
8383
fir_length,
8484
compute_state,
8585
use_subquadratic_ops=False,
86+
check_subquadratic_ops=True,
8687
):
8788
"""Compute parallel finite impulse response filtering with optional state computation."""
8889
L = u.shape[1] # noqa: N806
@@ -94,7 +95,8 @@ def parallel_fir(
9495
if fir_length >= 128:
9596
if use_subquadratic_ops:
9697
# subq-ops fft_causal_conv1d expects [B, D, L] input and [D, L] filter; dtypes must match
97-
ensure_subquadratic_fft_causal_conv1d_supported()
98+
if check_subquadratic_ops and u.is_cuda:
99+
ensure_subquadratic_fft_causal_conv1d_supported()
98100
k = weight[:, :, :L].squeeze(1) if weight.dim() == 3 else weight[:, :L]
99101
u_fp32 = u.to(torch.float32)
100102
z = _subq_fft_causal_conv1d(u_fp32, k.to(torch.float32))
@@ -113,7 +115,8 @@ def parallel_fir(
113115
if _subq_causal_conv1d is None:
114116
raise ImportError(_subq_error_msg)
115117
# subq-ops causal_conv1d expects pre-padded [B, D, L+pad] input and [D, K] weight.
116-
ensure_subquadratic_causal_conv1d_supported()
118+
if check_subquadratic_ops and u.is_cuda:
119+
ensure_subquadratic_causal_conv1d_supported()
117120
pad_size = fir_length - 1
118121
x_padded = F.pad(u.to(torch.float32), (pad_size, 0))
119122
w = weight.squeeze(1) if weight.dim() == 3 else weight

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py

Lines changed: 47 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -60,20 +60,9 @@ def causal_conv1d_fn(*args, **kwargs):
6060
from subquadratic_ops_torch.fft_causal_conv1d import fft_causal_conv1d as _subq_fft_causal_conv1d
6161
from subquadratic_ops_torch.implicit_filter import implicit_filter
6262

63-
def causal_conv1d(*args, **kwargs):
64-
"""Run guarded subquadratic causal_conv1d."""
65-
ensure_subquadratic_causal_conv1d_supported()
66-
return _subq_causal_conv1d(*args, **kwargs)
67-
68-
def b2b_causal_conv1d(*args, **kwargs):
69-
"""Run guarded subquadratic b2b_causal_conv1d."""
70-
ensure_subquadratic_b2b_causal_conv1d_supported()
71-
return _subq_b2b_causal_conv1d(*args, **kwargs)
72-
73-
def fft_causal_conv1d(*args, **kwargs):
74-
"""Run guarded subquadratic fft_causal_conv1d."""
75-
ensure_subquadratic_fft_causal_conv1d_supported()
76-
return _subq_fft_causal_conv1d(*args, **kwargs)
63+
causal_conv1d = _subq_causal_conv1d
64+
b2b_causal_conv1d = _subq_b2b_causal_conv1d
65+
fft_causal_conv1d = _subq_fft_causal_conv1d
7766
except ImportError as e:
7867
msg_causal_conv1d = f"Problem importing subquadratic_ops: {e}. causal_conv1d is not available."
7968
msg_b2b_causal_conv1d = f"Problem importing subquadratic_ops: {e}. b2b_causal_conv1d is not available."
@@ -471,7 +460,17 @@ def hyena_no_weight_decay_cond_with_embeddings(name, param):
471460
return ("embedding" in name) or hyena_no_weight_decay_cond(name, param)
472461

473462

474-
def fftconv_func(u, k, D, dropout_mask, gelu=True, k_rev=None, bidirectional=False, use_subquadratic_ops=False): # noqa: N803
463+
def fftconv_func(
464+
u,
465+
k,
466+
D, # noqa: N803
467+
dropout_mask,
468+
gelu=True,
469+
k_rev=None,
470+
bidirectional=False,
471+
use_subquadratic_ops=False,
472+
check_subquadratic_ops=True,
473+
):
475474
"""Apply a 1D convolution to the input sequence u using the filter k and the shortcut D."""
476475
seqlen = u.shape[-1]
477476
fft_size = 2 * seqlen
@@ -504,6 +503,8 @@ def fftconv_func(u, k, D, dropout_mask, gelu=True, k_rev=None, bidirectional=Fal
504503
# causal
505504
else:
506505
if use_subquadratic_ops:
506+
if check_subquadratic_ops and u.is_cuda:
507+
ensure_subquadratic_fft_causal_conv1d_supported()
507508
y = fft_causal_conv1d(u, k.squeeze(0))
508509
else:
509510
fft_size = max(fft_size, 2 * k.shape[-1])
@@ -902,6 +903,7 @@ def __init__(
902903
self.zigzag = zigzag
903904

904905
self.use_subquadratic_ops = transformer_config.use_subquadratic_ops
906+
self._subquadratic_ops_checked = False
905907

906908
self.model_parallel_size = self.pg_collection.tp.size() if self.pg_collection.tp is not None else 1
907909
self.model_parallel_rank = self.pg_collection.tp.rank() if self.pg_collection.tp is not None else 0
@@ -984,6 +986,16 @@ def reset_parameters(self):
984986
bounds = math.sqrt(1 / self.kernel_size)
985987
torch.nn.init.uniform_(self.conv_bias, a=-bounds, b=bounds)
986988

989+
def _ensure_subquadratic_ops_supported(self):
990+
"""Run expensive subquadratic-op CUDA self-tests once per operator instance."""
991+
if self._subquadratic_ops_checked or not self.use_subquadratic_ops:
992+
return
993+
if self.operator_type == "hyena_medium_conv" and self.kernel_size < 128:
994+
ensure_subquadratic_causal_conv1d_supported()
995+
else:
996+
ensure_subquadratic_fft_causal_conv1d_supported()
997+
self._subquadratic_ops_checked = True
998+
987999
def forward_long(self, *, x1, x2, v, h, bias, inference_context):
9881000
"""Forward pass long."""
9891001
import bionemo.evo2.models.megatron.hyena.engine as engine
@@ -1074,6 +1086,7 @@ def get_filter_state(filter_name):
10741086
fir_length=self.kernel_size, # self.short_filter_length,
10751087
compute_state=inference_context is not None,
10761088
use_subquadratic_ops=self.use_subquadratic_ops,
1089+
check_subquadratic_ops=False,
10771090
)
10781091
y = rearrange(y, "b d l -> b l d")
10791092
y = y * x1
@@ -1099,6 +1112,8 @@ def forward(self, x1, x2, v, _hyena_use_cp=True, inference_context=None):
10991112
Input shapes: bs, (num_groups, group_size), seq_length
11001113
Output shapes: bs, (num_groups, group_size), seq_length
11011114
"""
1115+
if x1.is_cuda:
1116+
self._ensure_subquadratic_ops_supported()
11021117
B, GDG, L = x1.shape # noqa: N806
11031118
x1, x2, v = x1[..., :L], x2[..., :L], v[..., :L]
11041119

@@ -1189,6 +1204,7 @@ def forward(self, x1, x2, v, _hyena_use_cp=True, inference_context=None):
11891204
gelu=False,
11901205
bidirectional=self.bidirectional,
11911206
use_subquadratic_ops=self.use_subquadratic_ops,
1207+
check_subquadratic_ops=False,
11921208
)
11931209
z = z.to(v.dtype)
11941210

@@ -1388,6 +1404,7 @@ def __init__(
13881404
self.num_groups = num_groups
13891405
self.transformer_config = transformer_config
13901406
self.use_subquadratic_ops = transformer_config.use_subquadratic_ops
1407+
self._subquadratic_ops_checked = False
13911408
self.short_conv_L = hyena_config.short_conv_L
13921409
self.local_init = local_init
13931410
if pg_collection is None:
@@ -1543,6 +1560,7 @@ def __init__(
15431560
"""
15441561
super().__init__()
15451562
self.b2b_causal_conv1d_fn = b2b_causal_conv1d
1563+
self._check_subquadratic_ops = b2b_causal_conv1d is globals()["b2b_causal_conv1d"]
15461564
if pg_collection is None:
15471565
pg_collection = ProcessGroupCollection.use_mpu_process_groups()
15481566
self.pg_collection = pg_collection
@@ -1567,6 +1585,14 @@ def __init__(
15671585
raise ValueError(f"Operator type {operator_type} not supported")
15681586

15691587
self.effective_pad_size = (self._mixer_kernel_size - 1) + (self._proj_conv_kernel_size - 1)
1588+
self._subquadratic_ops_checked = False
1589+
1590+
def _ensure_subquadratic_ops_supported(self):
1591+
"""Run the B2B CUDA self-test once per wrapper instance."""
1592+
if self._subquadratic_ops_checked or not self._check_subquadratic_ops:
1593+
return
1594+
ensure_subquadratic_b2b_causal_conv1d_supported()
1595+
self._subquadratic_ops_checked = True
15701596

15711597
def forward(self, x, _use_cp=True):
15721598
"""Forward pass for the B2BCausalConv1dModule.
@@ -1580,6 +1606,8 @@ def forward(self, x, _use_cp=True):
15801606
# Validate input dimensions
15811607
if x.dim() != 3:
15821608
raise ValueError("Input tensor must be 3D [batch_size, hidden_dim, seq_len]")
1609+
if x.is_cuda:
1610+
self._ensure_subquadratic_ops_supported()
15831611

15841612
# Extract weights at runtime to avoid parameter registration
15851613
proj_weight = self._proj_conv_module.short_conv_weight
@@ -1713,6 +1741,9 @@ def get_filter_state(filter_name):
17131741
L = u.shape[1] # noqa: N806
17141742
fir_state = get_filter_state("fir")
17151743
if fir_state is None:
1744+
if self.use_subquadratic_ops and u.is_cuda and not self._subquadratic_ops_checked:
1745+
ensure_subquadratic_causal_conv1d_supported()
1746+
self._subquadratic_ops_checked = True
17161747
z_pre, fir_state = engine.parallel_fir(
17171748
u=u,
17181749
weight=torch.tensor(weight), # self.short_filter_weight,
@@ -1722,6 +1753,7 @@ def get_filter_state(filter_name):
17221753
fir_length=self.kernel_size, # self.short_filter_length,
17231754
compute_state=inference_context is not None,
17241755
use_subquadratic_ops=self.use_subquadratic_ops,
1756+
check_subquadratic_ops=False,
17251757
)
17261758
else:
17271759
if len(u.shape) > 2:

bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,29 @@ def test_b2b_causal_conv1d_module_device_handling(): # noqa: D103
278278
assert result_cuda.device == x_cuda.device, "Device mismatch on CUDA"
279279

280280

281+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for subquadratic guard test")
282+
@patch("bionemo.evo2.models.megatron.hyena.hyena_utils.ensure_subquadratic_b2b_causal_conv1d_supported")
283+
@patch("bionemo.evo2.models.megatron.hyena.hyena_utils.b2b_causal_conv1d")
284+
def test_b2b_causal_conv1d_module_checks_subquadratic_kernel_once(mock_b2b, mock_ensure): # noqa: D103
285+
mock_b2b.side_effect = mock_b2b_causal_conv1d
286+
proj_conv = MockProjConv(kernel_size=3)
287+
mixer = MockMixer(kernel_size=5)
288+
b2b_module = B2BCausalConv1dModule(
289+
proj_conv,
290+
mixer,
291+
operator_type="hyena_short_conv",
292+
b2b_causal_conv1d=mock_b2b,
293+
pg_collection=MockProcessGroupCollection(),
294+
)
295+
296+
x = torch.randn(2, 96, 32, device="cuda")
297+
b2b_module(x)
298+
b2b_module(x)
299+
300+
assert mock_ensure.call_count == 1
301+
assert mock_b2b.call_count == 2
302+
303+
281304
def test_b2b_causal_conv1d_effective_padding_size():
282305
"""Test the zigzag pattern for data distribution in context parallel mode."""
283306
proj_conv = MockProjConv(kernel_size=3)

0 commit comments

Comments
 (0)