Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions jenkins/L0_Test.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -3247,9 +3247,10 @@ def launchTestJobs(pipeline, testFilter)
"DGX_H100-4_GPUs-PyTorch-Ray-1": ["auto:dgx-h100-x4", "l0_dgx_h100", 1, 1, 4],
"DGX_H100-4_GPUs-AutoDeploy-1": ["auto:dgx-h100-x4", "l0_dgx_h100", 1, 1, 4],
"DGX_H100-4_GPUs-AutoDeploy-Post-Merge-1": ["auto:dgx-h100-x4", "l0_dgx_h100", 1, 1, 4],
"DGX_B200-PyTorch-1": ["auto:dgx-b200-flex", "l0_b200", 1, 3, 1, 1, true],
"DGX_B200-PyTorch-2": ["auto:dgx-b200-flex", "l0_b200", 2, 3, 1, 1, true],
"DGX_B200-PyTorch-3": ["auto:dgx-b200-flex", "l0_b200", 3, 3, 1, 1, true],
"DGX_B200-PyTorch-1": ["auto:dgx-b200-flex", "l0_b200", 1, 4, 1, 1, true],
"DGX_B200-PyTorch-2": ["auto:dgx-b200-flex", "l0_b200", 2, 4, 1, 1, true],
"DGX_B200-PyTorch-3": ["auto:dgx-b200-flex", "l0_b200", 3, 4, 1, 1, true],
"DGX_B200-PyTorch-4": ["auto:dgx-b200-flex", "l0_b200", 4, 4, 1, 1, true],
"DGX_B200-AutoDeploy-1": ["auto:dgx-b200-flex", "l0_b200", 1, 1, 1, 1, true],
"DGX_B200-Triton-Post-Merge-1": ["auto:dgx-b200-flex", "l0_b200", 1, 1, 1, 1, true],
"DGX_B200-PyTorch-Post-Merge-1": ["auto:dgx-b200-flex", "l0_b200", 1, 2, 1, 1, true],
Expand Down
38 changes: 26 additions & 12 deletions tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
4. Unified EPLB integration for backends that support it
"""

import copy
from typing import Dict, List, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -164,21 +165,34 @@ def __init__(
self.apply_router_weight_on_input = apply_router_weight_on_input

# ========== Create MoE Backend (Default: Cutlass) ==========
from tensorrt_llm._torch.modules.fused_moe.create_moe import create_moe_backend, get_moe_cls
from tensorrt_llm._torch.modules.fused_moe.create_moe import (
create_moe_backend,
resolve_moe_cls,
)

# Get MoE backend class based on override_quant_config, routing_method, and model_config
moe_cls = resolve_moe_cls(
model_config,
routing_method,
self.dtype,
override_quant_config=override_quant_config,
)

# Get MoE backend class based on override_quant_config or model_config
moe_cls = get_moe_cls(model_config, override_quant_config=override_quant_config)
backend_model_config = model_config
if override_quant_config is not None:
backend_model_config = copy.deepcopy(model_config)
backend_model_config.quant_config = override_quant_config

# Call create_moe_backend with all necessary parameters
# init_load_balancer=False: Prevents backend from registering itself with load balancer
# without_comm=True: Prevents backend from initializing communication (ConfigurableMoE handles it)
# skip_create_weights_in_init=True: Prevents backend from creating weights in __init__
# because backend uses layer_idx=None and may have different expert assignments
# We will create weights after syncing attributes from ConfigurableMoE
tmp_skip_create_weights_in_init = model_config.skip_create_weights_in_init
model_config._frozen = False
model_config.skip_create_weights_in_init = True
model_config._frozen = True
tmp_skip_create_weights_in_init = backend_model_config.skip_create_weights_in_init
backend_model_config._frozen = False
backend_model_config.skip_create_weights_in_init = True
backend_model_config._frozen = True

backend = create_moe_backend(
moe_cls=moe_cls,
Expand All @@ -188,7 +202,7 @@ def __init__(
intermediate_size=self.intermediate_size,
dtype=self.dtype,
reduce_results=self.reduce_results,
model_config=model_config,
model_config=backend_model_config,
aux_stream_dict=self.aux_stream_dict,
weight_loading_mode=self.weight_loading_mode,
bias=kwargs.get("bias", False),
Expand Down Expand Up @@ -223,10 +237,10 @@ def __init__(
self.backend.expert_size_per_partition = self.expert_size_per_partition

# Create weights here, because the backend needs the layer_load_balancer info to create weights
model_config._frozen = False
model_config.skip_create_weights_in_init = tmp_skip_create_weights_in_init
model_config._frozen = True
if not model_config.skip_create_weights_in_init:
backend_model_config._frozen = False
backend_model_config.skip_create_weights_in_init = tmp_skip_create_weights_in_init
backend_model_config._frozen = True
if not backend_model_config.skip_create_weights_in_init:
self.backend.create_weights()

# ========== Create Communication Strategy ==========
Expand Down
45 changes: 37 additions & 8 deletions tensorrt_llm/_torch/modules/fused_moe/create_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,23 @@ def get_moe_cls(
return CutlassFusedMoE
return DenseGEMMFusedMoE
elif moe_backend.upper() == "TRTLLM":
if quant_config is not None and (
quant_config.quant_mode.has_fp8_block_scales()
or quant_config.quant_mode.has_nvfp4()
or quant_config.quant_mode.has_w4a16_mxfp4()
or quant_config.quant_mode.has_w4a8_nvfp4_fp8()
or quant_config.quant_mode.has_w4a8_mxfp4_fp8()
or quant_config.quant_mode.has_w4a8_mxfp4_mxfp8()):
has_quant = quant_config is not None and quant_config.quant_mode.has_any_quant(
exclude_kv_cache=True)
if has_quant and (quant_config.quant_mode.has_fp8_block_scales()
or quant_config.quant_mode.has_nvfp4()
or quant_config.quant_mode.has_w4a16_mxfp4()
or quant_config.quant_mode.has_w4a8_nvfp4_fp8()
or quant_config.quant_mode.has_w4a8_mxfp4_fp8()
or quant_config.quant_mode.has_w4a8_mxfp4_mxfp8()):
return TRTLLMGenFusedMoE
if not has_quant and model_config.pretrained_config is not None and getattr(
model_config.pretrained_config, "torch_dtype",
None) == torch.bfloat16:
if TRTLLMGenFusedMoE._is_flashinfer_fused_moe_available():
return TRTLLMGenFusedMoE
raise RuntimeError(
"TRTLLMGenFusedMoE BF16 path requires FlashInfer fused MoE with "
"trtllm_bf16_moe support, but it is not available.")
else:
logger.warning(
"TRTLLMGenFusedMoE only supports fp8_block_scales, nvfp4, w4a16_mxfp4, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, and w4a8_mxfp4_mxfp8. "
Expand All @@ -85,6 +94,25 @@ def get_moe_cls(
raise ValueError(f"Unsupported moe backend: {moe_backend}")


def resolve_moe_cls(
        model_config: ModelConfig,
        routing_method: BaseMoeRoutingMethod,
        dtype: Optional[torch.dtype],
        override_quant_config: Optional[QuantConfig] = None) -> Type[MoE]:
    """Resolve the concrete MoE backend class for the given configuration.

    Extends get_moe_cls() with a routing-method-aware fallback: the
    unquantized (BF16) TRTLLMGen path only supports a subset of routing
    methods via FlashInfer, so unsupported combinations fall back to
    CutlassFusedMoE.

    Args:
        model_config: Model configuration carrying the default quant_config.
        routing_method: Routing method the MoE layer will use.
        dtype: Activation dtype. Currently unused; kept for interface
            stability / future dtype-dependent dispatch.
        override_quant_config: Optional quant config that takes precedence
            over model_config.quant_config.

    Returns:
        The MoE backend class to instantiate.
    """
    moe_cls = get_moe_cls(model_config,
                          override_quant_config=override_quant_config)

    # Explicit None-check (instead of `or`) so a present-but-falsy override
    # config would still take precedence over the model-level one.
    effective_quant_config = (override_quant_config
                              if override_quant_config is not None else
                              model_config.quant_config)
    has_quant = (effective_quant_config is not None
                 and effective_quant_config.layer_quant_mode.has_any_quant(
                     exclude_kv_cache=True))

    # BF16 TRTLLMGen relies on FlashInfer kernels with limited routing
    # support; fall back to Cutlass for routing methods banned there.
    if (moe_cls is TRTLLMGenFusedMoE and not has_quant
            and not TRTLLMGenFusedMoE._supports_flashinfer_bf16_routing_method(
                routing_method)):
        return CutlassFusedMoE

    return moe_cls


def create_moe_backend(
moe_cls: Type[MoE],
routing_method: BaseMoeRoutingMethod,
Expand Down Expand Up @@ -379,7 +407,8 @@ def create_moe(
pretrained_config, 'torch_dtype'):
dtype = pretrained_config.torch_dtype

moe_cls = get_moe_cls(model_config, override_quant_config)
moe_cls = resolve_moe_cls(model_config, routing_method, dtype,
override_quant_config)

enable_configurable_moe = os.environ.get("ENABLE_CONFIGURABLE_MOE",
"1") == "1"
Expand Down
136 changes: 115 additions & 21 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@

# isort: off
from .quantization import (
DeepSeekFP8BlockScalesFusedMoEMethod, NVFP4TRTLLMGenFusedMoEBaseMethod,
NVFP4TRTLLMGenFusedMoEMethod, W4A8MXFP4FP8TRTLLMGenFusedMoEMethod,
W4A8MXFP4MXFP8TRTLLMGenFusedMoEMethod, W4A8NVFP4FP8TRTLLMGenFusedMoEMethod,
W4A16MXFP4TRTLLMGenFusedMoEMethod)
BF16TRTLLMGenFusedMoEMethod, DeepSeekFP8BlockScalesFusedMoEMethod,
NVFP4TRTLLMGenFusedMoEBaseMethod, NVFP4TRTLLMGenFusedMoEMethod,
W4A8MXFP4FP8TRTLLMGenFusedMoEMethod, W4A8MXFP4MXFP8TRTLLMGenFusedMoEMethod,
W4A8NVFP4FP8TRTLLMGenFusedMoEMethod, W4A16MXFP4TRTLLMGenFusedMoEMethod)
# isort: on
from .routing import (BaseMoeRoutingMethod, DeepSeekV3MoeRoutingMethod,
DefaultMoeRoutingMethod)
Expand Down Expand Up @@ -115,7 +115,8 @@ def can_implement(
- W4A8_MXFP4_FP8
- W4A8_MXFP4_MXFP8

Does NOT support unquantized mode. Output dtype is hardcoded to bfloat16.
Unquantized BF16 path is supported only with FlashInfer fused MoE backend.
Output dtype is hardcoded to bfloat16.

Args:
quant_algo: The quantization algorithm to check (None for unquantized)
Expand Down Expand Up @@ -143,10 +144,16 @@ def can_implement(
f"TRTLLMGenFusedMoE only supports bfloat16 activation, got {dtype_activation}"
)

# TRTLLMGenFusedMoE does NOT support unquantized mode
if quant_algo is None:
return _warn_and_return(
"TRTLLMGenFusedMoE does not support unquantized mode")
if swiglu_gptoss_style:
return _warn_and_return(
"TRTLLMGenFusedMoE BF16 path does not support bias/swiglu custom parameters."
)
if not cls._is_flashinfer_fused_moe_available():
return _warn_and_return(
"TRTLLMGenFusedMoE unquantized BF16 path requires FlashInfer fused MoE "
"with trtllm_bf16_moe support.")
return True, None

# Check if quant_algo is supported
if quant_algo not in cls._SUPPORTED_QUANT_ALGOS:
Expand Down Expand Up @@ -210,7 +217,14 @@ def __init__(

assert not self.smart_router, "Smart router is not supported in TRTLLMGenFusedMoE."

self.use_flashinfer = self._check_op_backend_support()
self.use_flashinfer = self._check_flashinfer_backend_support()
if (self.quant_config is None
or not self.quant_config.layer_quant_mode.has_any_quant(
exclude_kv_cache=True)) and not self.use_flashinfer:
raise NotImplementedError(
"TRTLLMGenFusedMoE BF16 path requires FlashInfer fused MoE. "
"Please install a FlashInfer build with trtllm_bf16_moe support."
)
backend_name = "flashinfer" if self.use_flashinfer else "trtllm"
self.op_backend: MoEOpBackend = get_op_backend(backend_name)

Expand Down Expand Up @@ -292,7 +306,43 @@ def _to_trtllm_gen_activation_type(self,
else:
raise ValueError(f"Unsupported activation type: {activation_type}")

def _check_op_backend_support(self) -> bool:
@staticmethod
def _is_flashinfer_fused_moe_available() -> bool:
try:
from flashinfer.fused_moe import core as _core
except (ImportError, ModuleNotFoundError):
return False
return (hasattr(_core, "trtllm_bf16_moe")
and hasattr(_core, "trtllm_bf16_routed_moe"))

def _is_unquantized_path(self) -> bool:
return self.quant_config is None or not self.quant_config.layer_quant_mode.has_any_quant(
exclude_kv_cache=True)

@staticmethod
def _supports_flashinfer_bf16_routing_method(
        routing_method: BaseMoeRoutingMethod, ) -> bool:
    """Whether the FlashInfer BF16 routed-MoE kernel supports this routing.

    FIXME: DeepSeekV3 routing is rejected because FlashInfer's
    trtllm_bf16_routed_moe() appears to have a bug with it.
    """
    if isinstance(routing_method, DeepSeekV3MoeRoutingMethod):
        return False
    return True

def _requires_separated_routing(self) -> bool:
"""Whether this backend instance expects precomputed top-k routing."""
# FIXME: ban FlashInfer BF16 MoE direct routing as it appears to have accuracy bug
return self.use_flashinfer and self._is_unquantized_path()

def _check_flashinfer_backend_support(self) -> bool:
# For BF16 (unquantized) path, we will use FlashInfer regardless whether
# env TRTLLM_GEN_FUSED_MOE_USE_FLASHINFER=1 is set or not as it's the only way.
if self._is_unquantized_path():
if not self._is_flashinfer_fused_moe_available():
return False
if self.activation_type != ActivationType.Swiglu:
return False
if not self._supports_flashinfer_bf16_routing_method(
self.routing_method):
return False
return True

use_flashinfer = os.environ.get("TRTLLM_GEN_FUSED_MOE_USE_FLASHINFER",
"0")
if use_flashinfer != "1":
Expand All @@ -311,8 +361,6 @@ def _check_op_backend_support(self) -> bool:
if type(quant_method) is NVFP4TRTLLMGenFusedMoEBaseMethod:
return True

if self.quant_config is None:
return False
mode = self.quant_config.layer_quant_mode

# These quant modes are never supported via op backend
Expand Down Expand Up @@ -365,7 +413,14 @@ def select_alltoall_method_type(self) -> AlltoallMethodType:
return AlltoallMethodType.NVLinkOneSided

def _supports_load_balancer(self) -> bool:
"""TRTLLMGenFusedMoE supports load balancer."""
"""Whether separated routing (top-k outside the kernel) is used.

