Skip to content

Commit fce68b3

Browse files
author
yinding
committed
Add SM90 MegaMoE support with TVM FFI bindings
1 parent 86d705d commit fce68b3

7 files changed

Lines changed: 4737 additions & 2 deletions

File tree

csrc/apis/sm90_mega.hpp

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
#pragma once
2+
3+
#include <functional>
4+
5+
#include "mega.hpp"
6+
#include "../jit_kernels/impls/sm90_fp8_mega_moe.hpp"
7+
8+
namespace deep_gemm::mega {
9+
10+
static int get_token_alignment_for_sm90_mega_moe() {
11+
return layout::kLCMCandidateBlockM;
12+
}
13+
14+
static std::tuple<int64_t, std::function<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>(const torch::Tensor&)>>
15+
get_symm_buffer_size_for_sm90_mega_moe(
16+
const int& num_ranks, const int& num_experts,
17+
const int& num_max_tokens_per_rank, const int& num_topk,
18+
const int& hidden, const int& intermediate_hidden,
19+
const bool& use_fp8_dispatch, const std::string& activation) {
20+
DG_HOST_ASSERT(num_experts % num_ranks == 0);
21+
DG_HOST_ASSERT(use_fp8_dispatch);
22+
DG_HOST_ASSERT(activation == "swiglu");
23+
24+
const auto workspace = layout::Workspace(nullptr, num_ranks, num_experts, num_max_tokens_per_rank, num_topk);
25+
26+
const auto fp8_token_layout = layout::Data(hidden);
27+
const auto bf16_token_layout = layout::Data(hidden * 2);
28+
const auto fp8_intermediate_token_layout = layout::Data(intermediate_hidden);
29+
const auto fp8_sf_layout = layout::Data(hidden / 32);
30+
const auto fp8_intermediate_sf_layout = layout::Data(intermediate_hidden / 16);
31+
const auto input_topk_idx_layout = layout::Data(num_topk * sizeof(int64_t), false);
32+
const auto input_topk_weights_layout = layout::Data(num_topk * sizeof(float), false);
33+
const auto l1_topk_weights_layout = layout::Data(sizeof(float), false);
34+
35+
const auto input_token_buffer = layout::Buffer(
36+
fp8_token_layout, 1, num_max_tokens_per_rank,
37+
workspace.get_end_ptr());
38+
const auto input_sf_buffer = layout::Buffer(
39+
fp8_sf_layout, 1, num_max_tokens_per_rank,
40+
input_token_buffer.get_end_ptr());
41+
const auto input_topk_idx_buffer = layout::Buffer(
42+
input_topk_idx_layout, 1, num_max_tokens_per_rank,
43+
input_sf_buffer.get_end_ptr());
44+
const auto input_topk_weights_buffer = layout::Buffer(
45+
input_topk_weights_layout, 1, num_max_tokens_per_rank,
46+
input_topk_idx_buffer.get_end_ptr());
47+
48+
const auto num_max_pool_tokens = static_cast<int>(workspace.num_max_pool_tokens);
49+
int num_max_padded_sf_pool_tokens = 0;
50+
for (int block_m: layout::kCandidateBlockM) {
51+
num_max_padded_sf_pool_tokens = std::max(
52+
num_max_padded_sf_pool_tokens,
53+
layout::get_num_padded_sf_pool_tokens(num_max_pool_tokens, block_m)
54+
);
55+
}
56+
57+
const auto l1_token_buffer = layout::Buffer(
58+
fp8_token_layout, 1, num_max_pool_tokens,
59+
input_topk_weights_buffer.get_end_ptr());
60+
const auto l1_sf_buffer = layout::Buffer(
61+
fp8_sf_layout, 1, num_max_padded_sf_pool_tokens,
62+
l1_token_buffer.get_end_ptr());
63+
const auto l1_topk_weights_buffer = layout::Buffer(
64+
l1_topk_weights_layout, 1, num_max_pool_tokens,
65+
l1_sf_buffer.get_end_ptr());
66+
67+
const auto l2_token_buffer = layout::Buffer(
68+
fp8_intermediate_token_layout, 1, num_max_pool_tokens,
69+
l1_topk_weights_buffer.get_end_ptr());
70+
const auto l2_sf_buffer = layout::Buffer(
71+
fp8_intermediate_sf_layout, 1, num_max_padded_sf_pool_tokens,
72+
l2_token_buffer.get_end_ptr());
73+
74+
const auto combine_token_buffer = layout::Buffer(
75+
bf16_token_layout, num_topk, num_max_tokens_per_rank,
76+
l2_sf_buffer.get_end_ptr());
77+
78+
DG_HOST_ASSERT(hidden % 128 == 0 and intermediate_hidden % 128 == 0);
79+
80+
auto slice_input_buffers = [=](const torch::Tensor& buffer) {
81+
auto x = torch::from_blob(
82+
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_token_buffer.base)),
83+
{num_max_tokens_per_rank, hidden},
84+
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
85+
auto x_sf = torch::from_blob(
86+
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_sf_buffer.base)),
87+
{num_max_tokens_per_rank, hidden / 128},
88+
torch::TensorOptions().dtype(torch::kFloat32).device(buffer.device()));
89+
auto topk_idx = torch::from_blob(
90+
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_topk_idx_buffer.base)),
91+
{num_max_tokens_per_rank, num_topk},
92+
torch::TensorOptions().dtype(torch::kInt64).device(buffer.device()));
93+
auto topk_weights = torch::from_blob(
94+
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_topk_weights_buffer.base)),
95+
{num_max_tokens_per_rank, num_topk},
96+
torch::TensorOptions().dtype(torch::kFloat32).device(buffer.device()));
97+
auto l1_acts = torch::from_blob(
98+
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_token_buffer.base)),
99+
{num_max_pool_tokens, hidden},
100+
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
101+
auto l1_acts_sf = torch::from_blob(
102+
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_sf_buffer.base)),
103+
{num_max_padded_sf_pool_tokens, hidden / 128},
104+
{1, num_max_padded_sf_pool_tokens},
105+
torch::TensorOptions().dtype(torch::kFloat32).device(buffer.device()));
106+
auto l2_acts = torch::from_blob(
107+
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_token_buffer.base)),
108+
{num_max_pool_tokens, intermediate_hidden},
109+
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
110+
auto l2_acts_sf = torch::from_blob(
111+
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_sf_buffer.base)),
112+
{num_max_padded_sf_pool_tokens, intermediate_hidden / 64},
113+
{1, num_max_padded_sf_pool_tokens},
114+
torch::TensorOptions().dtype(torch::kFloat32).device(buffer.device()));
115+
return std::make_tuple(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf);
116+
};
117+
return {reinterpret_cast<int64_t>(combine_token_buffer.get_end_ptr()), slice_input_buffers};
118+
}
119+
120+
static void fp8_mega_moe(
121+
const torch::Tensor& y,
122+
const std::tuple<torch::Tensor, torch::Tensor>& l1_weights_tuple,
123+
const std::tuple<torch::Tensor, torch::Tensor>& l2_weights_tuple,
124+
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
125+
const torch::Tensor& sym_buffer,
126+
const std::vector<int64_t>& sym_buffer_ptrs, const int& rank_idx,
127+
const int& num_max_tokens_per_rank,
128+
const int& num_experts, const int& num_topk,
129+
const std::tuple<int, int, int>& recipe,
130+
const std::string& activation,
131+
const std::optional<float>& activation_clamp_opt,
132+
const bool& fast_math
133+
) {
134+
const auto [l1_weights, l1_weights_sf] = l1_weights_tuple;
135+
const auto [l2_weights, l2_weights_sf] = l2_weights_tuple;
136+
137+
const auto arch_major = device_runtime->get_arch_major();
138+
DG_HOST_ASSERT(arch_major == 9);
139+
140+
const auto num_tokens = static_cast<int>(y.size(0));
141+
const auto [rm, rn, rk] = recipe;
142+
DG_HOST_ASSERT(rm == 128 and rn == 128 and rk == 128);
143+
DG_HOST_ASSERT(activation == "swiglu");
144+
145+
const auto activation_clamp =
146+
activation_clamp_opt.value_or(std::numeric_limits<float>::infinity());
147+
DG_HOST_ASSERT(activation_clamp >= 0);
148+
149+
DG_HOST_ASSERT(get_major_type_ab(l1_weights) == cute::UMMA::Major::K);
150+
DG_HOST_ASSERT(get_major_type_ab(l2_weights) == cute::UMMA::Major::K);
151+
DG_HOST_ASSERT(l1_weights.scalar_type() == torch::kFloat8_e4m3fn);
152+
DG_HOST_ASSERT(l2_weights.scalar_type() == torch::kFloat8_e4m3fn);
153+
const auto [num_experts_per_rank, intermediate_hidden_2, hidden] = get_shape<3>(l1_weights);
154+
const auto [num_experts_per_rank_, hidden_, intermediate_hidden] = get_shape<3>(l2_weights);
155+
DG_HOST_ASSERT(num_tokens <= num_max_tokens_per_rank);
156+
DG_HOST_ASSERT(num_experts_per_rank == num_experts_per_rank_);
157+
DG_HOST_ASSERT(hidden == hidden_);
158+
DG_HOST_ASSERT(intermediate_hidden_2 == 2 * intermediate_hidden);
159+
DG_HOST_ASSERT(l1_weights.is_contiguous() and l2_weights.is_contiguous());
160+
DG_HOST_ASSERT(hidden % 128 == 0 and intermediate_hidden % 128 == 0);
161+
DG_HOST_ASSERT(intermediate_hidden / 64 <= 64);
162+
163+
constexpr int kGranMN = 128, kGranK = 128;
164+
check_sf_layout(l1_weights_sf, intermediate_hidden * 2, hidden, kGranMN, kGranK,
165+
num_experts_per_rank, false, true, torch::kFloat);
166+
check_sf_layout(l2_weights_sf, hidden, intermediate_hidden, kGranMN, kGranK,
167+
num_experts_per_rank, false, true, torch::kFloat);
168+
169+
if (cumulative_local_expert_recv_stats.has_value()) {
170+
DG_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt);
171+
DG_HOST_ASSERT(cumulative_local_expert_recv_stats->numel() == num_experts_per_rank);
172+
DG_HOST_ASSERT(cumulative_local_expert_recv_stats->is_contiguous());
173+
}
174+
175+
const auto num_ranks = static_cast<int>(sym_buffer_ptrs.size());
176+
const auto num_experts_ = num_experts_per_rank * num_ranks;
177+
const auto [num_required_bytes, slice] = get_symm_buffer_size_for_sm90_mega_moe(
178+
num_ranks, num_experts,
179+
num_max_tokens_per_rank, num_topk,
180+
hidden, intermediate_hidden,
181+
true, activation);
182+
DG_HOST_ASSERT(sym_buffer.nbytes() >= static_cast<size_t>(num_required_bytes));
183+
DG_HOST_ASSERT(num_experts == num_experts_);
184+
185+
const auto [x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf] = slice(sym_buffer);
186+
187+
sm90_fp8_mega_moe(y,
188+
l1_acts, l1_acts_sf,
189+
l2_acts, l2_acts_sf,
190+
l1_weights, l2_weights,
191+
l1_weights_sf, l2_weights_sf,
192+
cumulative_local_expert_recv_stats,
193+
sym_buffer_ptrs,
194+
rank_idx, num_max_tokens_per_rank,
195+
num_experts_per_rank,
196+
num_tokens, num_topk,
197+
hidden, intermediate_hidden,
198+
activation_clamp, fast_math);
199+
200+
if (get_env<int>("DG_COMM_KERNEL_DEBUG"))
201+
sym_buffer.zero_();
202+
}
203+
204+
} // namespace deep_gemm::mega
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
#pragma once
2+
3+
#include "mega_moe.hpp"
4+
5+
namespace deep_gemm {
6+
7+
// ============================================================================
8+
// SM90 (Hopper) MegaMoE configuration
9+
// ----------------------------------------------------------------------------
10+
// SM90 differs from SM100 in:
11+
// - No tensor memory (TMEM): WGMMA accumulators live in registers.
12+
// - No FP4: weights are FP8 e4m3 with per-128 channel float scales.
13+
// - No 2-CTA cluster MMA: TMA multicast cluster=2 may still be used.
14+
// - Activation SF is float, not UE8M0 int: L1 input uses per-128 K and the
15+
// fused L1 epilogue writes L2 activation SF at per-64 K granularity.
16+
// The kernel implementation is in `deep_gemm/impls/sm90_fp8_mega_moe.cuh`.
17+
// ============================================================================
18+
19+
// Launch/tiling configuration of the SM90 MegaMoE kernel.
//
// Plain aggregate: fields are filled positionally (in declaration order) by
// brace initialization. Every field defaults to 0 so that a default-initialized
// config never holds indeterminate values (e.g. when it is streamed for logging).
struct MegaMoESM90Config {
    // Tile shape
    int block_m = 0, block_n = 0, block_k = 0;
    int cluster_size = 0;
    // Token pool capacities
    int num_max_pool_tokens = 0;
    int num_padded_sf_pool_tokens = 0;
    // Swizzle modes for activations and weights
    int swizzle_acts_mode = 0, swizzle_weights_mode = 0;
    int num_experts_per_wave = 0;
    // Pipeline depth and the shared-memory bytes it needs
    int num_stages = 0, smem_size = 0;
    // Thread counts per role
    int num_dispatch_threads = 0, num_non_epilogue_threads = 0, num_epilogue_threads = 0;

