Skip to content

Commit febe313

Browse files
committed
Fix TEGroupedLinear quantization for expert parallelism (EP > 1)
Signed-off-by: James Shen <yueshen@nvidia.com>
1 parent 452c5a0 commit febe313

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

modelopt/torch/quantization/plugins/transformer_engine.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,17 @@ def te_grouped_quantized_linear_fn(package, func_name, self, *args):
156156
_assert_te_fp8_enabled()
157157
idx = 1 if func_name == "_forward" else 0
158158
inp = args[idx]
159-
num_gemms = len(args[idx + 1])
159+
160+
# Handle both old and new TE signatures (changed in PR #2377 in TE 2.10)
161+
# New signature (TE >= 2.10): forward(ctx, inp, non_tensor_args: Tuple, *weights_and_biases)
162+
# Old signature (TE < 2.10): forward(ctx, inp, m_splits: List[int], use_bias, ...)
163+
if Version("2.10") <= _TE_VERSION:
164+
# New signature: non_tensor_args is a tuple, m_splits is the first element
165+
num_gemms = len(args[idx + 1][0])
166+
else:
167+
# Old signature: m_splits is directly args[idx + 1]
168+
num_gemms = len(args[idx + 1])
169+
160170
weights_and_biases = args[-2 * num_gemms :]
161171
weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
162172
quantized_inputs = self.input_quantizer(inp)

0 commit comments

Comments (0)