diff --git a/modelopt/torch/quantization/plugins/transformer_engine.py b/modelopt/torch/quantization/plugins/transformer_engine.py
index cec7ff9568..afc08211f0 100644
--- a/modelopt/torch/quantization/plugins/transformer_engine.py
+++ b/modelopt/torch/quantization/plugins/transformer_engine.py
@@ -156,7 +156,17 @@ def te_grouped_quantized_linear_fn(package, func_name, self, *args):
     _assert_te_fp8_enabled()
     idx = 1 if func_name == "_forward" else 0
     inp = args[idx]
-    num_gemms = len(args[idx + 1])
+
+    # Handle both old and new TE signatures (changed in PR #2377 in TE 2.10)
+    # New signature (TE >= 2.10): forward(ctx, inp, non_tensor_args: Tuple, *weights_and_biases)
+    # Old signature (TE < 2.10): forward(ctx, inp, m_splits: List[int], use_bias, ...)
+    if Version("2.10") <= _TE_VERSION:
+        # New signature: non_tensor_args is a tuple, m_splits is the first element
+        num_gemms = len(args[idx + 1][0])
+    else:
+        # Old signature: m_splits is directly args[idx + 1]
+        num_gemms = len(args[idx + 1])
+
     weights_and_biases = args[-2 * num_gemms :]
     weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
     quantized_inputs = self.input_quantizer(inp)
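
For reviewers, a minimal, self-contained sketch of the version gate above. This is not the ModelOpt implementation: the helper `num_gemms_from_args` and the argument tuples `old_args`/`new_args` are hypothetical placeholders (strings standing in for tensors) used only to illustrate the two TE argument layouts the patch distinguishes.

```python
# Sketch of the version-gated num_gemms extraction; placeholder values only,
# not captured from a real transformer_engine forward() call.
from packaging.version import Version


def num_gemms_from_args(args: tuple, idx: int, te_version: Version) -> int:
    """Derive the number of GEMMs from the grouped-linear forward args."""
    if te_version >= Version("2.10"):
        # TE >= 2.10: args[idx + 1] is a tuple of non-tensor args whose
        # first element is m_splits (one entry per GEMM).
        return len(args[idx + 1][0])
    # TE < 2.10: args[idx + 1] is m_splits itself.
    return len(args[idx + 1])


# Hypothetical args for a 2-GEMM call:
# old layout: (inp, m_splits, use_bias, ..., w0, w1, b0, b1)
old_args = ("inp", [128, 256], True, "w0", "w1", "b0", "b1")
# new layout: (inp, (m_splits, use_bias, ...), w0, w1, b0, b1)
new_args = ("inp", ([128, 256], True), "w0", "w1", "b0", "b1")

assert num_gemms_from_args(old_args, 0, Version("2.5")) == 2
assert num_gemms_from_args(new_args, 0, Version("2.10")) == 2
```

Gating on the installed TE version (rather than sniffing the shape of `args[idx + 1]`) keeps the behavior unambiguous: a list of `m_splits` and a tuple of non-tensor args could otherwise look alike at runtime.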