|
| 1 | +#pragma once |
| 2 | + |
| 3 | +#include <functional> |
| 4 | + |
| 5 | +#include "mega.hpp" |
| 6 | +#include "../jit_kernels/impls/sm90_fp8_mega_moe.hpp" |
| 7 | + |
| 8 | +namespace deep_gemm::mega { |
| 9 | + |
| 10 | +static int get_token_alignment_for_sm90_mega_moe() { |
| 11 | + return layout::kLCMCandidateBlockM; |
| 12 | +} |
| 13 | + |
| 14 | +static std::tuple<int64_t, std::function<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>(const torch::Tensor&)>> |
| 15 | +get_symm_buffer_size_for_sm90_mega_moe( |
| 16 | + const int& num_ranks, const int& num_experts, |
| 17 | + const int& num_max_tokens_per_rank, const int& num_topk, |
| 18 | + const int& hidden, const int& intermediate_hidden, |
| 19 | + const bool& use_fp8_dispatch, const std::string& activation) { |
| 20 | + DG_HOST_ASSERT(num_experts % num_ranks == 0); |
| 21 | + DG_HOST_ASSERT(use_fp8_dispatch); |
| 22 | + DG_HOST_ASSERT(activation == "swiglu"); |
| 23 | + |
| 24 | + const auto workspace = layout::Workspace(nullptr, num_ranks, num_experts, num_max_tokens_per_rank, num_topk); |
| 25 | + |
| 26 | + const auto fp8_token_layout = layout::Data(hidden); |
| 27 | + const auto bf16_token_layout = layout::Data(hidden * 2); |
| 28 | + const auto fp8_intermediate_token_layout = layout::Data(intermediate_hidden); |
| 29 | + const auto fp8_sf_layout = layout::Data(hidden / 32); |
| 30 | + const auto fp8_intermediate_sf_layout = layout::Data(intermediate_hidden / 16); |
| 31 | + const auto input_topk_idx_layout = layout::Data(num_topk * sizeof(int64_t), false); |
| 32 | + const auto input_topk_weights_layout = layout::Data(num_topk * sizeof(float), false); |
| 33 | + const auto l1_topk_weights_layout = layout::Data(sizeof(float), false); |
| 34 | + |
| 35 | + const auto input_token_buffer = layout::Buffer( |
| 36 | + fp8_token_layout, 1, num_max_tokens_per_rank, |
| 37 | + workspace.get_end_ptr()); |
| 38 | + const auto input_sf_buffer = layout::Buffer( |
| 39 | + fp8_sf_layout, 1, num_max_tokens_per_rank, |
| 40 | + input_token_buffer.get_end_ptr()); |
| 41 | + const auto input_topk_idx_buffer = layout::Buffer( |
| 42 | + input_topk_idx_layout, 1, num_max_tokens_per_rank, |
| 43 | + input_sf_buffer.get_end_ptr()); |
| 44 | + const auto input_topk_weights_buffer = layout::Buffer( |
| 45 | + input_topk_weights_layout, 1, num_max_tokens_per_rank, |
| 46 | + input_topk_idx_buffer.get_end_ptr()); |
| 47 | + |
| 48 | + const auto num_max_pool_tokens = static_cast<int>(workspace.num_max_pool_tokens); |
| 49 | + int num_max_padded_sf_pool_tokens = 0; |
| 50 | + for (int block_m: layout::kCandidateBlockM) { |
| 51 | + num_max_padded_sf_pool_tokens = std::max( |
| 52 | + num_max_padded_sf_pool_tokens, |
| 53 | + layout::get_num_padded_sf_pool_tokens(num_max_pool_tokens, block_m) |
| 54 | + ); |
| 55 | + } |
| 56 | + |
| 57 | + const auto l1_token_buffer = layout::Buffer( |
| 58 | + fp8_token_layout, 1, num_max_pool_tokens, |
| 59 | + input_topk_weights_buffer.get_end_ptr()); |
| 60 | + const auto l1_sf_buffer = layout::Buffer( |
| 61 | + fp8_sf_layout, 1, num_max_padded_sf_pool_tokens, |
| 62 | + l1_token_buffer.get_end_ptr()); |
| 63 | + const auto l1_topk_weights_buffer = layout::Buffer( |
| 64 | + l1_topk_weights_layout, 1, num_max_pool_tokens, |
| 65 | + l1_sf_buffer.get_end_ptr()); |
| 66 | + |
| 67 | + const auto l2_token_buffer = layout::Buffer( |
| 68 | + fp8_intermediate_token_layout, 1, num_max_pool_tokens, |
| 69 | + l1_topk_weights_buffer.get_end_ptr()); |
| 70 | + const auto l2_sf_buffer = layout::Buffer( |
| 71 | + fp8_intermediate_sf_layout, 1, num_max_padded_sf_pool_tokens, |
| 72 | + l2_token_buffer.get_end_ptr()); |
| 73 | + |
| 74 | + const auto combine_token_buffer = layout::Buffer( |
| 75 | + bf16_token_layout, num_topk, num_max_tokens_per_rank, |
| 76 | + l2_sf_buffer.get_end_ptr()); |
| 77 | + |
| 78 | + DG_HOST_ASSERT(hidden % 128 == 0 and intermediate_hidden % 128 == 0); |
| 79 | + |
| 80 | + auto slice_input_buffers = [=](const torch::Tensor& buffer) { |
| 81 | + auto x = torch::from_blob( |
| 82 | + math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_token_buffer.base)), |
| 83 | + {num_max_tokens_per_rank, hidden}, |
| 84 | + torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device())); |
| 85 | + auto x_sf = torch::from_blob( |
| 86 | + math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_sf_buffer.base)), |
| 87 | + {num_max_tokens_per_rank, hidden / 128}, |
| 88 | + torch::TensorOptions().dtype(torch::kFloat32).device(buffer.device())); |
| 89 | + auto topk_idx = torch::from_blob( |
| 90 | + math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_topk_idx_buffer.base)), |
| 91 | + {num_max_tokens_per_rank, num_topk}, |
| 92 | + torch::TensorOptions().dtype(torch::kInt64).device(buffer.device())); |
| 93 | + auto topk_weights = torch::from_blob( |
| 94 | + math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_topk_weights_buffer.base)), |
| 95 | + {num_max_tokens_per_rank, num_topk}, |
| 96 | + torch::TensorOptions().dtype(torch::kFloat32).device(buffer.device())); |
| 97 | + auto l1_acts = torch::from_blob( |
| 98 | + math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_token_buffer.base)), |
| 99 | + {num_max_pool_tokens, hidden}, |
| 100 | + torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device())); |
| 101 | + auto l1_acts_sf = torch::from_blob( |
| 102 | + math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_sf_buffer.base)), |
| 103 | + {num_max_padded_sf_pool_tokens, hidden / 128}, |
| 104 | + {1, num_max_padded_sf_pool_tokens}, |
| 105 | + torch::TensorOptions().dtype(torch::kFloat32).device(buffer.device())); |
| 106 | + auto l2_acts = torch::from_blob( |
| 107 | + math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_token_buffer.base)), |
| 108 | + {num_max_pool_tokens, intermediate_hidden}, |
| 109 | + torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device())); |
| 110 | + auto l2_acts_sf = torch::from_blob( |
| 111 | + math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_sf_buffer.base)), |
| 112 | + {num_max_padded_sf_pool_tokens, intermediate_hidden / 64}, |
| 113 | + {1, num_max_padded_sf_pool_tokens}, |
| 114 | + torch::TensorOptions().dtype(torch::kFloat32).device(buffer.device())); |
| 115 | + return std::make_tuple(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf); |
| 116 | + }; |
| 117 | + return {reinterpret_cast<int64_t>(combine_token_buffer.get_end_ptr()), slice_input_buffers}; |
| 118 | +} |
| 119 | + |
| 120 | +static void fp8_mega_moe( |
| 121 | + const torch::Tensor& y, |
| 122 | + const std::tuple<torch::Tensor, torch::Tensor>& l1_weights_tuple, |
| 123 | + const std::tuple<torch::Tensor, torch::Tensor>& l2_weights_tuple, |
| 124 | + const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats, |
| 125 | + const torch::Tensor& sym_buffer, |
| 126 | + const std::vector<int64_t>& sym_buffer_ptrs, const int& rank_idx, |
| 127 | + const int& num_max_tokens_per_rank, |
| 128 | + const int& num_experts, const int& num_topk, |
| 129 | + const std::tuple<int, int, int>& recipe, |
| 130 | + const std::string& activation, |
| 131 | + const std::optional<float>& activation_clamp_opt, |
| 132 | + const bool& fast_math |
| 133 | +) { |
| 134 | + const auto [l1_weights, l1_weights_sf] = l1_weights_tuple; |
| 135 | + const auto [l2_weights, l2_weights_sf] = l2_weights_tuple; |
| 136 | + |
| 137 | + const auto arch_major = device_runtime->get_arch_major(); |
| 138 | + DG_HOST_ASSERT(arch_major == 9); |
| 139 | + |
| 140 | + const auto num_tokens = static_cast<int>(y.size(0)); |
| 141 | + const auto [rm, rn, rk] = recipe; |
| 142 | + DG_HOST_ASSERT(rm == 128 and rn == 128 and rk == 128); |
| 143 | + DG_HOST_ASSERT(activation == "swiglu"); |
| 144 | + |
| 145 | + const auto activation_clamp = |
| 146 | + activation_clamp_opt.value_or(std::numeric_limits<float>::infinity()); |
| 147 | + DG_HOST_ASSERT(activation_clamp >= 0); |
| 148 | + |
| 149 | + DG_HOST_ASSERT(get_major_type_ab(l1_weights) == cute::UMMA::Major::K); |
| 150 | + DG_HOST_ASSERT(get_major_type_ab(l2_weights) == cute::UMMA::Major::K); |
| 151 | + DG_HOST_ASSERT(l1_weights.scalar_type() == torch::kFloat8_e4m3fn); |
| 152 | + DG_HOST_ASSERT(l2_weights.scalar_type() == torch::kFloat8_e4m3fn); |
| 153 | + const auto [num_experts_per_rank, intermediate_hidden_2, hidden] = get_shape<3>(l1_weights); |
| 154 | + const auto [num_experts_per_rank_, hidden_, intermediate_hidden] = get_shape<3>(l2_weights); |
| 155 | + DG_HOST_ASSERT(num_tokens <= num_max_tokens_per_rank); |
| 156 | + DG_HOST_ASSERT(num_experts_per_rank == num_experts_per_rank_); |
| 157 | + DG_HOST_ASSERT(hidden == hidden_); |
| 158 | + DG_HOST_ASSERT(intermediate_hidden_2 == 2 * intermediate_hidden); |
| 159 | + DG_HOST_ASSERT(l1_weights.is_contiguous() and l2_weights.is_contiguous()); |
| 160 | + DG_HOST_ASSERT(hidden % 128 == 0 and intermediate_hidden % 128 == 0); |
| 161 | + DG_HOST_ASSERT(intermediate_hidden / 64 <= 64); |
| 162 | + |
| 163 | + constexpr int kGranMN = 128, kGranK = 128; |
| 164 | + check_sf_layout(l1_weights_sf, intermediate_hidden * 2, hidden, kGranMN, kGranK, |
| 165 | + num_experts_per_rank, false, true, torch::kFloat); |
| 166 | + check_sf_layout(l2_weights_sf, hidden, intermediate_hidden, kGranMN, kGranK, |
| 167 | + num_experts_per_rank, false, true, torch::kFloat); |
| 168 | + |
| 169 | + if (cumulative_local_expert_recv_stats.has_value()) { |
| 170 | + DG_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt); |
| 171 | + DG_HOST_ASSERT(cumulative_local_expert_recv_stats->numel() == num_experts_per_rank); |
| 172 | + DG_HOST_ASSERT(cumulative_local_expert_recv_stats->is_contiguous()); |
| 173 | + } |
| 174 | + |
| 175 | + const auto num_ranks = static_cast<int>(sym_buffer_ptrs.size()); |
| 176 | + const auto num_experts_ = num_experts_per_rank * num_ranks; |
| 177 | + const auto [num_required_bytes, slice] = get_symm_buffer_size_for_sm90_mega_moe( |
| 178 | + num_ranks, num_experts, |
| 179 | + num_max_tokens_per_rank, num_topk, |
| 180 | + hidden, intermediate_hidden, |
| 181 | + true, activation); |
| 182 | + DG_HOST_ASSERT(sym_buffer.nbytes() >= static_cast<size_t>(num_required_bytes)); |
| 183 | + DG_HOST_ASSERT(num_experts == num_experts_); |
| 184 | + |
| 185 | + const auto [x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf] = slice(sym_buffer); |
| 186 | + |
| 187 | + sm90_fp8_mega_moe(y, |
| 188 | + l1_acts, l1_acts_sf, |
| 189 | + l2_acts, l2_acts_sf, |
| 190 | + l1_weights, l2_weights, |
| 191 | + l1_weights_sf, l2_weights_sf, |
| 192 | + cumulative_local_expert_recv_stats, |
| 193 | + sym_buffer_ptrs, |
| 194 | + rank_idx, num_max_tokens_per_rank, |
| 195 | + num_experts_per_rank, |
| 196 | + num_tokens, num_topk, |
| 197 | + hidden, intermediate_hidden, |
| 198 | + activation_clamp, fast_math); |
| 199 | + |
| 200 | + if (get_env<int>("DG_COMM_KERNEL_DEBUG")) |
| 201 | + sym_buffer.zero_(); |
| 202 | +} |
| 203 | + |
| 204 | +} // namespace deep_gemm::mega |
0 commit comments