Skip to content

Commit c3b1f5a

Browse files
kevalmorabia97jenchen13
authored andcommitted
fix(te-plugin): make _Linear arg indexing robust to TE signature changes (#1473)
Type of change: Bug fix ModelOpt's `te_quantized_linear_fn` and `te_grouped_quantized_linear_fn` read `weight` / `inp` from hard-coded positions in `args`. Two TE signature changes broke this scheme: - **TE 1.x → 2.0:** dropped the legacy `weight_fp8` slot between `weight` and `inp`. ModelOpt handled this with an `if Version("2.0") <= _TE_VERSION:` branch + a duplicate else branch. - **TE 2.14 → 2.15:** inserted `weight_workspace` between `weight` and `inp` at the `_Linear.forward` call site ([TE 2.15 linear.py L1663](https://github.com/NVIDIA/TransformerEngine/blob/release_v2.15/transformer_engine/pytorch/module/linear.py#L1663)). Unhandled by ModelOpt — `args[idx + 1]` resolved to `None` (workspace is None outside FP8), which then crashed `TensorQuantizer.forward` on `inputs.numel()` with `AttributeError: 'NoneType' object has no attribute 'numel'`. Surfaced as a regression in Megatron-Bridge after the TE 2.15 bump alongside ModelOpt 0.44.0rc3. - **TE 2.10:** `_GroupedLinear.forward`'s second positional slot was renamed `m_splits` → `non_tensor_args` (tuple wrapping). ModelOpt had a separate `Version("2.10")` gate for this. Replace all three version gates with **parameter-name introspection** of the live `_Linear.forward` / `_GroupedLinear.forward` signature. The parameter names (`weight`, `inp`, `m_splits`, `non_tensor_args`) have been stable across TE 1.x, 2.x, and 2.15+; only their relative positions shift. The new code reads the live signature via `inspect.signature(...).parameters`, locates `weight`/`inp` by name, and mutates only those positions in a list copy of `args` — everything between (e.g. TE 2.15's `weight_workspace`) and after passes through verbatim. The dual-branch code in `te_quantized_linear_fn` collapses to a single path. No public API change. PTQ continues to work transparently across all supported TE versions: ```python import modelopt.torch.quantization as mtq mtq.quantize(model, mtq.NVFP4_DEFAULT_CFG, forward_loop) ``` <!-- Mention how have you tested your change if applicable. --> Existing TE plugin tests (`tests/gpu_megatron/torch/quantization/plugins/test_transformer_engine.py`) exercise both the `_forward` (no-grad calibration) and `_apply` (grad-enabled training) paths of `te_quantized_linear_fn` for `te.pytorch.Linear` — they would have caught the original TE 2.15 regression on a CI matrix entry pinned to TE 2.15. Verified trace correctness across: | TE version | `_Linear.forward` signature | `_te_linear` weight→inp gap | `_GroupedLinear.forward` second slot | |---|---|---|---| | 1.x | `(ctx, weight, weight_fp8, inp, …)` | 1 | n/a | | 2.0–2.14 | `(ctx, weight, inp, bias, …)` | 0 | `m_splits` | | 2.15.x | `(ctx, weight, weight_workspace, inp, …)` | 1 | `non_tensor_args` | | 2.16+ (main) | `(ctx, weight, inp, bias, fwd_args)` | 0 | `non_tensor_args` | Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.). - Is this change backward compatible?: ✅ <!--- Public API unchanged; broadens the range of TE versions that work (TE 2.15.x now supported, TE 1.x still supported via the same introspection path). --> - If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: N/A <!--- Only adds a stdlib `inspect` import. --> - Did you write any new necessary tests?: Existing tests sufficient <!--- Bug fix is covered by existing `test_transformer_engine.py` for whatever single TE version CI exercises. A multi-version TE matrix is the right next step but is out of scope for this PR. --> <!-- E.g. related issue. --> Triggered by Megatron-Bridge NVIDIA-NeMo/Megatron-Bridge#3783 failing tests after bumping ModelOpt 0.44.0rc2 → 0.44.0rc3 together with a Megatron-LM bump that pulls TE 2.15. ModelOpt rc2 had the same latent bug — it just wasn't exercised until TE 2.15 became the runtime version. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> * **Refactor** * Improved Transformer Engine quantization plugin robustness by using runtime parameter inspection instead of version-based branching, ensuring compatibility across TE versions without requiring manual updates. [![Review Change Stack](https://storage.googleapis.com/coderabbit_public_assets/review-stack-in-coderabbit-ui.svg)](https://app.coderabbit.ai/change-stack/NVIDIA/Model-Optimizer/pull/1473) <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent 406c4b6 commit c3b1f5a

1 file changed

Lines changed: 43 additions & 54 deletions

File tree

modelopt/torch/quantization/plugins/transformer_engine.py

Lines changed: 43 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import copy
1919
import os
20+
import inspect
2021
import warnings
2122

2223
import torch
@@ -81,30 +82,24 @@ def _setup(self):
8182
def te_quantized_linear_fn(package, func_name, self, *args, **kwargs):
8283
"""Quantized version specifically for TE with weight first, then input."""
8384
_assert_te_fp8_enabled()
84-
if Version("2.0") <= _TE_VERSION:
85-
idx = 1 if func_name == "_forward" else 0
86-
weight, inputs = args[idx], args[idx + 1]
87-
remaining_args = args[idx + 2 :]
88-
weight = self.weight_quantizer(weight)
89-
inputs = self.input_quantizer(inputs)
90-
new_args = (weight, inputs, *remaining_args)
91-
new_args = (args[0], *new_args) if func_name == "_forward" else new_args
92-
output = getattr(package, func_name)(
93-
*new_args,
94-
**kwargs,
95-
)
96-
else:
97-
idx = 1 if func_name == "_forward" else 0
98-
weight, weight_fp8, inputs = args[idx], args[idx + 1], args[idx + 2]
99-
remaining_args = args[idx + 3 :]
100-
weight = self.weight_quantizer(weight)
101-
inputs = self.input_quantizer(inputs)
102-
new_args = (weight, weight_fp8, inputs, *remaining_args)
103-
new_args = (args[0], *new_args) if func_name == "_forward" else new_args
104-
output = getattr(package, func_name)(
105-
*new_args,
106-
**kwargs,
107-
)
85+
# Locate `weight` and `inp` by parameter name in the un-patched `_Linear.forward`
86+
# signature — robust to TE versions that insert positional args between them
87+
# (e.g. `weight_fp8` in TE 1.x, `weight_workspace` in TE 2.15).
88+
# NOTE: we're called from inside `replace_function`'s context, so
89+
# `_Linear.forward` may currently point at the `functools.partial` wrapper
90+
# (whose signature collapses to `*args, **kwargs`). The original is cached at
91+
# `_Linear._forward` while the patch is active (when `_apply` is patched
92+
# instead, `_forward` is absent and `forward` is itself the original).
93+
# `_forward` path receives a leading None (placeholder ctx); `_apply` does not.
94+
orig_forward = getattr(te_linear._Linear, "_forward", te_linear._Linear.forward)
95+
names = list(inspect.signature(orig_forward).parameters)
96+
ctx_offset = 0 if func_name == "_forward" else 1
97+
weight_pos = names.index("weight") - ctx_offset
98+
inp_pos = names.index("inp") - ctx_offset
99+
new_args = list(args)
100+
new_args[weight_pos] = self.weight_quantizer(args[weight_pos])
101+
new_args[inp_pos] = self.input_quantizer(args[inp_pos])
102+
output = getattr(package, func_name)(*new_args, **kwargs)
108103
return self.output_quantizer(output)
109104

110105
# Override the quantized linear function
@@ -175,37 +170,31 @@ def iter_weights_for_calibration(self):
175170
@staticmethod
176171
def te_grouped_quantized_linear_fn(package, func_name, self, *args):
177172
_assert_te_fp8_enabled()
178-
idx = 1 if func_name == "_forward" else 0
179-
inp = args[idx]
180-
181-
# Handle both old and new TE signatures (changed in PR #2377 in TE 2.10)
182-
# New signature (TE >= 2.10): forward(ctx, inp, non_tensor_args: Tuple, *weights_and_biases)
183-
# Old signature (TE < 2.10): forward(ctx, inp, m_splits: List[int], use_bias, ...)
184-
if Version("2.10") <= _TE_VERSION:
185-
# New signature: non_tensor_args is a tuple, m_splits is the first element
186-
num_gemms = len(args[idx + 1][0])
187-
else:
188-
# Old signature: m_splits is directly args[idx + 1]
189-
num_gemms = len(args[idx + 1])
190-
191-
weights_and_biases = args[-2 * num_gemms :]
192-
weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
193-
quantized_inputs = self.input_quantizer(inp)
194-
quantized_weights = [
195-
self._get_weight_quantizer(i)(weight) for i, weight in enumerate(weights)
196-
]
197-
198-
output = getattr(package, func_name)(
199-
*(
200-
args[0],
201-
quantized_inputs,
202-
)
203-
if func_name == "_forward"
204-
else (quantized_inputs,),
205-
*args[idx + 1 : -2 * num_gemms],
206-
*quantized_weights,
207-
*biases,
173+
# Locate `inp` and the m_splits-bearing arg by parameter name. The second
174+
# slot was renamed from `m_splits` (TE < 2.10) to `non_tensor_args` (TE
175+
# 2.10+, where m_splits is now at non_tensor_args[0]). `*weights_and_biases`
176+
# is always the trailing variadic — 2 * num_gemms tensors (weights, then biases).
177+
# See `te_quantized_linear_fn` for why we look up `_forward` here.
178+
# `_forward` path receives a leading None (placeholder ctx); `_apply` does not.
179+
orig_forward = getattr(
180+
te_grouped_linear._GroupedLinear,
181+
"_forward",
182+
te_grouped_linear._GroupedLinear.forward,
208183
)
184+
sig_params = list(inspect.signature(orig_forward).parameters)
185+
ctx_offset = 0 if func_name == "_forward" else 1
186+
inp_pos = sig_params.index("inp") - ctx_offset
187+
if "non_tensor_args" in sig_params:
188+
num_gemms = len(args[sig_params.index("non_tensor_args") - ctx_offset][0])
189+
else:
190+
num_gemms = len(args[sig_params.index("m_splits") - ctx_offset])
191+
weights_start = len(args) - 2 * num_gemms
192+
193+
new_args = list(args)
194+
new_args[inp_pos] = self.input_quantizer(args[inp_pos])
195+
for i in range(weights_start, weights_start + num_gemms):
196+
new_args[i] = self.weight_quantizer(args[i])
197+
output = getattr(package, func_name)(*new_args)
209198
return self.output_quantizer(output)
210199

211200
# Override the quantized linear function

0 commit comments

Comments
 (0)