17 changes: 17 additions & 0 deletions backends/apple/coreml/partition/coreml_partitioner.py
@@ -116,6 +116,23 @@ def should_override_support(self, node) -> bool:
)
return True

# https://github.com/pytorch/executorch/issues/11715
# argmin/argmax with dim=None reduces over the flattened input, which
# CoreML does not support and causes intermittent process crashes.
if node.target in [
torch.ops.aten.argmax.default,
torch.ops.aten.argmin.default,
exir_ops.edge.aten.argmax.default,
exir_ops.edge.aten.argmin.default,
]:
dim = node.args[1] if len(node.args) >= 2 else node.kwargs.get("dim", None)
if dim is None:
self.log_once(
"torch.ops.aten.{argmax, argmin}.default with dim=None is "
"not supported by CoreML. Overriding op support."
)
return True

# TODO: enable this after bugs in ExecuTorch's partitioner are fixed
# # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args
# # in the placeholders due to partitioning, which CoreML does not support
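# For reference, a minimal sketch (not part of this diff) of the semantic difference the
# check above guards against: torch.argmax/argmin with dim=None return a single index into
# the flattened tensor, while the dim=int form returns per-dimension indices and remains delegable.
#
#     import torch
#
#     x = torch.randn(10, 10)
#     flat_idx = torch.argmax(x)        # dim=None: 0-dim index into x.flatten(); rejected for CoreML
#     row_idx = torch.argmax(x, dim=1)  # per-row indices, shape (10,); still delegable
#     assert flat_idx.ndim == 0 and row_idx.shape == (10,)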
45 changes: 45 additions & 0 deletions backends/apple/coreml/test/test_coreml_partitioner.py
@@ -338,6 +338,51 @@ def forward(self, x):
torch.allclose(et_outputs, eager_outputs, atol=1e-02, rtol=1e-02)
)

def test_argmax_argmin_dim_none_is_skipped(self):
"""
Regression test for https://github.com/pytorch/executorch/issues/11715.

argmax/argmin with dim=None reduces over the flattened tensor, which
CoreML does not support; the resulting model intermittently crashes
the process at runtime. The partitioner must reject these so they
fall back to ExecuTorch's portable kernels, while still delegating the
ordinary dim=int form.
"""

class FlatModel(torch.nn.Module):
def forward(self, x):
return torch.argmax(x, dim=None, keepdim=False) + torch.argmin(
x, dim=None
)

ep = torch.export.export(FlatModel().eval(), (torch.randn(10, 10),), strict=True)
edge = executorch.exir.to_edge_transform_and_lower(
ep, partitioner=[CoreMLPartitioner()]
)
op_names = [
n.target.__name__
for n in edge.exported_program().graph.nodes
if n.op == "call_function"
]
self.assertIn("aten.argmax.default", op_names)
self.assertIn("aten.argmin.default", op_names)

class DimModel(torch.nn.Module):
def forward(self, x):
return torch.argmax(x, dim=1)

ep = torch.export.export(DimModel().eval(), (torch.randn(10, 10),), strict=True)
edge = executorch.exir.to_edge_transform_and_lower(
ep, partitioner=[CoreMLPartitioner()]
)
op_names = [
n.target.__name__
for n in edge.exported_program().graph.nodes
if n.op == "call_function"
]
self.assertIn("executorch_call_delegate", op_names)
self.assertNotIn("aten.argmax.default", op_names)

def test_deprecation_warning_for_to_backend_workflow(self):
"""
Test that the deprecated to_edge + to_backend workflow shows a deprecation warning.