ConfigurableMoE uses this flag to decide whether routing is separated
(top-k ids/scales computed outside backend) or fused inside the kernel.
BF16 FlashInfer path always requires separated routing.
"""
if self._requires_separated_routing():
return True
return self.use_dp and self.parallel_size > 1

@cached_property
Expand All @@ -375,9 +430,17 @@ def enable_alltoall(self):
return self.alltoall_method_type != AlltoallMethodType.NotEnabled

def _check_configs(self):
assert self.has_deepseek_fp8_block_scales \
assert not self.has_any_quant \
or self.has_deepseek_fp8_block_scales \
or self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_nvfp4_fp8 \
or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, "TRTLLMGenFusedMoE only supports fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_fp8 and w4a8_mxfp4_mxfp8 dtypes."
or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, \
"TRTLLMGenFusedMoE only supports bf16 (FlashInfer), fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_fp8 and w4a8_mxfp4_mxfp8 dtypes."

if not self.has_any_quant:
assert self.activation_type == ActivationType.Swiglu, \
"TRTLLMGenFusedMoE BF16 path only supports Swiglu activation."
assert not self.bias and self.swiglu_alpha is None and self.swiglu_beta is None and self.swiglu_limit is None, \
"TRTLLMGenFusedMoE BF16 path does not support bias/swiglu custom parameters."

if self.bias or self.swiglu_alpha is not None or self.swiglu_beta is not None or self.swiglu_limit is not None:
assert self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, "TRTLLMGenFusedMoE supports bias/swiglu only for nvfp4 and mxfp4 variants."
Expand Down Expand Up @@ -405,8 +468,7 @@ def _get_quant_method(self):
f"Unsupported quantization method by TRTLLMGenFusedMoE: {self.quant_config.quant_mode}"
)
else:
raise NotImplementedError(
"TRTLLMGenFusedMoE doesn't support fp16/bf16/fp32 MoE.")
return BF16TRTLLMGenFusedMoEMethod()

def create_weights(self):
if self._weights_created:
Expand Down Expand Up @@ -467,6 +529,8 @@ def quantize_input(self, x, post_quant_comm: bool = True):
- scaling_vector_size is typically the group size for block-wise quantization
"""
x_sf = None
if not self.has_any_quant:
return x, x_sf
if self.has_w4a8_mxfp4_fp8:
pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
x = torch.nn.functional.pad(x, (0, pad_size))
Expand Down Expand Up @@ -526,7 +590,7 @@ def quantize_input(self, x, post_quant_comm: bool = True):
return x, x_sf

