Skip to content

Commit aa06dbc

Browse files
authored
Skip aten.rand in CoreML partitioner (pytorch#19246)
### Summary coremltools 9.0's MIL converter hits an unimplemented branch in its `aten.rand.default` handler ("not enough values to unpack (expected 5, got 1)") and aborts during compilation with an internal error (see issue for the full traceback). Add `aten.rand.default` (both the `torch.ops` and `exir_ops.edge` variants) to the new `_UNSUPPORTED_OP_TARGETS` deny list so the partitioner refuses to delegate it and it falls back to the portable backend, the same way `acosh` / `asinh` are handled today. The sibling random ops (`randn` / `rand_like` / `randn_like` / `randint`) lower cleanly in coremltools 9.0 (verified by lowering each op in isolation) and remain delegated. Fixes pytorch#11722. ### Test plan Added `test_aten_rand_default_falls_back_to_portable`, which lowers a model using `torch.rand` and asserts `aten.rand.default` remains in the top-level graph (not delegated), and `test_aten_randn_is_still_delegated`, which asserts `torch.randn` is still delegated. Also reproduced the original repro from the issue and confirmed it now lowers without crashing: ``` $ python -m unittest -v executorch.backends.apple.coreml.test.test_coreml_partitioner.TestCoreMLPartitioner.test_aten_rand_default_falls_back_to_portable Ran 1 test in 0.475s OK ``` Authored with Claude. cc @metascroy
1 parent 4c474af commit aa06dbc

2 files changed

Lines changed: 86 additions & 20 deletions

File tree

backends/apple/coreml/partition/coreml_partitioner.py

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,36 @@
3333
logger.setLevel(get_coreml_log_level(default_level=logging.INFO))
3434

3535

36+
# Ops that the CoreML partitioner must always reject regardless of their
37+
# arguments. Each entry is annotated with the upstream issue that motivates
38+
# it so future readers can tell when an entry is safe to drop.
39+
_UNSUPPORTED_OP_TARGETS = frozenset(
40+
[
41+
# https://github.com/apple/coremltools/issues/2565 — diagonal has a
42+
# CoreML correctness bug.
43+
torch.ops.aten.diagonal.default,
44+
torch.ops.aten.diagonal_copy.default,
45+
exir_ops.edge.aten.diagonal.default,
46+
exir_ops.edge.aten.diagonal_copy.default,
47+
# https://github.com/apple/coremltools/issues/2569 — acosh / asinh
48+
# are not implemented in coremltools.
49+
torch.ops.aten.acosh.default,
50+
exir_ops.edge.aten.acosh.default,
51+
torch.ops.aten.asinh.default,
52+
exir_ops.edge.aten.asinh.default,
53+
# https://github.com/pytorch/executorch/issues/11722 — only
54+
# ``aten.rand.default`` actually reaches an unimplemented branch in
55+
# coremltools 9.0 ("not enough values to unpack (expected 5, got 1)"
56+
# raised from the rand handler). ``randn`` / ``rand_like`` /
57+
# ``randn_like`` / ``randint`` lower cleanly today, so we leave them
58+
# delegated. Verified locally against coremltools 9.0 / Python 3.10
59+
# by lowering each op in isolation.
60+
torch.ops.aten.rand.default,
61+
exir_ops.edge.aten.rand.default,
62+
]
63+
)
64+
65+
3666
def _is_view_op(op: torch._ops.OpOverload) -> bool:
3767
schema = op._schema
3868
if len(schema.arguments) == 0:
@@ -92,27 +122,13 @@ def should_override_support(self, node) -> bool:
92122
)
93123
return True
94124

