Skip to content

Commit f283e2a

Browse files
authored
Merge branch 'main' into dev-mixin-cleanups
2 parents ebffba0 + 09449d4 commit f283e2a

139 files changed

Lines changed: 6161 additions & 1260 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ TensorRT LLM
1010
[![python](https://img.shields.io/badge/python-3.10-green)](https://www.python.org/downloads/release/python-31012/)
1111
[![cuda](https://img.shields.io/badge/cuda-13.1.1-green)](https://developer.nvidia.com/cuda-downloads)
1212
[![torch](https://img.shields.io/badge/torch-2.10.0-green)](https://pytorch.org)
13-
[![version](https://img.shields.io/badge/release-1.3.0rc18-green)](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/version.py)
13+
[![version](https://img.shields.io/badge/release-1.3.0rc19-green)](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/version.py)
1414
[![license](https://img.shields.io/badge/license-Apache%202-blue)](https://github.com/NVIDIA/TensorRT-LLM/blob/main/LICENSE)
1515

1616
[Architecture](https://nvidia.github.io/TensorRT-LLM/developer-guide/overview.html)   |   [Performance](https://nvidia.github.io/TensorRT-LLM/developer-guide/perf-overview.html)   |   [Examples](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html)   |   [Documentation](https://nvidia.github.io/TensorRT-LLM/)   |   [Roadmap](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue%20state%3Aopen%20label%3Aroadmap)

cpp/include/tensorrt_llm/batch_manager/kvCacheTransferManager.h

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ namespace kvc = tensorrt_llm::executor::kv_cache;
2424

2525
#pragma once
2626

27+
namespace tensorrt_llm::testing
28+
{
29+
class KVCacheTransferManagerTestAccess;
30+
} // namespace tensorrt_llm::testing
31+
2732
namespace tensorrt_llm::batch_manager::kv_cache_manager
2833
{
2934

@@ -76,10 +81,15 @@ class KVCacheTransferManager
7681
[[nodiscard]] KvCacheTransferStats getAndResetTransferStats();
7782

7883
private:
84+
friend class ::tensorrt_llm::testing::KVCacheTransferManagerTestAccess;
85+
7986
//! \brief Get pointer to pool specified by cache block.
8087
static tr::ITensor::SharedPtr computeBlockPointer(
8188
BlockPtr const& block, std::vector<KVCacheBlockPool> const& pools, size_t poolIdx);
8289

90+
//! \brief Get pool-qualified index for pending transfer tracking.
91+
[[nodiscard]] static kernels::KVCacheIndex::UnderlyingType getPendingTransferIndex(BlockPtr const& block);
92+
8393
/*!
8494
* \brief The key method that copies the src block to the dst block.
8595
*
@@ -107,8 +117,8 @@ class KVCacheTransferManager
107117
runtime::BufferManager mOnboardManager;
108118
runtime::BufferManager mOffloadManager;
109119

110-
// Track reads and writes for blocks. Note that it is the memory pool index that
111-
// identifies the raw memory blocks involved in I/O, not the block Id.
120+
// Track reads and writes for blocks. Note that it is the pool-qualified memory pool index
121+
// that identifies the raw memory blocks involved in I/O, not the block Id.
112122
std::unordered_map<kernels::KVCacheIndex::UnderlyingType, tr::CudaEvent> mPendingReads;
113123
std::unordered_map<kernels::KVCacheIndex::UnderlyingType, tr::CudaEvent> mPendingWrites;
114124
// Reference to parent loopback agent

cpp/kernels/fmha_v2/src/fmha/kernel_traits.h

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,9 @@ template <
145145
// Do we use half epilogue for the 2nd GEMM (hmma_fp32)
146146
bool BMM2_FP16_EPILOGUE = true,
147147
// non-positive means disabled
148-
int SAGE_BLOCK_SIZE_Q_ = 0, int SAGE_BLOCK_SIZE_K_ = 0, int SAGE_BLOCK_SIZE_V_ = 0>
148+
int SAGE_BLOCK_SIZE_Q_ = 0, int SAGE_BLOCK_SIZE_K_ = 0, int SAGE_BLOCK_SIZE_V_ = 0,
149+
// Enable skip softmax attention feature.
150+
bool ENABLE_SKIP_SOFTMAX_ = false>
149151
struct Kernel_traits_
150152
{
151153

@@ -197,6 +199,9 @@ struct Kernel_traits_
197199
SAGE_BLOCK_SIZE_V = SAGE_BLOCK_SIZE_V_
198200
};
199201

202+
// Are we enabling skip softmax attention feature?
203+
static constexpr bool ENABLE_SKIP_SOFTMAX = ENABLE_SKIP_SOFTMAX_;
204+
200205
// TODO: expose these tiling params to the interface
201206
enum
202207
{
@@ -1005,10 +1010,13 @@ template <
10051010
// The output type.
10061011
typename OutputType = typename Traits::A_type,
10071012
// The sage attention block size for Q, K and V
1008-
int SAGE_BLOCK_SIZE_Q = 0, int SAGE_BLOCK_SIZE_K = 0, int SAGE_BLOCK_SIZE_V = 0>
1013+
int SAGE_BLOCK_SIZE_Q = 0, int SAGE_BLOCK_SIZE_K = 0, int SAGE_BLOCK_SIZE_V = 0,
1014+
// Enable skip softmax attention feature.
1015+
bool ENABLE_SKIP_SOFTMAX = false>
10091016
using Kernel_traits_v2 = Kernel_traits_<Traits, fmha::v2::Gmem_tile_qkv, fmha::v2::Gmem_tile_qkv,
10101017
fmha::v2::Gmem_tile_qkv, Gmem_tile_o_dispatcher<Traits, OutputType>::Gmem_tile_o, S, D, DV, STEP, WARPS_M, WARPS_N,
1011-
CTAS_PER_HEAD, FLAGS, 2, MASK_VERSION, BMM2_FP16_EPILOGUE, SAGE_BLOCK_SIZE_Q, SAGE_BLOCK_SIZE_K, SAGE_BLOCK_SIZE_V>;
1018+
CTAS_PER_HEAD, FLAGS, 2, MASK_VERSION, BMM2_FP16_EPILOGUE, SAGE_BLOCK_SIZE_Q, SAGE_BLOCK_SIZE_K, SAGE_BLOCK_SIZE_V,
1019+
ENABLE_SKIP_SOFTMAX>;
10121020

10131021
////////////////////////////////////////////////////////////////////////////////////////////////////
10141022

@@ -1038,11 +1046,13 @@ template <
10381046
// The output type.
10391047
typename OutputType = typename Traits::A_type,
10401048
// The sage attention block size for Q, K and V
1041-
int SAGE_BLOCK_SIZE_Q = 0, int SAGE_BLOCK_SIZE_K = 0, int SAGE_BLOCK_SIZE_V = 0>
1042-
using Kernel_traits_v2_q_k_v
1043-
= Kernel_traits_<Traits, fmha::v2::Gmem_tile_q_k_v, fmha::v2::Gmem_tile_q_k_v, fmha::v2::Gmem_tile_q_k_v,
1044-
Gmem_tile_o_dispatcher<Traits, OutputType>::Gmem_tile_o, S, D, DV, STEP, WARPS_M, WARPS_N, CTAS_PER_HEAD, FLAGS,
1045-
2, MASK_VERSION, BMM2_FP16_EPILOGUE, SAGE_BLOCK_SIZE_Q, SAGE_BLOCK_SIZE_K, SAGE_BLOCK_SIZE_V>;
1049+
int SAGE_BLOCK_SIZE_Q = 0, int SAGE_BLOCK_SIZE_K = 0, int SAGE_BLOCK_SIZE_V = 0,
1050+
// Enable skip softmax attention feature.
1051+
bool ENABLE_SKIP_SOFTMAX = false>
1052+
using Kernel_traits_v2_q_k_v = Kernel_traits_<Traits, fmha::v2::Gmem_tile_q_k_v, fmha::v2::Gmem_tile_q_k_v,
1053+
fmha::v2::Gmem_tile_q_k_v, Gmem_tile_o_dispatcher<Traits, OutputType>::Gmem_tile_o, S, D, DV, STEP, WARPS_M,
1054+
WARPS_N, CTAS_PER_HEAD, FLAGS, 2, MASK_VERSION, BMM2_FP16_EPILOGUE, SAGE_BLOCK_SIZE_Q, SAGE_BLOCK_SIZE_K,
1055+
SAGE_BLOCK_SIZE_V, ENABLE_SKIP_SOFTMAX>;
10461056

10471057
////////////////////////////////////////////////////////////////////////////////////////////////////
10481058

@@ -1072,11 +1082,13 @@ template <
10721082
// The output type.
10731083
typename OutputType = typename Traits::A_type,
10741084
// The sage attention block size for Q, K and V
1075-
int SAGE_BLOCK_SIZE_Q = 0, int SAGE_BLOCK_SIZE_K = 0, int SAGE_BLOCK_SIZE_V = 0>
1076-
using Kernel_traits_v2_paged_kv_cache
1077-
= Kernel_traits_<Traits, fmha::v2::Gmem_tile_q_k_v, fmha::v2::Gmem_tile_paged_kv, fmha::v2::Gmem_tile_paged_kv,
1078-
Gmem_tile_o_dispatcher<Traits, OutputType>::Gmem_tile_o, S, D, DV, STEP, WARPS_M, WARPS_N, CTAS_PER_HEAD, FLAGS,
1079-
2, MASK_VERSION, BMM2_FP16_EPILOGUE, SAGE_BLOCK_SIZE_Q, SAGE_BLOCK_SIZE_K, SAGE_BLOCK_SIZE_V>;
1085+
int SAGE_BLOCK_SIZE_Q = 0, int SAGE_BLOCK_SIZE_K = 0, int SAGE_BLOCK_SIZE_V = 0,
1086+
// Enable skip softmax attention feature.
1087+
bool ENABLE_SKIP_SOFTMAX = false>
1088+
using Kernel_traits_v2_paged_kv_cache = Kernel_traits_<Traits, fmha::v2::Gmem_tile_q_k_v, fmha::v2::Gmem_tile_paged_kv,
1089+
fmha::v2::Gmem_tile_paged_kv, Gmem_tile_o_dispatcher<Traits, OutputType>::Gmem_tile_o, S, D, DV, STEP, WARPS_M,
1090+
WARPS_N, CTAS_PER_HEAD, FLAGS, 2, MASK_VERSION, BMM2_FP16_EPILOGUE, SAGE_BLOCK_SIZE_Q, SAGE_BLOCK_SIZE_K,
1091+
SAGE_BLOCK_SIZE_V, ENABLE_SKIP_SOFTMAX>;
10801092

10811093
////////////////////////////////////////////////////////////////////////////////////////////////////
10821094

@@ -1106,11 +1118,13 @@ template <
11061118
// The output type.
11071119
typename OutputType = typename Traits::A_type,
11081120
// The sage attention block size for Q, K and V
1109-
int SAGE_BLOCK_SIZE_Q = 0, int SAGE_BLOCK_SIZE_K = 0, int SAGE_BLOCK_SIZE_V = 0>
1121+
int SAGE_BLOCK_SIZE_Q = 0, int SAGE_BLOCK_SIZE_K = 0, int SAGE_BLOCK_SIZE_V = 0,
1122+
// Enable skip softmax attention feature.
1123+
bool ENABLE_SKIP_SOFTMAX = false>
11101124
using Kernel_traits_v2_contiguous_kv_cache = Kernel_traits_<Traits, fmha::v2::Gmem_tile_q_k_v,
11111125
fmha::v2::Gmem_tile_contiguous_kv, fmha::v2::Gmem_tile_contiguous_kv,
11121126
Gmem_tile_o_dispatcher<Traits, OutputType>::Gmem_tile_o, S, D, 0, STEP, WARPS_M, WARPS_N, CTAS_PER_HEAD, FLAGS, 2,
1113-
MASK_VERSION, BMM2_FP16_EPILOGUE, SAGE_BLOCK_SIZE_Q, SAGE_BLOCK_SIZE_K, SAGE_BLOCK_SIZE_V>;
1127+
MASK_VERSION, BMM2_FP16_EPILOGUE, SAGE_BLOCK_SIZE_Q, SAGE_BLOCK_SIZE_K, SAGE_BLOCK_SIZE_V, ENABLE_SKIP_SOFTMAX>;
11141128

11151129
////////////////////////////////////////////////////////////////////////////////////////////////////
11161130

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# skip_softmax — TMA-load + sync-MMA warp-specialized FMHA for sm_120 / sm_121
2+
3+
> This is the sm_120 / sm_121 warp-specialized context FMHA that carries the
4+
> per-warp **skip-softmax** optimization (hence the name). Only half of the
5+
> Hopper warp-specialization recipe ports to consumer Blackwell: TMA-driven
6+
> async loads survive, but async MMA does not (sm_120 / sm_121 have no
7+
> `wgmma.async` equivalent), so the compute warps stay on `mma.sync` while a
8+
> dedicated producer warp drives the loads with TMA.
9+
10+
This directory implements a warp-specialized context FMHA for the sm_120
11+
family (sm_120 / sm_121). It targets BF16, causal mask, `head_dim ==
12+
head_dim_v` in `{128, 256}`, and the PACKED_QKV layout. The kernel carries the
13+
per-warp skip-softmax optimization into the warp-specialized design.
14+
15+
## Files
16+
17+
| File | Role |
18+
|------|------|
19+
| `kernel_traits.h` | `Kernel_traits_skip_softmax_sm120`: wraps `fmha::Kernel_traits_v2` for the LDGSTS-friendly `Smem_tile_*` types, then layers on the producer/consumer warp roles, the granular smem buffers, the circular-buffer barriers, and the V re-tile (see below). |
20+
| `dma_sync_mma.h` | Producer (`DMA::run`). Issues `cp.async.bulk.tensor.3d.shared::cta.global.tile` for Q / K / V into the granular buffers. `DMA::Host::init_params` builds the three `CUtensorMap` descriptors with the driver-API `cuTensorMapEncodeTiled`. |
21+
| `compute_sync_mma.h` | Consumer (`Compute::run`). The kv-loop body — BMM1 (`fmha::gemm`) + softmax + causal mask + per-warp skip-softmax vote + BMM2 + epilogue — reading the granular `Smem_tile_q/k/v` per ring slot. |
22+
23+
The translation unit and the in-engine dispatch bridges live in
24+
`cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/skip_softmax_sm120/fused_multihead_flash_attention_ws_sm120.cu`,
25+
and the entry kernel in
26+
`cpp/kernels/fmha_v2/src/fused_multihead_flash_attention_kernel_ws_sm120.h`.
27+
28+
## How the runner reaches this kernel
29+
30+
This is the **default** sm_120 / sm_121 context FMHA — there is no opt-in flag.
31+
`FusedMultiHeadAttentionXMMAKernelV2::run` dispatches every prefill whose config
32+
matches (sm_120 / sm_121, BF16 in/out, causal, `head_dim == head_dim_v` in
33+
`{128, 256}`, PACKED_QKV) **and** that carries no feature the kernel does not
34+
implement (alibi, logit softcapping, sage attention, sliding-window / custom mask,
35+
returning softmax stats, interleaved) to the `run_skip_softmax_*` bridges; every
36+
other config falls through to the cubin/launcher path.
37+
38+
The per-tile skip-softmax optimization is selected by
39+
`Launch_params::enableSkipSoftmax` (set when a skip-softmax threshold `> 0` is
40+
configured): the bridges instantiate the `ENABLE_SKIP_SOFTMAX = true` kernel
41+
variant when skipping is requested, and the `false` variant — a plain
42+
full-softmax prefill with no skip-check overhead — otherwise.
43+
44+
The translation unit is compiled only into the `_context_attention_kernels_120`
45+
CMake target (sm_120 family). The all-architecture dispatch TU references the
46+
bridge symbols under `TLLM_ENABLE_SKIP_SOFTMAX_SM120`, which CMake defines only when
47+
sm_120 is built, so builds that exclude sm_120 neither reference nor link the
48+
(then-absent) symbols.
49+
50+
## Design rationale
51+
52+
### Why TMA loads, not "just split the warps"
53+
54+
In the non-warp-specialized tiled kernel, the Q / K / V loads are *multi-thread*
55+
LDGSTS operations: each of the 128 threads issues several `LDGSTS` instructions
56+
to cover `(tile rows × D bytes)`. There is no way to "have warp 0 do the load"
57+
without rewriting the gmem/smem tile load helpers — the partition is baked into
58+
them. TMA fixes exactly this: a single descriptor + a single
59+
`cp.async.bulk.tensor` from one thread issues an entire tile load, and the
60+
consumers wait on an `mbarrier`. So the producer warp uses TMA, not LDGSTS.
61+
62+
### TMA descriptor format
63+
64+
Blackwell's TMA engine requires the driver-API `cuTensorMapEncodeTiled`
65+
(128-byte `CUtensorMap`) descriptor — the same form the shipping
66+
trtllmGenKernels FMHA uses. The fmha_v2 hand-rolled 64-byte `fmha::cudaTmaDesc`
67+
(Hopper-era bit layout) is rejected and faults at `UTMALDG`. The descriptors
68+
are built host-side in `DMA::Host::init_params` and passed to the kernel as
69+
`__grid_constant__` params.
70+
71+
### Why the LDGSTS smem tiles can be filled by TMA
72+
73+
The make-or-break question for reusing the existing consumer `Smem_tile_*` is
74+
whether their LDGSTS XOR swizzle equals a TMA hardware swizzle mode. It does:
75+
the Q and K granular tiles use `BYTES_PER_ROW = 128`, `BYTES_PER_STS = 16`,
76+
`ROWS_PER_XOR_PATTERN = 8`, i.e. a physical 16-byte chunk index of
77+
`(col / 8) ^ (row % 8)` — byte-identical to the TMA 128B hardware swizzle. So a
78+
chunked 128B-swizzle TMA load fills `Smem_tile_q/k` directly and the consumer's
79+
`ldmatrix` reads correct data.
80+
81+
### V is re-tiled to 64-wide DV chunks
82+
83+
The natural `Smem_tile_v` packs the full `DV` (256) into the lead dim, giving
84+
512-byte smem rows that no TMA swizzle mode can reproduce (`cuTensorMapEncodeTiled`
85+
caps the leading box dim at the 128-byte swizzle width; a 512-byte leading dim
86+
only encodes with `SWIZZLE_NONE`, which is plain row-major and does not match
87+
the consumer's XOR-swizzled read). Instead, V is tiled into `BMM2_DV_CHUNK = 64`
88+
wide groups so the V smem tile has `LEAD_DIM = 64` → 128-byte rows — the same
89+
proven layout as K — and the existing `N == 64` `ldsmt` read path applies
90+
unchanged. The producer streams `DV / 64` dv-chunks per kv-tile; the consumer
91+
BMM2 contracts per dv-chunk into the corresponding `acc_o` sub-range.
92+
93+
### `setmaxnreg` is unavailable here
94+
95+
`setmaxnreg.{dec,inc}` is a Hopper / datacenter-Blackwell instruction
96+
(sm_90 / sm_100 / sm_103); ptxas hard-errors on sm_120 / sm_121. The producer/consumer
97+
register-budget split therefore does not exist on this hardware and is guarded
98+
off (no-op on sm_120 / sm_121).
99+
100+
## What the port wins, and what it does not
101+
102+
Wins on sm_120 / sm_121:
103+
104+
- **Fewer load instructions** — one `cp.async.bulk.tensor` per tile replaces
105+
the many per-thread `LDGSTS` of the tiled kernel.
106+
- **Per-buffer-slot waits** (`mbarrier`) instead of CTA-wide `__syncthreads()`
107+
between load and compute: a consumer warp unblocks as soon as its tile lands.
108+
109+
Does not win:
110+
111+
- **MMA / softmax overlap** — there is no `wgmma.async` on sm_120, so a consumer
112+
warp's `mma.sync` blocks its issuing thread until result registers commit. The
113+
Hopper warpspec hides BMM1/BMM2 MMA latency behind softmax/`frag_p` work; that
114+
is not achievable with sync MMA only.
115+
- **Register-budget split** — `setmaxnreg` is unavailable (see above).
116+
117+
## Relationship to a CuTe-DSL kernel
118+
119+
CUTLASS 4.x has Blackwell sm_120 FMHA examples implementing the TMA-load +
120+
sync-MMA pattern in CuTe DSL. A longer-term direction is to route the sm_120 /
121+
sm_121 dispatch into a CuTe-DSL kernel. This fmha_v2 implementation maps the
122+
relationship between the existing fmha_v2 infrastructure and that design and is
123+
self-contained: the dispatch is gated, and the directory plus the entry-kernel
124+
header are isolated (no other code includes them).

0 commit comments

Comments
 (0)