Commit bec302c
Fix TEGroupedLinear quantization for expert parallelism (EP > 1) (#833)
## What does this PR do?
**Type of change:** Bug fix / Compatibility update
**Overview:**
Fix `te_grouped_quantized_linear_fn` argument parsing for
TEGroupedLinear quantization when parallelism configuration results in
fewer local experts per GPU.
### Problem
TransformerEngine changed the `_GroupedLinear.forward` signature in PR
#2377 (released in TE 2.10):
```python
# Old signature (TE < 2.10): m_splits is passed directly after inp
forward(ctx, inp, m_splits: List[int], use_bias, is_first_microbatch, ...)

# New signature (TE >= 2.10): non-tensor args are packed into a single tuple
forward(ctx, inp, non_tensor_args: Tuple, *weights_and_biases)
# where non_tensor_args = (m_splits, use_bias, is_first_microbatch, ...)
```
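To make the difference concrete, here is a minimal sketch (with hypothetical argument values, not actual ModelOpt or TE code) showing where `m_splits` lands under each calling convention; `idx` stands for the position of `inp` in the intercepted args:

```python
# Stand-ins for the real forward arguments (hypothetical values)
inp = object()          # stand-in for the input tensor
m_splits = [4, 4, 4]    # one row-split per local expert (GEMM)

# TE < 2.10: m_splits follows inp directly
old_args = (inp, m_splits, False, None)

# TE >= 2.10: non-tensor args are packed into one tuple after inp
non_tensor_args = (m_splits, False, None)
new_args = (inp, non_tensor_args)

idx = 0  # position of inp in this sketch
assert old_args[idx + 1] is m_splits        # direct access works on old TE
assert new_args[idx + 1][0] is m_splits     # new TE requires unpacking the tuple
```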
Without this fix, ModelOpt's quantization code fails with newer TE
versions because it tries to access `m_splits` directly from
`args[idx + 1]`, but in TE >= 2.10 that position holds the
`non_tensor_args` tuple instead.
### Root Cause
The code assumed `m_splits` was always directly accessible at
`args[idx + 1]`, but TransformerEngine PR #2377 packed all non-tensor
arguments into a single tuple, so that position now holds the tuple
rather than `m_splits` itself. Qwen3-30B-A3B (with `num_gemms=21`,
threshold=44) is one configuration that hit this failure.
### Solution
Added version checking to handle both signatures:
```python
from packaging.version import Version  # _TE_VERSION is the installed TE version

if Version("2.10") <= _TE_VERSION:
    # New signature: args[idx + 1] is the non_tensor_args tuple;
    # m_splits is its first element
    num_gemms = len(args[idx + 1][0])
else:
    # Old signature: args[idx + 1] is m_splits itself
    num_gemms = len(args[idx + 1])
```
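The dispatch above can be exercised as a self-contained sketch; `count_gemms` and the argument values here are illustrative stand-ins (the real `_TE_VERSION` and `args` come from TransformerEngine and the intercepted forward call):

```python
from packaging.version import Version


def count_gemms(args, idx, te_version):
    """Return the number of GEMMs (local experts) from the forward args.

    Handles both TE calling conventions; idx is the position of inp.
    (Hypothetical helper for illustration, not ModelOpt's actual API.)
    """
    if Version("2.10") <= te_version:
        # New signature: args[idx + 1] is the non_tensor_args tuple,
        # whose first element is m_splits
        return len(args[idx + 1][0])
    # Old signature: args[idx + 1] is m_splits itself
    return len(args[idx + 1])


# Stubbed calls with hypothetical m_splits values
m_splits = [8, 8]
assert count_gemms((object(), m_splits), 0, Version("2.9")) == 2
assert count_gemms((object(), (m_splits, False, None)), 0, Version("2.10")) == 2
```

Note that `packaging.version.Version` compares release segments numerically, so `"2.10"` correctly sorts after `"2.9"`, which a plain string comparison would get wrong.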
## Usage
Works seamlessly with any TransformerEngine version:
```shell
# High EP quantization - previously failed, now works
torchrun --nproc_per_node 8 examples/quantization/quantize.py \
--hf-model-id /models/Qwen3-30B-A3B \
--export-quant-cfg fp8 \
--megatron-save-path /models/Qwen3-30B-A3B_fp8_mlm \
--tp 8 \
--ep 8
# High EP inference - previously failed, now works
torchrun --nproc_per_node 8 examples/quantization/ptq_generate.py \
--megatron-load-path /models/Qwen3-30B-A3B_fp8_mlm \
--hf-model-id /models/Qwen3-30B-A3B \
--tp 8 \
--ep 8
```
## Testing
Verified by running the usage commands above (fp8 quantization and PTQ
generation for Qwen3-30B-A3B with `--tp 8 --ep 8` on 8 GPUs); both
previously failed and now succeed.
## Summary by CodeRabbit
* **Bug Fixes**
* Enhanced Mixture of Experts (MoE) calibration validation and
synchronization to ensure consistency across distributed training
setups.
* Improved grouped linear quantization robustness to handle varying
input patterns and tensor dimensions.
* **Improvements**
* Better error handling for incomplete MoE expert calibration detection.
* More flexible argument parsing for quantization operations.
Signed-off-by: James Shen <yueshen@nvidia.com>
1 parent 452c5a0, commit bec302c
1 file changed: +11 −1 lines