Skip to content

Ad fast iter/nemotron nano 20260323#250

Open
suyoggupta wants to merge 158 commits into
mainfrom
ad-fast-iter/nemotron-nano-20260323
Open

Ad fast iter/nemotron nano 20260323#250
suyoggupta wants to merge 158 commits into
mainfrom
ad-fast-iter/nemotron-nano-20260323

Conversation

@suyoggupta

Copy link
Copy Markdown

@coderabbitai summary

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

… at conc=256)

Custom Triton SSM decode kernel optimized for Nemotron Nano v3 dimensions
(nheads=64, dim=64, dstate=128). Uses BLOCK_SIZE_M=16 vs stock 4, reducing
grid size by 4x. Correctness validated against FlashInfer kernel.

Results (reduced 6-layer model, ISL/OSL=1000):
  conc=1:   TTFT 45.62->44.61ms (-2.2%), TPOT 2.181->2.187ms (noise)
  conc=16:  TTFT 1003->964ms (-3.8%), TPS 4220->4300 (+1.9%)
  conc=256: TTFT 3809->2876ms (-24.5%), TPS 13076->13437 (+2.8%)

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Add slot-indexed cache access to the fused conv1d+SSM Triton kernel.
This kernel combines causal conv1d update (with SiLU) and SSM state update
into a single kernel launch, eliminating intermediate buffer writes.

Microbenchmark: 1.38x faster at batch=1, 1.09x at batch=4 vs separate
conv1d + FlashInfer SSM. Needs graph transform integration to test e2e.

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…@c1)

Fuse cuda_cached_causal_conv1d + tuned_cached_ssm into a single
fused_cached_conv_ssm custom op via compile-stage graph transform.

For decode tokens: single Triton kernel (_fused_conv_ssm_kernel) eliminates
the intermediate [batch, conv_dim] HBM round-trip between conv1d and SSM.
For prefill tokens: falls back to causal_conv1d_fn + mamba_chunk_scan_combined.

Key changes:
- fused_mamba_decode.py: add fused_conv_ssm_decode() Python launcher; fix
  B/C channel conv-state update (missing write bug); guard concurrent writes
  with pid_h % nheads_per_group == 0 to avoid write-write races
- fused_cached_conv_ssm.py: new custom op; normalizes 3D->2D conv weight
- fuse_conv_ssm.py: compile-stage transform; traces hidden/B/C args back
  through reshape/split to find matching conv+SSM node pairs
- default.yaml: register fuse_conv_ssm (disabled by default)
- nano_v3_reduced.yaml: enable fuse_conv_ssm

Results (6-layer reduced model, ISL=1000/OSL=1000):
- conc=1 TPOT: 2.181ms vs iter11 2.187ms (-0.3%, within noise)
- conc=1 TTFT: 46.17ms vs iter11 44.61ms (+3.5%, within noise)
- Transform: 3 matches found, correctness verified over 3 decode steps
- Expected gain small: SSM only 1.4% of decode time, 1.38x fused speedup
  at batch=1 -> ~0.4% e2e improvement (matches observed)

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
… kernel

Adds auto_deploy::relu2_quant_fp8 Triton custom op and a post_load_fusion
graph transform fuse_relu2_quant_fp8 that replaces the three-kernel chain:

    relu(x)                     → vectorized_elementwise_kernel (clamp)
    pow(relu_out, 2)            → vectorized_elementwise_kernel (pow)
    trtllm_quant_fp8_linear(…)  → includes scaleMatrixPerTensorVec (quantize)

with a single Triton kernel:

    relu2_quant_fp8(x, scale)   → clamp(relu(x)^2 / scale).to(fp8)
    trtllm_fp8_prequant_linear  → pure FP8 GEMM with no internal quantize step

On the 6-layer reduced Nemotron Nano 30B-FP8 model: 2 relu2 patterns fused,
TPOT -0.2% at conc=1 (within noise at this scale; expected ~1.5% on the full
30B model which has ~15 MoE shared-expert blocks using relu2 activation).

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
… + add TTFT TODO

iter20's fuse_conv_ssm caused TTFT +3.5% regression (44.61→46.17ms) with
only -0.3% TPOT gain (within noise). Disable in reduced model config.

Add TODO for large TTFT opportunity: fuse prefill conv1d+SSM scan into a
single block-tiled Triton kernel (two HBM round-trips → one), targeting
~15-25% TTFT reduction similar in magnitude to iter11's decode SSM tuning.

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…ssm + TTFT TODO

