|
19 | 19 | from test_triton_mamba_cached_op import _random_params |
20 | 20 |
|
21 | 21 | import tensorrt_llm._torch.auto_deploy # noqa: F401 |
22 | | -from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import BatchInfo |
| 22 | +from tensorrt_llm._torch.auto_deploy._compat import KvCacheConfig |
| 23 | +from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import ( |
| 24 | + BatchInfo, |
| 25 | + CausalConvResourceHandler, |
| 26 | + IntermediateConvStateHandler, |
| 27 | + ReplayCacheBufIdxHandler, |
| 28 | + ReplayNWritesHandler, |
| 29 | + ReplayOldBHandler, |
| 30 | + ReplayOldDAcumsumHandler, |
| 31 | + ReplayOldDtHandler, |
| 32 | + ReplayOldXHandler, |
| 33 | + ReplayPrevNumAcceptedHandler, |
| 34 | + ReplayWorkItemsHandler, |
| 35 | + SSMResourceHandler, |
| 36 | +) |
| 37 | +from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface |
23 | 38 | from tensorrt_llm._torch.modules.mamba.mamba2_metadata import ( |
24 | 39 | REPLAY_WORK_CACHE_BUF_IDX, |
25 | 40 | REPLAY_WORK_CACHE_SLOT, |
@@ -265,3 +280,191 @@ def test_flashinfer_extend_replay_calls_replay_kernel(mamba_env, head_dim): |
265 | 280 | ) |
266 | 281 | assert out.shape == hidden_states.shape |
267 | 282 | assert torch.isfinite(out).all() |
| 283 | + |
| 284 | + |
| 285 | +class _SpecDecModeForReplayTest: |
| 286 | + def use_one_engine(self): |
| 287 | + return True |
| 288 | + |
| 289 | + |
| 290 | +class _SpecConfigForReplayTest: |
| 291 | + def __init__(self, max_draft_len: int): |
| 292 | + self.max_draft_len = max_draft_len |
| 293 | + self.tokens_per_gen_step = max_draft_len + 1 |
| 294 | + self.spec_dec_mode = _SpecDecModeForReplayTest() |
| 295 | + |
| 296 | + |
| 297 | +def _build_interface_with_replay_buffers(num_heads, head_dim, d_state, n_groups, max_batch_size): |
| 298 | + """Allocate replay buffers through the real production path (CachedSequenceInterface). |
| 299 | +
|
| 300 | + Registers the Mamba + replay-buffer resource bundle for one layer and runs |
| 301 | + initialize_resources(), which is where the cache-manager-bound replay |
| 302 | + work-items buffer (interface._replay_work_items -- the tensor the replay SSM |
| 303 | + kernel actually reads) is allocated. |
| 304 | + """ |
| 305 | + conv_dim = head_dim * num_heads + 2 * n_groups * d_state |
| 306 | + interface = CachedSequenceInterface( |
| 307 | + max_seq_len=128, |
| 308 | + max_batch_size=max_batch_size, |
| 309 | + max_num_tokens=(128 + 1) * max_batch_size, |
| 310 | + device="cuda", |
| 311 | + kv_cache_config=KvCacheConfig( |
| 312 | + tokens_per_block=32, max_tokens=1024, free_gpu_memory_fraction=0.0 |
| 313 | + ), |
| 314 | + spec_config=_SpecConfigForReplayTest(max_draft_len=2), |
| 315 | + ) |
| 316 | + interface.add_resource( |
| 317 | + "ssm_state_0", |
| 318 | + SSMResourceHandler( |
| 319 | + num_heads=num_heads, head_dim=head_dim, d_state=d_state, dtype=torch.bfloat16 |
| 320 | + ), |
| 321 | + ) |
| 322 | + interface.add_resource( |
| 323 | + "conv_state_0", CausalConvResourceHandler(conv_dim=conv_dim, d_conv=4, dtype=torch.float32) |
| 324 | + ) |
| 325 | + interface.add_resource( |
| 326 | + "intermediate_conv_state_0", |
| 327 | + IntermediateConvStateHandler(conv_dim=conv_dim, d_conv=4, dtype=torch.float32), |
| 328 | + ) |
| 329 | + interface.add_resource( |
| 330 | + "replay_old_x_0", |
| 331 | + ReplayOldXHandler(num_heads=num_heads, head_dim=head_dim, dtype=torch.bfloat16), |
| 332 | + ) |
| 333 | + interface.add_resource( |
| 334 | + "replay_old_B_0", |
| 335 | + ReplayOldBHandler(n_groups=n_groups, d_state=d_state, dtype=torch.bfloat16), |
| 336 | + ) |
| 337 | + interface.add_resource("replay_old_dt_0", ReplayOldDtHandler(num_heads=num_heads)) |
| 338 | + interface.add_resource("replay_old_dA_cumsum_0", ReplayOldDAcumsumHandler(num_heads=num_heads)) |
| 339 | + interface.add_resource("replay_cache_buf_idx_0", ReplayCacheBufIdxHandler()) |
| 340 | + interface.add_resource("replay_prev_num_accepted_0", ReplayPrevNumAcceptedHandler()) |
| 341 | + interface.add_resource("replay_work_items_0", ReplayWorkItemsHandler()) |
| 342 | + interface.add_resource("replay_n_writes_0", ReplayNWritesHandler()) |
| 343 | + return interface |
| 344 | + |
| 345 | + |
| 346 | +def test_extend_replay_init_buffers(mamba_env): |
| 347 | + """The replay path must not cause an out-of-bounds access on the replay buffers. |
| 348 | +
|
| 349 | + Behavioral guard for the replay path: every buffer the prepare hook populates (the |
| 350 | + work-items buffer and the n-writes count) is filled with garbage (out-of-bounds |
| 351 | + values, simulating uninitialized memory), then the production metadata-prep path runs |
| 352 | + and the real replay op executes; the test asserts no CUDA fault. With the fix, prep |
| 353 | + populates the buffers before the kernel reads them, so the garbage never reaches the |
| 354 | + kernel; without it the out-of-bounds values survive and fault. |
| 355 | +
|
| 356 | + Filling the buffers directly keeps the poison confined to them and makes the failure |
| 357 | + deterministic: fresh CUDA memory is often benign, so a poison-free run cannot |
| 358 | + reliably reproduce the bug. |
| 359 | + """ |
| 360 | + device = mamba_env["device"] |
| 361 | + dtype = mamba_env["dtype"] |
| 362 | + |
| 363 | + # Production SuperV3 Mamba2 shape (AutoDeploy replicates mamba -> full heads/groups), |
| 364 | + # large enough that the replay kernel runs its persistent_main path, which reads the |
| 365 | + # cache slot from the replay work-items buffer. |
| 366 | + num_extend = 8 |
| 367 | + tokens_per_extend = 7 # num_nextn_predict_layers (6) + 1 |
| 368 | + num_heads = 128 |
| 369 | + head_dim = 64 |
| 370 | + n_groups, ssm_state_size = 8, 128 |
| 371 | + |
| 372 | + interface = _build_interface_with_replay_buffers( |
| 373 | + num_heads, head_dim, ssm_state_size, n_groups, max_batch_size=num_extend |
| 374 | + ) |
| 375 | + interface.initialize_resources() |
| 376 | + |
| 377 | + # Poison every buffer the prepare hook populates -- the work-items buffer and the |
| 378 | + # n-writes count -- with out-of-bounds values, simulating garbage / uninitialized |
| 379 | + # memory. The production metadata-prep below must overwrite them before the kernel |
| 380 | + # reads them; if prep is missing (the bug) the poison survives and faults. |
| 381 | + interface._replay_work_items.fill_(0x7FFFFFFF) # int32-max: out-of-bounds cache slot |
| 382 | + interface._replay_n_writes.fill_(0x7FFFFFFF) # int32-max: out-of-bounds write count |
| 383 | + |
| 384 | + # Drive the production metadata-prep path -- the same one cudagraph capture uses -- |
| 385 | + # so the replay work-items / n-writes buffers are populated exactly as in real runs |
| 386 | + # (set_capture_batch -> nest_sequences -> prepare_replay_metadata host-prepare hook). |
| 387 | + interface.info.set_capture_batch(max_draft_len=tokens_per_extend - 1, batch_size=num_extend) |
| 388 | + replay_work_items = interface._replay_work_items |
| 389 | + replay_n_writes = interface._replay_n_writes |
| 390 | + |
| 391 | + # Per-token inputs and the remaining replay caches for the same extend batch. |
| 392 | + (hidden_states, A, B, C, D, dt, dt_bias, time_step_limit, chunk_size) = _random_params( |
| 393 | + device, dtype, num_extend, tokens_per_extend, num_heads, head_dim, n_groups, ssm_state_size |
| 394 | + ) |
| 395 | + ssm_state_cache = torch.zeros( |
| 396 | + num_extend, num_heads, head_dim, ssm_state_size, device=device, dtype=dtype |
| 397 | + ) |
| 398 | + slot_idx = torch.arange(num_extend, device=device, dtype=torch.int32) |
| 399 | + |
| 400 | + replay_history_size = 16 |
| 401 | + replay_old_x = torch.zeros( |
| 402 | + num_extend, 2, replay_history_size, num_heads, head_dim, device=device, dtype=torch.bfloat16 |
| 403 | + ) |
| 404 | + replay_old_b = torch.zeros( |
| 405 | + num_extend, |
| 406 | + 2, |
| 407 | + replay_history_size, |
| 408 | + n_groups, |
| 409 | + ssm_state_size, |
| 410 | + device=device, |
| 411 | + dtype=torch.bfloat16, |
| 412 | + ) |
| 413 | + replay_old_dt = torch.zeros( |
| 414 | + num_extend, 2, num_heads, replay_history_size, device=device, dtype=torch.float32 |
| 415 | + ) |
| 416 | + replay_old_da_cumsum = torch.zeros( |
| 417 | + num_extend, 2, num_heads, replay_history_size, device=device, dtype=torch.float32 |
| 418 | + ) |
| 419 | + replay_cache_buf_idx = torch.zeros(num_extend, device=device, dtype=torch.int32) |
| 420 | + replay_prev_num_accepted = torch.zeros(num_extend, device=device, dtype=torch.int32) |
| 421 | + |
| 422 | + _bi = BatchInfo() |
| 423 | + _bi.update([0, 0, num_extend, num_extend * tokens_per_extend, 0, 0]) |
| 424 | + _bi.update_use_replay(True) |
| 425 | + batch_info_host = _bi.serialize() |
| 426 | + cu_seqlen = torch.arange( |
| 427 | + 0, (num_extend + 1) * tokens_per_extend, tokens_per_extend, device=device, dtype=torch.int32 |
| 428 | + ) |
| 429 | + use_initial_states = torch.zeros(num_extend, device=device, dtype=torch.bool) |
| 430 | + any_prefill_use_initial_states_host = torch.tensor([False], device=device, dtype=torch.bool) |
| 431 | + |
| 432 | + out = torch.ops.auto_deploy.flashinfer_cached_ssm( |
| 433 | + hidden_states, |
| 434 | + A, |
| 435 | + B, |
| 436 | + C, |
| 437 | + D, |
| 438 | + dt, |
| 439 | + dt_bias, |
| 440 | + # STANDARD METADATA |
| 441 | + batch_info_host, |
| 442 | + cu_seqlen, |
| 443 | + slot_idx, |
| 444 | + use_initial_states, |
| 445 | + any_prefill_use_initial_states_host, |
| 446 | + # EXTRA METADATA |
| 447 | + None, |
| 448 | + None, |
| 449 | + None, # chunk_indices, chunk_offsets, seq_idx_prefill |
| 450 | + # CACHES |
| 451 | + ssm_state_cache, |
| 452 | + None, # intermediate_ssm_state_cache (None in replay mode) |
| 453 | + replay_old_x, |
| 454 | + replay_old_b, |
| 455 | + replay_old_dt, |
| 456 | + replay_old_da_cumsum, |
| 457 | + replay_cache_buf_idx, |
| 458 | + replay_prev_num_accepted, |
| 459 | + replay_work_items, |
| 460 | + replay_n_writes, |
| 461 | + # CONSTANTS |
| 462 | + time_step_limit, |
| 463 | + chunk_size, |
| 464 | + ) |
| 465 | + |
| 466 | + # Synchronize so any out-of-bounds access on the replay buffers surfaces here as a |
| 467 | + # CUDA error rather than asynchronously later. |
| 468 | + torch.cuda.synchronize() |
| 469 | + assert out.shape == hidden_states.shape |
| 470 | + assert torch.isfinite(out).all() |
0 commit comments