|
37 | 37 |
|
38 | 38 | #include "../include/fp4_gemm.h" |
39 | 39 | #include "mxfp8_mxfp4_gemm_template_sm100.h" |
| 40 | +#include "mxfp8_mxfp8_gemm_template_sm100.h" |
40 | 41 | #include "nvfp4_nvfp4_gemm_template_sm100.h" |
41 | 42 | #include "nvfp4_nvfp4_gemm_template_sm120.h" |
42 | 43 |
|
@@ -323,6 +324,94 @@ size_t dispatchMXFP8xMXFP4GemmCTAShapeSm100(T* D, void const* A, void const* B, |
323 | 324 | } |
324 | 325 | } |
325 | 326 |
|
| 327 | +template <typename T, typename CTA_M_, typename CTA_N_, typename CTA_K_> |
| 328 | +size_t dispatchMXFP8xMXFP8GemmClusterShapeSm100(T* D, void const* A, void const* B, void const* input_sf, |
| 329 | + void const* weight_sf, float const* global_sf, int m, int n, int k, int batch_count, |
| 330 | + tkc::CutlassGemmConfig gemmConfig, char* workspace, const size_t workspaceBytes, cudaStream_t stream, |
| 331 | + int* occupancy = nullptr) |
| 332 | +{ |
| 333 | + |
| 334 | + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); |
| 335 | + |
| 336 | + switch (gemmConfig.cluster_shape) |
| 337 | + { |
| 338 | + case tkc::ClusterShape::ClusterShape_2x1x1: |
| 339 | + return genericMXFP8xMXFP8GemmKernelLauncher<T, CTA_M_, CTA_N_, CTA_K_, cute::Int<2>, cute::Int<1>, cute::Int<1>, |
| 340 | + __2SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, |
| 341 | + stream, occupancy); |
| 342 | + break; |
| 343 | + case tkc::ClusterShape::ClusterShape_2x2x1: |
| 344 | + return genericMXFP8xMXFP8GemmKernelLauncher<T, CTA_M_, CTA_N_, CTA_K_, cute::Int<2>, cute::Int<2>, cute::Int<1>, |
| 345 | + __2SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, |
| 346 | + stream, occupancy); |
| 347 | + break; |
| 348 | + case tkc::ClusterShape::ClusterShape_4x2x1: |
| 349 | + return genericMXFP8xMXFP8GemmKernelLauncher<T, CTA_M_, CTA_N_, CTA_K_, cute::Int<4>, cute::Int<2>, cute::Int<1>, |
| 350 | + __2SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, |
| 351 | + stream, occupancy); |
| 352 | + break; |
| 353 | + case tkc::ClusterShape::ClusterShape_2x4x1: |
| 354 | + return genericMXFP8xMXFP8GemmKernelLauncher<T, CTA_M_, CTA_N_, CTA_K_, cute::Int<2>, cute::Int<4>, cute::Int<1>, |
| 355 | + __2SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, |
| 356 | + stream, occupancy); |
| 357 | + break; |
| 358 | + case tkc::ClusterShape::ClusterShape_4x4x1: |
| 359 | + return genericMXFP8xMXFP8GemmKernelLauncher<T, CTA_M_, CTA_N_, CTA_K_, cute::Int<4>, cute::Int<4>, cute::Int<1>, |
| 360 | + __2SM>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, |
| 361 | + stream, occupancy); |
| 362 | + break; |
| 363 | + default: |
| 364 | + throw std::runtime_error( |
| 365 | + "[TensorRT LLM Error][MXFP8][dispatch_gemm_cluster_shape] Config is invalid for MXFP8xMXFP8 GEMM."); |
| 366 | + break; |
| 367 | + } |
| 368 | +} |
| 369 | + |
| 370 | +template <typename T> |
| 371 | +size_t dispatchMXFP8xMXFP8GemmCTAShapeSm100(T* D, void const* A, void const* B, void const* input_sf, |
| 372 | + void const* weight_sf, float const* global_sf, int m, int n, int k, int batch_count, |
| 373 | + tkc::CutlassGemmConfig gemmConfig, char* workspace, const size_t workspaceBytes, cudaStream_t stream, |
| 374 | + int* occupancy = nullptr) |
| 375 | +{ |
| 376 | + |
| 377 | + TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); |
| 378 | + switch (gemmConfig.tile_config_sm100) |
| 379 | + { |
| 380 | + case tkc::CutlassTileConfigSM100::CtaShape128x64x128B: |
| 381 | + return dispatchMXFP8xMXFP8GemmClusterShapeSm100<T, cute::Int<128>, cute::Int<64>, cute::Int<128>>(D, A, B, |
| 382 | + input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream, |
| 383 | + occupancy); |
| 384 | + break; |
| 385 | + case tkc::CutlassTileConfigSM100::CtaShape128x256x128B: |
| 386 | + return dispatchMXFP8xMXFP8GemmClusterShapeSm100<T, cute::Int<128>, cute::Int<256>, cute::Int<128>>(D, A, B, |
| 387 | + input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream, |
| 388 | + occupancy); |
| 389 | + break; |
| 390 | + case tkc::CutlassTileConfigSM100::CtaShape128x128x256B: |
| 391 | + return dispatchMXFP8xMXFP8GemmClusterShapeSm100<T, cute::Int<128>, cute::Int<128>, cute::Int<256>>(D, A, B, |
| 392 | + input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream, |
| 393 | + occupancy); |
| 394 | + break; |
| 395 | + case tkc::CutlassTileConfigSM100::CtaShape128x256x256B: |
| 396 | + return dispatchMXFP8xMXFP8GemmClusterShapeSm100<T, cute::Int<128>, cute::Int<256>, cute::Int<256>>(D, A, B, |
| 397 | + input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, workspaceBytes, stream, |
| 398 | + occupancy); |
| 399 | + break; |
| 400 | + case tkc::CutlassTileConfigSM100::Undefined: |
| 401 | + throw std::runtime_error("[TensorRT LLM Error][MXFP8][dispatch_gemm_cta_shape] Gemm config undefined."); |
| 402 | + break; |
| 403 | + case tkc::CutlassTileConfigSM100::ChooseWithHeuristic: |
| 404 | + throw std::runtime_error( |
| 405 | + "[TensorRT LLM Error][MXFP8][dispatch_gemm_cta_shape] Gemm config should have already been set by " |
| 406 | + "heuristic."); |
| 407 | + break; |
| 408 | + default: |
| 409 | + throw std::runtime_error( |
| 410 | + "[TensorRT LLM Error][MXFP8][dispatch_gemm_cta_shape] Config is invalid for MXFP8xMXFP8 GEMM."); |
| 411 | + break; |
| 412 | + } |
| 413 | +} |
| 414 | + |
326 | 415 | template <typename T, FP4GemmType fp4GemmType> |
327 | 416 | CutlassFp4GemmRunner<T, fp4GemmType>::CutlassFp4GemmRunner() |
328 | 417 | { |
@@ -358,6 +447,19 @@ size_t CutlassFp4GemmRunner<T, fp4GemmType>::dispatchToArch(T* D, void const* A, |
358 | 447 | "[TensorRT LLM Error][CutlassFp4GemmRunner][GEMM Dispatch] Arch unsupported for CUTLASS FP4 GEMM"); |
359 | 448 | } |
360 | 449 | } |
| 450 | + else if constexpr (fp4GemmType == FP4GemmType::W8A8_MXFP8_MXFP8) |
| 451 | + { |
| 452 | + if (mSm == 100 || mSm == 103) |
| 453 | + { |
| 454 | + return dispatchMXFP8xMXFP8GemmCTAShapeSm100<T>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, |
| 455 | + batch_count, gemmConfig, workspace, workspaceBytes, stream, occupancy); |
| 456 | + } |
| 457 | + else |
| 458 | + { |
| 459 | + throw std::runtime_error( |
| 460 | + "[TensorRT LLM Error][CutlassFp4GemmRunner][GEMM Dispatch] Arch unsupported for CUTLASS MXFP8 GEMM"); |
| 461 | + } |
| 462 | + } |
361 | 463 | else if constexpr (fp4GemmType == FP4GemmType::W4A4_NVFP4_NVFP4) |
362 | 464 | { |
363 | 465 | if (mSm == 103) |
@@ -437,9 +539,12 @@ std::vector<tkc::CutlassGemmConfig> CutlassFp4GemmRunner<T, fp4GemmType>::getCon |
437 | 539 | { |
438 | 540 | for (auto const& cluster_config : clusterShapes) |
439 | 541 | { |
440 | | - if constexpr (fp4GemmType == FP4GemmType::W4A8_MXFP4_MXFP8) |
| 542 | + if constexpr (fp4GemmType == FP4GemmType::W4A8_MXFP4_MXFP8 |
| 543 | + || fp4GemmType == FP4GemmType::W8A8_MXFP8_MXFP8) |
441 | 544 | { |
442 | | - // Skip for high smem usage. |
| 545 | + // Skip for high smem usage (MXFP8xMXFP8 has even higher |
| 546 | + // smem pressure than MXFP8xMXFP4 because the B operand is |
| 547 | + // 2x wider, so the same skips apply). |
443 | 548 | if (cluster_config == tkc::ClusterShape::ClusterShape_1x1x1 |
444 | 549 | || cluster_config == tkc::ClusterShape::ClusterShape_1x2x1 |
445 | 550 | || cluster_config == tkc::ClusterShape::ClusterShape_1x4x1) |
|
0 commit comments