iter11's tuned_ssm (BLOCK_SIZE_M=16) gave -24.5% TTFT@c256 but regressed
TPOT by +1.2% (2.161→2.187ms) vs flashinfer_ssm baseline. Not worth the
decode latency cost.

Reverted insert_cached_ssm_attention to backend: flashinfer_ssm.

Added TODO: optimize mamba_chunk_scan_combined (prefill SSM scan) instead —
purely on the TTFT path, zero TPOT impact. Target: ~20-30% TTFT reduction
at high concurrency by tuning block sizes or fusing prefill conv1d+SSM scan.

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…tuned_ssm>32)

Recover TTFT -24.5%@c256 from iter11 without the TPOT regression.

Root cause: tuned_ssm Triton kernel (BLOCK_SIZE_M=16) is ~1.2% slower than
flashinfer CUDA kernel at batch=1 decode-only (c1 TPOT).

Fix: adaptive dispatch in tuned_ssm backend —
  num_decode <= 32: flashinfer.mamba.selective_state_update (no regression)
  num_decode >  32: tuned_selective_state_update (BLOCK_SIZE_M=16, 4x fewer blocks)

Expected: TPOT back to ~2.161ms at c1 (flashinfer path), TTFT -24.5% at c256
(tuned path kicks in at 256 concurrent decode tokens).

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…, BLOCK_N=128)

Fix the core issue with the prior single-tile kernel:
  BLOCK_N=next_power_of_2(2688)=4096: 35% compute wasted (1408 masked
  elements processed as zeros). num_stages=3 with no loop = dead config.

New two-pass streaming approach:
  - BLOCK_N=128: divides 2688 exactly (128x21), zero waste
  - Pass1: stream x in tiles, accumulate x^2 (no register pressure)
  - Pass2: normalize+gamma+write BF16+FP8 with num_stages=2 prefetch
  - tl.rsqrt (hardware instruction) + 1/scale->mul instead of div/element

Microbench at H=2688 (Nemotron):
  B=256: 23.0us vs separate 19.2us (0.83x, still slower - kernel overhead)
  B=4096: 51.6us vs separate 92.7us (1.8x faster - bandwidth dominates)
  Crossover ~B=1000-4096.

fuse_rmsnorm_quant_fp8 remains disabled (hurts TPOT at small batch).
Improved kernel is ready for workloads with large prefill batches.

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…h capture

Move `import flashinfer` to module level in tuned_backend_mamba.py.
The inline import inside the adaptive dispatch conditional ran during
CUDA graph capture (for batch sizes ≤32), triggering flashinfer CUDA
initialization mid-capture and causing SIGBUS at batch_size=384.

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
… into conv1d)

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…FT@c256 4.5x regression)

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…1→2.154ms, beats baseline 2.161ms)

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…_ssm (flashinfer<=32, tuned Triton>32, same op name)

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
… (prefill uses inline silu to avoid large-batch regression)

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…f module-level import causes TPOT regression)

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…ter28b/29 on parallel GPUs

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…lashinfer only (no Triton, isolates branch overhead)

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…eros+copy_)

For pure-decode batches (no prefill), bypass the preallocated-buffer path:
- Skip torch.zeros([bs, H, D]) → eliminates cudaMemset per Mamba layer
- Skip copy_(y_decode → preallocated_ssm_out_d) → eliminates cudaMemcpy per Mamba layer
- Return flashinfer/tuned output directly without intermediate buffer

Also restores tuned Triton path for large batches (>32) from iter28b diagnostic
(iter28c had Triton disabled for diagnostic purposes).

Expected TPOT gain at c1: ~2 GPU ops × 15 Mamba layers ≈ 30-90μs saved.
TTFT@c256 improvement preserved via Triton path for large decode batches.

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…el + re-enable transform

For seq_len <= 4 (decode path), dispatch a single-pass kernel that loads all
n_cols=2688 elements into registers at once (BLOCK_N=4096, 32 warps), computes
norm factor in-register, and writes BF16+FP8 in one pass — no L2 re-read vs
the two-pass streaming approach (21 iterations at BLOCK_N=128).

Saves: (a) one HBM read per row, (b) 20 loop iterations, (c) one separate FP8
quant kernel launch per RMSNorm in the graph (by re-enabling fuse_rmsnorm_quant_fp8).

