Commit bec302c

yueshen2016 authored and kevalmorabia97 committed
Fix TEGroupedLinear quantization for expert parallelism (EP > 1) (#833)
## What does this PR do?

**Type of change:** Bug fix / Compatibility update

**Overview:** Fix `te_grouped_quantized_linear_fn` argument parsing for TEGroupedLinear quantization when the parallelism configuration results in fewer local experts per GPU.

### Problem

TransformerEngine changed the `_GroupedLinear.forward` signature in PR #2377 (released in TE 2.10):

- Old signature (TE < 2.10): `forward(ctx, inp, m_splits: List[int], use_bias, is_first_microbatch, ...)`
- New signature (TE >= 2.10): `forward(ctx, inp, non_tensor_args: Tuple, *weights_and_biases)`, where `non_tensor_args = (m_splits, use_bias, is_first_microbatch, ...)`

Without this fix, ModelOpt's quantization code fails with newer TE versions: it tries to read `m_splits` directly from `args[idx + 1]`, but in TE >= 2.10 that position holds the `non_tensor_args` tuple instead.

### Root Cause

The code assumed `m_splits` was always directly accessible at `args[idx + 1]`, but TransformerEngine PR #2377 changed the signature to pack all non-tensor arguments into a tuple. Taking Qwen3-30B-A3B (with `num_gemms=21`, threshold=44) as an example.

### Solution

Added version checking to handle both signatures:

```python
if Version("2.10") <= _TE_VERSION:
    # New signature: non_tensor_args is a tuple, m_splits is the first element
    num_gemms = len(args[idx + 1][0])
else:
    # Old signature: m_splits is directly args[idx + 1]
    num_gemms = len(args[idx + 1])
```

## Usage

Works seamlessly with any TransformerEngine version:

```shell
# High EP quantization - previously failed, now works
torchrun --nproc_per_node 8 examples/quantization/quantize.py \
    --hf-model-id /models/Qwen3-30B-A3B \
    --export-quant-cfg fp8 \
    --megatron-save-path /models/Qwen3-30B-A3B_fp8_mlm \
    --tp 8 \
    --ep 8

# High EP inference - previously failed, now works
torchrun --nproc_per_node 8 examples/quantization/ptq_generate.py \
    --megatron-load-path /models/Qwen3-30B-A3B_fp8_mlm \
    --hf-model-id /models/Qwen3-30B-A3B \
    --tp 8 \
    --ep 8
```

## Testing

Tested with the same quantization and inference commands shown in the Usage section above (`--tp 8 --ep 8` on 8 GPUs); both previously failed and now succeed.

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

## Summary by CodeRabbit

* **Bug Fixes**
  * Enhanced Mixture of Experts (MoE) calibration validation and synchronization to ensure consistency across distributed training setups.
  * Improved grouped linear quantization robustness to handle varying input patterns and tensor dimensions.
* **Improvements**
  * Better error handling for incomplete MoE expert calibration detection.
  * More flexible argument parsing for quantization operations.

Signed-off-by: James Shen <yueshen@nvidia.com>
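The signature dispatch described above can be sketched as a small standalone function. This is a hypothetical illustration, not ModelOpt's actual code: `extract_num_gemms`, the simulated `args` tuples, and the placeholder `"ctx"`/`"inp"` strings are all made up for the example; only the `args[idx + 1]` vs. `args[idx + 1][0]` distinction comes from the PR.

```python
# Hypothetical sketch of the version-dependent argument parsing.
# In TE < 2.10, forward() receives m_splits directly after the input;
# in TE >= 2.10, it receives one tuple packing all non-tensor args.
from packaging.version import Version


def extract_num_gemms(args, te_version, func_name="_forward"):
    """Recover num_gemms (== len(m_splits)) from forward() positional args."""
    idx = 1 if func_name == "_forward" else 0  # skip ctx for _forward
    if Version("2.10") <= te_version:
        # New signature: args[idx + 1] is non_tensor_args; m_splits is its first element
        return len(args[idx + 1][0])
    # Old signature: args[idx + 1] is m_splits itself
    return len(args[idx + 1])


# Simulated calls with 3 local experts (m_splits has 3 entries):
m_splits = [128, 64, 32]
old_args = ("ctx", "inp", m_splits, False, None)    # TE < 2.10 layout
new_args = ("ctx", "inp", (m_splits, False, None))  # TE >= 2.10 layout

assert extract_num_gemms(old_args, Version("2.9")) == 3
assert extract_num_gemms(new_args, Version("2.10")) == 3
```

Note that `packaging.version.Version` compares components numerically, so `"2.10"` correctly sorts after `"2.9"` (a plain string comparison would get this wrong).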
1 parent 452c5a0 commit bec302c

File tree

1 file changed: +11 −1 lines changed

modelopt/torch/quantization/plugins/transformer_engine.py

Lines changed: 11 additions & 1 deletion
```diff
@@ -156,7 +156,17 @@ def te_grouped_quantized_linear_fn(package, func_name, self, *args):
     _assert_te_fp8_enabled()
     idx = 1 if func_name == "_forward" else 0
     inp = args[idx]
-    num_gemms = len(args[idx + 1])
+
+    # Handle both old and new TE signatures (changed in PR #2377 in TE 2.10)
+    # New signature (TE >= 2.10): forward(ctx, inp, non_tensor_args: Tuple, *weights_and_biases)
+    # Old signature (TE < 2.10): forward(ctx, inp, m_splits: List[int], use_bias, ...)
+    if Version("2.10") <= _TE_VERSION:
+        # New signature: non_tensor_args is a tuple, m_splits is the first element
+        num_gemms = len(args[idx + 1][0])
+    else:
+        # Old signature: m_splits is directly args[idx + 1]
+        num_gemms = len(args[idx + 1])
+
     weights_and_biases = args[-2 * num_gemms :]
     weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
     quantized_inputs = self.input_quantizer(inp)
```
