Skip to content

Commit ca2bb08

Browse files
committed
[None][chore] Allow fp8 per-tensor base weights for MoE LoRA
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
1 parent c9b6518 commit ca2bb08

11 files changed

Lines changed: 824 additions & 239 deletions

File tree

cpp/tensorrt_llm/thop/moeOp.cpp

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -606,15 +606,20 @@ class FusedMoeRunner : public torch::CustomClassHolder
606606
TORCH_CHECK(!(lora_per_request && lora_slot_indexed),
607607
"MoE LoRA: the per-request (fc1_lora_ranks, ...) and slot-indexed (fc1_slot_lora_ranks, ..., "
608608
"token_to_slot) input schemas are mutually exclusive. Provide exactly one, not both.");
609-
// Conservative rejections (min-latency, alltoall, quant, graph capture).
609+
// Conservative rejections (min-latency, alltoall, unsupported quant, graph capture).
610610
TORCH_CHECK(!min_latency_mode, "MoE LoRA is not supported in min-latency mode.");
611611
TORCH_CHECK(!enable_alltoall,
612612
"MoE LoRA is not supported with alltoall: the per-token adapter pointer arrays do not survive "
613613
"cross-rank token reshuffling.");
614-
TORCH_CHECK(mActivationDtype == c10::ScalarType::Half || mActivationDtype == c10::ScalarType::BFloat16,
615-
"MoE LoRA only supports fp16 and bf16 activation dtypes.");
616-
TORCH_CHECK(mWeightDtype == c10::ScalarType::Half || mWeightDtype == c10::ScalarType::BFloat16,
617-
"MoE LoRA only supports unquantized fp16/bf16 expert weights.");
614+
bool const is_per_tensor_fp8 = isFp8Quant();
615+
TORCH_CHECK(mActivationDtype == c10::ScalarType::Half || mActivationDtype == c10::ScalarType::BFloat16
616+
|| is_per_tensor_fp8,
617+
"MoE LoRA only supports fp16, bf16, or per-tensor FP8 (qdq) base weights. FP8 block-scale, NVFP4, "
618+
"MXFP8, and integer quant are not supported.");
619+
TORCH_CHECK(
620+
mWeightDtype == c10::ScalarType::Half || mWeightDtype == c10::ScalarType::BFloat16 || is_per_tensor_fp8,
621+
"MoE LoRA supports unquantized fp16/bf16 or per-tensor FP8 (qdq) base expert weights only "
622+
"(LoRA adapters are always fp16/bf16).");
618623
// CUDA-graph capture is only safe on the device LoRA path. The
619624
// legacy host path performs a host-side cudaEventSynchronize and
620625
// per-token pointer expansion in setupLoraWorkspace, plus host-side
@@ -1274,14 +1279,20 @@ class FusedMoeRunner : public torch::CustomClassHolder
12741279
}
12751280

12761281
// Map a torch dtype to the TRT-LLM nvinfer1::DataType expected by LoraImpl.
1277-
static nvinfer1::DataType loraTypeFromActDtype(c10::ScalarType dtype)
1282+
nvinfer1::DataType loraTypeFromActDtype(c10::ScalarType dtype) const
12781283
{
12791284
switch (dtype)
12801285
{
12811286
case c10::ScalarType::Half: return nvinfer1::DataType::kHALF;
12821287
case c10::ScalarType::Float: return nvinfer1::DataType::kFLOAT;
12831288
#ifdef ENABLE_BF16
12841289
case c10::ScalarType::BFloat16: return nvinfer1::DataType::kBF16;
1290+
#endif
1291+
#ifdef ENABLE_FP8
1292+
case c10::ScalarType::Float8_e4m3fn:
1293+
TORCH_CHECK(mOutputDtype != c10::ScalarType::Float8_e4m3fn,
1294+
"MoE LoRA with FP8 base activations requires an fp16/bf16 output (LoRA compute) dtype.");
1295+
return loraTypeFromActDtype(mOutputDtype);
12851296
#endif
12861297
default: C10_THROW_ERROR_FORMATTED(Error, "MoE LoRA only supports fp16/bf16/fp32 activation dtype.");
12871298
}

tensorrt_llm/_torch/models/modeling_mixtral.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,16 @@ def forward(
6161
self,
6262
hidden_states: torch.Tensor,
6363
attn_metadata: AttentionMetadata,
64+
lora_params: Optional[dict] = None,
6465
) -> torch.Tensor:
6566
all_rank_num_tokens = attn_metadata.all_rank_num_tokens
6667
router_logits = self.gate(hidden_states)
6768
final_hidden_states = self.experts(
6869
hidden_states,
6970
router_logits,
7071
all_rank_num_tokens=all_rank_num_tokens,
71-
use_dp_padding=False)
72+
use_dp_padding=False,
73+
lora_params=lora_params)
7274
return final_hidden_states
7375

