@@ -93,6 +93,10 @@ def te_quantized_linear_fn(package, func_name, self, *args, **kwargs):
9393 new_args [weight_pos ] = self .weight_quantizer (args [weight_pos ])
9494 new_args [inp_pos ] = self .input_quantizer (args [inp_pos ])
9595 output = getattr (package , func_name )(* new_args , ** kwargs )
96+ # TE 2.15+ returns `(out, new_weight_workspace)`; TE <= 2.14 returns just `out`.
97+ # Only the activation tensor participates in output quantization.
98+ if isinstance (output , tuple ):
99+ return (self .output_quantizer (output [0 ]), * output [1 :])
96100 return self .output_quantizer (output )
97101
98102 # Override the quantized linear function
@@ -181,6 +185,10 @@ def te_grouped_quantized_linear_fn(package, func_name, self, *args):
181185 for i in range (weights_start , weights_start + num_gemms ):
182186 new_args [i ] = self .weight_quantizer (args [i ])
183187 output = getattr (package , func_name )(* new_args )
188+ # TE 2.15+ returns `(out, new_workspaces)`; TE <= 2.14 returns just `out`.
189+ # Only the activation tensor participates in output quantization.
190+ if isinstance (output , tuple ):
191+ return (self .output_quantizer (output [0 ]), * output [1 :])
184192 return self .output_quantizer (output )
185193
186194 # Override the quantized linear function
0 commit comments