Previously disabled: iter9 kernel (BLOCK_N=4096, two-pass) regressed TPOT +1.2%.
iter31 uses single-pass for B<=4, two-pass for B>4 — only the B=1 decode path
is changed.

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…-pass kernel still regresses TPOT)

Single-pass B<=4 kernel still regresses TPOT at c1 vs FlashInfer+FP8-quant baseline.
Disabled in config; code kept for future investigation with profiling data.
Iter30 SSM fast path (skip torch.zeros+copy_) remains active.

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…er<=32, triton>32)

Add _RMN_ADAPTIVE_THRESHOLD=32 adaptive dispatch to fuse_rmsnorm_quant_fp8:
- seq_len<=32 (decode): FlashInfer CUDA kernels (rmsnorm_quant + rmsnorm for
  non-fused; fused_add_rmsnorm_quant + rmsnorm for fused add case) — avoids
  Triton's B=1 overhead that caused TPOT regression in iter6/9/25/31.
- seq_len>32 (prefill): Triton two-pass fused kernel (1 vs 2 kernel launches).
- Lazy imports to prevent CUDA context ordering issues (same as iter28b SSM).
- Enable fuse_rmsnorm_quant_fp8 in nano_v3_reduced.yaml.

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…out_bf16 is DCE'd)

Key insight: fuse_rmsnorm_quant_fp8 transform only fires when ALL terminal consumers
of the norm output are FP8 linears (lines 251-256 of transform). After rewiring,
bf16_node has zero users and is eliminated by eliminate_dead_code(). So out_bf16 never
needs to be computed correctly — return an uninitialized empty tensor.

Simplify FlashInfer path for seq_len<=32 to single kernel:
- Non-fused: rmsnorm_quant(out_fp8, input, weight, scale, eps) only (1 call vs iter33's 2)
- Fused: elementwise add → out_add, then rmsnorm_quant(out_fp8, out_add, ...) (2 ops)
  out_add is still correct (needed for next layer residual); out_bf16 is empty (DCE'd)

Result: at B=1 decode, FlashInfer path is now truly 1 kernel (non-fused) or 2 ops (fused),
matching or beating the original unfused baseline.

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…ale.item() → tl.load)

Replace FlashInfer rmsnorm_quant (requires scale.item() → D2H sync → CUDA graph
capture crash) with:
  1. flashinfer.norm.rmsnorm (no scale, CUDA graph compatible)
  2. _fp8_quant_only_kernel: reads scale via tl.load(scale_ptr) with no .item()

Both iter33 and iter33b crashed during CUDA graph warmup with:
  CUDA error: operation not permitted when stream is capturing

Root cause: scale.item() inside rmsnorm_quant path triggered cudaMemcpy D2H +
device synchronization which is forbidden during CUDA stream capture.

Fix: split rmsnorm_quant into two CUDA-graph-safe ops:
  non-fused: rmsnorm → fp8_quant_only (2 kernels)
  fused:     add → rmsnorm → fp8_quant_only (3 kernels)
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…(regression)

iter33c measured TPOT avg=2.245ms (+4.2% vs 2.154ms baseline) and TTFT@c256
p5=760ms (2x worse than 353ms). Root causes:

1. Small-batch path (B<=32): 3-kernel FlashInfer+Triton split adds extra launch
   overhead vs 1-kernel Triton single-pass that was used before.

2. Large-batch: fuse_rmsnorm_quant_fp8 switches trtllm_quant_fp8_linear →
   trtllm_fp8_prequant_linear + separate quant kernel. The hardware-fused
   quant+GEMM in trtllm_quant_fp8_linear is faster at large batch.

Disable the transform. This is the 4th attempt at fuse_rmsnorm_quant_fp8
that showed regression (iter9, iter25, iter31, iter33c). The transform
fundamentally conflicts with the hardware-fused FP8 GEMM pipeline.

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
… M≤4 on sm_90+

iter34 (attn_backend=flashinfer): TPOT +4% (2.263ms), TTFT@c256 +31% → reverted.
FlashInfer decode attention slower than trtllm kernel_mha for this model.

iter35: Add Triton GEMV kernel for trtllm_quant_fp8_linear at M≤4 on H100.
- cuBLAS on H100 achieves ~13% HBM bandwidth at M=1 (GEMV regime)
- Custom Triton GEMV targets 50%+ bandwidth utilisation
- Dispatches from _trtllm_fp8_prequant_linear_core when: sm≥90, M≤4,
  out_dtype ∈ {bf16, fp32}, scales present
- Scale combined as device tensor (tl.load) — CUDA graph safe, no .item()
- Expected: 4-8% TPOT improvement from faster Mamba in_proj/out_proj GEMVs
  (31.8MB in_proj weight × 2 Mamba layers = dominant bandwidth consumer)
- BLOCK_N=128, BLOCK_K=256, num_warps=4

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…ults

iter35 regression: TPOT +5.8% (2.154→2.280ms).
Root cause: cuBLAS FP8 GEMV on H100 already achieves ~95% HBM bandwidth
(near-optimal). Triton GEMV kernel at ~53% bandwidth was ~1.8x slower.
Regression = ~8μs/call × 15+ FP8 linear ops = ~126μs accumulated overhead.

iter34 result also recorded: attn_backend=flashinfer TPOT +4% (2.263ms),
TTFT@c256 +31% (4235ms vs 3238ms). trtllm attention backend retained.

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…on, reverted)

