Commit c897fbe

fix(te-plugin): handle TE 2.15+ tuple return from _Linear / _GroupedLinear

TE 2.15+ changed `_Linear.forward` and `_GroupedLinear.forward` to return `(out, new_workspace)` tuples instead of a single tensor. ModelOpt's patched `te_quantized_linear_fn` / `te_grouped_quantized_linear_fn` still passed the whole tuple into `self.output_quantizer`, crashing inside `TensorQuantizer.forward` on `tuple.numel()`:

    AttributeError: 'tuple' object has no attribute 'numel'

Mirror the existing pattern from `_QuantTELayerNormLinear.forward`: quantize only `output[0]` (activation) and pass auxiliary workspace metadata through verbatim. TE <= 2.14 returns a single tensor and falls through the `isinstance` branch unchanged.

This unblocks Megatron-Bridge's TE 2.15 path; the local `patch_modelopt_te_linear_tuple_output` shim can be removed once this ships in a tagged release.

Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>

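
The failure mode described above can be reproduced without TE or ModelOpt installed. The sketch below is illustrative only: `FakeTensor`, `fake_output_quantizer`, `old_wrapper`, and the two `forward_*` functions are hypothetical stand-ins, not code from either library. The only behavior assumed is what the commit message states: the quantizer calls `numel()` on its input, and the newer forward returns a tuple.

```python
class FakeTensor:
    """Hypothetical stand-in for a tensor; only numel() is modeled."""

    def __init__(self, values):
        self.values = list(values)

    def numel(self):
        return len(self.values)


def fake_output_quantizer(x):
    """Hypothetical stand-in for self.output_quantizer.

    Like the real quantizer described in the commit message, it calls
    numel() on whatever it receives.
    """
    if x.numel() == 0:
        return x
    return FakeTensor(round(v) for v in x.values)


def old_wrapper(te_forward, inp):
    # Pre-fix behavior: the raw return value goes straight to the quantizer.
    return fake_output_quantizer(te_forward(inp))


def forward_le_2_14(inp):
    # Models the TE <= 2.14 contract: a single tensor.
    return inp


def forward_ge_2_15(inp):
    # Models the TE 2.15+ contract: (out, new_workspace).
    return inp, "workspace"


x = FakeTensor([0.4, 1.6])
ok = old_wrapper(forward_le_2_14, x)  # single tensor: works

try:
    old_wrapper(forward_ge_2_15, x)  # tuple: quantizer calls tuple.numel()
    error_message = None
except AttributeError as exc:
    error_message = str(exc)

print(error_message)
```

With the single-tensor forward the wrapper succeeds; with the tuple-returning forward it raises the same `AttributeError` quoted in the commit message.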
1 parent 50e112e commit c897fbe

1 file changed

Lines changed: 8 additions & 0 deletions

modelopt/torch/quantization/plugins/transformer_engine.py

@@ -93,6 +93,10 @@ def te_quantized_linear_fn(package, func_name, self, *args, **kwargs):
         new_args[weight_pos] = self.weight_quantizer(args[weight_pos])
         new_args[inp_pos] = self.input_quantizer(args[inp_pos])
         output = getattr(package, func_name)(*new_args, **kwargs)
+        # TE 2.15+ returns `(out, new_weight_workspace)`; TE <= 2.14 returns just `out`.
+        # Only the activation tensor participates in output quantization.
+        if isinstance(output, tuple):
+            return (self.output_quantizer(output[0]), *output[1:])
         return self.output_quantizer(output)
 
     # Override the quantized linear function
@@ -181,6 +185,10 @@ def te_grouped_quantized_linear_fn(package, func_name, self, *args):
         for i in range(weights_start, weights_start + num_gemms):
             new_args[i] = self.weight_quantizer(args[i])
         output = getattr(package, func_name)(*new_args)
+        # TE 2.15+ returns `(out, new_workspaces)`; TE <= 2.14 returns just `out`.
+        # Only the activation tensor participates in output quantization.
+        if isinstance(output, tuple):
+            return (self.output_quantizer(output[0]), *output[1:])
         return self.output_quantizer(output)
 
     # Override the quantized linear function
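
Both hunks apply the same four-line pattern. As a rough sketch of that pattern in isolation, it could be expressed as one helper; `quantize_primary_output` and the `double` stand-in quantizer are hypothetical names for illustration and are not part of this commit or of ModelOpt.

```python
def quantize_primary_output(output, output_quantizer):
    """Quantize only the activation and pass any extra items through.

    Hypothetical helper mirroring the branch added in both hunks:
    a tuple is treated as (out, *auxiliary), anything else as a bare `out`.
    """
    if isinstance(output, tuple):
        return (output_quantizer(output[0]), *output[1:])
    return output_quantizer(output)


def double(x):
    # Stand-in for self.output_quantizer; any callable on the activation works.
    return x * 2


workspace = object()  # opaque auxiliary item, as in the TE 2.15+ return shape

single = quantize_primary_output(3, double)
paired = quantize_primary_output((3, workspace), double)

print(single)
print(paired[0], paired[1] is workspace)
```

The bare-value path is untouched, which matches the commit's claim that TE <= 2.14 behavior is unchanged, and the auxiliary item comes back as the same object rather than a copy. Whether the real plugin should share a helper like this or keep the inline branches is a design choice the commit does not address.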
