Skip to content

Commit fff917e

Browse files
authored
Merge branch 'main' into dmoodie/bugfix/trtexec_safe
2 parents 17c662c + 229ba61 commit fff917e

2 files changed

Lines changed: 9 additions & 3 deletions

File tree

modelopt/torch/quantization/algorithms.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,10 +1094,8 @@ def run_search_with_stats(self, max_weight_size, verbose=False):
10941094
return best_recipes, is_satisfied
10951095

10961096

1097-
# TODO: Enable torch compile for this function
1098-
# Currently modelopt.onnx is breaking this
1097+
@torch.compile(dynamic=True)
10991098
def _get_log_softmax_dist(logits: torch.Tensor, tp_group) -> torch.Tensor:
1100-
# TODO: test this
11011099
dtype = logits.dtype
11021100
max_logits = torch.amax(logits, dim=-1, keepdim=True)
11031101
torch.distributed.all_reduce(max_logits, op=torch.distributed.ReduceOp.MAX, group=tp_group)

modelopt/torch/quantization/plugins/transformer_engine.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ def te_quantized_linear_fn(package, func_name, self, *args, **kwargs):
9393
new_args[weight_pos] = self.weight_quantizer(args[weight_pos])
9494
new_args[inp_pos] = self.input_quantizer(args[inp_pos])
9595
output = getattr(package, func_name)(*new_args, **kwargs)
96+
# TE 2.15+ returns `(out, new_weight_workspace)`; TE <= 2.14 returns just `out`.
97+
# Only the activation tensor participates in output quantization.
98+
if isinstance(output, tuple):
99+
return (self.output_quantizer(output[0]), *output[1:])
96100
return self.output_quantizer(output)
97101

98102
# Override the quantized linear function
@@ -181,6 +185,10 @@ def te_grouped_quantized_linear_fn(package, func_name, self, *args):
181185
for i in range(weights_start, weights_start + num_gemms):
182186
new_args[i] = self.weight_quantizer(args[i])
183187
output = getattr(package, func_name)(*new_args)
188+
# TE 2.15+ returns `(out, new_workspaces)`; TE <= 2.14 returns just `out`.
189+
# Only the activation tensor participates in output quantization.
190+
if isinstance(output, tuple):
191+
return (self.output_quantizer(output[0]), *output[1:])
184192
return self.output_quantizer(output)
185193

186194
# Override the quantized linear function

0 commit comments

Comments
 (0)