7476

@@ -141,6 +143,7 @@ def forward(
141143
hidden_states: torch.Tensor,
142144
attn_metadata: AttentionMetadata,
143145
residual: Optional[torch.Tensor],
146+
lora_params: Optional[dict] = None,
144147
**kwargs,
145148
) -> torch.Tensor:
146149
if residual is None:
@@ -155,13 +158,16 @@ def forward(
155158
position_ids=position_ids,
156159
hidden_states=hidden_states,
157160
attn_metadata=attn_metadata,
161+
lora_params=lora_params,
158162
**kwargs,
159163
)
160164

161165
# Fully Connected
162166
hidden_states, residual = self.post_attention_layernorm(
163167
hidden_states, residual)
164-
hidden_states = self.block_sparse_moe(hidden_states, attn_metadata)
168+
hidden_states = self.block_sparse_moe(hidden_states,
169+
attn_metadata,
170+
lora_params=lora_params)
165171
return hidden_states, residual
166172

167173

@@ -195,6 +201,7 @@ def forward(
195201
input_ids: Optional[torch.IntTensor] = None,
196202
position_ids: Optional[torch.IntTensor] = None,
197203
inputs_embeds: Optional[torch.FloatTensor] = None,
204+
lora_params: Optional[dict] = None,
198205
**kwargs,
199206
) -> torch.Tensor:
200207
if (input_ids is None) ^ (inputs_embeds is not None):
@@ -212,7 +219,8 @@ def forward(
212219
hidden_states, residual = decoder_layer(position_ids=position_ids,
213220
hidden_states=hidden_states,
214221
attn_metadata=attn_metadata,
215-
residual=residual)
222+
residual=residual,
223+
lora_params=lora_params)
216224

217225
hidden_states, _ = self.norm(hidden_states, residual)
218226
return hidden_states

tensorrt_llm/_torch/peft/lora/validation.py

Lines changed: 59 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,40 @@
22
# SPDX-License-Identifier: Apache-2.0
33
"""Validation helpers for routed-expert (MoE) LoRA.
44
5-
MoE LoRA is supported only on the Cutlass backend with unquantized fp16/bf16
6-
base weights. This module provides a single helper, `check_moe_lora_supported`,
7-
that callers (typically the MoE factory in `create_moe.py`) can invoke at
8-
construction time so that unsupported combinations fail loudly instead of
9-
silently dropping the LoRA contribution at runtime.
10-
11-
Runtime-only rejections (min-latency mode, alltoall, FP4 base, CUDA-graph
12-
without slot pointers) are enforced in the C++ thop / runtime call paths and
13-
are NOT re-checked here.
5+
MoE LoRA is supported only on the Cutlass backend with unquantized fp16/bf16 or
6+
per-tensor FP8 (qdq) base weights. This module provides a single helper,
7+
`check_moe_lora_supported`, that callers (typically the MoE factory in
8+
`create_moe.py`) can invoke at construction time so that unsupported
9+
combinations fail loudly instead of silently dropping the LoRA contribution at
10+
runtime.
11+
12+
Runtime-only rejections (min-latency mode, alltoall, CUDA-graph without slot
13+
pointers) are enforced in the C++ thop / runtime call paths and are NOT
14+
re-checked here.
1415
"""
1516

1617
from typing import Iterable, Optional, Set
1718