Test: multi_stream_moe: enabled: false
Result: TPOT@c=1 +3.5% (2.154→2.229ms) regression
Insight: shared+routed expert DO overlap at c=1 via multi-streaming; parallelism
benefit outweighs @torch._dynamo.disable stream-sync overhead
Action: reverted to enabled: true

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…a out_proj (2/3 layers, neutral result)

Fuse triton_rmsnorm_gated + scaleMatrixPerTensorVec into a single Triton kernel
(gated_rms_norm_quant_fp8) for Mamba out_proj, replacing trtllm_quant_fp8_linear
with trtllm_fp8_prequant_linear for 2 of 3 Mamba layers in the reduced model.

Changes:
- custom_ops/quantization/gated_rms_norm_quant_fp8.py: New Triton kernel fusing
  gated-RMSNorm + FP8 per-tensor quantization in one pass
- transform/library/fuse_gated_rmsnorm_quant_fp8.py: Graph transform matching
  triton_rmsnorm_gated then trtllm_quant_fp8_linear and replacing with fused op.
  Key fix: skip dtype-cast ops (aten.to.dtype, prims.convert_element_type) during
  view-chain re-apply to avoid FP8 to BF16 cast after fusion.
- modeling_nemotron_h.py: MambaRMSNormGated.forward now calls triton_rmsnorm_gated
  directly (previously used gated_rms_norm_ref plain-Python)
- default.yaml: register fuse_gated_rmsnorm_quant_fp8 (disabled by default)
- nano_v3_reduced.yaml: enable fuse_gated_rmsnorm_quant_fp8

Result: 2.178ms avg / 2.114ms p50 (baseline 2.154ms) - neutral, within noise.
Only 2 of 3 Mamba layers match (3rd out_proj consumed by earlier transform).

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…P8 Mamba layers fused (neutral result)

Debug investigation: only 2/3 Mamba layers have FP8 out_proj in the checkpoint;
the 3rd Mamba layer uses BF16 out_proj. fuse_gated_rmsnorm_quant_fp8 already
fuses all applicable layers (2/2). operator.getitem added to _VIEW_OPS to trace
through TP shard-slice ops in future models. Neutral TPOT result (~2.18ms vs
2.154ms baseline, within noise).

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…sion, kept as disabled)

Adds ShardLmHeadVocabParallel transform that slices lm_head.weight to the
local TP rank's vocab shard and inserts AllGather along dim=-1 after partial
logits. Solves the weight-tying issue by using get_lm_head_node() (position-
based) rather than attribute name matching.

Result on v3-Mcore checkpoint (4x H100, TP=4):
  baseline TPOT avg @c=1: 2.268ms
  iter39 TPOT avg @c=1:   2.340ms (+3.2% regression)
  TTFT regression:         +14% (AllGather adds latency)

Root cause: trtllm_dist_all_gather overhead at c=1 decode (~72μs) exceeds
the GEMV savings from 4x smaller weight. Left disabled in config.

Separately: on v3-Mcore checkpoint fuse_gated_rmsnorm_quant_fp8 fires
3/3 Mamba layers (vs 2/2 on old checkpoint).

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Add benchmark sweep script and optimization doc for the fused Mamba
decode kernel. Baseline (num_warps=4, BLOCK_DIM=64, BLOCK_DSTATE=128):
B=1→8.78us, B=8→19.58us, B=64→109.88us, B=384→630.83us.
Kernel accounts for ~100% of e2e time. Documents B/C conv-state
write-read race condition (nheads_per_group=8).

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…CK_DSTATE sweep

