Commit 94e1305

Cast int input to fp32 in torch reciprocal converter
`torch.reciprocal` returns a float for int inputs in PyTorch, but `mb.inverse` only accepts fp16/fp32. As a result, common patterns like `1 / x.shape[0]` (which TorchScript traces as `reciprocal(prim::NumToTensor(int))`) failed conversion with:

    Op (op_type: inverse) Input x expects tensor or scalar of dtype from type domain ['fp16', 'fp32'] but got tensor[1, int32]

Insert an fp32 cast before `mb.inverse` when the input dtype is integer, mirroring the pattern already used by `log`, `sqrt`, and other unary ops that share the same MIL constraint.

Verified end-to-end on the issue repro and a representative RoPE-style inverse-frequency expression.

Fixes #2579.
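For reference, a minimal repro sketch of the failing pattern (the module, shapes, and names here are illustrative rather than the exact script from the issue; it uses the public `torch.jit.trace` and `ct.convert` APIs):

    import torch
    import coremltools as ct

    class Repro(torch.nn.Module):
        def forward(self, x):
            # TorchScript traces this division over a Python int as
            # reciprocal(prim::NumToTensor(int)) followed by a multiply.
            return 1 / x.shape[0] * x

    traced = torch.jit.trace(Repro().eval(), torch.rand(2, 16))
    # Before this commit, conversion raised the dtype error quoted above;
    # with the fp32 cast in the reciprocal converter it converts cleanly.
    mlmodel = ct.convert(traced, inputs=[ct.TensorType(shape=(2, 16))])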
1 parent e95804f commit 94e1305

2 files changed

Lines changed: 37 additions & 1 deletion

coremltools/converters/mil/frontend/torch/ops.py

Lines changed: 8 additions & 1 deletion
@@ -7323,7 +7323,14 @@ def floor(context, node):
 @register_torch_op
 def reciprocal(context, node):
     inputs = _get_inputs(context, node, expected=1)
-    context.add(mb.inverse(x=inputs[0], name=node.name))
+    x = inputs[0]
+    # PyTorch's reciprocal promotes int inputs to float; mb.inverse only
+    # accepts fp16/fp32. Without this cast, common patterns like
+    # `1 / x.shape[0]` (which TorchScript traces as
+    # reciprocal(prim::NumToTensor(int))) fail to convert.
+    if types.is_int(x.dtype):
+        x = mb.cast(x=x, dtype="fp32")
+    context.add(mb.inverse(x=x, name=node.name))
 
 
 @register_torch_op
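To see the MIL constraint in isolation, here is a hypothetical builder-level sketch (it assumes the public MIL `Builder` API and is for illustration only; it is not part of this commit):

    from coremltools.converters.mil import Builder as mb
    from coremltools.converters.mil.mil import types

    @mb.program(input_specs=[mb.TensorSpec(shape=(1,), dtype=types.int32)])
    def prog(x):
        # Calling mb.inverse(x=x) directly on the int32 input would fail:
        # the op's type domain is fp16/fp32 only.
        x = mb.cast(x=x, dtype="fp32")  # the same cast the converter now inserts
        return mb.inverse(x=x)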

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

Lines changed: 29 additions & 0 deletions
@@ -6676,6 +6676,35 @@ def test_div(self, compute_unit, backend, frontend, rounding_mode, x2_type):
         )
 
 
+class TestReciprocal(TorchBaseTest):
+    @pytest.mark.parametrize(
+        "compute_unit, backend, frontend",
+        itertools.product(compute_units, backends, frontends),
+    )
+    def test_reciprocal_int_shape(self, compute_unit, backend, frontend):
+        # Regression test for #2579: TorchScript traces `16 / x.shape[0]`
+        # as reciprocal(int) -> mul(16), and reciprocal previously rejected
+        # the int input because mb.inverse only accepts fp16/fp32.
+        if frontend in TORCH_EXPORT_BASED_FRONTENDS:
+            pytest.skip("torch.export folds shape-derived constants")
+
+        class TestModel(nn.Module):
+            def forward(self, x):
+                return 16 / x.shape[0] * x
+
+        # mb.inverse uses hardware reciprocal with limited precision; loosen
+        # tolerance to accommodate fp16 backends.
+        self.run_compare_torch(
+            (2, 16, 11),
+            TestModel(),
+            frontend=frontend,
+            backend=backend,
+            compute_unit=compute_unit,
+            atol=1e-2,
+            rtol=1e-2,
+        )
+
+
 class TestElementWiseUnary(TorchBaseTest):
     @pytest.mark.parametrize(
         "compute_unit, backend, frontend, shape, op_string",
