Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions cpp/tensorrt_llm/thop/moeOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -606,15 +606,20 @@ class FusedMoeRunner : public torch::CustomClassHolder
TORCH_CHECK(!(lora_per_request && lora_slot_indexed),
"MoE LoRA: the per-request (fc1_lora_ranks, ...) and slot-indexed (fc1_slot_lora_ranks, ..., "
"token_to_slot) input schemas are mutually exclusive. Provide exactly one, not both.");
// Conservative rejections (min-latency, alltoall, quant, graph capture).
// Conservative rejections (min-latency, alltoall, unsupported quant, graph capture).
TORCH_CHECK(!min_latency_mode, "MoE LoRA is not supported in min-latency mode.");
TORCH_CHECK(!enable_alltoall,
"MoE LoRA is not supported with alltoall: the per-token adapter pointer arrays do not survive "
"cross-rank token reshuffling.");
TORCH_CHECK(mActivationDtype == c10::ScalarType::Half || mActivationDtype == c10::ScalarType::BFloat16,
"MoE LoRA only supports fp16 and bf16 activation dtypes.");
TORCH_CHECK(mWeightDtype == c10::ScalarType::Half || mWeightDtype == c10::ScalarType::BFloat16,
"MoE LoRA only supports unquantized fp16/bf16 expert weights.");
bool const is_per_tensor_fp8 = isFp8Quant();
TORCH_CHECK(mActivationDtype == c10::ScalarType::Half || mActivationDtype == c10::ScalarType::BFloat16
|| is_per_tensor_fp8,
"MoE LoRA only supports fp16, bf16, or per-tensor FP8 (qdq) base weights. FP8 block-scale, NVFP4, "
"MXFP8, and integer quant are not supported.");
TORCH_CHECK(
mWeightDtype == c10::ScalarType::Half || mWeightDtype == c10::ScalarType::BFloat16 || is_per_tensor_fp8,
"MoE LoRA supports unquantized fp16/bf16 or per-tensor FP8 (qdq) base expert weights only "
"(LoRA adapters are always fp16/bf16).");
// CUDA-graph capture is only safe on the device LoRA path. The
// legacy host path performs a host-side cudaEventSynchronize and
// per-token pointer expansion in setupLoraWorkspace, plus host-side
Expand Down Expand Up @@ -1274,14 +1279,20 @@ class FusedMoeRunner : public torch::CustomClassHolder
}

// Map a torch dtype to the TRT-LLM nvinfer1::DataType expected by LoraImpl.
static nvinfer1::DataType loraTypeFromActDtype(c10::ScalarType dtype)
nvinfer1::DataType loraTypeFromActDtype(c10::ScalarType dtype) const
{
switch (dtype)
{
case c10::ScalarType::Half: return nvinfer1::DataType::kHALF;
case c10::ScalarType::Float: return nvinfer1::DataType::kFLOAT;
#ifdef ENABLE_BF16
case c10::ScalarType::BFloat16: return nvinfer1::DataType::kBF16;
#endif
#ifdef ENABLE_FP8
case c10::ScalarType::Float8_e4m3fn:
TORCH_CHECK(mOutputDtype != c10::ScalarType::Float8_e4m3fn,
"MoE LoRA with FP8 base activations requires an fp16/bf16 output (LoRA compute) dtype.");
return loraTypeFromActDtype(mOutputDtype);
#endif
default: C10_THROW_ERROR_FORMATTED(Error, "MoE LoRA only supports fp16/bf16/fp32 activation dtype.");
}
Expand Down
14 changes: 11 additions & 3 deletions tensorrt_llm/_torch/models/modeling_mixtral.py
Comment thread
brb-nv marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,16 @@ def forward(
self,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
lora_params: Optional[dict] = None,
) -> torch.Tensor:
all_rank_num_tokens = attn_metadata.all_rank_num_tokens
router_logits = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states,
router_logits,
all_rank_num_tokens=all_rank_num_tokens,
use_dp_padding=False)
use_dp_padding=False,
lora_params=lora_params)
return final_hidden_states


Expand Down Expand Up @@ -141,6 +143,7 @@ def forward(
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
lora_params: Optional[dict] = None,
**kwargs,
) -> torch.Tensor:
if residual is None:
Expand All @@ -155,13 +158,16 @@ def forward(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
lora_params=lora_params,
**kwargs,
)

# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.block_sparse_moe(hidden_states, attn_metadata)
hidden_states = self.block_sparse_moe(hidden_states,
attn_metadata,
lora_params=lora_params)
return hidden_states, residual


