Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions coremltools/converters/mil/frontend/torch/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1063,12 +1063,16 @@ def add(context, node):
add_inputs = _get_inputs(context, node)
assert len(node.outputs) == 1

# TODO (sberardi): 3rd param to aten::add is a scale factor, need to handle that.
# out=input+alpha x other
# rdar://60175736
if len(add_inputs) > 2 and add_inputs[2].val != 1:
raise ValueError("ADD does not support scale factor param")
x, y = add_inputs[:2]
# aten::add(self, other, alpha=1) computes self + alpha * other.
# alpha may be passed positionally (TorchScript) or as a kwarg (torch.export).
alpha = add_inputs[2] if len(add_inputs) > 2 else None
alpha = _get_kwinputs(context, node, "alpha", default=[alpha])[0]
if isinstance(alpha, Var):
alpha = alpha.val
if alpha is not None and alpha != 1:
y, alpha_var = promote_input_dtypes([y, alpha])
y = mb.mul(x=y, y=alpha_var)
if types.is_bool(x.dtype) and types.is_bool(y.dtype):
add_node = mb.logical_or(x=x, y=y, name=node.name)
elif types.is_complex(x.dtype) or types.is_complex(y.dtype):
Expand Down Expand Up @@ -2074,14 +2078,15 @@ def sub(context, node):
x = inputs[0]
y = inputs[1]

if len(inputs) > 2:
alpha = inputs[2].val

# TODO (sberardi): 3rd param to aten::sub is a scale factor, need to handle that.
# out=input-alpha x other
# rdar://60175736
if alpha != 1:
raise ValueError("SUB does not support scale factor param")
# aten::sub(self, other, alpha=1) computes self - alpha * other.
# alpha may be passed positionally (TorchScript) or as a kwarg (torch.export).
alpha = inputs[2] if len(inputs) > 2 else None
alpha = _get_kwinputs(context, node, "alpha", default=[alpha])[0]
if isinstance(alpha, Var):
alpha = alpha.val
if alpha is not None and alpha != 1:
y, alpha_var = promote_input_dtypes([y, alpha])
y = mb.mul(x=y, y=alpha_var)

x, y = promote_input_dtypes([x, y])
res = mb.sub(x=x, y=y, name=node.name)
Expand Down
27 changes: 27 additions & 0 deletions coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4289,6 +4289,33 @@ def forward(self, x, y):
)


class TestAddSubAlpha(TorchBaseTest):
    """Check that aten::add / aten::sub honor the ``alpha`` scale factor."""

    @pytest.mark.parametrize(
        "compute_unit, backend, frontend, op, alpha",
        itertools.product(
            compute_units,
            backends,
            frontends,
            ["add", "sub"],
            [2, 0.5, -3.0],
        ),
    )
    def test_alpha(self, compute_unit, backend, frontend, op, alpha):
        # The op name and alpha are captured by closure, so each parametrized
        # case traces/exports a model with that specific scale factor baked in.
        class ScaledBinaryModel(nn.Module):
            def forward(self, lhs, rhs):
                torch_fn = torch.add if op == "add" else torch.sub
                return torch_fn(lhs, rhs, alpha=alpha)

        self.run_compare_torch(
            [(2, 3), (2, 3)],
            ScaledBinaryModel(),
            frontend=frontend,
            backend=backend,
            compute_unit=compute_unit,
        )


class TestFull(TorchBaseTest):
@pytest.mark.parametrize(
"compute_unit, backend, frontend, rank",
Expand Down