Skip to content

Commit 976bf72

Browse files
committed
Skip torch random ops in CoreML partitioner
coremltools' converter fails the input-count check for the torch random ops (rand / randn / rand_like / randn_like / randint / randint_like) and aborts with an internal error during CoreML compilation. Reject them in the partitioner so they fall back to the portable backend. Fixes #11722.
1 parent 94d2881 commit 976bf72

2 files changed

Lines changed: 50 additions & 0 deletions

File tree

backends/apple/coreml/partition/coreml_partitioner.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,28 @@ def should_override_support(self, node) -> bool:
116116
)
117117
return True
118118

119+
# https://github.com/pytorch/executorch/issues/11722
120+
# coremltools' converter for the torch random ops does not pass the
121+
# number-of-inputs check and aborts with an internal error.
122+
if node.target in [
123+
torch.ops.aten.rand.default,
124+
torch.ops.aten.randn.default,
125+
torch.ops.aten.rand_like.default,
126+
torch.ops.aten.randn_like.default,
127+
torch.ops.aten.randint.default,
128+
torch.ops.aten.randint_like.default,
129+
exir_ops.edge.aten.rand.default,
130+
exir_ops.edge.aten.randn.default,
131+
exir_ops.edge.aten.rand_like.default,
132+
exir_ops.edge.aten.randn_like.default,
133+
exir_ops.edge.aten.randint.default,
134+
exir_ops.edge.aten.randint_like.default,
135+
]:
136+
self.log_once(
137+
"torch random ops are not supported by CoreML. Overriding op support."
138+
)
139+
return True
140+
119141
# TODO: enable this after bugs in ExecuTorch's partitioner are fixed
120142
# # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args
121143
# # in the placeholders due to partitioning, which CoreML does not support

backends/apple/coreml/test/test_coreml_partitioner.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,34 @@ def forward(self, x):
338338
torch.allclose(et_outputs, eager_outputs, atol=1e-02, rtol=1e-02)
339339
)
340340

341+
def test_random_ops_are_skipped(self):
342+
"""
343+
Regression test for https://github.com/pytorch/executorch/issues/11722.
344+
345+
coremltools' converter aborts when it encounters torch random ops.
346+
The partitioner must reject them so they fall back to the portable
347+
backend instead of crashing the export.
348+
"""
349+
350+
class Model(torch.nn.Module):
351+
def forward(self, x):
352+
return torch.randn(x.shape) + torch.rand(x.shape) + x
353+
354+
model = Model().eval()
355+
example_inputs = (torch.zeros(5, 5),)
356+
exir_program_aten = torch.export.export(model, example_inputs, strict=True)
357+
edge_program_manager = executorch.exir.to_edge_transform_and_lower(
358+
exir_program_aten, partitioner=[CoreMLPartitioner()]
359+
)
360+
361+
op_names = [
362+
node.target.__name__
363+
for node in edge_program_manager.exported_program().graph.nodes
364+
if node.op == "call_function"
365+
]
366+
self.assertIn("aten.randn.default", op_names)
367+
self.assertIn("aten.rand.default", op_names)
368+
341369
def test_deprecation_warning_for_to_backend_workflow(self):
342370
"""
343371
Test that the deprecated to_edge + to_backend workflow shows a deprecation warning.

0 commit comments

Comments
 (0)