Skip to content

Commit a65b400

Browse files
rosenrodt authored and nv-guomingz committed
[None][feat] Add bf16 trtllm moe through flashinfer.
Signed-off-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
1 parent e16317e commit a65b400

File tree

11 files changed

+486
-65
lines changed

11 files changed

+486
-65
lines changed

tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
4. Unified EPLB integration for backends that support it
2929
"""
3030

31+
import copy
3132
from typing import Dict, List, Optional, Tuple, Union
3233

3334
import torch
@@ -164,21 +165,34 @@ def __init__(
164165
self.apply_router_weight_on_input = apply_router_weight_on_input
165166

166167
# ========== Create MoE Backend (Default: Cutlass) ==========
167-
from tensorrt_llm._torch.modules.fused_moe.create_moe import create_moe_backend, get_moe_cls
168+
from tensorrt_llm._torch.modules.fused_moe.create_moe import (
169+
create_moe_backend,
170+
resolve_moe_cls,
171+
)
172+
173+
# Get MoE backend class based on override_quant_config, routing_method, and model_config
174+
moe_cls = resolve_moe_cls(
175+
model_config,
176+
routing_method,
177+
self.dtype,
178+
override_quant_config=override_quant_config,
179+
)
168180

169-
# Get MoE backend class based on override_quant_config or model_config
170-
moe_cls = get_moe_cls(model_config, override_quant_config=override_quant_config)
181+
backend_model_config = model_config
182+
if override_quant_config is not None:
183+
backend_model_config = copy.deepcopy(model_config)
184+
backend_model_config.quant_config = override_quant_config
171185

172186
# Call create_moe_backend with all necessary parameters
173187
# init_load_balancer=False: Prevents backend from registering itself with load balancer
174188
# without_comm=True: Prevents backend from initializing communication (ConfigurableMoE handles it)
175189
# skip_create_weights_in_init=True: Prevents backend from creating weights in __init__
176190
# because backend uses layer_idx=None and may have different expert assignments
177191
# We will create weights after syncing attributes from ConfigurableMoE
178-
tmp_skip_create_weights_in_init = model_config.skip_create_weights_in_init
179-
model_config._frozen = False
180-
model_config.skip_create_weights_in_init = True
181-
model_config._frozen = True
192+
tmp_skip_create_weights_in_init = backend_model_config.skip_create_weights_in_init
193+
backend_model_config._frozen = False
194+
backend_model_config.skip_create_weights_in_init = True
195+
backend_model_config._frozen = True
182196

183197
backend = create_moe_backend(
184198
moe_cls=moe_cls,
@@ -188,7 +202,7 @@ def __init__(
188202
intermediate_size=self.intermediate_size,
189203
dtype=self.dtype,
190204
reduce_results=self.reduce_results,
191-
model_config=model_config,
205+
model_config=backend_model_config,
192206
aux_stream_dict=self.aux_stream_dict,
193207
weight_loading_mode=self.weight_loading_mode,
194208
bias=kwargs.get("bias", False),
@@ -223,10 +237,10 @@ def __init__(
223237
self.backend.expert_size_per_partition = self.expert_size_per_partition
224238

225239
# Create weights here, because the backend needs the layer_load_balancer info to create weights
226-
model_config._frozen = False
227-
model_config.skip_create_weights_in_init = tmp_skip_create_weights_in_init
228-
model_config._frozen = True
229-
if not model_config.skip_create_weights_in_init:
240+
backend_model_config._frozen = False
241+
backend_model_config.skip_create_weights_in_init = tmp_skip_create_weights_in_init
242+
backend_model_config._frozen = True
243+
if not backend_model_config.skip_create_weights_in_init:
230244
self.backend.create_weights()
231245

232246
# ========== Create Communication Strategy ==========

tensorrt_llm/_torch/modules/fused_moe/create_moe.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,23 @@ def get_moe_cls(
6363
return CutlassFusedMoE
6464
return DenseGEMMFusedMoE
6565
elif moe_backend.upper() == "TRTLLM":
66-
if quant_config is not None and (
67-
quant_config.quant_mode.has_fp8_block_scales()
68-
or quant_config.quant_mode.has_nvfp4()
69-
or quant_config.quant_mode.has_w4a16_mxfp4()
70-
or quant_config.quant_mode.has_w4a8_nvfp4_fp8()
71-
or quant_config.quant_mode.has_w4a8_mxfp4_fp8()
72-
or quant_config.quant_mode.has_w4a8_mxfp4_mxfp8()):
66+
has_quant = quant_config is not None and quant_config.quant_mode.has_any_quant(
67+
exclude_kv_cache=True)
68+
if has_quant and (quant_config.quant_mode.has_fp8_block_scales()
69+
or quant_config.quant_mode.has_nvfp4()
70+
or quant_config.quant_mode.has_w4a16_mxfp4()
71+
or quant_config.quant_mode.has_w4a8_nvfp4_fp8()
72+
or quant_config.quant_mode.has_w4a8_mxfp4_fp8()
73+
or quant_config.quant_mode.has_w4a8_mxfp4_mxfp8()):
7374
return TRTLLMGenFusedMoE
75+
if not has_quant and model_config.pretrained_config is not None and getattr(
76+
model_config.pretrained_config, "torch_dtype",
77+
None) == torch.bfloat16:
78+
if TRTLLMGenFusedMoE._is_flashinfer_fused_moe_available():
79+
return TRTLLMGenFusedMoE
80+
raise RuntimeError(
81+
"TRTLLMGenFusedMoE BF16 path requires FlashInfer fused MoE with "
82+
"trtllm_bf16_moe support, but it is not available.")
7483
else:
7584
logger.warning(
7685
"TRTLLMGenFusedMoE only supports fp8_block_scales, nvfp4, w4a16_mxfp4, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, and w4a8_mxfp4_mxfp8. "
@@ -85,6 +94,25 @@ def get_moe_cls(
8594
raise ValueError(f"Unsupported moe backend: {moe_backend}")
8695

8796

97+
def resolve_moe_cls(
98+
model_config: ModelConfig,
99+
routing_method: BaseMoeRoutingMethod,
100+
dtype: Optional[torch.dtype],
101+
override_quant_config: Optional[QuantConfig] = None) -> Type[MoE]:
102+
moe_cls = get_moe_cls(model_config, override_quant_config)
103+
104+
effective_quant_config = override_quant_config or model_config.quant_config
105+
has_quant = (effective_quant_config is not None
106+
and effective_quant_config.layer_quant_mode.has_any_quant(
107+
exclude_kv_cache=True))
108+
if (moe_cls == TRTLLMGenFusedMoE and not has_quant
109+
and not TRTLLMGenFusedMoE._supports_flashinfer_bf16_routing_method(
110+
routing_method)):
111+
return CutlassFusedMoE
112+
113+
return moe_cls
114+
115+
88116
def create_moe_backend(
89117
moe_cls: Type[MoE],
90118
routing_method: BaseMoeRoutingMethod,
@@ -379,7 +407,8 @@ def create_moe(
379407
pretrained_config, 'torch_dtype'):
380408
dtype = pretrained_config.torch_dtype
381409

382-
moe_cls = get_moe_cls(model_config, override_quant_config)
410+
moe_cls = resolve_moe_cls(model_config, routing_method, dtype,
411+
override_quant_config)
383412

384413
enable_configurable_moe = os.environ.get("ENABLE_CONFIGURABLE_MOE",
385414
"1") == "1"

tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py

Lines changed: 115 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@
3838

3939
# isort: off
4040
from .quantization import (
41-
DeepSeekFP8BlockScalesFusedMoEMethod, NVFP4TRTLLMGenFusedMoEBaseMethod,
42-
NVFP4TRTLLMGenFusedMoEMethod, W4A8MXFP4FP8TRTLLMGenFusedMoEMethod,
43-
W4A8MXFP4MXFP8TRTLLMGenFusedMoEMethod, W4A8NVFP4FP8TRTLLMGenFusedMoEMethod,
44-
W4A16MXFP4TRTLLMGenFusedMoEMethod)
41+
BF16TRTLLMGenFusedMoEMethod, DeepSeekFP8BlockScalesFusedMoEMethod,
42+
NVFP4TRTLLMGenFusedMoEBaseMethod, NVFP4TRTLLMGenFusedMoEMethod,
43+
W4A8MXFP4FP8TRTLLMGenFusedMoEMethod, W4A8MXFP4MXFP8TRTLLMGenFusedMoEMethod,
44+
W4A8NVFP4FP8TRTLLMGenFusedMoEMethod, W4A16MXFP4TRTLLMGenFusedMoEMethod)
4545
# isort: on
4646
from .routing import (BaseMoeRoutingMethod, DeepSeekV3MoeRoutingMethod,
4747
DefaultMoeRoutingMethod)
@@ -115,7 +115,8 @@ def can_implement(
115115
- W4A8_MXFP4_FP8
116116
- W4A8_MXFP4_MXFP8
117117
118-
Does NOT support unquantized mode. Output dtype is hardcoded to bfloat16.
118+
Unquantized BF16 path is supported only with FlashInfer fused MoE backend.
119+
Output dtype is hardcoded to bfloat16.
119120
120121
Args:
121122
quant_algo: The quantization algorithm to check (None for unquantized)
@@ -143,10 +144,16 @@ def can_implement(
143144
f"TRTLLMGenFusedMoE only supports bfloat16 activation, got {dtype_activation}"
144145
)
145146

146-
# TRTLLMGenFusedMoE does NOT support unquantized mode
147147
if quant_algo is None:
148-
return _warn_and_return(
149-
"TRTLLMGenFusedMoE does not support unquantized mode")
148+
if swiglu_gptoss_style:
149+
return _warn_and_return(
150+
"TRTLLMGenFusedMoE BF16 path does not support bias/swiglu custom parameters."
151+
)
152+
if not cls._is_flashinfer_fused_moe_available():
153+
return _warn_and_return(
154+
"TRTLLMGenFusedMoE unquantized BF16 path requires FlashInfer fused MoE "
155+
"with trtllm_bf16_moe support.")
156+
return True, None
150157

151158
# Check if quant_algo is supported
152159
if quant_algo not in cls._SUPPORTED_QUANT_ALGOS:
@@ -210,7 +217,14 @@ def __init__(
210217

211218
assert not self.smart_router, "Smart router is not supported in TRTLLMGenFusedMoE."
212219

213-
self.use_flashinfer = self._check_op_backend_support()
220+
self.use_flashinfer = self._check_flashinfer_backend_support()
221+
if (self.quant_config is None
222+
or not self.quant_config.layer_quant_mode.has_any_quant(
223+
exclude_kv_cache=True)) and not self.use_flashinfer:
224+
raise NotImplementedError(
225+
"TRTLLMGenFusedMoE BF16 path requires FlashInfer fused MoE. "
226+
"Please install a FlashInfer build with trtllm_bf16_moe support."
227+
)
214228
backend_name = "flashinfer" if self.use_flashinfer else "trtllm"
215229
self.op_backend: MoEOpBackend = get_op_backend(backend_name)
216230

@@ -292,7 +306,43 @@ def _to_trtllm_gen_activation_type(self,
292306
else:
293307
raise ValueError(f"Unsupported activation type: {activation_type}")
294308

295-
def _check_op_backend_support(self) -> bool:
309+
@staticmethod
310+
def _is_flashinfer_fused_moe_available() -> bool:
311+
try:
312+
from flashinfer.fused_moe import core as _core
313+
except (ImportError, ModuleNotFoundError):
314+
return False
315+
return (hasattr(_core, "trtllm_bf16_moe")
316+
and hasattr(_core, "trtllm_bf16_routed_moe"))
317+
318+
def _is_unquantized_path(self) -> bool:
319+
return self.quant_config is None or not self.quant_config.layer_quant_mode.has_any_quant(
320+
exclude_kv_cache=True)
321+
322+
@staticmethod
323+
def _supports_flashinfer_bf16_routing_method(
324+
routing_method: BaseMoeRoutingMethod, ) -> bool:
325+
# FIXME: ban DeepSeekV3 FlashInfer trtllm_bf16_routed_moe() as it appears to have bug
326+
return not isinstance(routing_method, DeepSeekV3MoeRoutingMethod)
327+
328+
def _requires_separated_routing(self) -> bool:
329+
"""Whether this backend instance expects precomputed top-k routing."""
330+
# FIXME: ban FlashInfer BF16 MoE direct routing as it appears to have accuracy bug
331+
return self.use_flashinfer and self._is_unquantized_path()
332+
333+
def _check_flashinfer_backend_support(self) -> bool:
334+
# For BF16 (unquantized) path, we will use FlashInfer regardless whether
335+
# env TRTLLM_GEN_FUSED_MOE_USE_FLASHINFER=1 is set or not as it's the only way.
336+
if self._is_unquantized_path():
337+
if not self._is_flashinfer_fused_moe_available():
338+
return False
339+
if self.activation_type != ActivationType.Swiglu:
340+
return False
341+
if not self._supports_flashinfer_bf16_routing_method(
342+
self.routing_method):
343+
return False
344+
return True
345+
296346
use_flashinfer = os.environ.get("TRTLLM_GEN_FUSED_MOE_USE_FLASHINFER",
297347
"0")
298348
if use_flashinfer != "1":
@@ -311,8 +361,6 @@ def _check_op_backend_support(self) -> bool:
311361
if type(quant_method) is NVFP4TRTLLMGenFusedMoEBaseMethod:
312362
return True
313363

314-
if self.quant_config is None:
315-
return False
316364
mode = self.quant_config.layer_quant_mode
317365

318366
# These quant modes are never supported via op backend
@@ -365,7 +413,14 @@ def select_alltoall_method_type(self) -> AlltoallMethodType:
365413
return AlltoallMethodType.NVLinkOneSided
366414

367415
def _supports_load_balancer(self) -> bool:
368-
"""TRTLLMGenFusedMoE supports load balancer."""
416+
"""Whether separated routing (top-k outside the kernel) is used.
417+
418+
ConfigurableMoE uses this flag to decide whether routing is separated
419+
(top-k ids/scales computed outside backend) or fused inside the kernel.
420+
BF16 FlashInfer path always requires separated routing.
421+
"""
422+
if self._requires_separated_routing():
423+
return True
369424
return self.use_dp and self.parallel_size > 1
370425

371426
@cached_property
@@ -375,9 +430,17 @@ def enable_alltoall(self):
375430
return self.alltoall_method_type != AlltoallMethodType.NotEnabled
376431

377432
def _check_configs(self):
378-
assert self.has_deepseek_fp8_block_scales \
433+
assert not self.has_any_quant \
434+
or self.has_deepseek_fp8_block_scales \
379435
or self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_nvfp4_fp8 \
380-
or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, "TRTLLMGenFusedMoE only supports fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_fp8 and w4a8_mxfp4_mxfp8 dtypes."
436+
or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, \
437+
"TRTLLMGenFusedMoE only supports bf16 (FlashInfer), fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_fp8 and w4a8_mxfp4_mxfp8 dtypes."
438+
439+
if not self.has_any_quant:
440+
assert self.activation_type == ActivationType.Swiglu, \
441+
"TRTLLMGenFusedMoE BF16 path only supports Swiglu activation."
442+
assert not self.bias and self.swiglu_alpha is None and self.swiglu_beta is None and self.swiglu_limit is None, \
443+
"TRTLLMGenFusedMoE BF16 path does not support bias/swiglu custom parameters."
381444

382445
if self.bias or self.swiglu_alpha is not None or self.swiglu_beta is not None or self.swiglu_limit is not None:
383446
assert self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, "TRTLLMGenFusedMoE supports bias/swiglu only for nvfp4 and mxfp4 variants."
@@ -405,8 +468,7 @@ def _get_quant_method(self):
405468
f"Unsupported quantization method by TRTLLMGenFusedMoE: {self.quant_config.quant_mode}"
406469
)
407470
else:
408-
raise NotImplementedError(
409-
"TRTLLMGenFusedMoE doesn't support fp16/bf16/fp32 MoE.")
471+
return BF16TRTLLMGenFusedMoEMethod()
410472

411473
def create_weights(self):
412474
if self._weights_created:
@@ -467,6 +529,8 @@ def quantize_input(self, x, post_quant_comm: bool = True):
467529
- scaling_vector_size is typically the group size for block-wise quantization
468530
"""
469531
x_sf = None
532+
if not self.has_any_quant:
533+
return x, x_sf
470534
if self.has_w4a8_mxfp4_fp8:
471535
pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
472536
x = torch.nn.functional.pad(x, (0, pad_size))
@@ -526,7 +590,7 @@ def quantize_input(self, x, post_quant_comm: bool = True):
526590
return x, x_sf
527591