Expand Down Expand Up @@ -195,6 +201,7 @@ def forward(
input_ids: Optional[torch.IntTensor] = None,
position_ids: Optional[torch.IntTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
lora_params: Optional[dict] = None,
**kwargs,
) -> torch.Tensor:
if (input_ids is None) ^ (inputs_embeds is not None):
Expand All @@ -212,7 +219,8 @@ def forward(
hidden_states, residual = decoder_layer(position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
residual=residual)
residual=residual,
lora_params=lora_params)

hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
Expand Down
73 changes: 59 additions & 14 deletions tensorrt_llm/_torch/peft/lora/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,40 @@
# SPDX-License-Identifier: Apache-2.0
"""Validation helpers for routed-expert (MoE) LoRA.

MoE LoRA is supported only on the Cutlass backend with unquantized fp16/bf16
base weights. This module provides a single helper, `check_moe_lora_supported`,
that callers (typically the MoE factory in `create_moe.py`) can invoke at
construction time so that unsupported combinations fail loudly instead of
silently dropping the LoRA contribution at runtime.

Runtime-only rejections (min-latency mode, alltoall, FP4 base, CUDA-graph
without slot pointers) are enforced in the C++ thop / runtime call paths and
are NOT re-checked here.
MoE LoRA is supported only on the Cutlass backend with unquantized fp16/bf16 or
per-tensor FP8 (qdq) base weights. This module provides a single helper,
`check_moe_lora_supported`, that callers (typically the MoE factory in
`create_moe.py`) can invoke at construction time so that unsupported
combinations fail loudly instead of silently dropping the LoRA contribution at
runtime.

Runtime-only rejections (min-latency mode, alltoall, CUDA-graph without slot
pointers) are enforced in the C++ thop / runtime call paths and are NOT
re-checked here.
"""

from typing import Iterable, Optional, Set

from tensorrt_llm.lora_helper import LoraConfig
from tensorrt_llm.quantization.mode import QuantMode

# Base-weight quantization bits that MoE LoRA does not support. Only per-tensor
# FP8 (qdq) composes with MoE LoRA; any of the bits below makes the combination
# unsupported.
#
# The individual QuantMode members are OR-ed into one mask so that a mode can
# be tested with a single bitwise AND (`mode & _UNSUPPORTED_QUANT`); a non-empty
# result means at least one unsupported bit is set on that mode.
_UNSUPPORTED_QUANT = (
QuantMode.INT4_WEIGHTS
| QuantMode.INT8_WEIGHTS
| QuantMode.ACTIVATIONS
| QuantMode.FP8_ROWWISE
| QuantMode.FP8_1x128_128x128
| QuantMode.W4A8_QSERVE
| QuantMode.NVFP4
| QuantMode.W4A8_NVFP4_FP8
| QuantMode.W4A8_MXFP4_FP8
| QuantMode.W4A16_MXFP4
| QuantMode.W4A8_MXFP4_MXFP8
| QuantMode.MXFP8
)

# TRTLLM module names that map to routed-expert MoE projections.
MOE_LORA_MODULE_NAMES: Set[str] = {"moe_h_to_4h", "moe_4h_to_h", "moe_gate"}
Expand All @@ -35,6 +55,30 @@ def has_moe_lora_targets(lora_config: Optional[LoraConfig]) -> bool:
)


def _is_supported_quant(quant_mode) -> bool:
    """Tell whether `quant_mode` is the one quantized setup MoE LoRA accepts.

    `quant_mode` may be a single QuantMode, or a QuantModeWrapper-style object
    that exposes its per-layer QuantModes through an `objs` attribute. The
    result is True when at least one inspected mode reports per-tensor FP8
    (qdq) and none of them carries a bit from `_UNSUPPORTED_QUANT`. It is
    False for `None`, for an empty per-layer list, and whenever any mode has
    an unsupported bit set.

    Rationale: the CUTLASS MoE LoRA kernel runs the LoRA GEMM on bf16/fp16
    activations, dequantizing per-tensor FP8 (qdq) activations to the backbone
    type first. FP8 block-scale, NVFP4, and the integer / MXFP4 / W4A8 formats
    listed in `_UNSUPPORTED_QUANT` have no such path and stay rejected.
    """
    if quant_mode is None:
        return False

    # A wrapper forwards has_* queries but cannot be AND-ed with a QuantMode
    # mask, so unwrap it to the underlying per-layer modes when present.
    per_layer = getattr(quant_mode, "objs", None)
    candidates = [quant_mode] if per_layer is None else per_layer

    found_fp8_qdq = False
    for candidate in candidates:
        if candidate.has_fp8_qdq():
            found_fp8_qdq = True
        # A single unsupported bit on any layer vetoes the whole configuration.
        if candidate & _UNSUPPORTED_QUANT:
            return False
    return found_fp8_qdq


def check_moe_lora_supported(
*,
moe_backend_name: str,
Expand All @@ -55,7 +99,8 @@ def check_moe_lora_supported(

Constraints:
- MoE backend MUST be CUTLASS.
- Base weight quantization MUST be off (no FP8 / FP4 / INT8 / INT4 / W4A8 ...).
- Base weight quantization MUST be off (fp16/bf16) or per-tensor FP8
(qdq). FP8 block-scale / FP4 / INT8 / INT4 / W4A8 ... are rejected.

Other constraints (alltoall, min-latency, CUDA-graph) are enforced at
runtime; we do not pre-check them here because they depend on per-call
Expand Down Expand Up @@ -83,10 +128,10 @@ def check_moe_lora_supported(
except TypeError:
# Older signatures may not accept the kwarg; fall back.
is_quantized = bool(quant_mode.has_any_quant())
if is_quantized:
if is_quantized and not _is_supported_quant(quant_mode):
raise ValueError(
f"{prefix}Routed-expert MoE LoRA only supports unquantized "
f"fp16/bf16 base weights; got quant_mode={quant_mode}. "
"FP8/FP4/INT4/INT8 base weights combined with MoE LoRA are not "
"supported."
f"fp16/bf16 or per-tensor FP8 (qdq) base weights; got "
f"quant_mode={quant_mode}. FP8 block-scale / FP4 / INT4 / INT8 / "
"W4A8 base weights combined with MoE LoRA are not supported."
)
Loading
Loading