Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 45 additions & 4 deletions cpp/tensorrt_llm/thop/moeOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -480,8 +480,19 @@ class FusedMoeRunner : public torch::CustomClassHolder
{
// Note: The weight shape for INT8 weight only quantization is different, e.g., fc2_expert_weights:
// [num_experts, inter_size, hidden_size]
TORCH_CHECK(fc1_expert_weights.sizes()[2] == fc2_expert_weights.sizes()[1] * mInnerDimMultiplier * 2,
"fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size.");
// Mirror the non-woq else-branch below: gated activations (Swiglu/Geglu) require fc1's
// intermediate dim to be 2x fc2's (one half each for gate and up), while non-gated
// activations (Relu2/Identity/ReLU/SiLU/Gelu, e.g. Nemotron-H) require them to be equal.
if (isGatedActivation(base_activation_type))
{
TORCH_CHECK(fc1_expert_weights.sizes()[2] == fc2_expert_weights.sizes()[1] * mInnerDimMultiplier * 2,
"fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size.");
}
else
{
TORCH_CHECK(fc1_expert_weights.sizes()[2] == fc2_expert_weights.sizes()[1] * mInnerDimMultiplier,
"fc1_expert_weights inter size must be equal to fc2_expert_weights inter size.");
}
}
else
{
Expand Down Expand Up @@ -771,8 +782,6 @@ class FusedMoeRunner : public torch::CustomClassHolder
}
TORCH_CHECK(fc1_expert_weights.sizes()[0] == fc2_expert_weights.sizes()[0],
"fc1_expert_weights and fc2_expert_weights must have the same number of experts.");
TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2,
"fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size.");

TORCH_CHECK(!input_sf.has_value() || isWMxfp4AMxfp8Quant() || isNvfp4Quant(),
"Block-scaling factors provided for non block-scaling quantization");
Expand Down Expand Up @@ -815,6 +824,38 @@ class FusedMoeRunner : public torch::CustomClassHolder
reinterpret_cast<float const*>(swiglu_alpha.has_value() ? swiglu_alpha.value().const_data_ptr() : nullptr),
reinterpret_cast<float const*>(swiglu_beta.has_value() ? swiglu_beta.value().const_data_ptr() : nullptr),
reinterpret_cast<float const*>(swiglu_limit.has_value() ? swiglu_limit.value().const_data_ptr() : nullptr));

// Validate the fc1/fc2 inter-size relationship now that the activation type (gated vs
// non-gated) is finalized. INT8-woq uses a transposed weight layout, so its fc1/fc2 dim
// ordering differs from the non-woq path; both mirror the gated/non-gated split used in
// runMoe(). Gated activations (Swiglu/Geglu) require fc1's intermediate dim to be 2x fc2's;
// non-gated (Relu2/Identity/ReLU/SiLU/Gelu, e.g. Nemotron-H) require them to be equal.
if (mUseINT8WoqPerChannel)
{
if (isGatedActivation(base_activation_type))
{
TORCH_CHECK(fc1_expert_weights.sizes()[2] == fc2_expert_weights.sizes()[1] * mInnerDimMultiplier * 2,
"fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size.");
}
else
{
TORCH_CHECK(fc1_expert_weights.sizes()[2] == fc2_expert_weights.sizes()[1] * mInnerDimMultiplier,
"fc1_expert_weights inter size must be equal to fc2_expert_weights inter size.");
}
}
else
{
if (isGatedActivation(base_activation_type))
{
TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2,
"fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size.");
}
else
{
TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier,
"fc1_expert_weights inter size must be equal to fc2_expert_weights inter size.");
}
}

setRunnerProfiles(profile_ids);

Expand Down
51 changes: 34 additions & 17 deletions tensorrt_llm/_torch/modules/fused_moe/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -1312,14 +1312,14 @@ def create_weights(self, module: torch.nn.Module):
# since the quantized weights have their own layout
w3_w1_weight_shape = (module.expert_size_per_partition,
module.hidden_size,
module.intermediate_size_per_partition * 2)
module.expand_intermediate_size_per_partition)
w2_weight_shape = (module.expert_size_per_partition,
module.intermediate_size_per_partition,
module.hidden_size)

fc31_weight_scale = nn.Parameter(torch.empty(
module.expert_size_per_partition,
module.intermediate_size_per_partition * 2,
module.expand_intermediate_size_per_partition,
dtype=module.dtype),
requires_grad=False)
module.register_parameter("fc31_weight_scale", fc31_weight_scale)
Expand Down Expand Up @@ -1354,10 +1354,19 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module,
w1_weight_shard = load_weight_shard(w1_weight, module.tp_size,
module.tp_rank,
TensorParallelMode.COLUMN)
w3_weight_shard = load_weight_shard(w3_weight, module.tp_size,
module.tp_rank,
TensorParallelMode.COLUMN)
w31_weight_shard = torch.cat([w3_weight_shard, w1_weight_shard], dim=0)

# w3_weight (gate_proj) is empty for non-gated MoE (e.g. Nemotron-H squared-ReLU).
# Only concatenate the gate projection when present; otherwise the single
# up-projection fills the (non-doubled) intermediate buffer. Mirrors the
# non-gated handling in UnquantizedFusedMoEMethod.load_expert_w3_w1_weight.
if w3_weight is not None and w3_weight.numel() > 0:
w3_weight_shard = load_weight_shard(w3_weight, module.tp_size,
module.tp_rank,
TensorParallelMode.COLUMN)
w31_weight_shard = torch.cat([w3_weight_shard, w1_weight_shard],
dim=0)
else:
w31_weight_shard = w1_weight_shard

weight_dtype = torch.int8

Expand Down Expand Up @@ -1398,25 +1407,33 @@ def load_expert_w2_weight(self, module: torch.nn.Module,
non_blocking=True)

def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
# fc31 scales
all_w3_scales = [
load_weight_shard(weights[f"{expert_id}.w3.weight_scale"],
module.tp_size, module.tp_rank,
TensorParallelMode.COLUMN)
for expert_id in module.initial_local_expert_ids
]
# fc31 scales. w1 (up_proj) is always present; w3 (gate_proj) is absent
# for non-gated MoE (e.g. Nemotron-H squared-ReLU). Only concatenate the
# gate-projection scales when the gate weights are present; otherwise the
# up-projection scales alone fill the (non-doubled) fc31 scale buffer.
all_w1_scales = [
load_weight_shard(weights[f"{expert_id}.w1.weight_scale"],
module.tp_size, module.tp_rank,
TensorParallelMode.COLUMN)
for expert_id in module.initial_local_expert_ids
]
w3_w1_scales = torch.cat(
[torch.stack(all_w3_scales),
torch.stack(all_w1_scales)], dim=-1)
has_w3_scales = all(
f"{expert_id}.w3.weight_scale" in weights
for expert_id in module.initial_local_expert_ids)
if module.is_gated_activation and has_w3_scales:
all_w3_scales = [
load_weight_shard(weights[f"{expert_id}.w3.weight_scale"],
module.tp_size, module.tp_rank,
TensorParallelMode.COLUMN)
for expert_id in module.initial_local_expert_ids
]
w3_w1_scales = torch.cat(
[torch.stack(all_w3_scales),
torch.stack(all_w1_scales)], dim=-1)
else:
w3_w1_scales = torch.stack(all_w1_scales)
w3_w1_scales = w3_w1_scales.to(module.dtype)
module.fc31_weight_scale.data.copy_(w3_w1_scales.contiguous())

# fc2 scales
all_w2_scales = [
load_weight_shard(weights[f"{expert_id}.w2.weight_scale"],
Expand Down