Skip to content

Commit dd445ba

Browse files
kevalmorabia97jenchen13
authored andcommitted
fix(te-plugin): handle TE 2.15+ tuple return from _Linear / _GroupedLinear (#1481)
### What does this PR do? Type of change: Bug fix TE 2.15+ changed `_Linear.forward` and `_GroupedLinear.forward` to return `(out, new_workspace)` tuples instead of a single tensor. ModelOpt's patched `te_quantized_linear_fn` / `te_grouped_quantized_linear_fn` still piped the whole tuple into `self.output_quantizer`, crashing inside `TensorQuantizer.forward` on `tuple.numel()`: ``` File ".../modelopt/torch/quantization/plugins/transformer_engine.py", line 184, in te_grouped_quantized_linear_fn return self.output_quantizer(output) File ".../tensor_quantizer.py", line 1037, in forward if inputs.numel() == 0: AttributeError: 'tuple' object has no attribute 'numel' ``` Mirror the existing pattern from `_QuantTELayerNormLinear.forward`: when the underlying TE call returns a tuple, quantize only `output[0]` (the activation tensor) and pass auxiliary workspace metadata through unchanged. TE <= 2.14 returns a single tensor and falls through the `isinstance` branch identically to before this change. Already landed on `release/0.44.0` as commit `c897fbeaaf`; this brings `main` in sync. Follow-up to [#1473](#1473) (signature introspection + `_forward` cache lookup), which fixed an earlier symptom of the same TE 2.15 signature change but not this tuple-return path. ### Usage No public API change. PTQ continues to work transparently across TE 2.x: ```python import modelopt.torch.quantization as mtq mtq.quantize(model, mtq.NVFP4_DEFAULT_CFG, forward_loop) # now works on TE 2.15.x ``` ### Testing <!-- Mention how have you tested your change if applicable. --> Verified locally against **both TE 2.12** and **TE 2.15.0** using: ```bash pytest tests/gpu_megatron/torch/quantization/plugins/test_transformer_engine.py ``` Without this fix on TE 2.15, the same test fails immediately with `AttributeError: 'tuple' object has no attribute 'numel'`. With this fix, both versions exercise the same code paths and pass — TE <= 2.14 skips the `isinstance(output, tuple)` branch and behaves identically to before. ### Before your PR is "*Ready for review*" Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.). - Is this change backward compatible?: ✅ <!--- Public API unchanged; TE <= 2.14 path is identical (isinstance branch is false). --> - If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: N/A - Did you write any new necessary tests?: N/A <!--- Existing `test_transformer_engine.py` already exercises both paths; it would have caught this on TE 2.15 had CI been running against that version. A TE-version matrix is the right follow-up but is out of scope here. --> - Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: N/A ### Additional Information <!-- E.g. related issue. --> Triggered by Megatron-Bridge failing tests after their TE 2.15 bump. The `release/0.44.0` cherry-pick was pushed directly (commit `c897fbeaaf`) so Bridge could unblock; this PR carries the same fix forward to main. Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent d6e1973 commit dd445ba

1 file changed

Lines changed: 8 additions & 0 deletions

File tree

modelopt/torch/quantization/plugins/transformer_engine.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,10 @@ def te_quantized_linear_fn(package, func_name, self, *args, **kwargs):
100100
new_args[weight_pos] = self.weight_quantizer(args[weight_pos])
101101
new_args[inp_pos] = self.input_quantizer(args[inp_pos])
102102
output = getattr(package, func_name)(*new_args, **kwargs)
103+
# TE 2.15+ returns `(out, new_weight_workspace)`; TE <= 2.14 returns just `out`.
104+
# Only the activation tensor participates in output quantization.
105+
if isinstance(output, tuple):
106+
return (self.output_quantizer(output[0]), *output[1:])
103107
return self.output_quantizer(output)
104108

105109
# Override the quantized linear function
@@ -195,6 +199,10 @@ def te_grouped_quantized_linear_fn(package, func_name, self, *args):
195199
for i in range(weights_start, weights_start + num_gemms):
196200
new_args[i] = self.weight_quantizer(args[i])
197201
output = getattr(package, func_name)(*new_args)
202+
# TE 2.15+ returns `(out, new_workspaces)`; TE <= 2.14 returns just `out`.
203+
# Only the activation tensor participates in output quantization.
204+
if isinstance(output, tuple):
205+
return (self.output_quantizer(output[0]), *output[1:])
198206
return self.output_quantizer(output)
199207

200208
# Override the quantized linear function

0 commit comments

Comments
 (0)