1819
from tensorrt_llm.lora_helper import LoraConfig
20+
from tensorrt_llm.quantization.mode import QuantMode
21+
22+
# Base-weight quantization bits that MoE LoRA does not support. Only per-tensor
23+
# FP8 (qdq) composes with MoE LoRA; any of the bits below makes the combination
24+
# unsupported.
25+
_UNSUPPORTED_QUANT = (
26+
QuantMode.INT4_WEIGHTS
27+
| QuantMode.INT8_WEIGHTS
28+
| QuantMode.ACTIVATIONS
29+
| QuantMode.FP8_ROWWISE
30+
| QuantMode.FP8_1x128_128x128
31+
| QuantMode.W4A8_QSERVE
32+
| QuantMode.NVFP4
33+
| QuantMode.W4A8_NVFP4_FP8
34+
| QuantMode.W4A8_MXFP4_FP8
35+
| QuantMode.W4A16_MXFP4
36+
| QuantMode.W4A8_MXFP4_MXFP8
37+
| QuantMode.MXFP8
38+
)
1939

2040
# TRTLLM module names that map to routed-expert MoE projections.
2141
MOE_LORA_MODULE_NAMES: Set[str] = {"moe_h_to_4h", "moe_4h_to_h", "moe_gate"}
@@ -35,6 +55,30 @@ def has_moe_lora_targets(lora_config: Optional[LoraConfig]) -> bool:
3555
)
3656

3757

58+
def _is_supported_quant(quant_mode) -> bool:
59+
"""Return True iff the only base-weight quantization is per-tensor FP8 (qdq).
60+
61+
The CUTLASS MoE LoRA kernel runs the LoRA GEMM on the bf16/fp16 activations,
62+
dequantizing the per-tensor FP8 (qdq) activations to the backbone type before
63+
the LoRA GEMM. FP8 block-scale, NVFP4, and the integer / MXFP4 / W4A8 formats
64+
in `_UNSUPPORTED_QUANT` have no such path and stay rejected.
65+
"""
66+
if quant_mode is None:
67+
return False
68+
# quant_mode may be a QuantMode, or a QuantModeWrapper that holds a per-layer
69+
# list of QuantModes and forwards has_* queries. Normalize to the underlying
70+
# QuantMode(s) so the bitwise check works in either case.
71+
objs = getattr(quant_mode, "objs", None)
72+
modes = objs if objs is not None else [quant_mode]
73+
has_supported = False
74+
for mode in modes:
75+
if mode.has_fp8_qdq():
76+
has_supported = True
77+
if bool(mode & _UNSUPPORTED_QUANT):
78+
return False
79+
return has_supported
80+
81+
3882
def check_moe_lora_supported(
3983
*,
4084
moe_backend_name: str,
@@ -55,7 +99,8 @@ def check_moe_lora_supported(
5599
56100
Constraints:
57101
- MoE backend MUST be CUTLASS.
58-
- Base weight quantization MUST be off (no FP8 / FP4 / INT8 / INT4 / W4A8 ...).
102+
- Base weight quantization MUST be off (fp16/bf16) or per-tensor FP8
103+
(qdq). FP8 block-scale / FP4 / INT8 / INT4 / W4A8 ... are rejected.
59104
60105
Other constraints (alltoall, min-latency, FP4, CUDA-graph) are enforced at
61106
runtime; we do not pre-check them here because they depend on per-call
@@ -83,10 +128,10 @@ def check_moe_lora_supported(
83128
except TypeError:
84129
# Older signatures may not accept the kwarg; fall back.
85130
is_quantized = bool(quant_mode.has_any_quant())
86-
if is_quantized:
131+
if is_quantized and not _is_supported_quant(quant_mode):
87132
raise ValueError(
88133
f"{prefix}Routed-expert MoE LoRA only supports unquantized "
89-
f"fp16/bf16 base weights; got quant_mode={quant_mode}. "
90-
"FP8/FP4/INT4/INT8 base weights combined with MoE LoRA are not "
91-
"supported."
134+
f"fp16/bf16 or per-tensor FP8 (qdq) base weights; got "
135+
f"quant_mode={quant_mode}. FP8 block-scale / FP4 / INT4 / INT8 / "
136+
"W4A8 base weights combined with MoE LoRA are not supported."
92137
)

0 commit comments

Comments
 (0)