3838
3939# isort: off
4040from .quantization import (
41- DeepSeekFP8BlockScalesFusedMoEMethod , NVFP4TRTLLMGenFusedMoEBaseMethod ,
42- NVFP4TRTLLMGenFusedMoEMethod , W4A8MXFP4FP8TRTLLMGenFusedMoEMethod ,
43- W4A8MXFP4MXFP8TRTLLMGenFusedMoEMethod , W4A8NVFP4FP8TRTLLMGenFusedMoEMethod ,
44- W4A16MXFP4TRTLLMGenFusedMoEMethod )
41+ BF16TRTLLMGenFusedMoEMethod , DeepSeekFP8BlockScalesFusedMoEMethod ,
42+ NVFP4TRTLLMGenFusedMoEBaseMethod , NVFP4TRTLLMGenFusedMoEMethod ,
43+ W4A8MXFP4FP8TRTLLMGenFusedMoEMethod , W4A8MXFP4MXFP8TRTLLMGenFusedMoEMethod ,
44+ W4A8NVFP4FP8TRTLLMGenFusedMoEMethod , W4A16MXFP4TRTLLMGenFusedMoEMethod )
4545# isort: on
4646from .routing import (BaseMoeRoutingMethod , DeepSeekV3MoeRoutingMethod ,
4747 DefaultMoeRoutingMethod )
@@ -115,7 +115,8 @@ def can_implement(
115115 - W4A8_MXFP4_FP8
116116 - W4A8_MXFP4_MXFP8
117117
118- Does NOT support unquantized mode. Output dtype is hardcoded to bfloat16.
118+ Unquantized BF16 path is supported only with FlashInfer fused MoE backend.
119+ Output dtype is hardcoded to bfloat16.
119120
120121 Args:
121122 quant_algo: The quantization algorithm to check (None for unquantized)
@@ -143,10 +144,16 @@ def can_implement(
143144 f"TRTLLMGenFusedMoE only supports bfloat16 activation, got { dtype_activation } "
144145 )
145146
146- # TRTLLMGenFusedMoE does NOT support unquantized mode
147147 if quant_algo is None :
148- return _warn_and_return (
149- "TRTLLMGenFusedMoE does not support unquantized mode" )
148+ if swiglu_gptoss_style :
149+ return _warn_and_return (
150+ "TRTLLMGenFusedMoE BF16 path does not support bias/swiglu custom parameters."
151+ )
152+ if not cls ._is_flashinfer_fused_moe_available ():
153+ return _warn_and_return (
154+ "TRTLLMGenFusedMoE unquantized BF16 path requires FlashInfer fused MoE "
155+ "with trtllm_bf16_moe support." )
156+ return True , None
150157
151158 # Check if quant_algo is supported
152159 if quant_algo not in cls ._SUPPORTED_QUANT_ALGOS :
@@ -210,7 +217,14 @@ def __init__(
210217
211218 assert not self .smart_router , "Smart router is not supported in TRTLLMGenFusedMoE."
212219
213- self .use_flashinfer = self ._check_op_backend_support ()
220+ self .use_flashinfer = self ._check_flashinfer_backend_support ()
221+ if (self .quant_config is None
222+ or not self .quant_config .layer_quant_mode .has_any_quant (
223+ exclude_kv_cache = True )) and not self .use_flashinfer :
224+ raise NotImplementedError (
225+ "TRTLLMGenFusedMoE BF16 path requires FlashInfer fused MoE. "
226+ "Please install a FlashInfer build with trtllm_bf16_moe support."
227+ )
214228 backend_name = "flashinfer" if self .use_flashinfer else "trtllm"
215229 self .op_backend : MoEOpBackend = get_op_backend (backend_name )
216230
@@ -292,7 +306,43 @@ def _to_trtllm_gen_activation_type(self,
292306 else :
293307 raise ValueError (f"Unsupported activation type: { activation_type } " )
294308
295- def _check_op_backend_support (self ) -> bool :
309+ @staticmethod
310+ def _is_flashinfer_fused_moe_available () -> bool :
311+ try :
312+ from flashinfer .fused_moe import core as _core
313+ except (ImportError , ModuleNotFoundError ):
314+ return False
315+ return (hasattr (_core , "trtllm_bf16_moe" )
316+ and hasattr (_core , "trtllm_bf16_routed_moe" ))
317+
318+ def _is_unquantized_path (self ) -> bool :
319+ return self .quant_config is None or not self .quant_config .layer_quant_mode .has_any_quant (
320+ exclude_kv_cache = True )
321+
@staticmethod
def _supports_flashinfer_bf16_routing_method(
        routing_method: BaseMoeRoutingMethod, ) -> bool:
    """Whether the FlashInfer BF16 MoE path can handle this routing method."""
    # FIXME: ban DeepSeekV3 FlashInfer trtllm_bf16_routed_moe() as it appears to have bug
    if isinstance(routing_method, DeepSeekV3MoeRoutingMethod):
        return False
    return True
327+
328+ def _requires_separated_routing (self ) -> bool :
329+ """Whether this backend instance expects precomputed top-k routing."""
330+ # FIXME: ban FlashInfer BF16 MoE direct routing as it appears to have accuracy bug
331+ return self .use_flashinfer and self ._is_unquantized_path ()
332+
333+ def _check_flashinfer_backend_support (self ) -> bool :
334+ # For BF16 (unquantized) path, we will use FlashInfer regardless whether
335+ # env TRTLLM_GEN_FUSED_MOE_USE_FLASHINFER=1 is set or not as it's the only way.
336+ if self ._is_unquantized_path ():
337+ if not self ._is_flashinfer_fused_moe_available ():
338+ return False
339+ if self .activation_type != ActivationType .Swiglu :
340+ return False
341+ if not self ._supports_flashinfer_bf16_routing_method (
342+ self .routing_method ):
343+ return False
344+ return True
345+
296346 use_flashinfer = os .environ .get ("TRTLLM_GEN_FUSED_MOE_USE_FLASHINFER" ,
297347 "0" )
298348 if use_flashinfer != "1" :
@@ -311,8 +361,6 @@ def _check_op_backend_support(self) -> bool:
311361 if type (quant_method ) is NVFP4TRTLLMGenFusedMoEBaseMethod :
312362 return True
313363
314- if self .quant_config is None :
315- return False
316364 mode = self .quant_config .layer_quant_mode
317365
318366 # These quant modes are never supported via op backend
@@ -365,7 +413,14 @@ def select_alltoall_method_type(self) -> AlltoallMethodType:
365413 return AlltoallMethodType .NVLinkOneSided
366414
367415 def _supports_load_balancer (self ) -> bool :
368- """TRTLLMGenFusedMoE supports load balancer."""
416+ """Whether separated routing (top-k outside the kernel) is used.
417+
418+ ConfigurableMoE uses this flag to decide whether routing is separated
419+ (top-k ids/scales computed outside backend) or fused inside the kernel.
420+ BF16 FlashInfer path always requires separated routing.
421+ """
422+ if self ._requires_separated_routing ():
423+ return True
369424 return self .use_dp and self .parallel_size > 1
370425
371426 @cached_property
@@ -375,9 +430,17 @@ def enable_alltoall(self):
375430 return self .alltoall_method_type != AlltoallMethodType .NotEnabled
376431
377432 def _check_configs (self ):
378- assert self .has_deepseek_fp8_block_scales \
433+ assert not self .has_any_quant \
434+ or self .has_deepseek_fp8_block_scales \
379435 or self .has_nvfp4 or self .has_w4a16_mxfp4 or self .has_w4a8_nvfp4_fp8 \
380- or self .has_w4a8_mxfp4_fp8 or self .has_w4a8_mxfp4_mxfp8 , "TRTLLMGenFusedMoE only supports fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_fp8 and w4a8_mxfp4_mxfp8 dtypes."
436+ or self .has_w4a8_mxfp4_fp8 or self .has_w4a8_mxfp4_mxfp8 , \
437+ "TRTLLMGenFusedMoE only supports bf16 (FlashInfer), fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_fp8 and w4a8_mxfp4_mxfp8 dtypes."
438+
439+ if not self .has_any_quant :
440+ assert self .activation_type == ActivationType .Swiglu , \
441+ "TRTLLMGenFusedMoE BF16 path only supports Swiglu activation."
442+ assert not self .bias and self .swiglu_alpha is None and self .swiglu_beta is None and self .swiglu_limit is None , \
443+ "TRTLLMGenFusedMoE BF16 path does not support bias/swiglu custom parameters."
381444
382445 if self .bias or self .swiglu_alpha is not None or self .swiglu_beta is not None or self .swiglu_limit is not None :
383446 assert self .has_nvfp4 or self .has_w4a16_mxfp4 or self .has_w4a8_mxfp4_fp8 or self .has_w4a8_mxfp4_mxfp8 , "TRTLLMGenFusedMoE supports bias/swiglu only for nvfp4 and mxfp4 variants."
@@ -405,8 +468,7 @@ def _get_quant_method(self):
405468 f"Unsupported quantization method by TRTLLMGenFusedMoE: { self .quant_config .quant_mode } "
406469 )
407470 else :
408- raise NotImplementedError (
409- "TRTLLMGenFusedMoE doesn't support fp16/bf16/fp32 MoE." )
471+ return BF16TRTLLMGenFusedMoEMethod ()
410472
411473 def create_weights (self ):
412474 if self ._weights_created :
@@ -467,6 +529,8 @@ def quantize_input(self, x, post_quant_comm: bool = True):
467529 - scaling_vector_size is typically the group size for block-wise quantization
468530 """
469531 x_sf = None
532+ if not self .has_any_quant :
533+ return x , x_sf
470534 if self .has_w4a8_mxfp4_fp8 :
471535 pad_size = self .w3_w1_weight .shape [- 1 ] * 2 - x .shape [- 1 ]
472536 x = torch .nn .functional .pad (x , (0 , pad_size ))
@@ -526,7 +590,7 @@ def quantize_input(self, x, post_quant_comm: bool = True):
526590 return x , x_sf
527591
def supports_moe_output_in_alltoall_workspace(self):
    """Return True only for the quantized, non-FlashInfer backend path."""
    if not self.has_any_quant:
        return False
    return not self.use_flashinfer
530594
531595 def run_moe (
532596 self ,
@@ -542,8 +606,8 @@ def run_moe(
542606 Run MoE computation with TRTLLMGen backend.
543607
544608 This method encapsulates the core MoE computation logic, handling different
545- quantization schemes (fp8_block_scales, nvfp4, w4a16_mxfp4, w4a8_nvfp4_fp8 ,
546- w4a8_mxfp4_fp8, w4a8_mxfp4_mxfp8).
609+ quantization schemes (bf16, fp8_block_scales, nvfp4, w4a16_mxfp4,
610+ w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, w4a8_mxfp4_mxfp8).
547611
548612 Args:
549613 # Standard MoE interface parameters:
@@ -592,7 +656,37 @@ def run_moe(
592656 ) == 2 , f"x_sf should be 2D tensor, got shape { x_sf .shape } "
593657 x_sf = x_sf .flatten ()
594658
595- if self .has_deepseek_fp8_block_scales :
659+ if not self .has_any_quant :
660+ result = self .op_backend .run_bf16_moe (
661+ router_logits = router_logits ,
662+ routing_bias = routing_bias ,
663+ hidden_states = x ,
664+ gemm1_weights = self .w3_w1_weight ,
665+ gemm2_weights = self .w2_weight ,
666+ num_experts = self .num_slots ,
667+ top_k = top_k ,
668+ n_group = n_group ,
669+ topk_group = topk_group ,
670+ intermediate_size = self .intermediate_size_per_partition ,
671+ local_expert_offset = self .slot_start ,
672+ local_num_experts = self .expert_size_per_partition ,
673+ routed_scaling_factor = routed_scaling_factor ,
674+ routing_method_type = self .routing_method .routing_method_type ,
675+ topk_weights = token_final_scales ,
676+ topk_ids = token_selected_experts ,
677+ gated_act_type = self ._to_trtllm_gen_activation_type (
678+ self .activation_type ),
679+ output = moe_output ,
680+ use_shuffled_weight = getattr (self .quant_method ,
681+ "use_shuffled_weight" , False ),
682+ weight_layout = getattr (self .quant_method , "weight_layout" , 0 ),
683+ do_finalize = do_finalize ,
684+ )
685+ if not do_finalize :
686+ assert not self .reduce_results , "reduce_results must be False when do_finalize is False"
687+ return result
688+ final_hidden_states = result
689+ elif self .has_deepseek_fp8_block_scales :
596690 assert do_finalize , "fp8_block_scale_moe_runner does not support do_finalize=False"
597691 # fp8_block_scale_moe_runner needs 2D shape for x_sf and only support SM100+
598692 if x_sf is None :
0 commit comments