|
28 | 28 | from .fused_moe_backend_base import UnquantizedFusedMoEMethod |
29 | 29 |
|
30 | 30 | if current_platform.is_cuda(): |
31 | | - from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce |
| 31 | + from fastdeploy.model_executor.ops.gpu import ( |
| 32 | + count_tokens_per_expert_func, |
| 33 | + moe_expert_dispatch, |
| 34 | + moe_expert_reduce, |
| 35 | + ) |
32 | 36 |
|
33 | 37 | try: |
34 | 38 | from fastdeploy.model_executor.ops.gpu import ( |
@@ -286,6 +290,70 @@ def apply_tp( |
286 | 290 | layer.gate_correction_bias, |
287 | 291 | getattr(layer, "renormalize", True), |
288 | 292 | ) |
| 293 | + if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16": |
| 294 | + # moe_permute path: CUDA-graph safe, no D2H copies |
| 295 | + print("use moe_permute in tp") |
| 296 | + topk_idx_i32 = topk_idx.astype(paddle.int32) |
| 297 | + override_buffer_size = x.shape[0] * layer.top_k + layer.num_experts * (128 - 1) |
| 298 | + ( |
| 299 | + permute_input, |
| 300 | + permute_indices_per_token, # zipped_expertwise_rowmap |
| 301 | + dst_weights, |
| 302 | + _scale_out, |
| 303 | + _m_indices, |
| 304 | + ) = paddle.nn.functional.moe_permute( |
| 305 | + hidden_states=x, |
| 306 | + scale=None, |
| 307 | + expert_routemap_topk=topk_idx_i32, |
| 308 | + expert_prob_topk=topk_weights, |
| 309 | + num_experts=layer.num_experts, |
| 310 | + tokens_per_expert=[], |
| 311 | + padding_alignment=128, |
| 312 | + return_expert_indices=True, |
| 313 | + override_buffer_size=override_buffer_size, |
| 314 | + ) |
| 315 | + |
| 316 | + # Compute token_nums_per_expert (prefix sum) on GPU. |
| 317 | + # Use PADDED counts (row 1) because moe_permute with padding_alignment=128 |
| 318 | + # lays out each expert's tokens in 128-aligned blocks. Using actual counts |
| 319 | + # (row 0) would cause moe_expert_ffn to read wrong positions. |
| 320 | + # Use matmul with a cached lower-triangular matrix instead of |
| 321 | + # paddle.cumsum, because CUB inclusive_scan allocates temp memory |
| 322 | + # which is forbidden during CUDA graph capture. |
| 323 | + padded_counts = count_tokens_per_expert_func(topk_idx, layer.num_experts)[ |
| 324 | + 1 |
| 325 | + ] # [num_experts], int32, 128-aligned |
| 326 | + if not hasattr(self, "_cumsum_tril") or self._cumsum_tril.shape[0] != layer.num_experts: |
| 327 | + self._cumsum_tril = paddle.tril( |
| 328 | + paddle.ones([layer.num_experts, layer.num_experts], dtype="float32") |
| 329 | + ) |
| 330 | + token_nums_per_expert = paddle.mv(self._cumsum_tril, padded_counts.cast("float32")).cast(paddle.int64) |
| 331 | + |
| 332 | + if topk_ids_hookfunc is not None: |
| 333 | + topk_ids_hookfunc(topk_ids=topk_idx) |
| 334 | + |
| 335 | + ffn_out = self.compute_ffn( |
| 336 | + layer, |
| 337 | + permute_input, |
| 338 | + token_nums_per_expert, |
| 339 | + None, # expert_idx_per_token not needed for w16a16 without bias |
| 340 | + False, |
| 341 | + -1, |
| 342 | + None, # dequant_scale |
| 343 | + None, # max_tokens_per_expert |
| 344 | + ) |
| 345 | + |
| 346 | + fused_moe_out, _out_probs = paddle.nn.functional.moe_unpermute( |
| 347 | + hidden_states_unzipped=ffn_out, |
| 348 | + zipped_expertwise_rowmap=permute_indices_per_token, |
| 349 | + expert_routemap_topk=topk_idx_i32, |
| 350 | + token_prob_unzipped=dst_weights, |
| 351 | + total_zipped_tokens=x.shape[0], |
| 352 | + num_experts=layer.num_experts, |
| 353 | + using_weighted_combine=True, |
| 354 | + ) |
| 355 | + return fused_moe_out |
| 356 | + |
289 | 357 | ( |
290 | 358 | permute_input, |
291 | 359 | token_nums_per_expert, |
|
0 commit comments