Skip to content

Commit 5fe0a17

Browse files
authored
[None][perf] DSv4 prep: attention fusion custom ops (#15390)
Signed-off-by: Fanrong Li <lfr-0531@users.noreply.github.com> Co-authored-by: Fanrong Li <lfr-0531@users.noreply.github.com>
1 parent 2772b99 commit 5fe0a17

14 files changed

Lines changed: 1286 additions & 4 deletions
Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
/*
2+
* Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
// Fused 1x128 FP8 quantize + UE8M0 scale packing.
18+
//
19+
// Replaces the (scale_1x128_kernel + pack_fp32_into_ue8m0) two-kernel sequence
20+
// used by SM100 deep_gemm fp8 block-scale GEMMs. Adapted from the SM120 MoE
21+
// in-kernel packing pattern (`scale_1x128_kernel_sm120` in
22+
// sm120_blockwise_gemm/sm120_fp8_moe_gemm_1d1d.cuh), specialised for the
23+
// non-MoE case (single contiguous batch, no token offsets).
24+
25+
#include "fp8_blockscale_quant_packed.h"
26+
27+
#include "tensorrt_llm/common/config.h"
28+
#include "tensorrt_llm/common/envUtils.h"
29+
30+
#include <cstdint>
31+
#include <cuda_bf16.h>
32+
#include <cuda_fp8.h>
33+
#include <cuda_runtime_api.h>
34+
35+
TRTLLM_NAMESPACE_BEGIN
36+
37+
namespace kernels::fp8_blockscale_gemm
38+
{
39+
40+
namespace
41+
{
42+
43+
// Returns an approximate reciprocal of `a` using the PTX instruction
// `rcp.approx.ftz.f32` (the `.ftz` modifier selects flush-to-zero handling of
// subnormal values, per the PTX ISA). Device-only helper, local to this
// translation unit.
__device__ __forceinline__ float reciprocal_approximate_ftz_local(float a)
{
    float inv;
    asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(inv) : "f"(a));
    return inv;
}
49+
50+
// Fused BF16 -> FP8 (E4M3) 1x128 block quantization with UE8M0 scale packing.
//
// Work decomposition:
//   * blockIdx.x selects one packed scale word, i.e. 4 quantization blocks
//     (4 x 128 = 512 K elements) of a row.
//   * blockIdx.y * WarpsPerBlock + warp_id selects the row; one warp handles
//     one row, so the kernel expects blockDim.x == WarpsPerBlock * 32.
//   * The 32 lanes form 4 lane-groups of 8. Each lane owns 16 consecutive
//     elements, so one lane-group covers one 128-element quantization block.
// After the per-block amax, lanes 0/8/16/24 each hold one UE8M0 scale byte;
// lane 0 packs them into one 32-bit word and stores it at
// packed_scale_output[packed_sf_k_idx * scale_leading_dim_uint32 + m_idx]
// (the MN-major layout expected by deep_gemm, per the original author's note).
//
// Parameters:
//   fp8_output               : E4M3 output, indexed as [m_idx * k + k_idx].
//   packed_scale_output      : packed UE8M0 scales, one 32-bit word per
//                              (packed_sf_k_idx, m_idx).
//   input                    : BF16 input, indexed as [m_idx * k + k_idx].
//   m, k                     : number of rows / elements per row.
//   scale_leading_dim_uint32 : stride (in 32-bit words) between consecutive
//                              packed_sf_k rows of the scale tensor.
//
// Tail handling: a lane whose 16-element slice is not entirely inside the row,
// or whose address is not suitably aligned for a vector access, falls back to
// element-wise loads/stores limited to the valid elements. This keeps the
// kernel from reading or writing past the end of a row (and past the end of
// the buffers on the last row) when k is not a multiple of 16.
template <int WarpsPerBlock>
__global__ void fp8_quantize_1x128_packed_kernel_impl(__nv_fp8_e4m3* __restrict__ fp8_output,
    int32_t* __restrict__ packed_scale_output, __nv_bfloat16 const* __restrict__ input, int const m, int const k,
    int const scale_leading_dim_uint32)
{
    int const packed_sf_k_idx = static_cast<int>(blockIdx.x);
    int const warp_id = static_cast<int>(threadIdx.x) >> 5;
    int const lane_id = static_cast<int>(threadIdx.x) & 31;
    int const m_idx = static_cast<int>(blockIdx.y) * WarpsPerBlock + warp_id;

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    cudaGridDependencySynchronize();
#endif

    // m_idx is uniform across the warp, so the whole warp leaves together and
    // the full-mask shuffles below never see a partially exited warp.
    if (m_idx >= m)
    {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
        cudaTriggerProgrammaticLaunchCompletion();
#endif
        return;
    }

    constexpr int kLoadNumElems = sizeof(double4) / sizeof(__nv_bfloat16); // 16
    constexpr int kStoreNumElems = sizeof(float4) / sizeof(__nv_fp8_e4m3); // 16
    static_assert(kLoadNumElems == 16 && kStoreNumElems == 16, "each lane must own exactly 16 elements");

    int const k_base = packed_sf_k_idx * 512 + lane_id * 16;
    int64_t const elem_offset = static_cast<int64_t>(m_idx) * k + k_base;

    // Number of elements of this lane's slice that lie inside the row (0..16).
    int const remaining = k - k_base;
    int const valid = remaining <= 0 ? 0 : (remaining < kLoadNumElems ? remaining : kLoadNumElems);

    // ---- 1. Load up to 16 BF16 elements per lane; missing elements stay zero. ----
    union LoadTrick
    {
        double4 pack;
        __nv_bfloat16 v[kLoadNumElems];
    };

    LoadTrick load_trick;
    load_trick.pack = double4{};
    if (valid > 0)
    {
        __nv_bfloat16 const* in_elems = input + elem_offset;
        bool const vec_load_ok
            = (valid == kLoadNumElems) && (reinterpret_cast<uintptr_t>(in_elems) % alignof(double4) == 0);
        if (vec_load_ok)
        {
            load_trick.pack = *reinterpret_cast<double4 const*>(in_elems);
        }
        else
        {
            // Partial or unaligned slice: touch only the elements that exist.
#pragma unroll
            for (int i = 0; i < kLoadNumElems; ++i)
            {
                if (i < valid)
                {
                    load_trick.v[i] = in_elems[i];
                }
            }
        }
    }

    // ---- 2. Per-block amax (lanes 0..7 / 8..15 / 16..23 / 24..31 = 4 quant blocks). ----
    __nv_bfloat16 max_elem = __nv_bfloat16(0.0f);
#pragma unroll
    for (int i = 0; i < kLoadNumElems; ++i)
    {
        max_elem = __hmax(max_elem, __habs(load_trick.v[i]));
    }
    float amax = static_cast<float>(max_elem);
    // Butterfly reduction restricted to each 8-lane group (width = 8).
    amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFFu, amax, 4, 8));
    amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFFu, amax, 2, 8));
    amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFFu, amax, 1, 8));
    // Floor the amax so an all-zero block still yields a finite, non-zero scale.
    amax = fmaxf(amax, 1e-10f);

    // ---- 3. UE8M0 dequant scale (amax / 448, rounded up to a power of two). ----
    float const dequant_scale_raw = amax * reciprocal_approximate_ftz_local(448.0f);
    __nv_fp8_e8m0 ue8m0_scale;
    ue8m0_scale.__x = __nv_cvt_float_to_e8m0(dequant_scale_raw, __NV_SATFINITE, cudaRoundPosInf);

    // Recover quant_scale = 1 / 2^(exp - 127) for fp8 conversion.
    constexpr uint32_t FP32_EXPONENT_BIAS = 127u;
    float const quant_scale = (ue8m0_scale.__x == 0)
        ? 1.0f
        : exp2f(static_cast<float>(FP32_EXPONENT_BIAS) - static_cast<float>(ue8m0_scale.__x));

    // ---- 4. Quantize and store FP8 output. ----
    union StoreTrick
    {
        float4 pack;
        __nv_fp8_e4m3 v[kStoreNumElems];
    };

    StoreTrick store_trick;
    store_trick.pack = float4{};
#pragma unroll
    for (int i = 0; i < kStoreNumElems; ++i)
    {
        store_trick.v[i] = __nv_fp8_e4m3(static_cast<float>(load_trick.v[i]) * quant_scale);
    }
    if (valid > 0)
    {
        __nv_fp8_e4m3* out_elems = fp8_output + elem_offset;
        bool const vec_store_ok
            = (valid == kStoreNumElems) && (reinterpret_cast<uintptr_t>(out_elems) % alignof(float4) == 0);
        if (vec_store_ok)
        {
            *reinterpret_cast<float4*>(out_elems) = store_trick.pack;
        }
        else
        {
            // Partial or unaligned slice: write only the elements that belong to this row.
#pragma unroll
            for (int i = 0; i < kStoreNumElems; ++i)
            {
                if (i < valid)
                {
                    out_elems[i] = store_trick.v[i];
                }
            }
        }
    }

    // ---- 5. Pack 4 UE8M0 scales (lanes 0/8/16/24) and store. ----
    uint32_t const s0 = __shfl_sync(0xFFFFFFFFu, static_cast<uint32_t>(ue8m0_scale.__x), 0);
    uint32_t const s1 = __shfl_sync(0xFFFFFFFFu, static_cast<uint32_t>(ue8m0_scale.__x), 8);
    uint32_t const s2 = __shfl_sync(0xFFFFFFFFu, static_cast<uint32_t>(ue8m0_scale.__x), 16);
    uint32_t const s3 = __shfl_sync(0xFFFFFFFFu, static_cast<uint32_t>(ue8m0_scale.__x), 24);
    if (lane_id == 0)
    {
        // Mask off scale bytes whose sf_k is past the actual K.
        int const num_sf_k = (k + 127) / 128;
        int const sf_k_base = packed_sf_k_idx * 4;
        uint32_t packed = 0u;
        if (sf_k_base + 0 < num_sf_k)
        {
            packed |= s0;
        }
        if (sf_k_base + 1 < num_sf_k)
        {
            packed |= (s1 << 8);
        }
        if (sf_k_base + 2 < num_sf_k)
        {
            packed |= (s2 << 16);
        }
        if (sf_k_base + 3 < num_sf_k)
        {
            packed |= (s3 << 24);
        }
        // Layout: packed_scale[packed_sf_k_idx, m_idx]
        packed_scale_output[static_cast<int64_t>(packed_sf_k_idx) * scale_leading_dim_uint32 + m_idx] = packed;
    }

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    cudaTriggerProgrammaticLaunchCompletion();
#endif
}
191+
192+
} // namespace
193+
194+
// Host launcher for the fused 1x128 FP8 quantize + UE8M0 pack kernel.
//
// Parameters:
//   fp8_output               : E4M3 output buffer, indexed by the kernel as [row * k + col].
//   packed_scale_output      : packed scale buffer, indexed by the kernel as
//                              [packed_sf_k_idx * scale_leading_dim_uint32 + row].
//   input                    : BF16 input buffer, indexed by the kernel as [row * k + col].
//   m, k                     : number of rows / elements per row; the call is a
//                              no-op when either is <= 0.
//   scale_leading_dim_uint32 : stride (in 32-bit words) between packed_sf_k rows.
//   stream                   : CUDA stream the kernel is launched on.
//
// The grid is (ceil(ceil(k/128)/4), ceil(rows/4)) with 4 warps (128 threads)
// per block, one warp per row. Rows are mapped onto gridDim.y, which CUDA
// limits to 65535, so large `m` is processed in several launches of at most
// 65535 * 4 rows each; every launch receives base pointers offset to its first
// row. For m <= 65535 * 4 exactly one launch is issued.
void launch_fp8_quantize_1x128_packed_bf16_e4m3(__nv_fp8_e4m3* fp8_output, int32_t* packed_scale_output,
    __nv_bfloat16 const* input, int m, int k, int scale_leading_dim_uint32, cudaStream_t stream)
{
    if (m <= 0 || k <= 0)
    {
        return;
    }

    constexpr int kWarpsPerBlock = 4;
    constexpr int kMaxGridDimY = 65535; // CUDA limit for the y-dimension of a grid
    constexpr int kMaxRowsPerLaunch = kMaxGridDimY * kWarpsPerBlock;

    int const num_packed_sf_k = (((k + 127) / 128) + 3) / 4;
    dim3 const block(kWarpsPerBlock * 32, 1, 1);

    // 64-bit row cursor: `row_begin + kMaxRowsPerLaunch` must not overflow for m close to INT_MAX.
    for (int64_t row_begin = 0; row_begin < m; row_begin += kMaxRowsPerLaunch)
    {
        int64_t const rows_left = static_cast<int64_t>(m) - row_begin;
        int rows = static_cast<int>(rows_left < kMaxRowsPerLaunch ? rows_left : kMaxRowsPerLaunch);
        int const m_blocks = (rows + kWarpsPerBlock - 1) / kWarpsPerBlock;
        dim3 const grid(num_packed_sf_k, m_blocks, 1);

        // Shift every base pointer to the first row of this chunk. The scale tensor
        // is indexed [packed_sf_k_idx * stride + row], so a plain row offset suffices.
        int64_t const elem_offset = row_begin * static_cast<int64_t>(k);
        __nv_fp8_e4m3* fp8_chunk = fp8_output + elem_offset;
        int32_t* scale_chunk = packed_scale_output + row_begin;
        __nv_bfloat16 const* input_chunk = input + elem_offset;

        tensorrt_llm::common::launchWithPdlWhenEnabled("fp8_quantize_1x128_packed_kernel_impl",
            fp8_quantize_1x128_packed_kernel_impl<kWarpsPerBlock>, grid, block, 0, stream, fp8_chunk, scale_chunk,
            input_chunk, rows, k, scale_leading_dim_uint32);
    }
}
212+
213+
} // namespace kernels::fp8_blockscale_gemm
214+
215+
TRTLLM_NAMESPACE_END
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/*
2+
* Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
// Host-callable launcher for the fused FP8 1x128 quantize + UE8M0-pack kernel.
18+
// Kernel implementation lives in fp8_blockscale_quant_packed.cu and is built
19+
// by nvcc; this header is safe to include from .cpp files compiled by g++.
20+
21+
#pragma once
22+
23+
#include "tensorrt_llm/common/config.h"
24+
25+
#include <cstdint>
26+
#include <cuda_bf16.h>
27+
#include <cuda_fp8.h>
28+
#include <cuda_runtime_api.h>
29+
30+
TRTLLM_NAMESPACE_BEGIN
31+
32+
namespace kernels::fp8_blockscale_gemm
33+
{
34+
35+
// Launches the fused 1x128 FP8 quant + UE8M0 pack kernel.
36+
//
37+
// Inputs:
38+
// input : BF16 [m, k] row-major contiguous
39+
// Outputs:
40+
// fp8_output : E4M3 [m, k] row-major contiguous
41+
// packed_scale_output : uint32 [packed_sf_k, scale_leading_dim_uint32]
42+
// where packed_sf_k = ceil(ceil(k/128)/4)
43+
//
44+
// `scale_leading_dim_uint32` is the (uint32) stride between consecutive
45+
// packed_sf_k rows of the scale tensor; caller is responsible for choosing
46+
// it (typically aligned to 4 uint32 = 16 bytes for TMA alignment).
47+
// Returns without launching anything when m <= 0 or k <= 0.
// NOTE(review): `packed_scale_output` is declared as int32_t* while the layout
// description speaks of uint32 words; the kernel writes each word as four packed
// UE8M0 scale bytes (lowest byte = lowest sf_k index), so the signedness of the
// pointee is irrelevant to the stored bit pattern.
// NOTE(review): presumably the launch is asynchronous with respect to the host
// and ordered on `stream` — confirm against launchWithPdlWhenEnabled.
void launch_fp8_quantize_1x128_packed_bf16_e4m3(__nv_fp8_e4m3* fp8_output, int32_t* packed_scale_output,
48+
__nv_bfloat16 const* input, int m, int k, int scale_leading_dim_uint32, cudaStream_t stream);
49+
50+
} // namespace kernels::fp8_blockscale_gemm
51+
52+
TRTLLM_NAMESPACE_END

0 commit comments

Comments
 (0)