def supports_moe_output_in_alltoall_workspace(self):
return True
return self.has_any_quant and not self.use_flashinfer

def run_moe(
self,
Expand All @@ -542,8 +606,8 @@ def run_moe(
Run MoE computation with TRTLLMGen backend.

This method encapsulates the core MoE computation logic, handling different
quantization schemes (fp8_block_scales, nvfp4, w4a16_mxfp4, w4a8_nvfp4_fp8,
w4a8_mxfp4_fp8, w4a8_mxfp4_mxfp8).
quantization schemes (bf16, fp8_block_scales, nvfp4, w4a16_mxfp4,
w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, w4a8_mxfp4_mxfp8).

Args:
# Standard MoE interface parameters:
Expand Down Expand Up @@ -592,7 +656,37 @@ def run_moe(
) == 2, f"x_sf should be 2D tensor, got shape {x_sf.shape}"
x_sf = x_sf.flatten()

if self.has_deepseek_fp8_block_scales:
if not self.has_any_quant:
result = self.op_backend.run_bf16_moe(
router_logits=router_logits,
routing_bias=routing_bias,
hidden_states=x,
gemm1_weights=self.w3_w1_weight,
gemm2_weights=self.w2_weight,
num_experts=self.num_slots,
top_k=top_k,
n_group=n_group,
topk_group=topk_group,
intermediate_size=self.intermediate_size_per_partition,
local_expert_offset=self.slot_start,
local_num_experts=self.expert_size_per_partition,
routed_scaling_factor=routed_scaling_factor,
routing_method_type=self.routing_method.routing_method_type,
topk_weights=token_final_scales,
topk_ids=token_selected_experts,
gated_act_type=self._to_trtllm_gen_activation_type(
self.activation_type),
output=moe_output,
use_shuffled_weight=getattr(self.quant_method,
"use_shuffled_weight", False),
weight_layout=getattr(self.quant_method, "weight_layout", 0),
do_finalize=do_finalize,
)
if not do_finalize:
assert not self.reduce_results, "reduce_results must be False when do_finalize is False"
return result
final_hidden_states = result
elif self.has_deepseek_fp8_block_scales:
assert do_finalize, "fp8_block_scale_moe_runner does not support do_finalize=False"
# fp8_block_scale_moe_runner needs 2D shape for x_sf and only support SM100+
if x_sf is None:
Expand Down
Loading
Loading