Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,8 @@ set(APHRODITE_EXT_SRC
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")

# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
set(CUTLASS_REVISION "v4.0.0" CACHE STRING "CUTLASS revision to use")
# Updated to v4.2.0 (Sept 15, 2025) for improved SM120 support
set(CUTLASS_REVISION "v4.2.0" CACHE STRING "CUTLASS revision to use")

# Use the specified CUTLASS source directory for compilation if APHRODITE_CUTLASS_SRC_DIR is provided
if (DEFINED ENV{APHRODITE_CUTLASS_SRC_DIR})
Expand Down Expand Up @@ -553,7 +554,9 @@ set(APHRODITE_EXT_SRC
set(SRCS
"kernels/quantization/fp4/nvfp4_quant_kernels.cu"
"kernels/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"kernels/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu")
"kernels/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu"
"kernels/quantization/fp4/nvfp4_experts_quant.cu"
"kernels/quantization/fp4/nvfp4_blockwise_moe_sm120_kernels.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
Expand Down Expand Up @@ -634,7 +637,7 @@ set(APHRODITE_EXT_SRC
endif()
endif()

cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}")
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;12.0;12.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS "kernels/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu")
set_gencode_flags_for_srcs(
Expand All @@ -655,7 +658,7 @@ set(APHRODITE_EXT_SRC
endif()

# moe_data.cu is used by all CUTLASS MoE kernels.
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}")
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a;12.0;12.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
set(SRCS "kernels/quantization/cutlass_w8a8/moe/moe_data.cu")
set_gencode_flags_for_srcs(
Expand Down
14 changes: 14 additions & 0 deletions aphrodite/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,6 +1066,20 @@ def cutlass_fp4_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
- problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
MMs used in the fused MoE operation.
"""
# Detect SM120 architecture (RTX 5090/Blackwell GeForce)
device = a_tensors.device if a_tensors.is_cuda else torch.cuda.current_device()
major, minor = torch.cuda.get_device_capability(device)

# Use the SM120 kernel only for compute capability 12.0 exactly (major == 12 and minor == 0)
if major == 12 and minor == 0:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The check for the SM120 architecture is too specific. By checking for major == 12 and minor == 0, you are limiting this path to compute capability 12.0 exactly. To ensure forward compatibility with future minor revisions of the same architecture (e.g., 12.1), it's better to check only the major version number.

Suggested change
if major == 12 and minor == 0:
if major == 12:

# Check if SM120 kernel is available
if hasattr(torch.ops._C, 'cutlass_fp4_group_mm_sm120'):
return torch.ops._C.cutlass_fp4_group_mm_sm120(
out_tensors, a_tensors, b_tensors,
a_scales, b_scales, alphas,
problem_sizes, expert_offsets, sf_offsets)

# Fall back to standard kernel
return torch.ops._C.cutlass_fp4_group_mm(out_tensors, a_tensors, b_tensors,
a_scales, b_scales, alphas,
problem_sizes, expert_offsets,
Expand Down
6 changes: 6 additions & 0 deletions kernels/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,12 @@ void cutlass_fp4_group_mm(
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets);

// SM120-specific entry point for the FP4 grouped matrix multiply. It takes
// the same argument list as cutlass_fp4_group_mm.
// NOTE(review): presumably defined in
// kernels/quantization/fp4/nvfp4_blockwise_moe_sm120_kernels.cu — confirm.
void cutlass_fp4_group_mm_sm120(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets);

void get_cutlass_moe_mm_data(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
Expand Down
65 changes: 40 additions & 25 deletions kernels/quantization/cutlass_w8a8/scaled_mm_entry.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ void cutlass_moe_mm_sm90(

#endif

#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
#if defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100 || \
defined(ENABLE_SCALED_MM_SM120) && ENABLE_SCALED_MM_SM120
void cutlass_moe_mm_sm100(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
Expand All @@ -68,7 +69,8 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
#endif

#if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 || \
defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100
defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100 || \
defined(ENABLE_SCALED_MM_SM120) && ENABLE_SCALED_MM_SM120
void get_cutlass_moe_mm_data_caller(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
Expand Down Expand Up @@ -267,11 +269,21 @@ void cutlass_moe_mm(
c_strides, per_act_token, per_out_ch);
return;
}
#endif
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
if (version_num >= 120) {
// we use SM100 kernels for SM120 devices
cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch);
return;
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_mm for CUDA device capability: ", version_num,
". Required capability: 90 or 100");
"No compiled cutlass_moe_mm for CUDA device capability: ", version_num,
". Required capability: 90, 100, or 120. Note: SM120+ devices should use "
"SM100 kernels.");
}

void get_cutlass_moe_mm_data(
Expand All @@ -283,8 +295,9 @@ void get_cutlass_moe_mm_data(
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
(defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120)
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, input_permutation,
output_permutation, num_experts, n, k,
Expand All @@ -295,26 +308,27 @@ void get_cutlass_moe_mm_data(
false,
"No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
"CUDA device capability: ",
version_num, ". Required capability: 90 or 100");
version_num, ". Required capability: 90, 100, or 120");
}

void get_cutlass_moe_mm_problem_sizes(
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) {
int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1,
problem_sizes2, num_experts, n, k,
blockscale_offsets);
return;
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) {
int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
(defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120)
get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1,
problem_sizes2, num_experts, n, k,
blockscale_offsets);
return;
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled get_cutlass_moe_mm_problem_sizes: no cutlass_scaled_mm "
"kernel for CUDA device capability: ",
version_num, ". Required capability: 90 or 100");
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled get_cutlass_moe_mm_problem_sizes: no cutlass_scaled_mm "
"kernel for CUDA device capability: ",
version_num, ". Required capability: 90, 100, or 120");
}

void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
Expand All @@ -327,8 +341,9 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
(defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120)
get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1,
problem_sizes2, expert_num_tokens,
num_local_experts, padded_m, n, k);
Expand All @@ -338,7 +353,7 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
false,
"No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel "
"for CUDA device capability: ",
version_num, ". Required capability: 90 or 100");
version_num, ". Required capability: 90, 100, or 120");
}

void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
Expand Down
33 changes: 30 additions & 3 deletions kernels/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass_extensions/common.hpp"
#include <cassert>

using namespace cute;
Expand Down Expand Up @@ -351,11 +352,29 @@ constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
CHECK_CONTIGUOUS(x, m); \
CHECK_TYPE(x, st, m)

// Forward declaration of the SM120 grouped-GEMM entry point. It is only
// visible when ENABLE_NVFP4_SM120 is defined and non-zero, which lets
// cutlass_fp4_group_mm forward to it once get_sm_version_num() reports a
// value >= 120.
// NOTE(review): the definition is not in this translation unit; presumably it
// lives in nvfp4_blockwise_moe_sm120_kernels.cu — confirm.
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
void cutlass_fp4_group_mm_sm120(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets);
#endif

void cutlass_fp4_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) {
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
{
int32_t version_num = get_sm_version_num();
if (version_num >= 120) {
return cutlass_fp4_group_mm_sm120(output, a, b, a_blockscale,
b_blockscales, alphas, problem_sizes,
expert_offsets, sf_offsets);
}
}
#endif
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
// Input validation
CHECK_INPUT(a, FLOAT4_E2M1X2, "a");
Expand Down Expand Up @@ -395,10 +414,18 @@ void cutlass_fp4_group_mm(
expert_offsets, sf_offsets, M, N, K);
}
#else
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
{
int32_t version_num = get_sm_version_num();
if (version_num >= 120) {
return cutlass_fp4_group_mm_sm120(output, a, b, a_blockscale,
b_blockscales, alphas, problem_sizes,
expert_offsets, sf_offsets);
}
}
#endif
Comment on lines +417 to +426
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block of code, which dispatches to the sm120 kernel, is a duplicate of the logic at lines 368-377. This redundancy makes the code harder to maintain. The initial check at the beginning of the function is sufficient to handle the dispatch, so this second block can be safely removed.

TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_fp4_group_mm kernel, Aphrodite must "
"be compiled with ENABLE_NVFP4_SM100 for SM100+ and CUDA "
"12.8 or above.");
"No compiled cutlass_fp4_group_mm kernel for this architecture.");
#endif
}
Loading
Loading