Commit 8e5a865

jstjohn and claude authored
[Evo2] Switch from static to faster/more modern dynamic inference engine (NVIDIA-BioNeMo#1597)
### Description

* New dynamic inference engine in evo2 with cudagraph support.

Benchmarked with a 1024-token prompt and 1024 requested generated tokens. All runs verified the output JSONL reported 1024 prompt tokens and 1024 completion tokens. Comparison performed on 2xA6000 GPUs at bf16 precision.

| Model | Parallelism | Prompt / Generation | `origin/main` static engine | Dynamic engine | Speedup | Tokens verified |
|---|---:|---:|---:|---:|---:|---|
| Evo2 1B | TP=1 | 1024 / 1024 | 38.7 tok/s, 26.44s | 129.6 tok/s, 7.90s | 3.35x | 1024 prompt + 1024 completion |
| Evo2 7B | TP=2 | 1024 / 1024 | 28.4 tok/s, 36.07s | 62.2 tok/s, 16.47s | 2.19x | 1024 prompt + 1024 completion |

#### Usage

<!--- How does a user interact with the changed code -->

```python
TODO: Add code snippet
```

### Type of changes

<!-- Mark the relevant option with an [x] -->

- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Refactor
- [ ] Documentation update
- [ ] Other (please describe):

### CI Pipeline Configuration

Configure CI behavior by applying the relevant labels. By default, only basic unit tests are run.

- [ciflow:skip](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:skip) - Skip all CI tests for this PR
- [ciflow:notebooks](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:notebooks) - Run Jupyter notebooks execution tests
- [ciflow:slow](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:slow) - Run slow single GPU integration tests marked as @pytest.mark.slow
- [ciflow:all](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:all) - Run all tests (unit tests, slow tests, and notebooks). This label can be used to enforce running all framework tests.
- [ciflow:all-recipes](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:all-recipes) - Run tests for all recipes (under bionemo-recipes). This label can be used to enforce running tests for all recipes.

Unit tests marked as `@pytest.mark.multi_gpu` or `@pytest.mark.distributed` are not run in the PR pipeline.

For more details, see [CONTRIBUTING](CONTRIBUTING.md)

> [!NOTE]
> By default, only basic unit tests are run. Add appropriate labels to enable an additional test coverage.

#### Authorizing CI Runs

We use [copy-pr-bot](https://docs.gha-runners.nvidia.com/apps/copy-pr-bot/#automation) to manage authorization of CI runs on NVIDIA's compute resources.

- If a pull request is opened by a trusted user and contains only trusted changes, the pull request's code will automatically be copied to a pull-request/ prefixed branch in the source repository (e.g. pull-request/123)
- If a pull request is opened by an untrusted user or contains untrusted changes, an NVIDIA org member must leave an `/ok to test` comment on the pull request to trigger CI. This will need to be done for each new commit.

#### Triggering Code Rabbit AI Review

To trigger a code review from code rabbit, comment on a pull request with one of these commands:

- @coderabbitai review - Triggers a standard review
- @coderabbitai full review - Triggers a comprehensive review

See https://docs.coderabbit.ai/reference/review-commands for a full list of commands.

### Pre-submit Checklist

<!--- Ensure all items are completed before submitting -->

- [ ] I have tested these changes locally
- [ ] I have updated the documentation accordingly
- [ ] I have added/updated tests as needed
- [ ] All existing tests pass successfully

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **New Features**
  * Native dynamic inference engine for improved performance and memory efficiency
  * Chunked prefill and dynamic batching support for flexible inference control
* **Improvements**
  * Optimized inference state management through in-place operations
  * CUDA graph acceleration for faster decode-time inference
* **Documentation**
  * Updated fine-tuning tutorial with corrected CLI arguments for prediction commands

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: John St John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
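The throughput and speedup columns in the benchmark table can be cross-checked from the reported wall-clock times. The sketch below assumes throughput is the 1024 completion tokens divided by the total run time; the description does not state the formula, and the helper names are illustrative only.

```python
# Figures copied from the benchmark table:
# (model, completion tokens, static-engine seconds, dynamic-engine seconds).
RUNS = [
    ("Evo2 1B", 1024, 26.44, 7.90),
    ("Evo2 7B", 1024, 36.07, 16.47),
]


def tokens_per_second(tokens: int, seconds: float) -> float:
    """Assumed definition: generated tokens divided by wall-clock seconds."""
    return tokens / seconds


def speedup(static_seconds: float, dynamic_seconds: float) -> float:
    """Ratio of static-engine time to dynamic-engine time."""
    return static_seconds / dynamic_seconds


for name, tokens, static_s, dynamic_s in RUNS:
    print(
        f"{name}: static {tokens_per_second(tokens, static_s):.1f} tok/s, "
        f"dynamic {tokens_per_second(tokens, dynamic_s):.1f} tok/s, "
        f"speedup {speedup(static_s, dynamic_s):.2f}x"
    )
```

Under that assumption the recomputed values agree with the table to the printed precision.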
1 parent 6aa5103 commit 8e5a865

18 files changed

Lines changed: 2774 additions & 1070 deletions

bionemo-recipes/recipes/evo2_megatron/examples/evo2_classifier.py

Lines changed: 10 additions & 1 deletion
@@ -72,7 +72,7 @@
     get_checkpoint_run_config_filename,
     read_run_config,
 )
-from megatron.bridge.utils.instantiate_utils import instantiate
+from megatron.bridge.utils.instantiate_utils import instantiate, register_allowed_target_prefix
 from megatron.bridge.utils.vocab_utils import calculate_padded_vocab_size
 from megatron.core import dist_checkpointing, parallel_state
 from megatron.core.num_microbatches_calculator import destroy_num_microbatches_calculator
@@ -97,6 +97,15 @@
 logger: logging.Logger = logging.getLogger(__name__)
 
 
+# This example is launched as ``evo2_classifier.py`` (e.g. ``torchrun ... evo2_classifier.py``), so
+# the providers defined below are serialized into a trained checkpoint's run_config with a
+# ``_target_`` of ``evo2_classifier.<Provider>``. Megatron-Bridge's ``instantiate`` only resolves
+# targets under an allow-listed module prefix, so register this module's prefix here (mirroring
+# ``bionemo.evo2.`` in evo2_provider.py). Without it, rebuilding the model at predict time in
+# ``_build_classifier_from_checkpoint`` raises InstantiationException.
+register_allowed_target_prefix("evo2_classifier.")
+
+
 # ─────────────────────────────────────────────────────────────────────────────
 # Model: subclass of HyenaModel with a classification head
 # ─────────────────────────────────────────────────────────────────────────────
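The comment added in this hunk describes an allow-list gate on `_target_` resolution. The sketch below illustrates that general pattern using only the standard library. It is not Megatron-Bridge's implementation: `register_prefix`, `resolve_target`, and `_ALLOWED_PREFIXES` are hypothetical names, and the real `instantiate` raises its own exception type.

```python
import importlib

# Hypothetical allow-list of dotted-path prefixes that may be resolved.
_ALLOWED_PREFIXES: set[str] = {"collections."}


def register_prefix(prefix: str) -> None:
    """Allow targets whose dotted path starts with ``prefix``."""
    _ALLOWED_PREFIXES.add(prefix)


def resolve_target(target: str):
    """Import and return ``target`` only if it sits under an allowed prefix."""
    if not any(target.startswith(prefix) for prefix in _ALLOWED_PREFIXES):
        raise PermissionError(f"target {target!r} is not under an allow-listed prefix")
    module_name, _, attribute = target.rpartition(".")
    return getattr(importlib.import_module(module_name), attribute)


# A serialized config that names "json.dumps" resolves only after its prefix is registered.
register_prefix("json.")
print(resolve_target("json.dumps")({"ok": True}))
```

In this sketch the registration has to run before any config is resolved, which matches the commit's choice to call `register_allowed_target_prefix` at module import time.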

bionemo-recipes/recipes/evo2_megatron/examples/fine-tuning-tutorial.ipynb

Lines changed: 2 additions & 2 deletions
@@ -1010,7 +1010,7 @@
 "\n",
 "You can now run inference with:\n",
 " infer_evo2 --ckpt-dir pretraining_demo/evo2/checkpoints --prompt 'ATCGATCG' --max-new-tokens 100\n",
-" predict_evo2 --ckpt-dir pretraining_demo/evo2/checkpoints --input-fasta <your_fasta> --output-dir <output>\n"
+" predict_evo2 --ckpt-dir pretraining_demo/evo2/checkpoints --fasta <your_fasta> --output-dir <output>\n"
 ]
 }
 ],
@@ -1030,7 +1030,7 @@
 "\n",
 "print(\"\\nYou can now run inference with:\")\n",
 "print(f\" infer_evo2 --ckpt-dir {ckpt_dir} --prompt 'ATCGATCG' --max-new-tokens 100\")\n",
-"print(f\" predict_evo2 --ckpt-dir {ckpt_dir} --input-fasta <your_fasta> --output-dir <output>\")"
+"print(f\" predict_evo2 --ckpt-dir {ckpt_dir} --fasta <your_fasta> --output-dir <output>\")"
 ]
 },
 {

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py

Lines changed: 243 additions & 0 deletions
@@ -146,6 +146,248 @@ def reset(self):
             delattr(self, key)
 
 
+# =============================================================================
+# Dynamic-inference Hyena state packing
+# =============================================================================
+
+
+class _PackedHyenaSlotStateDict(dict):
+    """``id(module)`` -> packed-slot view map for Hyena recurrent state.
+
+    Hyena ops read/write recurrent state through ``*_filter_state_dict`` attributes keyed
+    by ``id(module)``. This dict preserves that API while routing registered ids into
+    sub-slice views of the live ``DynamicInferenceContext`` Mamba state buffers:
+    ``mamba_conv_states[layer, slot]`` for the projection FIR ring and
+    ``mamba_ssm_states[layer, slot]`` for the layer mixer state. Unregistered ids fall
+    back to plain dict storage.
+    """
+
+    def __init__(self, kind: str):
+        super().__init__()
+        self._kind = kind
+        # id(module) -> view tensor (a sub-slice of the packed slot buffer).
+        self._views: dict = {}
+
+    def register(self, module_id: int, view: "torch.Tensor") -> None:
+        """Register the packed-slot sub-slice view that backs ``module_id``."""
+        self._views[module_id] = view
+
+    def __setitem__(self, module_id, state):
+        view = self._views.get(module_id)
+        if view is None or state is None:
+            # Unregistered owner, or an explicit None clear (re-prefill seed wipe).
+            super().__setitem__(module_id, state)
+            return
+        if state.data_ptr() != view.data_ptr():
+            # Prefill seed, or a realloc step branch returning a NEW tensor: copy into the
+            # packed view. The in-place step branch returns the view itself -> no-op copy.
+            if state.shape == view.shape:
+                view.copy_(state)
+            else:
+                # Short prefill can seed a FIR ring smaller than the allocated slot. Right-align
+                # the available tail so decode can use the fixed-size in-place ring immediately.
+                assert state.shape[:-1] == view.shape[:-1] and state.shape[-1] <= view.shape[-1], (
+                    f"packed {self._kind} seed shape {tuple(state.shape)} incompatible with ring view "
+                    f"{tuple(view.shape)} (only the FIR ring last dim may be shorter)."
+                )
+                view.zero_()
+                view[..., view.shape[-1] - state.shape[-1] :].copy_(state)
+        super().__setitem__(module_id, view)
+
+    def reset_for_new_request(self) -> None:
+        """Drop dict entries so ``.get(id)`` returns None and the caller re-prefills."""
+        super().clear()
+
+
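The routing in `_PackedHyenaSlotStateDict.__setitem__` can be illustrated without torch. The following is a simplified analogy, not the commit's code: plain Python lists stand in for tensors, object identity stands in for the `data_ptr()` comparison, and `SlotRoutedDict` is a hypothetical name.

```python
class SlotRoutedDict(dict):
    """Dict that writes registered keys into preallocated, fixed-size buffers."""

    def __init__(self):
        super().__init__()
        self._views: dict = {}

    def register(self, key, view: list) -> None:
        """Associate ``key`` with the preallocated buffer that should back it."""
        self._views[key] = view

    def __setitem__(self, key, state):
        view = self._views.get(key)
        if view is None or state is None:
            # Unregistered key, or an explicit clear: behave like a normal dict.
            super().__setitem__(key, state)
            return
        if state is not view:
            if len(state) > len(view):
                raise ValueError("seed is longer than the preallocated buffer")
            # Zero-fill, then right-align the (possibly shorter) seed, all in place.
            pad = len(view) - len(state)
            view[:] = [0.0] * pad + list(state)
        # Always store the preallocated buffer itself, never the caller's object.
        super().__setitem__(key, view)


ring = [9.0, 9.0, 9.0, 9.0]   # stale contents from an earlier request
states = SlotRoutedDict()
states.register("fir", ring)
states["fir"] = [1.0, 2.0]    # short prefill seed
print(states["fir"] is ring, ring)
```

Because the stored value is always the registered buffer, later in-place updates by the caller land in the preallocated storage without another copy.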
+def build_evo2_mamba_inference_state_config(model, *, conv_dtype=None, ssm_dtype=None):
+    """Build the mcore Mamba state config used by Evo2 dynamic inference.
+
+    Evo2 Hyena layers expose two recurrent state slots per layer, matching the two slots
+    that ``DynamicInferenceContext`` allocates for hybrid Mamba models: ``conv_states``
+    for the projection FIR ring and ``ssm_states`` for the layer mixer state. The
+    ``HyenaStack`` provides the uniform packed shapes and layer type list that mcore uses
+    to allocate those buffers and map layer numbers to state slots.
+
+    The slot dtypes default to fp32 because Hyena decode recurrences run in fp32 and update
+    the packed sub-slice views in place.
+
+    Args:
+        model: The Evo2 ``HyenaModel``.
+        conv_dtype: Override for the conv slot dtype (default ``torch.float32``).
+        ssm_dtype: Override for the ssm slot dtype (default ``torch.float32``).
+
+    Returns:
+        A ``MambaInferenceStateConfig`` ready to pass as
+        ``InferenceConfig(mamba_inference_state_config=...)``.
+    """
+    from megatron.core.inference.config import (
+        MambaInferenceStateConfig,  # lazy: heavy mcore import — keep evo2_provider importable without the full inference stack
+    )
+
+    decoder = model.decoder if hasattr(model, "decoder") else model
+    conv_states_shape, ssm_states_shape = decoder.mamba_state_shapes_per_request()
+    layer_type_list = decoder.layer_type_list  # mcore symbols, set in HyenaStack.__init__
+    return MambaInferenceStateConfig(
+        layer_type_list=list(layer_type_list),
+        conv_states_shape=tuple(conv_states_shape),
+        ssm_states_shape=tuple(ssm_states_shape),
+        conv_states_dtype=conv_dtype or torch.float32,
+        ssm_states_dtype=ssm_dtype or torch.float32,
+    )
+
+
+def make_evo2_dynamic_inference_context_cls():
+    """Return mcore's ``DynamicInferenceContext`` class for Evo2 decode.
+
+    Evo2 constrains each standalone decode context to the active request count and enables
+    decode-only CUDA graph dimensions, so the graph path does not need an Evo2-specific
+    context subclass. Keeping the exact mcore type also preserves mcore's CUDA graph
+    argument checks without runtime compatibility hooks.
+
+    Returns:
+        The mcore ``DynamicInferenceContext`` class.
+    """
+    from megatron.core.inference.contexts.dynamic_context import (
+        DynamicInferenceContext,  # lazy: heavy mcore import; keep evo2_provider importable without the full inference stack
+    )
+
+    return DynamicInferenceContext
+
+
+def compute_evo2_paged_kv_buffer_size_gb(
+    model_config,
+    *,
+    mamba_state_config,
+    max_sequence_length: int,
+    max_requests: int,
+    block_size_tokens: int = 256,
+    safety_blocks: int = 2,
+) -> float:
+    """Compute a right-sized ``buffer_size_gb`` for one Evo2 dynamic context.
+
+    ``DynamicInferenceContext`` derives its KV block count from ``buffer_size_gb``. For
+    hybrid models in the installed mcore version, the no-``mamba_memory_ratio`` path uses
+    ``buffer_size_bytes // (block_size_bytes + mamba_states_memory_per_request)``. This
+    helper mirrors that arithmetic and returns the smallest buffer that covers
+    ``ceil(max_sequence_length / block_size_tokens) + 1 dummy + safety_blocks`` KV blocks.
+
+    Args:
+        model_config: The Evo2 ``HyenaModel`` transformer config.
+        mamba_state_config: The ``MambaInferenceStateConfig`` produced by
+            :func:`build_evo2_mamba_inference_state_config`.
+        max_sequence_length: Prompt plus generation length to allocate for.
+        max_requests: The context's ``max_requests``.
+        block_size_tokens: KV block size used by the context.
+        safety_blocks: Extra KV blocks beyond the requested sequence length.
+
+    Returns:
+        The ``buffer_size_gb`` value to pass to ``InferenceConfig``.
+    """
+    # --- Per-partition attention geometry (mcore dynamic_context.py:285-299). ---
+    num_attention_heads = getattr(model_config, "num_query_groups", None) or model_config.num_attention_heads
+    kv_channels = getattr(model_config, "kv_channels", None) or (
+        model_config.hidden_size // model_config.num_attention_heads
+    )
+    projection_size = kv_channels * num_attention_heads
+    head_dim = projection_size // num_attention_heads
+    tp_size = int(getattr(model_config, "tensor_model_parallel_size", 1) or 1)
+    heads_per_partition = num_attention_heads // tp_size if num_attention_heads >= tp_size else 1
+
+    # --- Layer-type counts from the mamba state config's layer_type_list. ---
+    # Symbols.ATTENTION layers need paged KV; the Hyena ("mamba"-slotted) layers do NOT (they hold
+    # only the conv/ssm recurrent-state slots). Counting from layer_type_list keeps this correct
+    # under any (truncated) hybrid_override_pattern.
+    from megatron.core.ssm.mamba_hybrid_layer_allocation import (  # lazy: heavy mcore import — keep evo2_provider importable without the full inference stack
+        Symbols as _McoreSymbols,
+    )
+
+    layer_type_list = list(mamba_state_config.layer_type_list)
+    num_attention_layers = sum(1 for s in layer_type_list if s == _McoreSymbols.ATTENTION)
+    num_mamba_layers = sum(1 for s in layer_type_list if s == _McoreSymbols.MAMBA)
+
+    # --- block_size_bytes (mcore dynamic_context.py:376-383). ---
+    kv_dtype_size_bytes = model_config.params_dtype.itemsize
+    block_size_bytes = (
+        kv_dtype_size_bytes * 2 * num_attention_layers * block_size_tokens * heads_per_partition * head_dim
+    )
+
+    # --- mamba_states_memory_per_request (mcore dynamic_context.py:386-394). ---
+    conv_bytes = math.prod(mamba_state_config.conv_states_shape) * mamba_state_config.conv_states_dtype.itemsize
+    ssm_bytes = math.prod(mamba_state_config.ssm_states_shape) * mamba_state_config.ssm_states_dtype.itemsize
+    mamba_per_request = (conv_bytes + ssm_bytes) * num_mamba_layers
+
+    # --- Target KV block count: requested sequence + dummy block + safety. ---
+    target_blocks = math.ceil(int(max_sequence_length) / block_size_tokens) + 1 + int(safety_blocks)
+    target_blocks = max(2, target_blocks)  # mcore floors block_count at 2 (active + dummy)
+
+    # --- Invert mcore's hybrid block-count formula for the no-mamba-ratio path. ---
+    total_bytes = target_blocks * (block_size_bytes + mamba_per_request)
+    return (total_bytes + 1) / (1024**3)
+
+
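The sizing arithmetic in `compute_evo2_paged_kv_buffer_size_gb` can be checked in isolation. The sketch below reproduces only the last two steps (target block count and formula inversion) with plain integers. The byte sizes are made-up illustrative values, and `blocks_from_buffer` is a hypothetical stand-in for the forward formula quoted in the docstring, not mcore's code.

```python
import math

GIB = 1024**3


def buffer_size_gb(block_size_bytes, mamba_bytes_per_request, max_sequence_length,
                   block_size_tokens=256, safety_blocks=2):
    """Smallest buffer (in GiB) that yields the target number of KV blocks."""
    # Requested sequence + one dummy block + safety margin, floored at 2 blocks.
    target_blocks = math.ceil(max_sequence_length / block_size_tokens) + 1 + safety_blocks
    target_blocks = max(2, target_blocks)
    total_bytes = target_blocks * (block_size_bytes + mamba_bytes_per_request)
    # One extra byte, as in the diff, before converting to GiB.
    return (total_bytes + 1) / GIB


def blocks_from_buffer(buffer_gb, block_size_bytes, mamba_bytes_per_request):
    """Assumed forward formula: buffer bytes // (KV block bytes + mamba bytes per request)."""
    return int(buffer_gb * GIB) // (block_size_bytes + mamba_bytes_per_request)


# Hypothetical geometry: 2 bytes (bf16) * 2 (K and V) * 5 attention layers
# * 256 tokens per block * 32 heads per partition * 128 head dim.
kv_block_bytes = 2 * 2 * 5 * 256 * 32 * 128
mamba_bytes = 1024 * 1024  # hypothetical per-request recurrent-state footprint

gb = buffer_size_gb(kv_block_bytes, mamba_bytes, max_sequence_length=2048)
print(f"{gb:.4f} GiB -> {blocks_from_buffer(gb, kv_block_bytes, mamba_bytes)} blocks")
```

Dividing by a power of two is exact in floating point for byte counts of this size, so the round trip recovers the target block count in this sketch.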
+def bind_hyena_packed_views_to_dynamic_context(model, dyn_ctx, *, request_slot: int):
+    """Bind Hyena state-dict entries to a live ``DynamicInferenceContext`` Mamba slot.
+
+    ``DynamicInferenceContext`` allocates ``mamba_conv_states`` and ``mamba_ssm_states``
+    from the Evo2 Mamba state config. This function installs the Hyena ``*_filter_state_dict``
+    dictionaries that route each layer's existing state writes into the assigned request slot:
+    the projection FIR ring uses the conv slot, and the layer mixer state uses the leading
+    sub-slice of the ssm slot.
+
+    It must run after the request has been added and after ``initialize_all_tensors`` so the
+    mamba state buffers and request slot are available. The current standalone path binds one
+    request slot at a time; batched decode would need per-row state gathers in the Hyena step
+    kernels.
+
+    Args:
+        model: The Evo2 ``HyenaModel``.
+        dyn_ctx: A live ``DynamicInferenceContext`` built with the Evo2 mamba state config.
+        request_slot: The mamba state slot assigned to the active request.
+
+    Returns:
+        The installed ``_PackedHyenaSlotStateDict`` objects.
+    """
+    decoder = model.decoder if hasattr(model, "decoder") else model
+    _conv_shape, _ssm_shape, per_layer = decoder.hyena_state_shapes_per_request()
+
+    conv_states = dyn_ctx.mamba_conv_states  # (num_mamba_layers, max_requests, *conv_shape)
+    ssm_states = dyn_ctx.mamba_ssm_states  # (num_mamba_layers, max_requests, *ssm_shape)
+    layer_map = dyn_ctx.layer_map  # global-0based -> per-type-local index
+
+    # One packed dict per state-dict bucket the Hyena ops use, installed on the live context.
+    packed: dict = {}
+    for kind in ("fir", "inner_fir", "iir"):
+        d = _PackedHyenaSlotStateDict(kind)
+        packed[kind] = d
+        object.__setattr__(dyn_ctx, f"{kind}_filter_state_dict", d)
+
+    # Iterate Hyena layers in the SAME order as ``per_layer`` (hyena_state_shapes_per_request
+    # walks ``decoder.layers`` skipping attention) and resolve each layer's mamba-local index via
+    # the context's layer_map so the conv/ssm sub-slice lands in the exact slot
+    # ``mamba_states_cache(layer_number)`` would return.
+    hyena_layers = [
+        layer
+        for layer in decoder.layers
+        if hasattr(layer, "mixer") and hasattr(layer.mixer, "hyena_state_shapes_per_request")
+    ]
+    assert len(hyena_layers) == len(per_layer), (
+        f"Hyena-layer/per-layer-shape count mismatch ({len(hyena_layers)} vs {len(per_layer)}); "
+        "hyena_state_shapes_per_request() and the layer walk disagree."
+    )
+    for layer, shapes in zip(hyena_layers, per_layer):
+        mamba_layer_idx = layer_map[layer.layer_number - 1]
+        # conv slot: whole per-(layer,request) row, reshaped to [B=1, *conv_shape] for the op.
+        conv_row = conv_states[mamba_layer_idx, request_slot]  # (*conv_shape)
+        conv_view = conv_row.unsqueeze(0)  # [1, *conv_shape] — STABLE alias (no copy)
+        packed["fir"].register(shapes.conv_owner_id, conv_view)
+        # ssm slot: leading sub-slice of the row, reshaped to [B=1, :width, :last_dim].
+        w, last = shapes.ssm_shape
+        ssm_view = ssm_states[mamba_layer_idx, request_slot, :w, :last].unsqueeze(0)  # [1, w, last]
+        packed[shapes.ssm_kind].register(shapes.ssm_owner_id, ssm_view)
+
+    return list(packed.values())
+
+
 def get_batch(
     data_iterator: Iterable, cfg: ConfigContainer, use_mtp: bool = False, *, pg_collection
 ) -> tuple[
@@ -821,5 +1063,6 @@ def infer_model_type(model_size: str) -> str:
     "HyenaNV40bModelProvider",
     "HyenaNVTestModelProvider",
     "HyenaTestModelProvider",
+    "compute_evo2_paged_kv_buffer_size_gb",
     "infer_model_type",
 ]
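`bind_hyena_packed_views_to_dynamic_context` depends on each registered view aliasing the preallocated state buffer instead of copying it. As a rough standard-library analogy (`array` and `memoryview` stand in for torch tensors and views; the sizes and the `slot_view` helper are made up for illustration), a write through a per-(layer, request) view lands directly in the shared buffer:

```python
from array import array

NUM_LAYERS, MAX_REQUESTS, WIDTH = 2, 3, 4

# One flat preallocated buffer standing in for a (layers, requests, width) state tensor.
packed = array("d", [0.0] * (NUM_LAYERS * MAX_REQUESTS * WIDTH))


def slot_view(buffer: array, layer: int, request_slot: int) -> memoryview:
    """Return a zero-copy view of one (layer, request) row of the flat buffer."""
    start = (layer * MAX_REQUESTS + request_slot) * WIDTH
    return memoryview(buffer)[start : start + WIDTH]


view = slot_view(packed, layer=1, request_slot=2)
view[0] = 7.0                        # single-element write through the view
view[1:3] = array("d", [8.0, 9.0])   # slice write through the view

# The writes are visible in the underlying buffer; no other row was touched.
print(list(packed[20:24]), sum(packed))
```

The same property is what lets the commit's in-place decode step update the context-owned state without a copy back, although the torch views there are multi-dimensional.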
