Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions backends/apple/coreml/partition/coreml_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,28 @@ def should_override_support(self, node) -> bool:
)
return True

# https://github.com/pytorch/executorch/issues/11722
# coremltools' converter for the torch random ops does not pass the
# number-of-inputs check and aborts with an internal error.
if node.target in [
torch.ops.aten.rand.default,
torch.ops.aten.randn.default,
torch.ops.aten.rand_like.default,
torch.ops.aten.randn_like.default,
torch.ops.aten.randint.default,
torch.ops.aten.randint_like.default,
exir_ops.edge.aten.rand.default,
exir_ops.edge.aten.randn.default,
exir_ops.edge.aten.rand_like.default,
exir_ops.edge.aten.randn_like.default,
exir_ops.edge.aten.randint.default,
exir_ops.edge.aten.randint_like.default,
]:
self.log_once(
"torch random ops are not supported by CoreML. Overriding op support."
)
return True

# TODO: enable this after bugs in ExecuTorch's partitioner are fixed
# # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args
# # in the placeholders due to partitioning, which CoreML does not support
Expand Down
28 changes: 28 additions & 0 deletions backends/apple/coreml/test/test_coreml_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,34 @@ def forward(self, x):
torch.allclose(et_outputs, eager_outputs, atol=1e-02, rtol=1e-02)
)

def test_random_ops_are_skipped(self):
"""
Regression test for https://github.com/pytorch/executorch/issues/11722.

coremltools' converter aborts when it encounters torch random ops.
The partitioner must reject them so they fall back to the portable
backend instead of crashing the export.
"""

class Model(torch.nn.Module):
def forward(self, x):
return torch.randn(x.shape) + torch.rand(x.shape) + x

model = Model().eval()
example_inputs = (torch.zeros(5, 5),)
exir_program_aten = torch.export.export(model, example_inputs, strict=True)
edge_program_manager = executorch.exir.to_edge_transform_and_lower(
exir_program_aten, partitioner=[CoreMLPartitioner()]
)

op_names = [
node.target.__name__
for node in edge_program_manager.exported_program().graph.nodes
if node.op == "call_function"
]
self.assertIn("aten.randn.default", op_names)
self.assertIn("aten.rand.default", op_names)

def test_deprecation_warning_for_to_backend_workflow(self):
"""
Test that the deprecated to_edge + to_backend workflow shows a deprecation warning.
Expand Down
Loading