# fix(te-plugin): make `_Linear` arg indexing robust to TE signature changes (#1473)
### What does this PR do?
Type of change: Bug fix
ModelOpt's `te_quantized_linear_fn` and `te_grouped_quantized_linear_fn`
read `weight` / `inp` from hard-coded positions in `args`. Three TE
signature changes have affected this scheme:
- **TE 1.x → 2.0:** dropped the legacy `weight_fp8` slot between
`weight` and `inp`. ModelOpt handled this with an `if Version("2.0") <=
_TE_VERSION:` branch + a duplicate else branch.
- **TE 2.14 → 2.15:** inserted `weight_workspace` between `weight` and
`inp` at the `_Linear.forward` call site ([TE 2.15 linear.py
L1663](https://github.com/NVIDIA/TransformerEngine/blob/release_v2.15/transformer_engine/pytorch/module/linear.py#L1663)).
Unhandled by ModelOpt — `args[idx + 1]` resolved to `None` (workspace is
None outside FP8), which then crashed `TensorQuantizer.forward` on
`inputs.numel()` with `AttributeError: 'NoneType' object has no
attribute 'numel'`. This surfaced as a regression in Megatron-Bridge after
the TE 2.15 bump alongside ModelOpt 0.44.0rc3 (see the sketch after this list).
- **TE 2.10:** `_GroupedLinear.forward`'s second positional slot was
renamed `m_splits` → `non_tensor_args` (tuple wrapping). ModelOpt had a
separate `Version("2.10")` gate for this.
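The TE 2.15 failure mode is easy to see without TE installed. The sketch below is a reconstruction from this description, not ModelOpt's actual code; the tuple layouts and the hard-coded index are assumptions that mirror the signature table in the Testing section:

```python
# Hypothetical pre-fix behavior: fetch `inp` at a hard-coded offset from `weight`.
te_2_0_style_args = ("ctx", "WEIGHT", "INP", "BIAS")          # weight and inp adjacent
te_2_15_style_args = ("ctx", "WEIGHT", None, "INP", "BIAS")   # None = weight_workspace outside FP8

WEIGHT_IDX = 1                                # assumed position of `weight`
inp = te_2_15_style_args[WEIGHT_IDX + 1]      # "slot right after weight" picks up the workspace
print(inp)  # None -> TensorQuantizer.forward later crashes on inputs.numel()
```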
Replace all three version gates with **parameter-name introspection** of
the live `_Linear.forward` / `_GroupedLinear.forward` signature. The
parameter names (`weight`, `inp`, `m_splits`, `non_tensor_args`) have
been stable across TE 1.x, 2.x, and 2.15+; only their relative positions
shift. The new code reads the live signature via
`inspect.signature(...).parameters`, locates `weight`/`inp` by name, and
mutates only those positions in a list copy of `args` — everything
between (e.g. TE 2.15's `weight_workspace`) and after passes through
verbatim. The dual-branch code in `te_quantized_linear_fn` collapses to
a single path.
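A minimal sketch of that introspection path, assuming `args` lines up positionally with the forward signature (including `ctx`); the helper name, quantizer arguments, and overall structure are illustrative, not the exact ModelOpt code:

```python
import inspect

def _remap_quantized_args(forward_fn, args, input_quantizer, weight_quantizer):
    """Quantize `weight` / `inp` by locating them in forward_fn's live signature.

    Anything sitting between or after them (e.g. TE 2.15's `weight_workspace`)
    is passed through verbatim.
    """
    param_names = list(inspect.signature(forward_fn).parameters)
    new_args = list(args)
    w_idx, i_idx = param_names.index("weight"), param_names.index("inp")
    new_args[w_idx] = weight_quantizer(new_args[w_idx])
    new_args[i_idx] = input_quantizer(new_args[i_idx])
    return tuple(new_args)
```

The same name-based lookup can cover `_GroupedLinear.forward` as well, keying on whichever of `m_splits` / `non_tensor_args` the installed TE version exposes.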
### Usage
No public API change. PTQ continues to work transparently across all
supported TE versions:
```python
import modelopt.torch.quantization as mtq
# Works on TE 1.x, 2.0-2.14, 2.15.x, and 2.16+ — no version flag needed.
mtq.quantize(model, mtq.NVFP4_DEFAULT_CFG, forward_loop)
```
### Testing
Existing TE plugin tests
(`tests/gpu_megatron/torch/quantization/plugins/test_transformer_engine.py`)
exercise both the `_forward` (no-grad calibration) and `_apply`
(grad-enabled training) paths of `te_quantized_linear_fn` for
`te.pytorch.Linear` — they would have caught the original TE 2.15
regression on a CI matrix entry pinned to TE 2.15. Verified trace
correctness across:
| TE version | `_Linear.forward` signature | `_te_linear` weight→inp gap | `_GroupedLinear.forward` second slot |
|---|---|---|---|
| 1.x | `(ctx, weight, weight_fp8, inp, …)` | 1 | n/a |
| 2.0–2.14 | `(ctx, weight, inp, bias, …)` | 0 | `m_splits` |
| 2.15.x | `(ctx, weight, weight_workspace, inp, …)` | 1 | `non_tensor_args` |
| 2.16+ (main) | `(ctx, weight, inp, bias, fwd_args)` | 0 | `non_tensor_args` |
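As a quick illustration of why the name lookup is stable across these rows, the stub signatures below (toy functions mirroring the table, not real TE code) all resolve `weight` / `inp` to the expected positions:

```python
import inspect

# Stub forwards whose parameter lists mirror the table above.
def forward_te_1x(ctx, weight, weight_fp8, inp, bias): ...
def forward_te_20(ctx, weight, inp, bias, extra): ...
def forward_te_215(ctx, weight, weight_workspace, inp, bias): ...
def forward_te_216(ctx, weight, inp, bias, fwd_args): ...

for fn, expected in [
    (forward_te_1x, (1, 3)),   # gap 1
    (forward_te_20, (1, 2)),   # gap 0
    (forward_te_215, (1, 3)),  # gap 1
    (forward_te_216, (1, 2)),  # gap 0
]:
    names = list(inspect.signature(fn).parameters)
    assert (names.index("weight"), names.index("inp")) == expected
```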
### Before your PR is "*Ready for review*"
- Make sure you read and follow the [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- Is this change backward compatible?: ✅ <!--- Public API unchanged;
broadens the range of TE versions that work (TE 2.15.x now supported, TE
1.x still supported via the same introspection path). -->
- If you copied code from any other sources or added a new PIP
dependency, did you follow guidance in `CONTRIBUTING.md`: N/A <!--- Only
adds a stdlib `inspect` import. -->
- Did you write any new necessary tests?: Existing tests sufficient
<!--- Bug fix is covered by existing `test_transformer_engine.py` for
whatever single TE version CI exercises. A multi-version TE matrix is
the right next step but is out of scope for this PR. -->
### Additional Information
Triggered by failing tests in NVIDIA-NeMo/Megatron-Bridge#3783
after bumping ModelOpt 0.44.0rc2 → 0.44.0rc3 together with a Megatron-LM
bump that pulls TE 2.15. ModelOpt rc2 had the same latent bug — it just
wasn't exercised until TE 2.15 became the runtime version.
## Summary by CodeRabbit
* **Refactor**
* Improved Transformer Engine quantization plugin robustness by using
runtime parameter inspection instead of version-based branching,
ensuring compatibility across TE versions without requiring manual
updates.
---------
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>

Commit 50e112e (parent 1b5d448): 1 file changed, 43 additions and 52 deletions.