
Commit 50e112e

fix(te-plugin): make _Linear arg indexing robust to TE signature changes (#1473)
### What does this PR do?

**Type of change:** Bug fix

ModelOpt's `te_quantized_linear_fn` and `te_grouped_quantized_linear_fn` read `weight` / `inp` from hard-coded positions in `args`. Three TE signature changes have shifted those positions:

- **TE 1.x → 2.0:** dropped the legacy `weight_fp8` slot between `weight` and `inp`. ModelOpt handled this with an `if Version("2.0") <= _TE_VERSION:` branch plus a near-duplicate `else` branch.
- **TE 2.14 → 2.15:** inserted `weight_workspace` between `weight` and `inp` at the `_Linear.forward` call site ([TE 2.15 linear.py L1663](https://github.com/NVIDIA/TransformerEngine/blob/release_v2.15/transformer_engine/pytorch/module/linear.py#L1663)). This one was unhandled by ModelOpt: `args[idx + 1]` resolved to `None` (the workspace is `None` outside FP8 execution), which then crashed `TensorQuantizer.forward` on `inputs.numel()` with `AttributeError: 'NoneType' object has no attribute 'numel'`. It surfaced as a regression in Megatron-Bridge after the TE 2.15 bump alongside ModelOpt 0.44.0rc3.
- **TE 2.10:** `_GroupedLinear.forward`'s second positional slot was renamed from `m_splits` to `non_tensor_args` (which now wraps `m_splits` in a tuple). ModelOpt had a separate `Version("2.10")` gate for this.

This PR replaces all three version gates with **parameter-name introspection** of the live `_Linear.forward` / `_GroupedLinear.forward` signature. The parameter names (`weight`, `inp`, `m_splits`, `non_tensor_args`) have been stable across TE 1.x, 2.x, and 2.15+; only their relative positions shift. The new code reads the live signature via `inspect.signature(...).parameters`, locates `weight`/`inp` by name, and mutates only those positions in a list copy of `args`; everything between them (e.g. TE 2.15's `weight_workspace`) and after them passes through verbatim. The dual-branch code in `te_quantized_linear_fn` collapses to a single path. The two sketches below illustrate the idea.
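To make the mechanism concrete, here is a minimal, self-contained sketch of name-based argument remapping. The toy signatures (`forward_te_2_14`, `forward_te_2_15`) and the `quantize_args` helper are invented for illustration; they mimic the shape of TE's `_Linear.forward` but are not ModelOpt or TE code.

```python
import inspect

# Toy stand-ins for two TE _Linear.forward generations: only the relative
# position of `weight` and `inp` differs; the parameter names stay stable.
def forward_te_2_14(ctx, weight, inp, bias): ...
def forward_te_2_15(ctx, weight, weight_workspace, inp, bias): ...

def quantize_args(forward_fn, func_name, args, quant_w, quant_x):
    """Quantize the weight/inp slots of args, wherever the live signature puts them."""
    names = list(inspect.signature(forward_fn).parameters)
    # A `_forward` call keeps a placeholder ctx in args[0]; an `_apply` call
    # drops it, so every name-derived index shifts left by one.
    ctx_offset = 0 if func_name == "_forward" else 1
    w = names.index("weight") - ctx_offset
    x = names.index("inp") - ctx_offset
    new_args = list(args)
    new_args[w], new_args[x] = quant_w(args[w]), quant_x(args[x])
    return new_args

quant = lambda t: f"Q({t})"  # stand-in quantizer
print(quantize_args(forward_te_2_14, "_forward", [None, "W", "X", "b"], quant, quant))
# -> [None, 'Q(W)', 'Q(X)', 'b']
print(quantize_args(forward_te_2_15, "_forward", [None, "W", None, "X", "b"], quant, quant))
# -> [None, 'Q(W)', None, 'Q(X)', 'b']   (the None workspace passes through untouched)
```

The same lookup would have survived the TE 2.15 change without any code edit, which is the point of the fix.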
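The grouped-linear half follows the same pattern, with one extra wrinkle: `num_gemms` must be recovered from whichever spelling of the second slot the running TE uses. Another hedged sketch with invented names (`gl_forward_old`, `gl_forward_new`, `find_num_gemms`) that mirrors the shape of the fix rather than quoting it:

```python
import inspect

# Toy stand-ins for the two _GroupedLinear.forward spellings.
def gl_forward_old(ctx, inp, m_splits, *weights_and_biases): ...        # TE < 2.10
def gl_forward_new(ctx, inp, non_tensor_args, *weights_and_biases): ... # TE >= 2.10

def find_num_gemms(forward_fn, args, ctx_offset=0):
    """m_splits has one entry per GEMM; TE >= 2.10 tucks it into non_tensor_args[0]."""
    params = list(inspect.signature(forward_fn).parameters)
    if "non_tensor_args" in params:
        return len(args[params.index("non_tensor_args") - ctx_offset][0])
    return len(args[params.index("m_splits") - ctx_offset])

# Both calls use the `_forward` convention: placeholder ctx at args[0], so ctx_offset=0.
# Old style: m_splits is passed directly ...
print(find_num_gemms(gl_forward_old, [None, "X", [128, 256, 64]]))          # -> 3
# ... new style: m_splits rides inside the non_tensor_args tuple.
print(find_num_gemms(gl_forward_new, [None, "X", ([128, 256, 64], True)]))  # -> 3
```

Once `num_gemms` is known, the trailing `2 * num_gemms` variadic tensors split into weights and then biases, and only the weights need the weight quantizer.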
### Usage

No public API change. PTQ continues to work transparently across all supported TE versions:

```python
import modelopt.torch.quantization as mtq

# Works on TE 1.x, 2.0-2.14, 2.15.x, and 2.16+ — no version flag needed.
mtq.quantize(model, mtq.NVFP4_DEFAULT_CFG, forward_loop)
```

### Testing

Existing TE plugin tests (`tests/gpu_megatron/torch/quantization/plugins/test_transformer_engine.py`) exercise both the `_forward` (no-grad calibration) and `_apply` (grad-enabled training) paths of `te_quantized_linear_fn` for `te.pytorch.Linear`; they would have caught the original TE 2.15 regression on a CI matrix entry pinned to TE 2.15.

Verified trace correctness across:

| TE version | `_Linear.forward` signature | `_te_linear` weight→inp gap | `_GroupedLinear.forward` second slot |
|---|---|---|---|
| 1.x | `(ctx, weight, weight_fp8, inp, …)` | 1 | n/a |
| 2.0–2.14 | `(ctx, weight, inp, bias, …)` | 0 | `m_splits` |
| 2.15.x | `(ctx, weight, weight_workspace, inp, …)` | 1 | `non_tensor_args` |
| 2.16+ (main) | `(ctx, weight, inp, bias, fwd_args)` | 0 | `non_tensor_args` |

### Before your PR is "*Ready for review*"

Make sure you read and follow the [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and that your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).

- Is this change backward compatible?: ✅ Public API unchanged; it broadens the range of TE versions that work (TE 2.15.x is now supported, and TE 1.x remains supported via the same introspection path).
- If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`?: N/A; the change only adds a stdlib `inspect` import.
- Did you write any new necessary tests?: Existing tests are sufficient. The bug fix is covered by the existing `test_transformer_engine.py` for whichever single TE version CI exercises; a multi-version TE matrix is the right next step but is out of scope for this PR.

### Additional Information

Triggered by Megatron-Bridge NVIDIA-NeMo/Megatron-Bridge#3783 failing tests after bumping ModelOpt 0.44.0rc2 → 0.44.0rc3 together with a Megatron-LM bump that pulls in TE 2.15. ModelOpt rc2 had the same latent bug; it simply wasn't exercised until TE 2.15 became the runtime version.

## Summary by CodeRabbit

- **Refactor**
  - Improved Transformer Engine quantization plugin robustness by using runtime parameter inspection instead of version-based branching, ensuring compatibility across TE versions without requiring manual updates.

---------

Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent: 1b5d448

1 file changed: 43 additions & 52 deletions

modelopt/torch/quantization/plugins/transformer_engine.py
```diff
diff --git a/modelopt/torch/quantization/plugins/transformer_engine.py b/modelopt/torch/quantization/plugins/transformer_engine.py
@@ -15,6 +15,7 @@
 
 """Support quantization for Transformer Engine layers."""
 
+import inspect
 import warnings
 
 import torch
@@ -74,30 +75,24 @@ def _setup(self):
 def te_quantized_linear_fn(package, func_name, self, *args, **kwargs):
     """Quantized version specifically for TE with weight first, then input."""
     _assert_te_fp8_enabled()
-    if Version("2.0") <= _TE_VERSION:
-        idx = 1 if func_name == "_forward" else 0
-        weight, inputs = args[idx], args[idx + 1]
-        remaining_args = args[idx + 2 :]
-        weight = self.weight_quantizer(weight)
-        inputs = self.input_quantizer(inputs)
-        new_args = (weight, inputs, *remaining_args)
-        new_args = (args[0], *new_args) if func_name == "_forward" else new_args
-        output = getattr(package, func_name)(
-            *new_args,
-            **kwargs,
-        )
-    else:
-        idx = 1 if func_name == "_forward" else 0
-        weight, weight_fp8, inputs = args[idx], args[idx + 1], args[idx + 2]
-        remaining_args = args[idx + 3 :]
-        weight = self.weight_quantizer(weight)
-        inputs = self.input_quantizer(inputs)
-        new_args = (weight, weight_fp8, inputs, *remaining_args)
-        new_args = (args[0], *new_args) if func_name == "_forward" else new_args
-        output = getattr(package, func_name)(
-            *new_args,
-            **kwargs,
-        )
+    # Locate `weight` and `inp` by parameter name in the un-patched `_Linear.forward`
+    # signature — robust to TE versions that insert positional args between them
+    # (e.g. `weight_fp8` in TE 1.x, `weight_workspace` in TE 2.15).
+    # NOTE: we're called from inside `replace_function`'s context, so
+    # `_Linear.forward` may currently point at the `functools.partial` wrapper
+    # (whose signature collapses to `*args, **kwargs`). The original is cached at
+    # `_Linear._forward` while the patch is active (when `_apply` is patched
+    # instead, `_forward` is absent and `forward` is itself the original).
+    # `_forward` path receives a leading None (placeholder ctx); `_apply` does not.
+    orig_forward = getattr(te_linear._Linear, "_forward", te_linear._Linear.forward)
+    names = list(inspect.signature(orig_forward).parameters)
+    ctx_offset = 0 if func_name == "_forward" else 1
+    weight_pos = names.index("weight") - ctx_offset
+    inp_pos = names.index("inp") - ctx_offset
+    new_args = list(args)
+    new_args[weight_pos] = self.weight_quantizer(args[weight_pos])
+    new_args[inp_pos] = self.input_quantizer(args[inp_pos])
+    output = getattr(package, func_name)(*new_args, **kwargs)
     return self.output_quantizer(output)
 
 # Override the quantized linear function
@@ -161,35 +156,31 @@ def iter_weights_for_calibration(self):
     @staticmethod
     def te_grouped_quantized_linear_fn(package, func_name, self, *args):
         _assert_te_fp8_enabled()
-        idx = 1 if func_name == "_forward" else 0
-        inp = args[idx]
-
-        # Handle both old and new TE signatures (changed in PR #2377 in TE 2.10)
-        # New signature (TE >= 2.10): forward(ctx, inp, non_tensor_args: Tuple, *weights_and_biases)
-        # Old signature (TE < 2.10): forward(ctx, inp, m_splits: List[int], use_bias, ...)
-        if Version("2.10") <= _TE_VERSION:
-            # New signature: non_tensor_args is a tuple, m_splits is the first element
-            num_gemms = len(args[idx + 1][0])
-        else:
-            # Old signature: m_splits is directly args[idx + 1]
-            num_gemms = len(args[idx + 1])
-
-        weights_and_biases = args[-2 * num_gemms :]
-        weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
-        quantized_inputs = self.input_quantizer(inp)
-        quantized_weights = [self.weight_quantizer(weight) for weight in weights]
-
-        output = getattr(package, func_name)(
-            *(
-                args[0],
-                quantized_inputs,
-            )
-            if func_name == "_forward"
-            else (quantized_inputs,),
-            *args[idx + 1 : -2 * num_gemms],
-            *quantized_weights,
-            *biases,
+        # Locate `inp` and the m_splits-bearing arg by parameter name. The second
+        # slot was renamed from `m_splits` (TE < 2.10) to `non_tensor_args` (TE
+        # 2.10+, where m_splits is now at non_tensor_args[0]). `*weights_and_biases`
+        # is always the trailing variadic — 2 * num_gemms tensors (weights, then biases).
+        # See `te_quantized_linear_fn` for why we look up `_forward` here.
+        # `_forward` path receives a leading None (placeholder ctx); `_apply` does not.
+        orig_forward = getattr(
+            te_grouped_linear._GroupedLinear,
+            "_forward",
+            te_grouped_linear._GroupedLinear.forward,
         )
+        sig_params = list(inspect.signature(orig_forward).parameters)
+        ctx_offset = 0 if func_name == "_forward" else 1
+        inp_pos = sig_params.index("inp") - ctx_offset
+        if "non_tensor_args" in sig_params:
+            num_gemms = len(args[sig_params.index("non_tensor_args") - ctx_offset][0])
+        else:
+            num_gemms = len(args[sig_params.index("m_splits") - ctx_offset])
+        weights_start = len(args) - 2 * num_gemms
+
+        new_args = list(args)
+        new_args[inp_pos] = self.input_quantizer(args[inp_pos])
+        for i in range(weights_start, weights_start + num_gemms):
+            new_args[i] = self.weight_quantizer(args[i])
+        output = getattr(package, func_name)(*new_args)
         return self.output_quantizer(output)
 
     # Override the quantized linear function
```
