Skip to content

Commit 35cbab6

Browse files
committed
Apply alpha scale factor in torch add/sub converters
`aten::add(self, other, alpha)` and `aten::sub(self, other, alpha)` compute `self ± alpha * other`, but the converters previously raised on positional `alpha != 1` and silently ignored kwarg `alpha` from the `torch.export` path. As a result, `torch.sub(x, y, alpha=5)` produced `x - y` instead of `x - 5*y` (issue #2573). Look up `alpha` from positional inputs (TorchScript) and kwinputs (torch.export), and apply `y = y * alpha` before the add/sub when `alpha != 1`, mirroring the existing `addmm` handler. The alpha=1 fast path is unchanged. Fixes #2573.
1 parent e95804f commit 35cbab6

2 files changed

Lines changed: 45 additions & 13 deletions

File tree

coremltools/converters/mil/frontend/torch/ops.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,12 +1063,16 @@ def add(context, node):
10631063
add_inputs = _get_inputs(context, node)
10641064
assert len(node.outputs) == 1
10651065

1066-
# TODO (sberardi): 3rd param to aten::add is a scale factor, need to handle that.
1067-
# out=input+alpha x other
1068-
# rdar://60175736
1069-
if len(add_inputs) > 2 and add_inputs[2].val != 1:
1070-
raise ValueError("ADD does not support scale factor param")
10711066
x, y = add_inputs[:2]
1067+
# aten::add(self, other, alpha=1) computes self + alpha * other.
1068+
# alpha may be passed positionally (TorchScript) or as a kwarg (torch.export).
1069+
alpha = add_inputs[2] if len(add_inputs) > 2 else None
1070+
alpha = _get_kwinputs(context, node, "alpha", default=[alpha])[0]
1071+
if isinstance(alpha, Var):
1072+
alpha = alpha.val
1073+
if alpha is not None and alpha != 1:
1074+
y, alpha_var = promote_input_dtypes([y, alpha])
1075+
y = mb.mul(x=y, y=alpha_var)
10721076
if types.is_bool(x.dtype) and types.is_bool(y.dtype):
10731077
add_node = mb.logical_or(x=x, y=y, name=node.name)
10741078
elif types.is_complex(x.dtype) or types.is_complex(y.dtype):
@@ -2074,14 +2078,15 @@ def sub(context, node):
20742078
x = inputs[0]
20752079
y = inputs[1]
20762080

2077-
if len(inputs) > 2:
2078-
alpha = inputs[2].val
2079-
2080-
# TODO (sberardi): 3rd param to aten::sub is a scale factor, need to handle that.
2081-
# out=input-alpha x other
2082-
# rdar://60175736
2083-
if alpha != 1:
2084-
raise ValueError("SUB does not support scale factor param")
2081+
# aten::sub(self, other, alpha=1) computes self - alpha * other.
2082+
# alpha may be passed positionally (TorchScript) or as a kwarg (torch.export).
2083+
alpha = inputs[2] if len(inputs) > 2 else None
2084+
alpha = _get_kwinputs(context, node, "alpha", default=[alpha])[0]
2085+
if isinstance(alpha, Var):
2086+
alpha = alpha.val
2087+
if alpha is not None and alpha != 1:
2088+
y, alpha_var = promote_input_dtypes([y, alpha])
2089+
y = mb.mul(x=y, y=alpha_var)
20852090

20862091
x, y = promote_input_dtypes([x, y])
20872092
res = mb.sub(x=x, y=y, name=node.name)

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4289,6 +4289,33 @@ def forward(self, x, y):
42894289
)
42904290

42914291

4292+
class TestAddSubAlpha(TorchBaseTest):
4293+
@pytest.mark.parametrize(
4294+
"compute_unit, backend, frontend, op, alpha",
4295+
itertools.product(
4296+
compute_units,
4297+
backends,
4298+
frontends,
4299+
["add", "sub"],
4300+
[2, 0.5, -3.0],
4301+
),
4302+
)
4303+
def test_alpha(self, compute_unit, backend, frontend, op, alpha):
4304+
class TestModel(nn.Module):
4305+
def forward(self, x, y):
4306+
if op == "add":
4307+
return torch.add(x, y, alpha=alpha)
4308+
return torch.sub(x, y, alpha=alpha)
4309+
4310+
self.run_compare_torch(
4311+
[(2, 3), (2, 3)],
4312+
TestModel(),
4313+
frontend=frontend,
4314+
backend=backend,
4315+
compute_unit=compute_unit,
4316+
)
4317+
4318+
42924319
class TestFull(TorchBaseTest):
42934320
@pytest.mark.parametrize(
42944321
"compute_unit, backend, frontend, rank",

0 commit comments

Comments
 (0)