From 197ecda1851b819e6f2e99c28f2cf0aa258720b4 Mon Sep 17 00:00:00 2001
From: James Shen
Date: Fri, 30 Jan 2026 14:05:12 -0800
Subject: [PATCH] Fix TEGroupedLinear quantization for expert parallelism (EP >
 1)

Signed-off-by: James Shen
---
 .../torch/quantization/plugins/transformer_engine.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/modelopt/torch/quantization/plugins/transformer_engine.py b/modelopt/torch/quantization/plugins/transformer_engine.py
index cec7ff9568..afc08211f0 100644
--- a/modelopt/torch/quantization/plugins/transformer_engine.py
+++ b/modelopt/torch/quantization/plugins/transformer_engine.py
@@ -156,7 +156,17 @@ def te_grouped_quantized_linear_fn(package, func_name, self, *args):
     _assert_te_fp8_enabled()
     idx = 1 if func_name == "_forward" else 0
     inp = args[idx]
-    num_gemms = len(args[idx + 1])
+
+    # Handle both old and new TE signatures (changed in PR #2377 in TE 2.10)
+    # New signature (TE >= 2.10): forward(ctx, inp, non_tensor_args: Tuple, *weights_and_biases)
+    # Old signature (TE < 2.10): forward(ctx, inp, m_splits: List[int], use_bias, ...)
+    if Version("2.10") <= _TE_VERSION:
+        # New signature: non_tensor_args is a tuple, m_splits is the first element
+        num_gemms = len(args[idx + 1][0])
+    else:
+        # Old signature: m_splits is directly args[idx + 1]
+        num_gemms = len(args[idx + 1])
+
     weights_and_biases = args[-2 * num_gemms :]
     weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
     quantized_inputs = self.input_quantizer(inp)