diff --git a/cpp/tensorrt_llm/thop/moeOp.cpp b/cpp/tensorrt_llm/thop/moeOp.cpp index feede19bf71a..7634faaf8a33 100644 --- a/cpp/tensorrt_llm/thop/moeOp.cpp +++ b/cpp/tensorrt_llm/thop/moeOp.cpp @@ -480,8 +480,19 @@ class FusedMoeRunner : public torch::CustomClassHolder { // Note: The weight shape for INT8 weight only quantization is different, e.g., fc2_expert_weights: // [num_experts, inter_size, hidden_size] - TORCH_CHECK(fc1_expert_weights.sizes()[2] == fc2_expert_weights.sizes()[1] * mInnerDimMultiplier * 2, - "fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size."); + // Mirror the non-woq else-branch below: gated activations (Swiglu/Geglu) require fc1's + // intermediate dim to be 2x fc2's (one half each for gate and up), while non-gated + // activations (Relu2/Identity/ReLU/SiLU/Gelu, e.g. Nemotron-H) require them to be equal. + if (isGatedActivation(base_activation_type)) + { + TORCH_CHECK(fc1_expert_weights.sizes()[2] == fc2_expert_weights.sizes()[1] * mInnerDimMultiplier * 2, + "fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size."); + } + else + { + TORCH_CHECK(fc1_expert_weights.sizes()[2] == fc2_expert_weights.sizes()[1] * mInnerDimMultiplier, + "fc1_expert_weights inter size must be equal to fc2_expert_weights inter size."); + } } else { @@ -771,8 +782,6 @@ class FusedMoeRunner : public torch::CustomClassHolder } TORCH_CHECK(fc1_expert_weights.sizes()[0] == fc2_expert_weights.sizes()[0], "fc1_expert_weights and fc2_expert_weights must have the same number of experts."); - TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2, - "fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size."); TORCH_CHECK(!input_sf.has_value() || isWMxfp4AMxfp8Quant() || isNvfp4Quant(), "Block-scaling factors provided for non block-scaling quantization"); @@ -815,6 +824,38 @@ class FusedMoeRunner : public torch::CustomClassHolder reinterpret_cast(swiglu_alpha.has_value() ? swiglu_alpha.value().const_data_ptr() : nullptr), reinterpret_cast(swiglu_beta.has_value() ? swiglu_beta.value().const_data_ptr() : nullptr), reinterpret_cast(swiglu_limit.has_value() ? swiglu_limit.value().const_data_ptr() : nullptr)); + + // Validate the fc1/fc2 inter-size relationship now that the activation type (gated vs + // non-gated) is finalized. INT8-woq uses a transposed weight layout, so its fc1/fc2 dim + // ordering differs from the non-woq path; both mirror the gated/non-gated split used in + // runMoe(). Gated activations (Swiglu/Geglu) require fc1's intermediate dim to be 2x fc2's; + // non-gated (Relu2/Identity/ReLU/SiLU/Gelu, e.g. Nemotron-H) require them to be equal. + if (mUseINT8WoqPerChannel) + { + if (isGatedActivation(base_activation_type)) + { + TORCH_CHECK(fc1_expert_weights.sizes()[2] == fc2_expert_weights.sizes()[1] * mInnerDimMultiplier * 2, + "fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size."); + } + else + { + TORCH_CHECK(fc1_expert_weights.sizes()[2] == fc2_expert_weights.sizes()[1] * mInnerDimMultiplier, + "fc1_expert_weights inter size must be equal to fc2_expert_weights inter size."); + } + } + else + { + if (isGatedActivation(base_activation_type)) + { + TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2, + "fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size."); + } + else + { + TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier, + "fc1_expert_weights inter size must be equal to fc2_expert_weights inter size."); + } + } setRunnerProfiles(profile_ids); diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py index 42c9c84e2408..5ceba9437e6a 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py +++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py @@ -1312,14 +1312,14 @@ def create_weights(self, module: torch.nn.Module): # since the quantized weights have their own layout w3_w1_weight_shape = (module.expert_size_per_partition, module.hidden_size, - module.intermediate_size_per_partition * 2) + module.expand_intermediate_size_per_partition) w2_weight_shape = (module.expert_size_per_partition, module.intermediate_size_per_partition, module.hidden_size) fc31_weight_scale = nn.Parameter(torch.empty( module.expert_size_per_partition, - module.intermediate_size_per_partition * 2, + module.expand_intermediate_size_per_partition, dtype=module.dtype), requires_grad=False) module.register_parameter("fc31_weight_scale", fc31_weight_scale) @@ -1354,10 +1354,19 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module, w1_weight_shard = load_weight_shard(w1_weight, module.tp_size, module.tp_rank, TensorParallelMode.COLUMN) - w3_weight_shard = load_weight_shard(w3_weight, module.tp_size, - module.tp_rank, - TensorParallelMode.COLUMN) - w31_weight_shard = torch.cat([w3_weight_shard, w1_weight_shard], dim=0) + + # w3_weight (gate_proj) is empty for non-gated MoE (e.g. Nemotron-H squared-ReLU). + # Only concatenate the gate projection when present; otherwise the single + # up-projection fills the (non-doubled) intermediate buffer. Mirrors the + # non-gated handling in UnquantizedFusedMoEMethod.load_expert_w3_w1_weight. + if w3_weight is not None and w3_weight.numel() > 0: + w3_weight_shard = load_weight_shard(w3_weight, module.tp_size, + module.tp_rank, + TensorParallelMode.COLUMN) + w31_weight_shard = torch.cat([w3_weight_shard, w1_weight_shard], + dim=0) + else: + w31_weight_shard = w1_weight_shard weight_dtype = torch.int8 @@ -1398,25 +1407,33 @@ def load_expert_w2_weight(self, module: torch.nn.Module, non_blocking=True) def load_quant_scales(self, module: torch.nn.Module, weights: Dict): - # fc31 scales - all_w3_scales = [ - load_weight_shard(weights[f"{expert_id}.w3.weight_scale"], - module.tp_size, module.tp_rank, - TensorParallelMode.COLUMN) - for expert_id in module.initial_local_expert_ids - ] + # fc31 scales. w1 (up_proj) is always present; w3 (gate_proj) is absent + # for non-gated MoE (e.g. Nemotron-H squared-ReLU). Only concatenate the + # gate-projection scales when the gate weights are present; otherwise the + # up-projection scales alone fill the (non-doubled) fc31 scale buffer. all_w1_scales = [ load_weight_shard(weights[f"{expert_id}.w1.weight_scale"], module.tp_size, module.tp_rank, TensorParallelMode.COLUMN) for expert_id in module.initial_local_expert_ids ] - w3_w1_scales = torch.cat( - [torch.stack(all_w3_scales), - torch.stack(all_w1_scales)], dim=-1) + has_w3_scales = all( + f"{expert_id}.w3.weight_scale" in weights + for expert_id in module.initial_local_expert_ids) + if module.is_gated_activation and has_w3_scales: + all_w3_scales = [ + load_weight_shard(weights[f"{expert_id}.w3.weight_scale"], + module.tp_size, module.tp_rank, + TensorParallelMode.COLUMN) + for expert_id in module.initial_local_expert_ids + ] + w3_w1_scales = torch.cat( + [torch.stack(all_w3_scales), + torch.stack(all_w1_scales)], dim=-1) + else: + w3_w1_scales = torch.stack(all_w1_scales) w3_w1_scales = w3_w1_scales.to(module.dtype) module.fc31_weight_scale.data.copy_(w3_w1_scales.contiguous()) - # fc2 scales all_w2_scales = [ load_weight_shard(weights[f"{expert_id}.w2.weight_scale"],