diff --git a/jenkins/L0_Test.groovy b/jenkins/L0_Test.groovy index 665fda0f6c6a..42e28ccb551a 100644 --- a/jenkins/L0_Test.groovy +++ b/jenkins/L0_Test.groovy @@ -3247,9 +3247,10 @@ def launchTestJobs(pipeline, testFilter) "DGX_H100-4_GPUs-PyTorch-Ray-1": ["auto:dgx-h100-x4", "l0_dgx_h100", 1, 1, 4], "DGX_H100-4_GPUs-AutoDeploy-1": ["auto:dgx-h100-x4", "l0_dgx_h100", 1, 1, 4], "DGX_H100-4_GPUs-AutoDeploy-Post-Merge-1": ["auto:dgx-h100-x4", "l0_dgx_h100", 1, 1, 4], - "DGX_B200-PyTorch-1": ["auto:dgx-b200-flex", "l0_b200", 1, 3, 1, 1, true], - "DGX_B200-PyTorch-2": ["auto:dgx-b200-flex", "l0_b200", 2, 3, 1, 1, true], - "DGX_B200-PyTorch-3": ["auto:dgx-b200-flex", "l0_b200", 3, 3, 1, 1, true], + "DGX_B200-PyTorch-1": ["auto:dgx-b200-flex", "l0_b200", 1, 4, 1, 1, true], + "DGX_B200-PyTorch-2": ["auto:dgx-b200-flex", "l0_b200", 2, 4, 1, 1, true], + "DGX_B200-PyTorch-3": ["auto:dgx-b200-flex", "l0_b200", 3, 4, 1, 1, true], + "DGX_B200-PyTorch-4": ["auto:dgx-b200-flex", "l0_b200", 4, 4, 1, 1, true], "DGX_B200-AutoDeploy-1": ["auto:dgx-b200-flex", "l0_b200", 1, 1, 1, 1, true], "DGX_B200-Triton-Post-Merge-1": ["auto:dgx-b200-flex", "l0_b200", 1, 1, 1, 1, true], "DGX_B200-PyTorch-Post-Merge-1": ["auto:dgx-b200-flex", "l0_b200", 1, 2, 1, 1, true], diff --git a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py index 986e2bf7b467..3b41bbaf2b77 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py +++ b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py @@ -28,6 +28,7 @@ 4. Unified EPLB integration for backends that support it """ +import copy from typing import Dict, List, Optional, Tuple, Union import torch @@ -164,10 +165,23 @@ def __init__( self.apply_router_weight_on_input = apply_router_weight_on_input # ========== Create MoE Backend (Default: Cutlass) ========== - from tensorrt_llm._torch.modules.fused_moe.create_moe import create_moe_backend, get_moe_cls + from tensorrt_llm._torch.modules.fused_moe.create_moe import ( + create_moe_backend, + resolve_moe_cls, + ) + + # Get MoE backend class based on override_quant_config, routing_method, and model_config + moe_cls = resolve_moe_cls( + model_config, + routing_method, + self.dtype, + override_quant_config=override_quant_config, + ) - # Get MoE backend class based on override_quant_config or model_config - moe_cls = get_moe_cls(model_config, override_quant_config=override_quant_config) + backend_model_config = model_config + if override_quant_config is not None: + backend_model_config = copy.deepcopy(model_config) + backend_model_config.quant_config = override_quant_config # Call create_moe_backend with all necessary parameters # init_load_balancer=False: Prevents backend from registering itself with load balancer @@ -175,10 +189,10 @@ def __init__( # skip_create_weights_in_init=True: Prevents backend from creating weights in __init__ # because backend uses layer_idx=None and may have different expert assignments # We will create weights after syncing attributes from ConfigurableMoE - tmp_skip_create_weights_in_init = model_config.skip_create_weights_in_init - model_config._frozen = False - model_config.skip_create_weights_in_init = True - model_config._frozen = True + tmp_skip_create_weights_in_init = backend_model_config.skip_create_weights_in_init + backend_model_config._frozen = False + backend_model_config.skip_create_weights_in_init = True + backend_model_config._frozen = True backend = create_moe_backend( moe_cls=moe_cls, @@ -188,7 +202,7 @@ def __init__( intermediate_size=self.intermediate_size, dtype=self.dtype, reduce_results=self.reduce_results, - model_config=model_config, + model_config=backend_model_config, aux_stream_dict=self.aux_stream_dict, weight_loading_mode=self.weight_loading_mode, bias=kwargs.get("bias", False), @@ -223,10 +237,10 @@ def __init__( self.backend.expert_size_per_partition = self.expert_size_per_partition # Create weights here, because the backend needs the layer_load_balancer info to create weights - model_config._frozen = False - model_config.skip_create_weights_in_init = tmp_skip_create_weights_in_init - model_config._frozen = True - if not model_config.skip_create_weights_in_init: + backend_model_config._frozen = False + backend_model_config.skip_create_weights_in_init = tmp_skip_create_weights_in_init + backend_model_config._frozen = True + if not backend_model_config.skip_create_weights_in_init: self.backend.create_weights() # ========== Create Communication Strategy ========== diff --git a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py index 0638904ab8e7..0ee2741743cc 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py +++ b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py @@ -63,14 +63,23 @@ def get_moe_cls( return CutlassFusedMoE return DenseGEMMFusedMoE elif moe_backend.upper() == "TRTLLM": - if quant_config is not None and ( - quant_config.quant_mode.has_fp8_block_scales() - or quant_config.quant_mode.has_nvfp4() - or quant_config.quant_mode.has_w4a16_mxfp4() - or quant_config.quant_mode.has_w4a8_nvfp4_fp8() - or quant_config.quant_mode.has_w4a8_mxfp4_fp8() - or quant_config.quant_mode.has_w4a8_mxfp4_mxfp8()): + has_quant = quant_config is not None and quant_config.quant_mode.has_any_quant( + exclude_kv_cache=True) + if has_quant and (quant_config.quant_mode.has_fp8_block_scales() + or quant_config.quant_mode.has_nvfp4() + or quant_config.quant_mode.has_w4a16_mxfp4() + or quant_config.quant_mode.has_w4a8_nvfp4_fp8() + or quant_config.quant_mode.has_w4a8_mxfp4_fp8() + or quant_config.quant_mode.has_w4a8_mxfp4_mxfp8()): return TRTLLMGenFusedMoE + if not has_quant and model_config.pretrained_config is not None and getattr( + model_config.pretrained_config, "torch_dtype", + None) == torch.bfloat16: + if TRTLLMGenFusedMoE._is_flashinfer_fused_moe_available(): + return TRTLLMGenFusedMoE + raise RuntimeError( + "TRTLLMGenFusedMoE BF16 path requires FlashInfer fused MoE with " + "trtllm_bf16_moe support, but it is not available.") else: logger.warning( "TRTLLMGenFusedMoE only supports fp8_block_scales, nvfp4, w4a16_mxfp4, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, and w4a8_mxfp4_mxfp8. " @@ -85,6 +94,25 @@ def get_moe_cls( raise ValueError(f"Unsupported moe backend: {moe_backend}") +def resolve_moe_cls( + model_config: ModelConfig, + routing_method: BaseMoeRoutingMethod, + dtype: Optional[torch.dtype], + override_quant_config: Optional[QuantConfig] = None) -> Type[MoE]: + moe_cls = get_moe_cls(model_config, override_quant_config) + + effective_quant_config = override_quant_config or model_config.quant_config + has_quant = (effective_quant_config is not None + and effective_quant_config.layer_quant_mode.has_any_quant( + exclude_kv_cache=True)) + if (moe_cls == TRTLLMGenFusedMoE and not has_quant + and not TRTLLMGenFusedMoE._supports_flashinfer_bf16_routing_method( + routing_method)): + return CutlassFusedMoE + + return moe_cls + + def create_moe_backend( moe_cls: Type[MoE], routing_method: BaseMoeRoutingMethod, @@ -379,7 +407,8 @@ def create_moe( pretrained_config, 'torch_dtype'): dtype = pretrained_config.torch_dtype - moe_cls = get_moe_cls(model_config, override_quant_config) + moe_cls = resolve_moe_cls(model_config, routing_method, dtype, + override_quant_config) enable_configurable_moe = os.environ.get("ENABLE_CONFIGURABLE_MOE", "1") == "1" diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py index 23354f5a5b34..65b82710871c 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py @@ -38,10 +38,10 @@ # isort: off from .quantization import ( - DeepSeekFP8BlockScalesFusedMoEMethod, NVFP4TRTLLMGenFusedMoEBaseMethod, - NVFP4TRTLLMGenFusedMoEMethod, W4A8MXFP4FP8TRTLLMGenFusedMoEMethod, - W4A8MXFP4MXFP8TRTLLMGenFusedMoEMethod, W4A8NVFP4FP8TRTLLMGenFusedMoEMethod, - W4A16MXFP4TRTLLMGenFusedMoEMethod) + BF16TRTLLMGenFusedMoEMethod, DeepSeekFP8BlockScalesFusedMoEMethod, + NVFP4TRTLLMGenFusedMoEBaseMethod, NVFP4TRTLLMGenFusedMoEMethod, + W4A8MXFP4FP8TRTLLMGenFusedMoEMethod, W4A8MXFP4MXFP8TRTLLMGenFusedMoEMethod, + W4A8NVFP4FP8TRTLLMGenFusedMoEMethod, W4A16MXFP4TRTLLMGenFusedMoEMethod) # isort: on from .routing import (BaseMoeRoutingMethod, DeepSeekV3MoeRoutingMethod, DefaultMoeRoutingMethod) @@ -115,7 +115,8 @@ def can_implement( - W4A8_MXFP4_FP8 - W4A8_MXFP4_MXFP8 - Does NOT support unquantized mode. Output dtype is hardcoded to bfloat16. + Unquantized BF16 path is supported only with FlashInfer fused MoE backend. + Output dtype is hardcoded to bfloat16. Args: quant_algo: The quantization algorithm to check (None for unquantized) @@ -143,10 +144,16 @@ def can_implement( f"TRTLLMGenFusedMoE only supports bfloat16 activation, got {dtype_activation}" ) - # TRTLLMGenFusedMoE does NOT support unquantized mode if quant_algo is None: - return _warn_and_return( - "TRTLLMGenFusedMoE does not support unquantized mode") + if swiglu_gptoss_style: + return _warn_and_return( + "TRTLLMGenFusedMoE BF16 path does not support bias/swiglu custom parameters." + ) + if not cls._is_flashinfer_fused_moe_available(): + return _warn_and_return( + "TRTLLMGenFusedMoE unquantized BF16 path requires FlashInfer fused MoE " + "with trtllm_bf16_moe support.") + return True, None # Check if quant_algo is supported if quant_algo not in cls._SUPPORTED_QUANT_ALGOS: @@ -210,7 +217,14 @@ def __init__( assert not self.smart_router, "Smart router is not supported in TRTLLMGenFusedMoE." - self.use_flashinfer = self._check_op_backend_support() + self.use_flashinfer = self._check_flashinfer_backend_support() + if (self.quant_config is None + or not self.quant_config.layer_quant_mode.has_any_quant( + exclude_kv_cache=True)) and not self.use_flashinfer: + raise NotImplementedError( + "TRTLLMGenFusedMoE BF16 path requires FlashInfer fused MoE. " + "Please install a FlashInfer build with trtllm_bf16_moe support." + ) backend_name = "flashinfer" if self.use_flashinfer else "trtllm" self.op_backend: MoEOpBackend = get_op_backend(backend_name) @@ -292,7 +306,43 @@ def _to_trtllm_gen_activation_type(self, else: raise ValueError(f"Unsupported activation type: {activation_type}") - def _check_op_backend_support(self) -> bool: + @staticmethod + def _is_flashinfer_fused_moe_available() -> bool: + try: + from flashinfer.fused_moe import core as _core + except (ImportError, ModuleNotFoundError): + return False + return (hasattr(_core, "trtllm_bf16_moe") + and hasattr(_core, "trtllm_bf16_routed_moe")) + + def _is_unquantized_path(self) -> bool: + return self.quant_config is None or not self.quant_config.layer_quant_mode.has_any_quant( + exclude_kv_cache=True) + + @staticmethod + def _supports_flashinfer_bf16_routing_method( + routing_method: BaseMoeRoutingMethod, ) -> bool: + # FIXME: ban DeepSeekV3 FlashInfer trtllm_bf16_routed_moe() as it appears to have bug + return not isinstance(routing_method, DeepSeekV3MoeRoutingMethod) + + def _requires_separated_routing(self) -> bool: + """Whether this backend instance expects precomputed top-k routing.""" + # FIXME: ban FlashInfer BF16 MoE direct routing as it appears to have accuracy bug + return self.use_flashinfer and self._is_unquantized_path() + + def _check_flashinfer_backend_support(self) -> bool: + # For BF16 (unquantized) path, we will use FlashInfer regardless whether + # env TRTLLM_GEN_FUSED_MOE_USE_FLASHINFER=1 is set or not as it's the only way. + if self._is_unquantized_path(): + if not self._is_flashinfer_fused_moe_available(): + return False + if self.activation_type != ActivationType.Swiglu: + return False + if not self._supports_flashinfer_bf16_routing_method( + self.routing_method): + return False + return True + use_flashinfer = os.environ.get("TRTLLM_GEN_FUSED_MOE_USE_FLASHINFER", "0") if use_flashinfer != "1": @@ -311,8 +361,6 @@ def _check_op_backend_support(self) -> bool: if type(quant_method) is NVFP4TRTLLMGenFusedMoEBaseMethod: return True - if self.quant_config is None: - return False mode = self.quant_config.layer_quant_mode # These quant modes are never supported via op backend @@ -365,7 +413,14 @@ def select_alltoall_method_type(self) -> AlltoallMethodType: return AlltoallMethodType.NVLinkOneSided def _supports_load_balancer(self) -> bool: - """TRTLLMGenFusedMoE supports load balancer.""" + """Whether separated routing (top-k outside the kernel) is used. + + ConfigurableMoE uses this flag to decide whether routing is separated + (top-k ids/scales computed outside backend) or fused inside the kernel. + BF16 FlashInfer path always requires separated routing. + """ + if self._requires_separated_routing(): + return True return self.use_dp and self.parallel_size > 1 @cached_property @@ -375,9 +430,17 @@ def enable_alltoall(self): return self.alltoall_method_type != AlltoallMethodType.NotEnabled def _check_configs(self): - assert self.has_deepseek_fp8_block_scales \ + assert not self.has_any_quant \ + or self.has_deepseek_fp8_block_scales \ or self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_nvfp4_fp8 \ - or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, "TRTLLMGenFusedMoE only supports fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_fp8 and w4a8_mxfp4_mxfp8 dtypes." + or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, \ + "TRTLLMGenFusedMoE only supports bf16 (FlashInfer), fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_fp8 and w4a8_mxfp4_mxfp8 dtypes." + + if not self.has_any_quant: + assert self.activation_type == ActivationType.Swiglu, \ + "TRTLLMGenFusedMoE BF16 path only supports Swiglu activation." + assert not self.bias and self.swiglu_alpha is None and self.swiglu_beta is None and self.swiglu_limit is None, \ + "TRTLLMGenFusedMoE BF16 path does not support bias/swiglu custom parameters." if self.bias or self.swiglu_alpha is not None or self.swiglu_beta is not None or self.swiglu_limit is not None: assert self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, "TRTLLMGenFusedMoE supports bias/swiglu only for nvfp4 and mxfp4 variants." @@ -405,8 +468,7 @@ def _get_quant_method(self): f"Unsupported quantization method by TRTLLMGenFusedMoE: {self.quant_config.quant_mode}" ) else: - raise NotImplementedError( - "TRTLLMGenFusedMoE doesn't support fp16/bf16/fp32 MoE.") + return BF16TRTLLMGenFusedMoEMethod() def create_weights(self): if self._weights_created: @@ -467,6 +529,8 @@ def quantize_input(self, x, post_quant_comm: bool = True): - scaling_vector_size is typically the group size for block-wise quantization """ x_sf = None + if not self.has_any_quant: + return x, x_sf if self.has_w4a8_mxfp4_fp8: pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1] x = torch.nn.functional.pad(x, (0, pad_size)) @@ -526,7 +590,7 @@ def quantize_input(self, x, post_quant_comm: bool = True): return x, x_sf def supports_moe_output_in_alltoall_workspace(self): - return True + return self.has_any_quant and not self.use_flashinfer def run_moe( self, @@ -542,8 +606,8 @@ def run_moe( Run MoE computation with TRTLLMGen backend. This method encapsulates the core MoE computation logic, handling different - quantization schemes (fp8_block_scales, nvfp4, w4a16_mxfp4, w4a8_nvfp4_fp8, - w4a8_mxfp4_fp8, w4a8_mxfp4_mxfp8). + quantization schemes (bf16, fp8_block_scales, nvfp4, w4a16_mxfp4, + w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, w4a8_mxfp4_mxfp8). Args: # Standard MoE interface parameters: @@ -592,7 +656,37 @@ def run_moe( ) == 2, f"x_sf should be 2D tensor, got shape {x_sf.shape}" x_sf = x_sf.flatten() - if self.has_deepseek_fp8_block_scales: + if not self.has_any_quant: + result = self.op_backend.run_bf16_moe( + router_logits=router_logits, + routing_bias=routing_bias, + hidden_states=x, + gemm1_weights=self.w3_w1_weight, + gemm2_weights=self.w2_weight, + num_experts=self.num_slots, + top_k=top_k, + n_group=n_group, + topk_group=topk_group, + intermediate_size=self.intermediate_size_per_partition, + local_expert_offset=self.slot_start, + local_num_experts=self.expert_size_per_partition, + routed_scaling_factor=routed_scaling_factor, + routing_method_type=self.routing_method.routing_method_type, + topk_weights=token_final_scales, + topk_ids=token_selected_experts, + gated_act_type=self._to_trtllm_gen_activation_type( + self.activation_type), + output=moe_output, + use_shuffled_weight=getattr(self.quant_method, + "use_shuffled_weight", False), + weight_layout=getattr(self.quant_method, "weight_layout", 0), + do_finalize=do_finalize, + ) + if not do_finalize: + assert not self.reduce_results, "reduce_results must be False when do_finalize is False" + return result + final_hidden_states = result + elif self.has_deepseek_fp8_block_scales: assert do_finalize, "fp8_block_scale_moe_runner does not support do_finalize=False" # fp8_block_scale_moe_runner needs 2D shape for x_sf and only support SM100+ if x_sf is None: diff --git a/tensorrt_llm/_torch/modules/fused_moe/moe_op_backend.py b/tensorrt_llm/_torch/modules/fused_moe/moe_op_backend.py index 0655410cd8ca..771484d642ec 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/moe_op_backend.py +++ b/tensorrt_llm/_torch/modules/fused_moe/moe_op_backend.py @@ -25,6 +25,8 @@ import torch +from ...utils import ActType_TrtllmGen + # Global registry for MoE backends _MOE_OP_BACKEND_REGISTRY: Dict[str, Type["MoEOpBackend"]] = {} @@ -164,6 +166,35 @@ def run_fp4_block_scale_moe( """Run FP4 block scale MoE computation.""" raise NotImplementedError + def run_bf16_moe( + self, + router_logits: Optional[torch.Tensor], + routing_bias: Optional[torch.Tensor], + hidden_states: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm2_weights: torch.Tensor, + num_experts: int, + top_k: int, + n_group: Optional[int], + topk_group: Optional[int], + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + routed_scaling_factor: Optional[float], + routing_method_type: int, + topk_weights: Optional[torch.Tensor] = None, + topk_ids: Optional[torch.Tensor] = None, + gated_act_type: int = 0, + output: Optional[torch.Tensor] = None, + use_shuffled_weight: bool = False, + weight_layout: int = 0, + do_finalize: bool = True, + enable_pdl: Optional[bool] = None, + tune_max_num_tokens: int = 8192, + ) -> torch.Tensor: + """Run BF16 MoE computation.""" + raise NotImplementedError + # ==================== TRTLLM Backend ==================== @register_op_backend("trtllm") @@ -405,6 +436,37 @@ def run_fp4_block_scale_moe( output=output, ) + def run_bf16_moe( + self, + router_logits, + routing_bias, + hidden_states, + gemm1_weights, + gemm2_weights, + num_experts, + top_k, + n_group, + topk_group, + intermediate_size, + local_expert_offset, + local_num_experts, + routed_scaling_factor, + routing_method_type, + topk_weights=None, + topk_ids=None, + gated_act_type=0, + output=None, + use_shuffled_weight=False, + weight_layout=0, + do_finalize=True, + enable_pdl=None, + tune_max_num_tokens=8192, + ): + raise NotImplementedError( + "TRTLLM native op backend does not support unquantized BF16 TRTLLM-Gen fused MoE. " + "Enable FlashInfer fused MoE for TRTLLM backend." + ) + # ==================== Flashinfer Backend ==================== @register_op_backend("flashinfer") @@ -689,3 +751,87 @@ def run_fp4_block_scale_moe( else: final_hidden_states = outputs[0] return final_hidden_states + + def run_bf16_moe( + self, + router_logits, + routing_bias, + hidden_states, + gemm1_weights, + gemm2_weights, + num_experts, + top_k, + n_group, + topk_group, + intermediate_size, + local_expert_offset, + local_num_experts, + routed_scaling_factor, + routing_method_type, + topk_weights=None, + topk_ids=None, + gated_act_type=0, + output=None, + use_shuffled_weight=False, + weight_layout=0, + do_finalize=True, + enable_pdl=None, + tune_max_num_tokens=8192, + ): + # FlashInfer BF16 MoE does not expose an activation_type argument. + # TRTLLMGen constrains the BF16 path to Swiglu, so reject anything + # else here instead of silently calling a mismatched kernel. + if gated_act_type != ActType_TrtllmGen.SwiGlu: + raise ValueError("FlashInfer BF16 fused MoE only supports Swiglu activation.") + + if router_logits is not None: + result = self._fused_moe.trtllm_bf16_moe( + routing_logits=router_logits, + routing_bias=routing_bias, + hidden_states=hidden_states, + gemm1_weights=gemm1_weights, + gemm2_weights=gemm2_weights, + num_experts=num_experts, + top_k=top_k, + n_group=n_group, + topk_group=topk_group, + intermediate_size=intermediate_size, + local_expert_offset=local_expert_offset, + local_num_experts=local_num_experts, + routed_scaling_factor=routed_scaling_factor, + routing_method_type=self.cvt_routing_method_type(routing_method_type), + use_shuffled_weight=use_shuffled_weight, + weight_layout=weight_layout, + do_finalize=do_finalize, + enable_pdl=enable_pdl, + tune_max_num_tokens=tune_max_num_tokens, + ) + else: + packed_topk_ids = (topk_ids.to(torch.int32) << 16) | topk_weights.to( + torch.bfloat16 + ).contiguous().view(torch.int16).to(torch.int32) + result = self._fused_moe.trtllm_bf16_routed_moe( + packed_topk_ids, + hidden_states, + gemm1_weights, + gemm2_weights, + num_experts, + top_k, + n_group, + topk_group, + intermediate_size, + local_expert_offset, + local_num_experts, + routed_scaling_factor, + self.cvt_routing_method_type(routing_method_type), + use_shuffled_weight=use_shuffled_weight, + weight_layout=weight_layout, + do_finalize=do_finalize, + enable_pdl=enable_pdl, + tune_max_num_tokens=tune_max_num_tokens, + ) + + if output is not None and do_finalize: + output.copy_(result) + return output + return result diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py index c30b7d771aa9..93a6d3cacee6 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py +++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py @@ -49,6 +49,10 @@ # pack weight block scales into int32, e.g. 4 x fp8 weight values FUSED_MOE_NVFP4_WEIGHT_BLOCK_SCALE_DTYPE = torch.int32 FUSED_MOE_MXFP4_WEIGHT_BLOCK_SCALE_DTYPE = torch.int32 +# TRTLLM-Gen MatrixLayout enum values. +TRTLLM_GEN_WEIGHT_LAYOUT_MAJOR_K = 0 +TRTLLM_GEN_WEIGHT_LAYOUT_MAJOR_MN = 1 +TRTLLM_GEN_WEIGHT_LAYOUT_BLOCK_MAJOR_K = 2 class FusedMoEQuantScalesFP8(NamedTuple): @@ -164,6 +168,28 @@ def trtllmgen_maybe_get_cached_w2_permute_indices( return permute_indices +def _convert_to_block_major_k_layout(input_tensor: torch.Tensor, + block_k: int) -> torch.Tensor: + if input_tensor.dim() != 2: + raise ValueError( + f"input_tensor must be 2D for BlockMajorK conversion, got shape={tuple(input_tensor.shape)}" + ) + m, k = input_tensor.shape + if k % block_k != 0: + raise ValueError( + f"K dimension ({k}) must be divisible by block_k ({block_k}) for BlockMajorK layout." + ) + return input_tensor.view(m, k // block_k, block_k).permute(1, 0, + 2).contiguous() + + +def _prepare_bf16_weight_for_trtllm_gen(weight: torch.Tensor, + permute_indices: torch.Tensor, + block_k: int) -> torch.Tensor: + shuffled_weight = weight[permute_indices.to(weight.device)].contiguous() + return _convert_to_block_major_k_layout(shuffled_weight, block_k) + + def maybe_pad_for_mxfp4(weight: torch.Tensor, col_alignment: int, row_alignment: Optional[int] = None) -> torch.Tensor: @@ -225,6 +251,8 @@ class FusedMoEMethodBase(ABC): to work with online EPLB should override this to SUPPORTED; those that have not yet been tested may set it to NOT_VERIFIED. """ + needs_post_load_processing_for_dummy: bool = False + """Whether LoadFormat.DUMMY must finish weight processing in post_load_weights().""" @classmethod def supports_online_eplb(cls) -> bool: @@ -300,6 +328,8 @@ def create_weights( module.w2_bias = None module.rebuild_tensor_metadata = {} + module._needs_post_load_weight_processing = True + module._weights_loaded_via_load_weights = False def load_expert_weights_to_dst( self, @@ -411,6 +441,7 @@ def load_weights(self, weights: List[Dict], weight_loading_mode: MoEWeightLoadingMode, allow_partial_loading: bool = False): + module._weights_loaded_via_load_weights = True if allow_partial_loading: if not isinstance(self, (UnquantizedFusedMoEMethod, FP8QDQFusedMoEMethod, @@ -493,8 +524,23 @@ def load_weights(self, if not allow_partial_loading: self.process_weights_after_loading(module) + module._needs_post_load_weight_processing = False def post_load_weights(self, module: torch.nn.Module): + # LoadFormat.DUMMY initializes parameters in-place without calling + # load_weights(), so only methods that explicitly opt in should finish + # their processing here unless load_weights() left work unfinished. + needs_post_load_processing = getattr( + module, "_needs_post_load_weight_processing", True) + loaded_via_load_weights = getattr(module, + "_weights_loaded_via_load_weights", + False) + if needs_post_load_processing and ( + loaded_via_load_weights + or self.needs_post_load_processing_for_dummy): + self.process_weights_after_loading(module) + module._needs_post_load_weight_processing = False + if self.need_load_shared_weights(module): weight_fns = { 'w3_w1_weight': getattr(module, 'local_shared_w3_w1_tensors'), @@ -645,6 +691,59 @@ def setup_quant_scales(self, module: torch.nn.Module): module.quant_scales = tuple() +class BF16TRTLLMGenFusedMoEMethod(UnquantizedFusedMoEMethod): + # BlockMajorK uses 128-byte K blocks. BF16 has 2 bytes per element. + block_k = 64 + use_shuffled_weight = True + weight_layout = TRTLLM_GEN_WEIGHT_LAYOUT_BLOCK_MAJOR_K + needs_post_load_processing_for_dummy = True + _cache_permute_indices: Dict[tuple[tuple[int, ...], str, int], + torch.Tensor] = {} + + def _get_w3_w1_permute_indices( + self, + w3_w1_weight: torch.Tensor, + is_gated_act_gemm: bool = True) -> torch.Tensor: + return trtllmgen_maybe_get_cached_w3_w1_permute_indices( + w3_w1_weight.view(torch.uint8), + self._cache_permute_indices, + epilogue_tile_m=128, + is_gated_act_gemm=is_gated_act_gemm) + + def _get_w2_permute_indices(self, w2_weight: torch.Tensor) -> torch.Tensor: + return trtllmgen_maybe_get_cached_w2_permute_indices( + w2_weight.view(torch.uint8), + self._cache_permute_indices, + epilogue_tile_m=128) + + def process_weights_after_loading(self, module: torch.nn.Module): + if module.w3_w1_weight.numel() == 0 or module.w2_weight.numel() == 0: + return + + w3_w1_permute_indices = self._get_w3_w1_permute_indices( + module.w3_w1_weight.data[0], + is_gated_act_gemm=getattr(module, "is_gated_activation", True)) + w2_permute_indices = self._get_w2_permute_indices( + module.w2_weight.data[0]) + + processed_w3_w1 = torch.stack([ + _prepare_bf16_weight_for_trtllm_gen(expert, w3_w1_permute_indices, + self.block_k) + for expert in module.w3_w1_weight.data + ]) + processed_w2 = torch.stack([ + _prepare_bf16_weight_for_trtllm_gen(expert, w2_permute_indices, + self.block_k) + for expert in module.w2_weight.data + ]) + + replace_parameter_and_save_metadata(module, "w3_w1_weight", + processed_w3_w1, + module.rebuild_tensor_metadata) + replace_parameter_and_save_metadata(module, "w2_weight", processed_w2, + module.rebuild_tensor_metadata) + + def load_expert_fc31_input_scale_fp8_qdq(w1_input_scale, w3_input_scale, dst_fc31_input_scale: torch.Tensor): if w1_input_scale is not None and w1_input_scale.numel() != 0: diff --git a/tests/integration/defs/.test_durations b/tests/integration/defs/.test_durations index 08b28bbc939e..8ef2b23c2525 100644 --- a/tests/integration/defs/.test_durations +++ b/tests/integration/defs/.test_durations @@ -293,6 +293,11 @@ "accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[fp8-latency-TRTLLM]": 3240.4374056719825603, "accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[mxfp8-latency-CUTLASS]": 3240.4651110970880836, "accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[mxfp8-latency-TRTLLM]": 3240.4263977239606902, + "accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.5-fp8kv=False]": 1202.0, + "accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.9-fp8kv=False]": 1127.0, + "accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_bf16[tp1-CUTLASS]": 1091.0, + "accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_bf16[tp1-TRTLLM]": 1091.0, + "accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8[tp1]": 313.0, "accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency-torch_compile=False]": 149.19146074401215, "accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]": 104.32479889906244, "accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_w4a8_mxfp4[fp8-latency]": 240.30756398336961865, diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index a334d46a8d01..f0fefde0857a 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -5720,41 +5720,55 @@ class TestQwen3_5_35B_A3B(LlmapiAccuracyTestHarness): chat_template_kwargs=dict(enable_thinking=False), ) - def test_bf16(self): - world_size = 1 + @pytest.mark.parametrize("moe_backend", ["CUTLASS", "TRTLLM"]) + @pytest.mark.parametrize( + "tp_size", + [1, pytest.param(2, marks=pytest.mark.skip_less_device(2))], + ids=["tp1", "tp2"], + ) + def test_bf16(self, moe_backend, tp_size): + if moe_backend == "TRTLLM" and get_sm_version() not in (100, 103): + pytest.skip(f"{moe_backend} backend supports SM 100 and 103 only") + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8, enable_block_reuse=False) cuda_graph_config = CudaGraphConfig( enable_padding=True, batch_sizes=[1, 2, 4, 8, 16, 32, 64, 128]) + moe_config = MoeConfig(backend=moe_backend) with LLM(self.MODEL_PATH, - tensor_parallel_size=world_size, - moe_expert_parallel_size=world_size, + tensor_parallel_size=tp_size, + moe_expert_parallel_size=1, max_seq_len=4096, - max_num_tokens=4096, - max_batch_size=128, + max_batch_size=32, enable_chunked_prefill=True, kv_cache_config=kv_cache_config, - cuda_graph_config=cuda_graph_config) as llm: + cuda_graph_config=cuda_graph_config, + moe_config=moe_config) as llm: task = GSM8K(self.MODEL_NAME) task.evaluate(llm, extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS) - def test_fp8(self): + @pytest.mark.parametrize( + "tp_size", + [1, pytest.param(2, marks=pytest.mark.skip_less_device(2))], + ids=["tp1", "tp2"], + ) + def test_fp8(self, tp_size): model_dir = f"{self.MODEL_PATH}-FP8" # Model is being added to CI. Skip at the moment. if not os.path.exists(model_dir): pytest.skip(f"Model directory {model_dir} does not exist") - world_size = 1 kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8, enable_block_reuse=False) moe_config = MoeConfig(backend='DEEPGEMM') with LLM(model_dir, - tensor_parallel_size=world_size, - moe_expert_parallel_size=world_size, + tensor_parallel_size=tp_size, + moe_expert_parallel_size=1, max_seq_len=4096, + max_batch_size=32, enable_chunked_prefill=True, kv_cache_config=kv_cache_config, moe_config=moe_config) as llm: diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt index a49644bbb479..1853c8debf09 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -191,8 +191,9 @@ accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_cu accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm] accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm_attention_dp] accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4_4gpus[latency_moe_trtllm_eagle3] -accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_bf16 -accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8 +accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_bf16[tp1-CUTLASS] +accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_bf16[tp1-TRTLLM] +accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8[tp1] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v1_kv_cache-True-True-cutlass-auto] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v1_kv_cache-True-True-cutlass-fp8] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v1_kv_cache-True-True-triton-auto] diff --git a/tests/integration/test_lists/qa/llm_function_core_sanity.txt b/tests/integration/test_lists/qa/llm_function_core_sanity.txt index 1378ce4efa40..27d02c9310a7 100644 --- a/tests/integration/test_lists/qa/llm_function_core_sanity.txt +++ b/tests/integration/test_lists/qa/llm_function_core_sanity.txt @@ -175,8 +175,10 @@ accuracy/test_llm_api_pytorch.py::TestQwen3_4B::test_eagle3 accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency] accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_w4a8_mxfp4[fp8-latency] accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_w4a8_mxfp4[mxfp8-latency] -accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_bf16 -accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8 +accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_bf16[tp1-CUTLASS] +accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_bf16[tp1-TRTLLM] +accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8[tp1] +accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8[tp2] # disaggregated serving accuracy test accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=0-overlap_scheduler=False] diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml index 17a17fb673ad..62ca5f7a2cdd 100644 --- a/tests/integration/test_lists/test-db/l0_b200.yml +++ b/tests/integration/test_lists/test-db/l0_b200.yml @@ -70,9 +70,8 @@ l0_b200: - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[mxfp8-latency-TRTLLM] - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[mxfp8-latency-CUTLASS] - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a16_mxfp4[latency-TRTLLM] - - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.9-fp8kv=True] - - accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_bf16 - - accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8 + - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.9-fp8kv=False] + - accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_bf16[tp1-TRTLLM] - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp1-cutlass] - disaggregated/test_workers.py::test_workers_kv_cache_aware_router_eviction[TinyLlama-1.1B-Chat-v1.0] # nvbugs 5300551 - test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-NVFP4-nvfp4-quantized/Meta-Llama-3.1-8B] @@ -269,6 +268,9 @@ l0_b200: - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=TRTLLM-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTEDSL-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.5-fp8kv=False] + - accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8[tp1] + - accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_bf16[tp1-CUTLASS] - accuracy/test_llm_api_pytorch.py::TestSeedOss_36B::test_auto_dtype - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B_Instruct_RocketKV::test_auto_dtype - unittest/_torch/visual_gen/test_wan.py::TestWanTwoStageTransformer diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml index c398f69ec883..e6aff4acb0a5 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml @@ -27,6 +27,9 @@ l0_dgx_b200: - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_trtllm-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_cutlass-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_cutlass-torch_compile=True] + - accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_bf16[tp2-CUTLASS] + - accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_bf16[tp2-TRTLLM] + - accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8[tp2] - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx[DeepSeek-V3-Lite-fp8] - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_nixl[DeepSeek-V3-Lite-fp8] - disaggregated/test_disaggregated.py::test_disaggregated_gpt_oss_120b_harmony[gpt_oss/gpt-oss-120b] diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index 6ad20ff814fe..e35ba2deaf59 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -292,7 +292,9 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_piecewi accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype SKIP (https://nvbugs/5992113) accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[use_temperature=False-attn_backend=TRTLLM] SKIP (https://nvbugs/5997547) accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_python_scheduler[ep4-mtp_nextn=0] SKIP (https://nvbugs/5997051) -accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8 SKIP (https://nvbugs/6004530) +accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8[tp1] SKIP (https://nvbugs/6004530) +accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8[tp2] SKIP (https://nvbugs/6004530) +perf/test_perf_sanity.py::test_e2e[aggr_upload-deepseek_v32_fp4_blackwell-v32_fp4_tep8_mtp3_8k1k] SKIP (https://nvbugs/5997092) unittest/_torch/modules/moe/test_moe_module.py::test_configurable_moe_multi_gpu[parallel=DEP-comm=DEEPEP-e60_k4_h2048_i1408-seq=8-dtype=torch.bfloat16-backend=TRTLLM-quant=NVFP4-routing=Renormalize] SKIP (https://nvbugs/6007285) accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_bf16_4gpu_mtp_ar SKIP (https://nvbugs/5959992) accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_vswa_reuse_4gpus[two_model] SKIP (https://nvbugs/6013562) diff --git a/tests/unittest/_torch/modules/moe/moe_test_utils.py b/tests/unittest/_torch/modules/moe/moe_test_utils.py index a26d0466ee87..a61c64387205 100644 --- a/tests/unittest/_torch/modules/moe/moe_test_utils.py +++ b/tests/unittest/_torch/modules/moe/moe_test_utils.py @@ -235,12 +235,13 @@ def should_skip_trtllm( QuantAlgo.W4A16_MXFP4, QuantAlgo.W4A8_MXFP4_MXFP8, } - - if quant_algo not in trtllm_gen_quant_algos: + # Quant_algo==None (BF16 path) also falls through and must meet the should_skip_trtllm criteria + if quant_algo is not None and quant_algo not in trtllm_gen_quant_algos: return None num_experts = model_config.num_experts top_k = model_config.top_k + hidden_size = model_config.hidden_size intermediate_size = model_config.intermediate_size # Check: num_experts must be divisible by 4 @@ -258,11 +259,22 @@ def should_skip_trtllm( f"TRTLLMGenFusedMoE requires num_experts > top_k " f"(got num_experts={num_experts}, top_k={top_k})" ) + + if quant_algo is None: + if swiglu_gptoss_style: + return "TRTLLMGenFusedMoE BF16 path does not support bias/swiglu custom parameters." + + if hidden_size % 128 != 0 or intermediate_size % 128 != 0: + return ( + "TRTLLMGenFusedMoE BF16 path requires hidden_size and intermediate_size " + f"to be multiples of 128 (got h={hidden_size}, i={intermediate_size})." + ) + return None + # W4A8_MXFP4_MXFP8 with non-128-aligned hidden_size or intermediate_size # causes block_scale_interleave_reverse to fail with # "rows of Interleaved block scales should be multiple of 128". if quant_algo == QuantAlgo.W4A8_MXFP4_MXFP8: - hidden_size = model_config.hidden_size if hidden_size % 128 != 0 or intermediate_size % 128 != 0: return ( f"TRTLLMGenFusedMoE W4A8_MXFP4_MXFP8 with non-128-aligned " @@ -959,7 +971,8 @@ def should_skip_to_accelerate_ci( all combinations run (local exhaustive testing). Rules applied (in order): - 0. Skip unquantized (quant=None) — quantized paths are the focus of CI + 0. Skip unquantized (quant=None) for most paths, but keep TRTLLM BF16 + unquantized coverage enabled. 1. e256 model: only DeepSeekV3 routing, bfloat16, seq=1, non-gptoss 2. Multi-GPU: only DEP and TTP parallel modes 3. Routing: full 6 routing methods only on (CUTLASS or TRTLLM) with NVFP4; @@ -985,8 +998,13 @@ def should_skip_to_accelerate_ci( if model_config is None: return None - # --- Rule 0: Skip gated and unquantized (quant=None) --- - if quant_algo is None and is_gated_activation(activation_type): + # --- Rule 0: Skip gated and unquantized (quant=None) for most backends --- + # Keep TRTLLM BF16 unquantized enabled to cover FlashInfer BF16 TRTLLM MoE. + if ( + quant_algo is None + and is_gated_activation(activation_type) + and not (backend_type == MoeBackendType.TRTLLM and dtype == torch.bfloat16) + ): return "[CI accel] Skip unquantized (quant=None) in CI" is_large_model = model_config.num_experts >= 256 and model_config.hidden_size >= 7168