Skip to content

Commit 5411e09

Browse files
committed
Address PR feedback
Signed-off-by: John St. John <jstjohn@nvidia.com>
1 parent c372637 commit 5411e09

3 files changed

Lines changed: 37 additions & 8 deletions

File tree

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_block.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,9 @@ def __init__(
121121
pp_layer_offset, layer_type_list = self._select_layers_for_pipeline_parallel(layer_type_list)
122122

123123
if get_cpu_offload_context is not None:
124-
# Megatron Core changed this helper from six to seven positional arguments
125-
# across releases. Pass only the arguments accepted by the installed version.
124+
# MCore 0.x has shipped both six- and seven-argument variants of this helper.
125+
# Pass only the arguments accepted by the installed version; if a future helper
126+
# uses *args, pass the full compatibility list rather than counting *args as one slot.
126127
offload_args = [
127128
self.config.cpu_offloading,
128129
self.config.cpu_offloading_num_layers,
@@ -132,9 +133,16 @@ def __init__(
132133
self.config.cpu_offloading_double_buffering,
133134
getattr(self.config, "cpu_offloading_retain_pinned_cpu_buffers", False),
134135
]
135-
num_offload_params = len(inspect.signature(get_cpu_offload_context).parameters)
136+
offload_params = tuple(inspect.signature(get_cpu_offload_context).parameters.values())
137+
if any(param.kind is inspect.Parameter.VAR_POSITIONAL for param in offload_params):
138+
num_offload_args = len(offload_args)
139+
else:
140+
num_offload_args = sum(
141+
param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
142+
for param in offload_params
143+
)
136144
(self.offload_context, self.group_prefetch_offload_commit_async) = get_cpu_offload_context(
137-
*offload_args[:num_offload_params],
145+
*offload_args[:num_offload_args],
138146
)
139147
self.config._cpu_offloading_context = self.offload_context if self.config.cpu_offloading else None
140148
else:

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ def fftconv_func(
467467
):
468468
"""Apply a 1D convolution to the input sequence u using the filter k and the shortcut D."""
469469
seqlen = u.shape[-1]
470-
fft_size = 2 * seqlen
470+
fft_size = max(2 * seqlen, 2 * k.shape[-1])
471471

472472
# check if k is less than seqlen -- subquadratic_ops input does not need padding
473473
if not use_subquadratic_ops and k.shape[-1] < seqlen:
@@ -499,7 +499,6 @@ def fftconv_func(
499499
if use_subquadratic_ops:
500500
y = fft_causal_conv1d(u, k.squeeze(0))
501501
else:
502-
fft_size = max(fft_size, 2 * k.shape[-1])
503502
k_f = torch.fft.rfft(k, n=fft_size) / fft_size
504503
if k_rev is not None:
505504
k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size
@@ -646,6 +645,8 @@ def compute_filter(self, L, t, glogp, R): # noqa: N803
646645

647646
return h, None
648647

648+
# Keep this eager. The short-prefill prefix-invariance tests in tests/bionemo/evo2/run
649+
# cover the prior torch.compile regression with dynamic filter lengths and custom ops.
649650
def filter(self, L, *args, **kwargs): # noqa: N803
650651
"""Get t and the convolution filter for t and the requested sequence length."""
651652
if self._cp_size > 1:
@@ -768,8 +769,7 @@ def forward(self, L, *args, **kwargs): # noqa: N803
768769
"""
769770
return self.filter(L, *args, **kwargs)
770771

771-
# Keep this eager. Compiling this helper can leave global dispatcher state
772-
# that interferes with unrelated custom autograd/custom-op call sites.
772+
# Keep this eager: compiling it with torch.compile hits the same short-prefill prefix-invariance regression as ImplicitModalFilter.filter.
773773
def filter(self, L, *args, **kwargs): # noqa: N803
774774
"""Compute the filter as a function of h and decay for the requested sequence length."""
775775
h = self.h[:, :L]

bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,27 @@ def test_fftconv_func():
537537
assert output_short.shape == u.shape
538538

539539

540+
def test_fftconv_func_bidirectional_is_prefix_invariant_when_filter_is_longer_than_input():
541+
"""Bidirectional FFT convolution should not alias short prefixes when the filter is long."""
542+
torch.manual_seed(1234)
543+
batch_size = 2
544+
short_len = 5
545+
long_len = 64
546+
hidden_size = 4
547+
filter_len = 64
548+
549+
u_short = torch.randn(batch_size, hidden_size, short_len)
550+
u_long = torch.zeros(batch_size, hidden_size, long_len)
551+
u_long[..., :short_len] = u_short
552+
k = torch.randn(1, 2 * hidden_size, filter_len)
553+
D = torch.randn(hidden_size) # noqa: N806
554+
555+
short_out = fftconv_func(u_short, k, D, None, gelu=False, bidirectional=True)
556+
long_out = fftconv_func(u_long, k, D, None, gelu=False, bidirectional=True)[..., :short_len]
557+
558+
torch.testing.assert_close(short_out, long_out, rtol=1e-5, atol=1e-5)
559+
560+
540561
def test_fftconv_func_high_dimensional_input():
541562
"""Test fftconv_func with high-dimensional input to cover the len(u.shape) > 3 case."""
542563
batch_size = 2

0 commit comments

Comments
 (0)