95-
# https://github.com/apple/coremltools/issues/2565
96-
if node.target in [
97-
torch.ops.aten.diagonal.default,
98-
torch.ops.aten.diagonal_copy.default,
99-
exir_ops.edge.aten.diagonal.default,
100-
exir_ops.edge.aten.diagonal_copy.default,
101-
]:
102-
self.log_once(
103-
"torch.ops.aten.diagonal.default has a bug in CoreML. Overriding op support."
104-
)
105-
return True
106-
107-
# https://github.com/apple/coremltools/issues/2569
108-
if node.target in [
109-
torch.ops.aten.acosh.default,
110-
exir_ops.edge.aten.acosh.default,
111-
torch.ops.aten.asinh.default,
112-
exir_ops.edge.aten.asinh.default,
113-
]:
125+
# Ops that are unsupported by CoreML purely on the basis of their
126+
# target — no per-arg conditions to check. Grouped by upstream issue
127+
# so the comment trail still points at the underlying coremltools /
128+
# executorch bug for each entry.
129+
if node.target in _UNSUPPORTED_OP_TARGETS:
114130
self.log_once(
115-
"torch.ops.aten.{acosh, asinh}.default is not supported by CoreML. Overriding op support."
131+
f"{node.target} is not supported by CoreML. Overriding op support."
116132
)
117133
return True
118134

backends/apple/coreml/test/test_coreml_partitioner.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,54 @@ def forward(self, x):
338338
torch.allclose(et_outputs, eager_outputs, atol=1e-02, rtol=1e-02)
339339
)
340340

341+
def test_aten_rand_default_falls_back_to_portable(self):
342+
"""
343+
Regression test for https://github.com/pytorch/executorch/issues/11722.
344+
345+
coremltools 9.0's ``aten.rand.default`` handler hits an unimplemented
346+
branch (``not enough values to unpack (expected 5, got 1)``). The
347+
partitioner must reject it so the op falls back to the portable
348+
backend instead of crashing the export. Sibling ops like
349+
``aten.randn``, ``aten.rand_like``, etc. lower cleanly and are
350+
intentionally still delegated.
351+
"""
352+
353+
class Model(torch.nn.Module):
354+
def forward(self, x):
355+
return torch.rand(x.shape) + x
356+
357+
model = Model().eval()
358+
example_inputs = (torch.zeros(5, 5),)
359+
exir_program_aten = torch.export.export(model, example_inputs, strict=True)
360+
edge_program_manager = executorch.exir.to_edge_transform_and_lower(
361+
exir_program_aten, partitioner=[CoreMLPartitioner()]
362+
)
363+
op_names = [
364+
node.target.__name__
365+
for node in edge_program_manager.exported_program().graph.nodes
366+
if node.op == "call_function"
367+
]
368+
self.assertIn("aten.rand.default", op_names)
369+
370+
def test_aten_randn_is_still_delegated(self):
371+
"""``aten.randn`` is *not* in the deny list — it lowers cleanly."""
372+
373+
class Model(torch.nn.Module):
374+
def forward(self, x):
375+
return torch.randn(x.shape) + x
376+
377+
ep = torch.export.export(Model().eval(), (torch.zeros(5, 5),), strict=True)
378+
edge = executorch.exir.to_edge_transform_and_lower(
379+
ep, partitioner=[CoreMLPartitioner()]
380+
)
381+
op_names = [
382+
n.target.__name__
383+
for n in edge.exported_program().graph.nodes
384+
if n.op == "call_function"
385+
]
386+
self.assertIn("executorch_call_delegate", op_names)
387+
self.assertNotIn("aten.randn.default", op_names)
388+
341389
def test_deprecation_warning_for_to_backend_workflow(self):
342390
"""
343391
Test that the deprecated to_edge + to_backend workflow shows a deprecation warning.
@@ -435,5 +483,7 @@ def forward(self, x):
435483
test_runner.test_lower_full_graph()
436484
# test_runner.test_symint_arg()
437485
test_runner.test_take_over_constant_data_false()
486+
test_runner.test_aten_rand_default_falls_back_to_portable()
487+
test_runner.test_aten_randn_is_still_delegated()
438488
test_runner.test_deprecation_warning_for_to_backend_workflow()
439489
test_runner.test_no_warning_for_to_edge_transform_and_lower_workflow()

0 commit comments

Comments
 (0)