    // Prints the config as `MegaMoESM90Config(name=value, ...)` in declaration order.
    friend std::ostream& operator << (std::ostream& os, const MegaMoESM90Config& config) {
        os << "MegaMoESM90Config(";
        os << "block_m=" << config.block_m
           << ", block_n=" << config.block_n
           << ", block_k=" << config.block_k;
        os << ", cluster_size=" << config.cluster_size;
        os << ", num_max_pool_tokens=" << config.num_max_pool_tokens;
        os << ", num_padded_sf_pool_tokens=" << config.num_padded_sf_pool_tokens;
        os << ", swizzle_acts_mode=" << config.swizzle_acts_mode
           << ", swizzle_weights_mode=" << config.swizzle_weights_mode;
        os << ", num_experts_per_wave=" << config.num_experts_per_wave;
        os << ", num_stages=" << config.num_stages
           << ", smem_size=" << config.smem_size;
        os << ", num_dispatch_threads=" << config.num_dispatch_threads
           << ", num_non_epilogue_threads=" << config.num_non_epilogue_threads
           << ", num_epilogue_threads=" << config.num_epilogue_threads;
        os << ")";
        return os;
    }
};
44+
45+
static std::tuple<int, int> get_block_config_for_mega_moe_sm90(
46+
const int& num_ranks, const int& num_experts,
47+
const int& num_max_tokens_per_rank, const int& num_topk,
48+
const int& num_tokens) {
49+
const float expected_tokens_per_expert =
50+
static_cast<float>(num_tokens) * num_ranks * num_topk / num_experts;
51+
const bool auto_split_mn = expected_tokens_per_expert >= 64.0f;
52+
if (auto_split_mn)
53+
return {128, 512};
54+
55+
const int block_m = 64;
56+
const int num_epilogue_warpgroups = 2;
57+
58+
DG_HOST_ASSERT(std::any_of(
59+
layout::kCandidateBlockM, layout::kCandidateBlockM + layout::kNumCandidateBlockMs,
60+
[=](const auto& candidate) { return candidate == block_m; })
61+
);
62+
return {block_m, num_epilogue_warpgroups * 128};
63+
}
64+
65+
static int get_num_experts_per_wave_for_mega_moe_sm90(
66+
const int& num_experts_per_rank, const int& num_tokens, const int& num_topk,
67+
const int& intermediate_hidden, const int& block_m, const int& block_n, const int& num_sms) {
68+
const float expected_tokens_per_expert =
69+
static_cast<float>(num_tokens) * num_topk / num_experts_per_rank;
70+
if (expected_tokens_per_expert < 1.0f or expected_tokens_per_expert > 4.0f)
71+
return num_experts_per_rank;
72+
73+
if (block_m == 64 and intermediate_hidden >= 3072) {
74+
const int num_n_blocks_per_expert = (2 * intermediate_hidden) / block_n;
75+
const int single_wave_blocks =
76+
num_experts_per_rank * num_n_blocks_per_expert;
77+
if (single_wave_blocks >= 4 * num_sms)
78+
return num_experts_per_rank;
79+
}
80+
return get_num_experts_per_wave_for_mega_moe(
81+
num_experts_per_rank, num_tokens, num_topk,
82+
intermediate_hidden, block_m, block_n, num_sms);
83+
}
84+
85+
static std::pair<int, int> get_pipeline_config_for_mega_moe_sm90(
86+
const int& smem_capacity,
87+
const int& num_experts, const int& hidden,
88+
const int& block_m, const int& block_n, const int& block_k,
89+
const int& num_dispatch_warps, const int& num_epilogue_warps) {
90+
constexpr int kSmemAlignment = 1024;
91+
92+
const int smem_expert_count_size = align(
93+
num_experts * static_cast<int>(sizeof(uint32_t)), kSmemAlignment);
94+
const int smem_send_buffers_size = align(
95+
static_cast<int>(layout::Buffer(layout::Data(hidden), num_dispatch_warps, 1).get_num_bytes()),
96+
kSmemAlignment);
97+
const int smem_dispatch_size = smem_expert_count_size + smem_send_buffers_size;
98+
99+
const int smem_cd_l1 = block_m * (block_n / 2);
100+
const int smem_cd_l2 = block_m * block_n * static_cast<int>(sizeof(nv_bfloat16));
101+
const int smem_cd = align(std::max(smem_cd_l1, smem_cd_l2), kSmemAlignment);
102+
103+
const int smem_sfa_per_stage = align(2 * block_m * static_cast<int>(sizeof(float)), 128);
104+
const int smem_sfb_per_stage = 0;
105+
const int smem_per_stage = block_m * block_k + block_n * block_k +
106+
smem_sfa_per_stage + smem_sfb_per_stage;
107+
108+
const int smem_barriers_fixed = (num_dispatch_warps + 2 * num_epilogue_warps) * 8;
109+
const int smem_barriers_per_stage = 2 * 8;
110+
const int smem_fixed = smem_dispatch_size + smem_cd + smem_barriers_fixed;
111+
112+
const int num_stages = (smem_capacity - smem_fixed) /
113+
(smem_per_stage + smem_barriers_per_stage);
114+
DG_HOST_ASSERT(num_stages >= 2);
115+
const int smem_size = smem_fixed + num_stages * (smem_per_stage + smem_barriers_per_stage);
116+
DG_HOST_ASSERT(smem_size <= smem_capacity);
117+
return {num_stages, smem_size};
118+
}
119+
120+
// Assembles the full `MegaMoESM90Config` for one problem shape.
//
// Steps: pick block-M and epilogue thread count, derive block-N from the
// expected per-expert load, size the token pool, choose experts-per-wave,
// settle the dispatch / non-epilogue thread counts, and fit the pipeline
// into `SM90ArchSpec::smem_capacity`.
//
// When `DG_JIT_DEBUG` or `DG_PRINT_CONFIGS` is set, the chosen config is
// printed to stdout once per distinct problem shape.
static MegaMoESM90Config get_mega_moe_config_sm90(
    const int& num_ranks, const int& num_experts, const int& num_experts_per_rank,
    const int& num_max_tokens_per_rank, const int& num_tokens, const int& num_topk,
    const int& hidden, const int& intermediate_hidden,
    const int& num_padded_sf_pool_tokens) {
    const auto [block_m, num_epilogue_threads] = get_block_config_for_mega_moe_sm90(
        num_ranks, num_experts, num_max_tokens_per_rank, num_topk, num_tokens);

    // Same load estimate as used by the block-config heuristic
    const float tokens_per_expert =
        static_cast<float>(num_tokens) * num_ranks * num_topk / num_experts;
    const bool heavy_load = tokens_per_expert >= 64.0f;
    const bool small_tile_two_warpgroups =
        block_m == 64 and num_epilogue_threads == 256;

    // Block-N: 256 under heavy load, or for small tiles with a wide enough FFN
    const bool small_tile_wants_n256 =
        small_tile_two_warpgroups and intermediate_hidden >= 3072 and
        tokens_per_expert >= 0.25f and
        (2 * intermediate_hidden) % 256 == 0;
    const int block_n = (heavy_load or small_tile_wants_n256) ? 256 : 128;
    const int block_k = 128;
    const int cluster_size = 1;

    const int num_max_pool_tokens = layout::get_num_max_pool_tokens(
        num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank);
    const int swizzle_acts_mode = 128;
    const int swizzle_weights_mode = 128;

    const int num_sms = device_runtime->get_num_sms();
    const int num_experts_per_wave = get_num_experts_per_wave_for_mega_moe_sm90(
        num_experts_per_rank, num_tokens, num_topk,
        intermediate_hidden, block_m, block_n, num_sms);

    // Thread budget of the non-epilogue roles
    const bool single_epilogue_warpgroup = num_epilogue_threads == 128;
    const bool shrink_non_epilogue =
        single_epilogue_warpgroup or small_tile_two_warpgroups;
    const bool use_compact_roles =
        num_epilogue_threads == 512 or shrink_non_epilogue;
    const int num_dispatch_threads = use_compact_roles ? 64 : 128;
    constexpr bool kSplitSfaLoaderWarp = false;
    const int num_non_epilogue_threads =
        kSplitSfaLoaderWarp ? 128 : (use_compact_roles ? 64 : 128);
    // Dispatch and non-epilogue threads together must fill whole 128-thread groups
    DG_HOST_ASSERT((num_dispatch_threads + num_non_epilogue_threads) % 128 == 0);

    const auto [num_stages, smem_size] = get_pipeline_config_for_mega_moe_sm90(
        SM90ArchSpec::smem_capacity,
        num_experts, hidden,
        block_m, block_n, block_k,
        num_dispatch_threads / 32, num_epilogue_threads / 32);

    const MegaMoESM90Config config {
        block_m, block_n, block_k,
        cluster_size,
        num_max_pool_tokens, num_padded_sf_pool_tokens,
        swizzle_acts_mode, swizzle_weights_mode,
        num_experts_per_wave,
        num_stages, smem_size,
        num_dispatch_threads, num_non_epilogue_threads, num_epilogue_threads
    };

    const bool should_print =
        get_env<int>("DG_JIT_DEBUG") or get_env<int>("DG_PRINT_CONFIGS");
    if (should_print) {
        const auto key = fmt::format(
            "MegaMoESM90Config(num_ranks={}, num_experts={}, hidden={}, intermediate_hidden={}, num_max_tokens_per_rank={}, num_tokens={}, num_topk={})",
            num_ranks, num_experts, hidden, intermediate_hidden, num_max_tokens_per_rank, num_tokens, num_topk);
        // Shapes that were already reported
        static std::unordered_set<std::string> printed;
        const bool first_time = printed.count(key) == 0;
        if (first_time) {
            std::cout << key << ": " << config << std::endl;
            printed.insert(key);
        }
    }
    return config;
}
190+
191+
} // namespace deep_gemm

0 commit comments

Comments
 (0)