From 765aee6ddbc08cb128c6261714ca20dcb28dd144 Mon Sep 17 00:00:00 2001
From: john-rocky
Date: Fri, 1 May 2026 13:28:09 +0900
Subject: [PATCH] Apply alpha scale factor in torch add/sub converters
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

`aten::add(self, other, alpha)` and `aten::sub(self, other, alpha)` compute
`self ± alpha * other`, but the converters previously raised on positional
`alpha != 1` and silently ignored kwarg `alpha` from the `torch.export` path.
As a result, `torch.sub(x, y, alpha=5)` produced `x - y` instead of
`x - 5*y` (issue #2573).

Look up `alpha` from positional inputs (TorchScript) and kwinputs
(torch.export), and apply `y = y * alpha` before the add/sub when
`alpha != 1`, mirroring the existing `addmm` handler. The alpha=1 fast path
is unchanged.

Fixes #2573.
---
 .../converters/mil/frontend/torch/ops.py      | 31 +++++++++++--------
 .../mil/frontend/torch/test/test_torch_ops.py | 27 ++++++++++++++++
 2 files changed, 45 insertions(+), 13 deletions(-)

diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py
index f75051467..c093c7d3c 100644
--- a/coremltools/converters/mil/frontend/torch/ops.py
+++ b/coremltools/converters/mil/frontend/torch/ops.py
@@ -1063,12 +1063,16 @@ def add(context, node):
     add_inputs = _get_inputs(context, node)
     assert len(node.outputs) == 1
 
-    # TODO (sberardi): 3rd param to aten::add is a scale factor, need to handle that.
-    # out=input+alpha x other
-    # rdar://60175736
-    if len(add_inputs) > 2 and add_inputs[2].val != 1:
-        raise ValueError("ADD does not support scale factor param")
     x, y = add_inputs[:2]
+    # aten::add(self, other, alpha=1) computes self + alpha * other.
+    # alpha may be passed positionally (TorchScript) or as a kwarg (torch.export).
+    alpha = add_inputs[2] if len(add_inputs) > 2 else None
+    alpha = _get_kwinputs(context, node, "alpha", default=[alpha])[0]
+    if isinstance(alpha, Var):
+        alpha = alpha.val
+    if alpha is not None and alpha != 1:
+        y, alpha_var = promote_input_dtypes([y, alpha])
+        y = mb.mul(x=y, y=alpha_var)
     if types.is_bool(x.dtype) and types.is_bool(y.dtype):
         add_node = mb.logical_or(x=x, y=y, name=node.name)
     elif types.is_complex(x.dtype) or types.is_complex(y.dtype):
@@ -2074,14 +2078,15 @@ def sub(context, node):
     x = inputs[0]
     y = inputs[1]
 
-    if len(inputs) > 2:
-        alpha = inputs[2].val
-
-        # TODO (sberardi): 3rd param to aten::sub is a scale factor, need to handle that.
-        # out=input-alpha x other
-        # rdar://60175736
-        if alpha != 1:
-            raise ValueError("SUB does not support scale factor param")
+    # aten::sub(self, other, alpha=1) computes self - alpha * other.
+    # alpha may be passed positionally (TorchScript) or as a kwarg (torch.export).
+    alpha = inputs[2] if len(inputs) > 2 else None
+    alpha = _get_kwinputs(context, node, "alpha", default=[alpha])[0]
+    if isinstance(alpha, Var):
+        alpha = alpha.val
+    if alpha is not None and alpha != 1:
+        y, alpha_var = promote_input_dtypes([y, alpha])
+        y = mb.mul(x=y, y=alpha_var)
 
     x, y = promote_input_dtypes([x, y])
     res = mb.sub(x=x, y=y, name=node.name)
diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
index dc17406d6..4db09989a 100644
--- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
+++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
@@ -4289,6 +4289,33 @@ def forward(self, x, y):
         )
 
 
+class TestAddSubAlpha(TorchBaseTest):
+    @pytest.mark.parametrize(
+        "compute_unit, backend, frontend, op, alpha",
+        itertools.product(
+            compute_units,
+            backends,
+            frontends,
+            ["add", "sub"],
+            [2, 0.5, -3.0],
+        ),
+    )
+    def test_alpha(self, compute_unit, backend, frontend, op, alpha):
+        class TestModel(nn.Module):
+            def forward(self, x, y):
+                if op == "add":
+                    return torch.add(x, y, alpha=alpha)
+                return torch.sub(x, y, alpha=alpha)
+
+        self.run_compare_torch(
+            [(2, 3), (2, 3)],
+            TestModel(),
+            frontend=frontend,
+            backend=backend,
+            compute_unit=compute_unit,
+        )
+
+
 class TestFull(TorchBaseTest):
     @pytest.mark.parametrize(
         "compute_unit, backend, frontend, rank",