22# SPDX-License-Identifier: Apache-2.0
33"""Validation helpers for routed-expert (MoE) LoRA.
44
5- MoE LoRA is supported only on the Cutlass backend with unquantized fp16/bf16
6- base weights. This module provides a single helper, `check_moe_lora_supported`,
7- that callers (typically the MoE factory in `create_moe.py`) can invoke at
8- construction time so that unsupported combinations fail loudly instead of
9- silently dropping the LoRA contribution at runtime.
10-
11- Runtime-only rejections (min-latency mode, alltoall, FP4 base, CUDA-graph
12- without slot pointers) are enforced in the C++ thop / runtime call paths and
13- are NOT re-checked here.
5+ MoE LoRA is supported only on the Cutlass backend with unquantized fp16/bf16 or
6+ per-tensor FP8 (qdq) base weights. This module provides a single helper,
7+ `check_moe_lora_supported`, that callers (typically the MoE factory in
8+ `create_moe.py`) can invoke at construction time so that unsupported
9+ combinations fail loudly instead of silently dropping the LoRA contribution at
10+ runtime.
11+
12+ Runtime-only rejections (min-latency mode, alltoall, CUDA-graph without slot
13+ pointers) are enforced in the C++ thop / runtime call paths and are NOT
14+ re-checked here.
1415"""
1516
1617from typing import Iterable , Optional , Set
1718
1819from tensorrt_llm .lora_helper import LoraConfig
20+ from tensorrt_llm .quantization .mode import QuantMode
21+
22+ # Base-weight quantization bits that MoE LoRA does not support. Only per-tensor
23+ # FP8 (qdq) composes with MoE LoRA; any of the bits below makes the combination
24+ # unsupported.
25+ _UNSUPPORTED_QUANT = (
26+ QuantMode .INT4_WEIGHTS
27+ | QuantMode .INT8_WEIGHTS
28+ | QuantMode .ACTIVATIONS
29+ | QuantMode .FP8_ROWWISE
30+ | QuantMode .FP8_1x128_128x128
31+ | QuantMode .W4A8_QSERVE
32+ | QuantMode .NVFP4
33+ | QuantMode .W4A8_NVFP4_FP8
34+ | QuantMode .W4A8_MXFP4_FP8
35+ | QuantMode .W4A16_MXFP4
36+ | QuantMode .W4A8_MXFP4_MXFP8
37+ | QuantMode .MXFP8
38+ )
1939
2040# TRTLLM module names that map to routed-expert MoE projections.
2141MOE_LORA_MODULE_NAMES : Set [str ] = {"moe_h_to_4h" , "moe_4h_to_h" , "moe_gate" }
@@ -35,6 +55,30 @@ def has_moe_lora_targets(lora_config: Optional[LoraConfig]) -> bool:
3555 )
3656
3757
58+ def _is_supported_quant (quant_mode ) -> bool :
59+ """Return True iff the only base-weight quantization is per-tensor FP8 (qdq).
60+
61+ The CUTLASS MoE LoRA kernel runs the LoRA GEMM on the bf16/fp16 activations,
62+ dequantizing the per-tensor FP8 (qdq) activations to the backbone type before
63+ the LoRA GEMM. FP8 block-scale, NVFP4, and the integer / MXFP4 / W4A8 formats
64+ in `_UNSUPPORTED_QUANT` have no such path and stay rejected.
65+ """
66+ if quant_mode is None :
67+ return False
68+ # quant_mode may be a QuantMode, or a QuantModeWrapper that holds a per-layer
69+ # list of QuantModes and forwards has_* queries. Normalize to the underlying
70+ # QuantMode(s) so the bitwise check works in either case.
71+ objs = getattr (quant_mode , "objs" , None )
72+ modes = objs if objs is not None else [quant_mode ]
73+ has_supported = False
74+ for mode in modes :
75+ if mode .has_fp8_qdq ():
76+ has_supported = True
77+ if bool (mode & _UNSUPPORTED_QUANT ):
78+ return False
79+ return has_supported
80+
81+
3882def check_moe_lora_supported (
3983 * ,
4084 moe_backend_name : str ,
@@ -55,7 +99,8 @@ def check_moe_lora_supported(
5599
56100 Constraints:
57101 - MoE backend MUST be CUTLASS.
58- - Base weight quantization MUST be off (no FP8 / FP4 / INT8 / INT4 / W4A8 ...).
102+ - Base weight quantization MUST be off (fp16/bf16) or per-tensor FP8
103+ (qdq). FP8 block-scale / FP4 / INT8 / INT4 / W4A8 ... are rejected.
59104
60105 Other constraints (alltoall, min-latency, FP4, CUDA-graph) are enforced at
61106 runtime; we do not pre-check them here because they depend on per-call
@@ -83,10 +128,10 @@ def check_moe_lora_supported(
83128 except TypeError :
84129 # Older signatures may not accept the kwarg; fall back.
85130 is_quantized = bool (quant_mode .has_any_quant ())
86- if is_quantized :
131+ if is_quantized and not _is_supported_quant ( quant_mode ) :
87132 raise ValueError (
88133 f"{ prefix } Routed-expert MoE LoRA only supports unquantized "
89- f"fp16/bf16 base weights; got quant_mode= { quant_mode } . "
90- " FP8/FP4/INT4/INT8 base weights combined with MoE LoRA are not "
91- "supported."
134+ f"fp16/bf16 or per-tensor FP8 (qdq) base weights; got "
135+ f"quant_mode= { quant_mode } . FP8 block-scale / FP4 / INT4 / INT8 / "
136+ "W4A8 base weights combined with MoE LoRA are not supported."
92137 )
0 commit comments