Skip to content

Commit 2e33221

Browse files
authored
[#15565][fix] AutoDeploy: Fix Super MTP IMA introduced by checkpointing replay (#15622)
Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
1 parent 9882d4f commit 2e33221

6 files changed

Lines changed: 221 additions & 13 deletions

File tree

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -993,7 +993,6 @@ def _prepare_inputs(
993993
_ungathered_new_lens=new_tokens_lens,
994994
**extra_args,
995995
)
996-
self.cache_seq_interface.prepare_replay_metadata()
997996

998997
self.iter_states["num_ctx_requests"] = num_prefill
999998
self.iter_states["num_ctx_tokens"] = num_prefill_tokens

tensorrt_llm/_torch/auto_deploy/shim/interface.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -904,7 +904,12 @@ def _create_and_assign_state_views(
904904
self._caches[buf_name] = global_tensor
905905

906906
if replay_work_items:
907-
self._replay_work_items = torch.empty(
907+
# Zero-init as an extra precaution (not torch.empty). The
908+
# prepare_replay_metadata host-prepare hook (registered in
909+
# initialize_resources) populates this buffer on every
910+
# nest_sequences -- runtime and cudagraph capture alike -- so the
911+
# replay SSM kernel never reads it unprepared.
912+
self._replay_work_items = torch.zeros(
908913
self.info.max_num_state_slots,
909914
REPLAY_WORK_ITEM_WIDTH,
910915
device=self.info.device,
@@ -1417,6 +1422,15 @@ def initialize_resources(self) -> int:
14171422
f"max_tokens={s['max_tokens']}"
14181423
)
14191424

1425+
if self.info.batch_info.is_use_replay():
1426+
# Wrapper takes **kwargs to satisfy the host-prepare callable protocol;
1427+
# prepare_replay_metadata reads everything it needs from self, so no
1428+
# graph-input args are requested (empty arg list).
1429+
def _replay_metadata_hook(**_sequence_info_args) -> None:
1430+
self.prepare_replay_metadata()
1431+
1432+
self.info.register_host_prepare_for_attention_forward(_replay_metadata_hook, [])
1433+
14201434
return len(self._caches)
14211435

14221436
def _requires_token_estimate(self) -> bool:

tests/integration/defs/accuracy/test_llm_api_autodeploy.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -825,10 +825,7 @@ def test_mtp(self, world_size, attn_backend, model_id):
825825

826826
task = GSM8K(self.MODEL_NAME)
827827
task.evaluate(llm)
828-
# bf16 acceptance is stable; fp8/nvfp4 have higher variance due to
829-
# arithmetic rounding, so use a lower threshold for quantized models.
830-
min_rate = 0.50 if model_id == "bf16" else 0.40
831-
self.check_acceptance_rate(llm, min_acceptance_rate=min_rate)
828+
self.check_acceptance_rate(llm, min_acceptance_rate=0.50)
832829

833830
print_memory_usage("after evaluation")
834831

tests/integration/test_lists/test-db/l0_dgx_b200.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,5 +408,6 @@ l0_dgx_b200:
408408
- accuracy/test_llm_api_autodeploy.py::TestModelRegistryAccuracy::test_autodeploy_from_registry[deepseek-ai_DeepSeek-R1-0528-True]
409409
- accuracy/test_llm_api_autodeploy.py::TestQwen3_5_397B_MoE::test_nvfp4[8]
410410
- accuracy/test_llm_api_autodeploy.py::TestNemotronUltraV3::test_accuracy[nvfp4-8]
411+
- accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_mtp[nvfp4_ws8_80gb-trtllm]
411412
# ------------- AutoDeploy Perf Sanity ---------------
412413
- perf/test_perf_sanity.py::test_e2e[aggr_upload-deepseek_r1_fp8_ad_blackwell-r1_fp8_ad_ws8_1k1k] TIMEOUT (120)

tests/integration/test_lists/waives.txt

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[xg
66
accuracy/test_llm_api_autodeploy.py::TestNemotronNanoV3::test_accuracy[bf16-4-attn_dp_off-trtllm] SKIP (https://nvbugs/6367792)
77
accuracy/test_llm_api_autodeploy.py::TestNemotronNanoV3::test_accuracy[fp8-4-attn_dp_off-trtllm] SKIP (https://nvbugs/6367792)
88
accuracy/test_llm_api_autodeploy.py::TestNemotronNanoV3::test_accuracy[nvfp4-4-attn_dp_off-trtllm] SKIP (https://nvbugs/6367792)
9-
accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_mtp[fp8_ws4_80gb-trtllm] SKIP (https://nvbugs/6336682)
109
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput_mtp] SKIP (https://nvbugs/6281818)
1110
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput_mtp_trtllm] SKIP (https://nvbugs/6281818)
1211
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[latency] SKIP (https://nvbugs/6276981)
@@ -157,20 +156,15 @@ full:B300/accuracy/test_llm_api_pytorch.py::TestKimiK25::test_nvfp4[ep8] SKIP (h
157156
full:B300/disaggregated/test_disaggregated.py::test_disaggregated_ctxpp2_genpp2[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/6322073)
158157
full:B300/unittest/_torch/modules/moe/test_moe_backend.py::test_moe_backend -k "TRTLLM" SKIP (https://nvbugs/6165866)
159158
full:DGX_B200/unittest/_torch/modules/moe/test_moe_backend.py::test_moe_backend -k "TRTLLM" SKIP (https://nvbugs/6165866)
160-
full:DGX_H100/accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_mtp[fp8_ws4_80gb-trtllm] SKIP (https://nvbugs/6336682)
161159
full:GB200/accuracy/test_dwdp_disaggregated_serving.py::TestDwdpDeepSeekV3Lite::test_dwdp_accuracy SKIP (https://nvbugs/6276923)
162160
full:GB200/accuracy/test_dwdp_disaggregated_serving.py::TestDwdpDeepSeekV3Lite::test_dwdp_accuracy_contention_opt SKIP (https://nvbugs/6276923)
163161
full:GB200/accuracy/test_dwdp_disaggregated_serving.py::TestDwdpDeepSeekV3Lite::test_dwdp_accuracy_mode_b_overlap SKIP (https://nvbugs/6276923)
164-
full:GB200/accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_mtp[fp8_ws4_80gb-trtllm] SKIP (https://nvbugs/6316981)
165162
full:GB200/accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8_moe_dflash SKIP (https://nvbugs/6316985)
166163
full:GB200/accuracy/test_llm_api_pytorch.py::TestQwen3_5_4B::test_dflash SKIP (https://nvbugs/6344883)
167164
full:GB200/accuracy/test_llm_api_pytorch_multimodal.py::TestQwen2_5_VL_7B::test_auto_dtype SKIP (https://nvbugs/6316983)
168165
full:GB200/disaggregated/test_ad_disagg.py::test_async_eagle3_full_model_handoff SKIP (https://nvbugs/6369254)
169166
full:GB300/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-False-False-False] SKIP (https://nvbugs/6316984)
170167
full:GB300/accuracy/test_llm_api_autodeploy.py::TestNemotronNanoV3::test_accuracy[nvfp4-1-attn_dp_off-trtllm] SKIP (https://nvbugs/6329165)
171-
full:GB300/accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_mtp[bf16_ws4_180gb-trtllm] SKIP (https://nvbugs/6316981)
172-
full:GB300/accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_mtp[fp8_ws4_80gb-trtllm] SKIP (https://nvbugs/6316981)
173-
full:GB300/accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_mtp[nvfp4_ws4_80gb-trtllm] SKIP (https://nvbugs/6316981)
174168
full:GB300/accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-ep4-trtllm-fp8] SKIP (https://nvbugs/6316980)
175169
full:GB300/accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache_no_reuse-tp4-trtllm-fp8] SKIP (https://nvbugs/6316980)
176170
full:GB300/accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8_moe_dflash SKIP (https://nvbugs/6316985)

tests/unittest/auto_deploy/singlegpu/custom_ops/mamba/test_flashinfer_mamba_cached_op.py

Lines changed: 204 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,22 @@
1919
from test_triton_mamba_cached_op import _random_params
2020

2121
import tensorrt_llm._torch.auto_deploy # noqa: F401
22-
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import BatchInfo
22+
from tensorrt_llm._torch.auto_deploy._compat import KvCacheConfig
23+
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import (
24+
BatchInfo,
25+
CausalConvResourceHandler,
26+
IntermediateConvStateHandler,
27+
ReplayCacheBufIdxHandler,
28+
ReplayNWritesHandler,
29+
ReplayOldBHandler,
30+
ReplayOldDAcumsumHandler,
31+
ReplayOldDtHandler,
32+
ReplayOldXHandler,
33+
ReplayPrevNumAcceptedHandler,
34+
ReplayWorkItemsHandler,
35+
SSMResourceHandler,
36+
)
37+
from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface
2338
from tensorrt_llm._torch.modules.mamba.mamba2_metadata import (
2439
REPLAY_WORK_CACHE_BUF_IDX,
2540
REPLAY_WORK_CACHE_SLOT,
@@ -265,3 +280,191 @@ def test_flashinfer_extend_replay_calls_replay_kernel(mamba_env, head_dim):
265280
)
266281
assert out.shape == hidden_states.shape
267282
assert torch.isfinite(out).all()
283+
284+
285+
class _SpecDecModeForReplayTest:
286+
def use_one_engine(self):
287+
return True
288+
289+
290+
class _SpecConfigForReplayTest:
291+
def __init__(self, max_draft_len: int):
292+
self.max_draft_len = max_draft_len
293+
self.tokens_per_gen_step = max_draft_len + 1
294+
self.spec_dec_mode = _SpecDecModeForReplayTest()
295+
296+
297+
def _build_interface_with_replay_buffers(num_heads, head_dim, d_state, n_groups, max_batch_size):
298+
"""Allocate replay buffers through the real production path (CachedSequenceInterface).
299+
300+
Registers the Mamba + replay-buffer resource bundle for one layer and runs
301+
initialize_resources(), which is where the cache-manager-bound replay
302+
work-items buffer (interface._replay_work_items -- the tensor the replay SSM
303+
kernel actually reads) is allocated.
304+
"""
305+
conv_dim = head_dim * num_heads + 2 * n_groups * d_state
306+
interface = CachedSequenceInterface(
307+
max_seq_len=128,
308+
max_batch_size=max_batch_size,
309+
max_num_tokens=(128 + 1) * max_batch_size,
310+
device="cuda",
311+
kv_cache_config=KvCacheConfig(
312+
tokens_per_block=32, max_tokens=1024, free_gpu_memory_fraction=0.0
313+
),
314+
spec_config=_SpecConfigForReplayTest(max_draft_len=2),
315+
)
316+
interface.add_resource(
317+
"ssm_state_0",
318+
SSMResourceHandler(
319+
num_heads=num_heads, head_dim=head_dim, d_state=d_state, dtype=torch.bfloat16
320+
),
321+
)
322+
interface.add_resource(
323+
"conv_state_0", CausalConvResourceHandler(conv_dim=conv_dim, d_conv=4, dtype=torch.float32)
324+
)
325+
interface.add_resource(
326+
"intermediate_conv_state_0",
327+
IntermediateConvStateHandler(conv_dim=conv_dim, d_conv=4, dtype=torch.float32),
328+
)
329+
interface.add_resource(
330+
"replay_old_x_0",
331+
ReplayOldXHandler(num_heads=num_heads, head_dim=head_dim, dtype=torch.bfloat16),
332+
)
333+
interface.add_resource(
334+
"replay_old_B_0",
335+
ReplayOldBHandler(n_groups=n_groups, d_state=d_state, dtype=torch.bfloat16),
336+
)
337+
interface.add_resource("replay_old_dt_0", ReplayOldDtHandler(num_heads=num_heads))
338+
interface.add_resource("replay_old_dA_cumsum_0", ReplayOldDAcumsumHandler(num_heads=num_heads))
339+
interface.add_resource("replay_cache_buf_idx_0", ReplayCacheBufIdxHandler())
340+
interface.add_resource("replay_prev_num_accepted_0", ReplayPrevNumAcceptedHandler())
341+
interface.add_resource("replay_work_items_0", ReplayWorkItemsHandler())
342+
interface.add_resource("replay_n_writes_0", ReplayNWritesHandler())
343+
return interface
344+
345+
346+
def test_extend_replay_init_buffers(mamba_env):
347+
"""The replay path must not cause an out-of-bounds access on the replay buffers.
348+
349+
Behavioral guard for the replay path: every buffer the prepare hook populates (the
350+
work-items buffer and the n-writes count) is filled with garbage (out-of-bounds
351+
values, simulating uninitialized memory), then the production metadata-prep path runs
352+
and the real replay op executes; the test asserts no CUDA fault. With the fix, prep
353+
populates the buffers before the kernel reads them, so the garbage never reaches the
354+
kernel; without it the out-of-bounds values survive and fault.
355+
356+
Filling the buffers directly keeps the poison confined to them and makes the failure
357+
deterministic: fresh CUDA memory is often benign, so a poison-free run cannot
358+
reliably reproduce the bug.
359+
"""
360+
device = mamba_env["device"]
361+
dtype = mamba_env["dtype"]
362+
363+
# Production SuperV3 Mamba2 shape (AutoDeploy replicates mamba -> full heads/groups),
364+
# large enough that the replay kernel runs its persistent_main path, which reads the
365+
# cache slot from the replay work-items buffer.
366+
num_extend = 8
367+
tokens_per_extend = 7 # num_nextn_predict_layers (6) + 1
368+
num_heads = 128
369+
head_dim = 64
370+
n_groups, ssm_state_size = 8, 128
371+
372+
interface = _build_interface_with_replay_buffers(
373+
num_heads, head_dim, ssm_state_size, n_groups, max_batch_size=num_extend
374+
)
375+
interface.initialize_resources()
376+
377+
# Poison every buffer the prepare hook populates -- the work-items buffer and the
378+
# n-writes count -- with out-of-bounds values, simulating garbage / uninitialized
379+
# memory. The production metadata-prep below must overwrite them before the kernel
380+
# reads them; if prep is missing (the bug) the poison survives and faults.
381+
interface._replay_work_items.fill_(0x7FFFFFFF) # int32-max: out-of-bounds cache slot
382+
interface._replay_n_writes.fill_(0x7FFFFFFF) # int32-max: out-of-bounds write count
383+
384+
# Drive the production metadata-prep path -- the same one cudagraph capture uses --
385+
# so the replay work-items / n-writes buffers are populated exactly as in real runs
386+
# (set_capture_batch -> nest_sequences -> prepare_replay_metadata host-prepare hook).
387+
interface.info.set_capture_batch(max_draft_len=tokens_per_extend - 1, batch_size=num_extend)
388+
replay_work_items = interface._replay_work_items
389+
replay_n_writes = interface._replay_n_writes
390+
391+
# Per-token inputs and the remaining replay caches for the same extend batch.
392+
(hidden_states, A, B, C, D, dt, dt_bias, time_step_limit, chunk_size) = _random_params(
393+
device, dtype, num_extend, tokens_per_extend, num_heads, head_dim, n_groups, ssm_state_size
394+
)
395+
ssm_state_cache = torch.zeros(
396+
num_extend, num_heads, head_dim, ssm_state_size, device=device, dtype=dtype
397+
)
398+
slot_idx = torch.arange(num_extend, device=device, dtype=torch.int32)
399+
400+
replay_history_size = 16
401+
replay_old_x = torch.zeros(
402+
num_extend, 2, replay_history_size, num_heads, head_dim, device=device, dtype=torch.bfloat16
403+
)
404+
replay_old_b = torch.zeros(
405+
num_extend,
406+
2,
407+
replay_history_size,
408+
n_groups,
409+
ssm_state_size,
410+
device=device,
411+
dtype=torch.bfloat16,
412+
)
413+
replay_old_dt = torch.zeros(
414+
num_extend, 2, num_heads, replay_history_size, device=device, dtype=torch.float32
415+
)
416+
replay_old_da_cumsum = torch.zeros(
417+
num_extend, 2, num_heads, replay_history_size, device=device, dtype=torch.float32
418+
)
419+
replay_cache_buf_idx = torch.zeros(num_extend, device=device, dtype=torch.int32)
420+
replay_prev_num_accepted = torch.zeros(num_extend, device=device, dtype=torch.int32)
421+
422+
_bi = BatchInfo()
423+
_bi.update([0, 0, num_extend, num_extend * tokens_per_extend, 0, 0])
424+
_bi.update_use_replay(True)
425+
batch_info_host = _bi.serialize()
426+
cu_seqlen = torch.arange(
427+
0, (num_extend + 1) * tokens_per_extend, tokens_per_extend, device=device, dtype=torch.int32
428+
)
429+
use_initial_states = torch.zeros(num_extend, device=device, dtype=torch.bool)
430+
any_prefill_use_initial_states_host = torch.tensor([False], device=device, dtype=torch.bool)
431+
432+
out = torch.ops.auto_deploy.flashinfer_cached_ssm(
433+
hidden_states,
434+
A,
435+
B,
436+
C,
437+
D,
438+
dt,
439+
dt_bias,
440+
# STANDARD METADATA
441+
batch_info_host,
442+
cu_seqlen,
443+
slot_idx,
444+
use_initial_states,
445+
any_prefill_use_initial_states_host,
446+
# EXTRA METADATA
447+
None,
448+
None,
449+
None, # chunk_indices, chunk_offsets, seq_idx_prefill
450+
# CACHES
451+
ssm_state_cache,
452+
None, # intermediate_ssm_state_cache (None in replay mode)
453+
replay_old_x,
454+
replay_old_b,
455+
replay_old_dt,
456+
replay_old_da_cumsum,
457+
replay_cache_buf_idx,
458+
replay_prev_num_accepted,
459+
replay_work_items,
460+
replay_n_writes,
461+
# CONSTANTS
462+
time_step_limit,
463+
chunk_size,
464+
)
465+
466+
# Synchronize so any out-of-bounds access on the replay buffers surfaces here as a
467+
# CUDA error rather than asynchronously later.
468+
torch.cuda.synchronize()
469+
assert out.shape == hidden_states.shape
470+
assert torch.isfinite(out).all()

0 commit comments

Comments
 (0)