num_warps=8 beats default 4 by ~3.6% at B=384 (608us vs 631us).
BLOCK_DIM=64 and BLOCK_DSTATE=128 are already optimal.
BLOCK_DSTATE=64 is incorrect (skips half dstate). Documents best
tuning params: num_warps=8, BLOCK_DIM=64, BLOCK_DSTATE=128.

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…static_range reverted

Update default num_warps from 4 to 8: B=384 608us (-3.6% vs baseline 631us).
Attempted tl.static_range for conv loops but it caused 1.59x regression
(966us) due to register pressure explosion from forced unrolling. Reverted.

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…onv state loads

evict_last on SSM state load: B=1 -10.8%, B=8 -3.2%, B=384 -3.1%.
evict_last on conv state reads (hidden+B/C): marginal additional gain.
Combined improvement vs baseline: B=384 631us→589us = -6.7%.

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…updates

Attempted and documented:
- Separate B/C state update kernel: reverted (kernel overhead outweighs savings)
- 2D batch-load for conv state [BLOCK_DIM, KW_PAD]: reverted (+4.8% regression)
- num_stages=2: no improvement
- evict_first on weights: no improvement
- BLOCK_DIM=32 num_warps=2: no improvement
- Precomputed B/C kernel (_bc_conv_compute_kernel): defined but not integrated

Key finding: 445 b32 registers/thread due to state[64,128]=8192 fp32 per CTA.
Only 1 CTA/SM possible. Future: dstate tiling or B/C precompute for >1 CTA/SM.
Final state: B=1 8.22us (-8.7%), B=384 588.6us (-6.7%) vs baseline.

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
… — no gain at B=384 (SSM state dominates), slower at small batch

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
… state is bf16 in production (mamba_ssm_cache_dtype=auto), not fp32; actual perf 2x better than previously measured

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…bf16 state — num_warps=4 optimal at B=64+ (was 8 for fp32); B=384 306us

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…erlap HBM latency with B/C conv compute; B=384 304us (-0.6%), B=64/128 -1.5%

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…ents — batched heads worse, B/C costs 3-5%, near bandwidth roofline

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…) before state prefetch — B=16/32 -2.5%, B=64/128 -1.4%, B=384 303us

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…es after SSM output — B=384 301us (-0.7%), B=64-256 -0.7-1.0%; overlaps B/C write with SSM drain

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
… prefetch — WORSE; evict_last kept

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
… no perf diff vs sequential (roofline-limited)

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…ts — no gain (49KB fits in 50MB L2)

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…e writes — within noise, no benefit

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
… store — WORSE (+5.8% at B=384); state is hot across decode steps

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…t — WORSE (+6.1% at B=384); input shared by 64 heads

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…sis — occupancy-limited

Iter 21: num_warps sweep for small batch — num_warps=4 confirmed optimal.
  Noop profiling reveals SSM state R+W = 84% of B=1 latency (6.2us of 7.3us).

Iter 22: Persistent kernel (128 CTAs filling all SMs) — 1.2-2.0x SLOWER.
  Dynamic while-loop prevents Triton compiler optimizations; added to file
  as fused_conv_ssm_decode_persistent() for reference only.

Iter 23: Two-kernel dstate-split approach — 2.0-2.2x SLOWER at B=1.
  Phase 1: conv kernel (batch, nheads), Phase 2: SSM with tl.atomic_add over
  N_DSTATE_TILES=4 tiles giving 256 CTAs. Extra kernel launches + atomic_add
  overhead negate parallelism gains. Correctness issues from bf16 intermediates.

Iter 24: Root cause confirmed — GPU occupancy is the architectural limit.
  At B=1, 64 CTAs on 132 SMs = 50% utilization. Effective BW: 165 GB/s vs
  1409 GB/s at B=384 (8.5x gap). Noop lower bound = 6.2us; full kernel = 7.3us.
  The 1.3us BW roofline is unreachable for single-kernel dispatch at B=1.
  BLOCK_DIM=8/16/32/64 all give same noop time; increasing CTAs does not help
  because bottleneck is occupancy, not instruction count. CUDA graph adds only
  0.13us improvement (launch overhead not the bottleneck).

No change to production fused_conv_ssm_decode() launcher.
B=384 remains at 301us = 1.01x BW roofline (optimal).

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…t analysis

