Skip to content

Commit 6f6aaf8

Browse files
committed
[None][feat] add MXFP8 weight format + CUTLASS W8A8 Linear and MoE
Add first-class MXFP8 (OCP microscaling: e4m3 elements + per-32-element UE8M0 block scales) weight quantization to the PyTorch backend and execute MXFP8xMXFP8 W8A8 GEMMs through CUTLASS on Blackwell sm_100/103 for both dense Linear layers and fused MoE. Signed-off-by: Fred Wei <20514172+WeiHaocheng@users.noreply.github.com>
1 parent f3b718a commit 6f6aaf8

31 files changed

Lines changed: 1670 additions & 79 deletions

cpp/include/tensorrt_llm/common/quantization.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,11 @@ class QuantMode
134134
return QuantMode(BaseType(1u) << 16);
135135
}
136136

137+
static constexpr QuantMode mxfp8() noexcept
138+
{
139+
return QuantMode(BaseType(1u) << 17);
140+
}
141+
137142
constexpr BaseType value() const noexcept
138143
{
139144
return mValue;
@@ -224,6 +229,11 @@ class QuantMode
224229
return isSet(w4a16Mxfp4());
225230
}
226231

232+
constexpr bool hasMxfp8() const noexcept
233+
{
234+
return isSet(mxfp8());
235+
}
236+
227237
constexpr bool hasKvCacheQuant() const noexcept
228238
{
229239
return hasInt8KvCache() || hasFp8KvCache() || hasFp4KvCache();

cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -449,10 +449,20 @@ std::vector<CutlassGemmConfig> get_candidate_configs_sm100(
449449
CutlassGemmConfig::CandidateConfigTypeParam const config, int sm)
450450
{
451451
#ifdef FAST_BUILD
452-
// Fast build disables all configs except this one for SM100
453-
return {CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B, MainloopScheduleType::AUTO,
454-
EpilogueScheduleType::TMA, ClusterShape::ClusterShape_1x1x1, ClusterShape::Undefined, ClusterShape::Undefined,
455-
sm}};
452+
// Fast build limits the candidate set to a single CTA tile shape but
453+
// keeps both 1SM (cluster 1x1x1) and 2SM (cluster 2x1x1) variants so
454+
// the autotuner can profile both. Block-scaled paths (MXFP8xMXFP8,
455+
// NVFP4) accept both; the 2SM variant is required as a candidate so
456+
// FAST_BUILD doesn't accidentally exclude all 2SM kernels (needed for
457+
// MMA M=256 configurations of the Mxf8f6f4 tensor-op).
458+
return {
459+
CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B, MainloopScheduleType::AUTO,
460+
EpilogueScheduleType::TMA, ClusterShape::ClusterShape_1x1x1, ClusterShape::Undefined,
461+
ClusterShape::Undefined, sm},
462+
CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B, MainloopScheduleType::AUTO,
463+
EpilogueScheduleType::TMA, ClusterShape::ClusterShape_2x1x1, ClusterShape::Undefined,
464+
ClusterShape::Undefined, sm},
465+
};
456466
#else
457467
if (config & CutlassGemmConfig::GROUPED_GEMM)
458468
{

cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,22 @@ template <class ArchTag, class TileShape, class ClusterShape, bool DYNAMIC_CGA,
3232
struct should_filter_tma_warp_specialized_gemm_problem_shape
3333
{
3434
#ifdef FAST_BUILD
35-
using SupportedCtaShape = cute::Shape<cute::_128, cute::_128, decltype(cute::get<2>(TileShape{}))>;
36-
using SupportedCgaShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
35+
// The launcher passes its MMA tile shape here, which is CTA_M *
36+
// (Is2SM ? 2 : 1). So a CTA_M=128 with cluster_M=2 (2SM mode) lands as
37+
// MmaTileShape M==256. We accept both M=128 (1SM) and M=256 (2SM) under
38+
// FAST_BUILD so MXFP8xMXFP8 grouped MoE (which requires the Mxf8f6f4
39+
// tensor-op's MMA M==256) is reachable. The 1SM variant is also kept
40+
// for per-tensor FP8 / BF16 paths.
41+
using SupportedCtaShape1Sm = cute::Shape<cute::_128, cute::_128, decltype(cute::get<2>(TileShape{}))>;
42+
using SupportedCtaShape2Sm = cute::Shape<cute::_256, cute::_128, decltype(cute::get<2>(TileShape{}))>;
43+
using SupportedCgaShape1Sm = cute::Shape<cute::_1, cute::_1, cute::_1>;
44+
using SupportedCgaShape2Sm = cute::Shape<cute::_2, cute::_1, cute::_1>;
3745

38-
constexpr static bool value = !cute::is_same_v<SupportedCtaShape, TileShape>
39-
|| !cute::is_same_v<SupportedCgaShape, ClusterShape> || DYNAMIC_CGA;
46+
constexpr static bool cta_ok
47+
= cute::is_same_v<SupportedCtaShape1Sm, TileShape> || cute::is_same_v<SupportedCtaShape2Sm, TileShape>;
48+
constexpr static bool cga_ok
49+
= cute::is_same_v<SupportedCgaShape1Sm, ClusterShape> || cute::is_same_v<SupportedCgaShape2Sm, ClusterShape>;
50+
constexpr static bool value = !cta_ok || !cga_ok || DYNAMIC_CGA;
4051
#else
4152
constexpr static bool value = false;
4253
#endif

cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_bf16.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER_SM120(__nv_bfloat16, 256, 128, 128, 1, 1, 1
8080

8181
template class CutlassFp4GemmRunner<__nv_bfloat16, FP4GemmType::W4A4_NVFP4_NVFP4>;
8282
template class CutlassFp4GemmRunner<__nv_bfloat16, FP4GemmType::W4A8_MXFP4_MXFP8>;
83+
template class CutlassFp4GemmRunner<__nv_bfloat16, FP4GemmType::W8A8_MXFP8_MXFP8>;
8384

8485
#endif
8586

cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_fp16.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER_SM120(half, 256, 128, 128, 1, 1, 1)
7979

8080
template class CutlassFp4GemmRunner<half, FP4GemmType::W4A4_NVFP4_NVFP4>;
8181
template class CutlassFp4GemmRunner<half, FP4GemmType::W4A8_MXFP4_MXFP8>;
82+
template class CutlassFp4GemmRunner<half, FP4GemmType::W8A8_MXFP8_MXFP8>;
8283

8384
} // namespace cutlass_kernels
8485
} // namespace kernels

cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_fp32.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER_SM120(float, 256, 128, 128, 1, 1, 1)
7979

8080
template class CutlassFp4GemmRunner<float, FP4GemmType::W4A4_NVFP4_NVFP4>;
8181
template class CutlassFp4GemmRunner<float, FP4GemmType::W4A8_MXFP4_MXFP8>;
82+
template class CutlassFp4GemmRunner<float, FP4GemmType::W8A8_MXFP8_MXFP8>;
8283

8384
} // namespace cutlass_kernels
8485
} // namespace kernels

cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_template.h

Lines changed: 107 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737

3838
#include "../include/fp4_gemm.h"
3939
#include "mxfp8_mxfp4_gemm_template_sm100.h"
40+
#include "mxfp8_mxfp8_gemm_template_sm100.h"
4041
#include "nvfp4_nvfp4_gemm_template_sm100.h"
4142
#include "nvfp4_nvfp4_gemm_template_sm120.h"
4243

@@ -323,6 +324,94 @@ size_t dispatchMXFP8xMXFP4GemmCTAShapeSm100(T* D, void const* A, void const* B,
323324
}
324325
}
325326

327+
template <typename T, typename CTA_M_, typename CTA_N_, typename CTA_K_>
328+
size_t dispatchMXFP8xMXFP8GemmClusterShapeSm100(T* D, void const* A, void const* B, void const* input_sf,
329+
void const* weight_sf, float const* global_sf, int m, int n, int k, int batch_count,
330+
tkc::CutlassGemmConfig gemmConfig, char* workspace, const size_t workspaceBytes, cudaStream_t stream,
331+
int* occupancy = nullptr)
332+
{
333+
334+
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
335+
336+
switch (gemmConfig.cluster_shape)
337+
{
338+
case tkc::ClusterShape::ClusterShape_2x1x1:
339+
return genericMXFP8xMXFP8GemmKernelLauncher<T, CTA_M_, CTA_N_, CTA_K_, cute::Int<2>, cute::Int<1>, cute::Int<1>,
340+
__2SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
341+
stream, occupancy);
342+
break;
343+
case tkc::ClusterShape::ClusterShape_2x2x1:
344+
return genericMXFP8xMXFP8GemmKernelLauncher<T, CTA_M_, CTA_N_, CTA_K_, cute::Int<2>, cute::Int<2>, cute::Int<1>,
345+
__2SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
346+
stream, occupancy);
347+
break;
348+
case tkc::ClusterShape::ClusterShape_4x2x1:
349+
return genericMXFP8xMXFP8GemmKernelLauncher<T, CTA_M_, CTA_N_, CTA_K_, cute::Int<4>, cute::Int<2>, cute::Int<1>,
350+
__2SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
351+
stream, occupancy);
352+
break;
353+
case tkc::ClusterShape::ClusterShape_2x4x1:
354+
return genericMXFP8xMXFP8GemmKernelLauncher<T, CTA_M_, CTA_N_, CTA_K_, cute::Int<2>, cute::Int<4>, cute::Int<1>,
355+
__2SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
356+
stream, occupancy);
357+
break;
358+
case tkc::ClusterShape::ClusterShape_4x4x1:
359+
return genericMXFP8xMXFP8GemmKernelLauncher<T, CTA_M_, CTA_N_, CTA_K_, cute::Int<4>, cute::Int<4>, cute::Int<1>,
360+
__2SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes,
361+
stream, occupancy);
362+
break;
363+
default:
364+
throw std::runtime_error(
365+
"[TensorRT LLM Error][MXFP8][dispatch_gemm_cluster_shape] Config is invalid for MXFP8xMXFP8 GEMM.");
366+
break;
367+
}
368+
}
369+
370+
template <typename T>
371+
size_t dispatchMXFP8xMXFP8GemmCTAShapeSm100(T* D, void const* A, void const* B, void const* input_sf,
372+
void const* weight_sf, float const* global_sf, int m, int n, int k, int batch_count,
373+
tkc::CutlassGemmConfig gemmConfig, char* workspace, const size_t workspaceBytes, cudaStream_t stream,
374+
int* occupancy = nullptr)
375+
{
376+
377+
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
378+
switch (gemmConfig.tile_config_sm100)
379+
{
380+
case tkc::CutlassTileConfigSM100::CtaShape128x64x128B:
381+
return dispatchMXFP8xMXFP8GemmClusterShapeSm100<T, cute::Int<128>, cute::Int<64>, cute::Int<128>>(D, A, B,
382+
input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream,
383+
occupancy);
384+
break;
385+
case tkc::CutlassTileConfigSM100::CtaShape128x256x128B:
386+
return dispatchMXFP8xMXFP8GemmClusterShapeSm100<T, cute::Int<128>, cute::Int<256>, cute::Int<128>>(D, A, B,
387+
input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream,
388+
occupancy);
389+
break;
390+
case tkc::CutlassTileConfigSM100::CtaShape128x128x256B:
391+
return dispatchMXFP8xMXFP8GemmClusterShapeSm100<T, cute::Int<128>, cute::Int<128>, cute::Int<256>>(D, A, B,
392+
input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream,
393+
occupancy);
394+
break;
395+
case tkc::CutlassTileConfigSM100::CtaShape128x256x256B:
396+
return dispatchMXFP8xMXFP8GemmClusterShapeSm100<T, cute::Int<128>, cute::Int<256>, cute::Int<256>>(D, A, B,
397+
input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream,
398+
occupancy);
399+
break;
400+
case tkc::CutlassTileConfigSM100::Undefined:
401+
throw std::runtime_error("[TensorRT LLM Error][MXFP8][dispatch_gemm_cta_shape] Gemm config undefined.");
402+
break;
403+
case tkc::CutlassTileConfigSM100::ChooseWithHeuristic:
404+
throw std::runtime_error(
405+
"[TensorRT LLM Error][MXFP8][dispatch_gemm_cta_shape] Gemm config should have already been set by "
406+
"heuristic.");
407+
break;
408+
default:
409+
throw std::runtime_error(
410+
"[TensorRT LLM Error][MXFP8][dispatch_gemm_cta_shape] Config is invalid for MXFP8xMXFP8 GEMM.");
411+
break;
412+
}
413+
}
414+
326415
template <typename T, FP4GemmType fp4GemmType>
327416
CutlassFp4GemmRunner<T, fp4GemmType>::CutlassFp4GemmRunner()
328417
{
@@ -358,6 +447,19 @@ size_t CutlassFp4GemmRunner<T, fp4GemmType>::dispatchToArch(T* D, void const* A,
358447
"[TensorRT LLM Error][CutlassFp4GemmRunner][GEMM Dispatch] Arch unsupported for CUTLASS FP4 GEMM");
359448
}
360449
}
450+
else if constexpr (fp4GemmType == FP4GemmType::W8A8_MXFP8_MXFP8)
451+
{
452+
if (mSm == 100 || mSm == 103)
453+
{
454+
return dispatchMXFP8xMXFP8GemmCTAShapeSm100<T>(D, A, B, input_sf, weight_sf, global_sf, m, n, k,
455+
batch_count, gemmConfig, workspace, workspaceBytes, stream, occupancy);
456+
}
457+
else
458+
{
459+
throw std::runtime_error(
460+
"[TensorRT LLM Error][CutlassFp4GemmRunner][GEMM Dispatch] Arch unsupported for CUTLASS MXFP8 GEMM");
461+
}
462+
}
361463
else if constexpr (fp4GemmType == FP4GemmType::W4A4_NVFP4_NVFP4)
362464
{
363465
if (mSm == 103)
@@ -437,9 +539,12 @@ std::vector<tkc::CutlassGemmConfig> CutlassFp4GemmRunner<T, fp4GemmType>::getCon
437539
{
438540
for (auto const& cluster_config : clusterShapes)
439541
{
440-
if constexpr (fp4GemmType == FP4GemmType::W4A8_MXFP4_MXFP8)
542+
if constexpr (fp4GemmType == FP4GemmType::W4A8_MXFP4_MXFP8
543+
|| fp4GemmType == FP4GemmType::W8A8_MXFP8_MXFP8)
441544
{
442-
// Skip for high smem usage.
545+
// Skip for high smem usage (MXFP8xMXFP8 has even higher
546+
// smem pressure than MXFP8xMXFP4 because the B operand is
547+
// 2x wider, so the same skips apply).
443548
if (cluster_config == tkc::ClusterShape::ClusterShape_1x1x1
444549
|| cluster_config == tkc::ClusterShape::ClusterShape_1x2x1
445550
|| cluster_config == tkc::ClusterShape::ClusterShape_1x4x1)

0 commit comments

Comments
 (0)