528592
def supports_moe_output_in_alltoall_workspace(self):
529-
return True
593+
return self.has_any_quant and not self.use_flashinfer
530594

531595
def run_moe(
532596
self,
@@ -542,8 +606,8 @@ def run_moe(
542606
Run MoE computation with TRTLLMGen backend.
543607
544608
This method encapsulates the core MoE computation logic, handling different
545-
quantization schemes (fp8_block_scales, nvfp4, w4a16_mxfp4, w4a8_nvfp4_fp8,
546-
w4a8_mxfp4_fp8, w4a8_mxfp4_mxfp8).
609+
quantization schemes (bf16, fp8_block_scales, nvfp4, w4a16_mxfp4,
610+
w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, w4a8_mxfp4_mxfp8).
547611
548612
Args:
549613
# Standard MoE interface parameters:
@@ -592,7 +656,37 @@ def run_moe(
592656
) == 2, f"x_sf should be 2D tensor, got shape {x_sf.shape}"
593657
x_sf = x_sf.flatten()
594658

595-
if self.has_deepseek_fp8_block_scales:
659+
if not self.has_any_quant:
660+
result = self.op_backend.run_bf16_moe(
661+
router_logits=router_logits,
662+
routing_bias=routing_bias,
663+
hidden_states=x,
664+
gemm1_weights=self.w3_w1_weight,
665+
gemm2_weights=self.w2_weight,
666+
num_experts=self.num_slots,
667+
top_k=top_k,
668+
n_group=n_group,
669+
topk_group=topk_group,
670+
intermediate_size=self.intermediate_size_per_partition,
671+
local_expert_offset=self.slot_start,
672+
local_num_experts=self.expert_size_per_partition,
673+
routed_scaling_factor=routed_scaling_factor,
674+
routing_method_type=self.routing_method.routing_method_type,
675+
topk_weights=token_final_scales,
676+
topk_ids=token_selected_experts,
677+
gated_act_type=self._to_trtllm_gen_activation_type(
678+
self.activation_type),
679+
output=moe_output,
680+
use_shuffled_weight=getattr(self.quant_method,
681+
"use_shuffled_weight", False),
682+
weight_layout=getattr(self.quant_method, "weight_layout", 0),
683+
do_finalize=do_finalize,
684+
)
685+
if not do_finalize:
686+
assert not self.reduce_results, "reduce_results must be False when do_finalize is False"
687+
return result
688+
final_hidden_states = result
689+
elif self.has_deepseek_fp8_block_scales:
596690
assert do_finalize, "fp8_block_scale_moe_runner does not support do_finalize=False"
597691
# fp8_block_scale_moe_runner needs 2D shape for x_sf and only support SM100+
598692
if x_sf is None:

0 commit comments

Comments (0)