Skip to content

Commit 5b0a3fb

Browse files
committed
ban DeepSeek routing w/ BF16 TRTLLMGenFusedMoE; bug inside Flashinfer
Signed-off-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
1 parent 6b67c8e commit 5b0a3fb

3 files changed

Lines changed: 56 additions & 13 deletions

File tree

tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
4. Unified EPLB integration for backends that support it
2929
"""
3030

31+
import copy
3132
from typing import Dict, List, Optional, Tuple, Union
3233

3334
import torch
@@ -162,21 +163,34 @@ def __init__(
162163
self.apply_router_weight_on_input = apply_router_weight_on_input
163164

164165
# ========== Create MoE Backend (Default: Cutlass) ==========
165-
from tensorrt_llm._torch.modules.fused_moe.create_moe import create_moe_backend, get_moe_cls
166+
from tensorrt_llm._torch.modules.fused_moe.create_moe import (
167+
create_moe_backend,
168+
resolve_moe_cls,
169+
)
170+
171+
# Get MoE backend class based on override_quant_config, routing_method, and model_config
172+
moe_cls = resolve_moe_cls(
173+
model_config,
174+
routing_method,
175+
self.dtype,
176+
override_quant_config=override_quant_config,
177+
)
166178

167-
# Get MoE backend class based on override_quant_config or model_config
168-
moe_cls = get_moe_cls(model_config, override_quant_config=override_quant_config)
179+
backend_model_config = model_config
180+
if override_quant_config is not None:
181+
backend_model_config = copy.deepcopy(model_config)
182+
backend_model_config.quant_config = override_quant_config
169183

170184
# Call create_moe_backend with all necessary parameters
171185
# init_load_balancer=False: Prevents backend from registering itself with load balancer
172186
# without_comm=True: Prevents backend from initializing communication (ConfigurableMoE handles it)
173187
# skip_create_weights_in_init=True: Prevents backend from creating weights in __init__
174188
# because backend uses layer_idx=None and may have different expert assignments
175189
# We will create weights after syncing attributes from ConfigurableMoE
176-
tmp_skip_create_weights_in_init = model_config.skip_create_weights_in_init
177-
model_config._frozen = False
178-
model_config.skip_create_weights_in_init = True
179-
model_config._frozen = True
190+
tmp_skip_create_weights_in_init = backend_model_config.skip_create_weights_in_init
191+
backend_model_config._frozen = False
192+
backend_model_config.skip_create_weights_in_init = True
193+
backend_model_config._frozen = True
180194

181195
backend = create_moe_backend(
182196
moe_cls=moe_cls,
@@ -186,7 +200,7 @@ def __init__(
186200
intermediate_size=self.intermediate_size,
187201
dtype=self.dtype,
188202
reduce_results=self.reduce_results,
189-
model_config=model_config,
203+
model_config=backend_model_config,
190204
aux_stream_dict=self.aux_stream_dict,
191205
weight_loading_mode=self.weight_loading_mode,
192206
bias=kwargs.get("bias", False),
@@ -221,10 +235,10 @@ def __init__(
221235
self.backend.expert_size_per_partition = self.expert_size_per_partition
222236

223237
# Create weights here, because the backend needs the layer_load_balancer info to create weights
224-
model_config._frozen = False
225-
model_config.skip_create_weights_in_init = tmp_skip_create_weights_in_init
226-
model_config._frozen = True
227-
if not model_config.skip_create_weights_in_init:
238+
backend_model_config._frozen = False
239+
backend_model_config.skip_create_weights_in_init = tmp_skip_create_weights_in_init
240+
backend_model_config._frozen = True
241+
if not backend_model_config.skip_create_weights_in_init:
228242
self.backend.create_weights()
229243

230244
# ========== Create Communication Strategy ==========

tensorrt_llm/_torch/modules/fused_moe/create_moe.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,25 @@ def get_moe_cls(
7777
raise ValueError(f"Unsupported moe backend: {moe_backend}")
7878

7979

80+
def resolve_moe_cls(
81+
model_config: ModelConfig,
82+
routing_method: BaseMoeRoutingMethod,
83+
dtype: Optional[torch.dtype],
84+
override_quant_config: Optional[QuantConfig] = None) -> Type[MoE]:
85+
moe_cls = get_moe_cls(model_config, override_quant_config)
86+
87+
effective_quant_config = override_quant_config or model_config.quant_config
88+
has_quant = (effective_quant_config is not None
89+
and effective_quant_config.layer_quant_mode.has_any_quant(
90+
exclude_kv_cache=True))
91+
if (moe_cls == TRTLLMGenFusedMoE and not has_quant
92+
and not TRTLLMGenFusedMoE._supports_flashinfer_bf16_routing_method(
93+
routing_method)):
94+
return CutlassFusedMoE
95+
96+
return moe_cls
97+
98+
8099
def create_moe_backend(
81100
moe_cls: Type[MoE],
82101
routing_method: BaseMoeRoutingMethod,
@@ -353,7 +372,8 @@ def create_moe(
353372
pretrained_config, 'torch_dtype'):
354373
dtype = pretrained_config.torch_dtype
355374

356-
moe_cls = get_moe_cls(model_config, override_quant_config)
375+
moe_cls = resolve_moe_cls(model_config, routing_method, dtype,
376+
override_quant_config)
357377

358378
enable_configurable_moe = os.environ.get("ENABLE_CONFIGURABLE_MOE",
359379
"1") == "1"

tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,12 @@ def _is_unquantized_path(self) -> bool:
318318
return self.quant_config is None or not self.quant_config.layer_quant_mode.has_any_quant(
319319
exclude_kv_cache=True)
320320

321+
@staticmethod
322+
def _supports_flashinfer_bf16_routing_method(
323+
routing_method: BaseMoeRoutingMethod, ) -> bool:
324+
# FIXME: ban DeepSeekV3 FlashInfer trtllm_bf16_routed_moe() as it appears to have bug
325+
return not isinstance(routing_method, DeepSeekV3MoeRoutingMethod)
326+
321327
def _requires_separated_routing(self) -> bool:
322328
"""Whether this backend instance expects precomputed top-k routing."""
323329
# FIXME: ban FlashInfer BF16 MoE direct routing as it appears to have accuracy bug
@@ -331,6 +337,9 @@ def _check_flashinfer_backend_support(self) -> bool:
331337
return False
332338
if self.activation_type != ActivationType.Swiglu:
333339
return False
340+
if not self._supports_flashinfer_bf16_routing_method(
341+
self.routing_method):
342+
return False
334343
return True
335344

336345
use_flashinfer = os.environ.get("TRTLLM_GEN_FUSED_MOE_USE_FLASHINFER",

0 commit comments

Comments
 (0)