Iter 25a: Tested num_stages=2 on conv loops (k=0..2) — no improvement.
  Only 3 loop iterations (kernel_width-1=3) are too few for software
  pipelining to amortize overhead. All batch sizes within noise (+-2.5%).

Iter 25b: Quantified B/C computation cost by zeroing B/C values.
  B/C section costs 0.6us at B=1 (8% of 7.4us) and 8.2us at B=384 (2.7%).
  Maximum savings from perfect elimination: 0.5us at B=1, 7.2us at B=384.
  Any B/C separation adds ~3-5us kernel launch overhead -- net-negative.
  This rules out the two-kernel B/C precompute approach for all batch sizes.

No code change to launcher. Running doc updated.

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…a_decode; fix _run_ssm_prefill unpack bug in tuned_backend_mamba

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…(pass num_decode not num_seq/num_total_tokens)

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…rgs (was missing, causing stale state across decode steps)

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…SiLU scope

Two correctness fixes for the fused conv+SSM decode kernel:

1. fused_mamba_decode.py: fix B/C conv state race condition
   In _fused_conv_ssm_kernel, head-0 per group wrote the B/C conv state
   shift in-kernel (guarded by pid_h%nheads_per_group==0). At large batch,
   head-0 on one SM could write the shifted B/C state before other heads on
   different SMs finished reading it, causing stale reads and wrong SSM inputs.
   Fixed by removing the in-kernel B/C state write entirely and launching a
   separate _bc_conv_state_update_kernel after the main kernel, serializing
   all reads before any writes.

   Also moved the hidden-channel conv state write to after all B/C reads
   (avoiding a cache-line false-sharing hazard at the hidden/B boundary).

2. fused_cached_conv_ssm.py: apply SiLU to all channels during prefill
   The original model applies self.act (SiLU) to the full conv output
   [batch, conv_dim] before splitting into x/B/C. The prefill path was only
   applying SiLU to the hidden (x) channels and leaving B/C unactivated,
   causing prefill/decode mismatch in multi-step generation.
   Fixed by applying SiLU to the entire inp_flat[:num_prefill_tokens].

Both bugs contribute to GSM8k accuracy degradation (~1.78% vs ~68% reference)
when fuse_conv_ssm is enabled.

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…CE to prevent double execution

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
… memory corruption

causal_conv1d_update and tuned_selective_state_update both guard against
PAD_SLOT_ID=-1 (pad batches in CUDA graph captures).  The two Triton
kernels in fused_conv_ssm_decode did not, causing pointer arithmetic
  ptr + (-1) * stride
which reads/writes memory BEFORE the tensor buffer.  This silently
corrupted adjacent tensors (model weights, other caches) on every
CUDA-graph decode step that contained padding.

Simple 1-sequence tests passed because no padding was used; the full
accuracy evaluation (batch up to 384) hit padding constantly, producing
catastrophically wrong outputs (GSM8k 2.35%).

Fix: add `if conv_slot < 0 or ssm_slot < 0: return` at the top of both
_fused_conv_ssm_kernel and _bc_conv_state_update_kernel, matching the
guard in causal_conv1d_triton.py and tuned_ssm_kernel.py.

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…x_dec

slot_idx is torch.long (int64). The .to(torch.int32) call created a copy
that was frozen in the CUDA graph memory pool at capture-time slot values.
During replay with different sequences at different cache slots, the Triton
kernel read stale slot indices => read/wrote SSM/conv state at wrong
positions => garbage decode outputs => GSM8k ~1.8%.

Removing the cast returns a view (int64 slice) that is updated in-place
by the resource handler before each CUDA graph replay. The Triton kernel
already does tl.load(...).to(tl.int64) internally so int64 input works
without any other change.

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…issing, causing 94% of dt values unclamped → wrong SSM states → 2% GSM8k accuracy)

The Triton fused conv+SSM decode kernel was missing the dt_clamp step
that the reference tuned_selective_state_update applies via dt_clamp_min/max.
With time_step_limit=(0.001, 0.1) for Nemotron Nano v3, 94% of dt values
after softplus exceed 0.1, causing exp(A*dt) ≈ 0 and effectively zeroing
the SSM state every decode step → garbage outputs → 2% GSM8k accuracy.

Fix: pass dt_clamp_min/max from time_step_limit through fused_conv_ssm_decode
to the Triton kernel, applying tl.minimum(tl.maximum(dt, min), max) after
softplus. Default values (0.0, inf) are no-ops for models without clamping.

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants