Commit 182ae1d
Skip argmin/argmax with dim=None in CoreML partitioner
argmax/argmin with dim=None reduces over the flattened input, which CoreML does not support and which intermittently crashes the process at runtime. Reject these in the partitioner so they fall back to the portable backend; the ordinary dim=int form is still delegated. Fixes #11715.
1 parent 94d2881 commit 182ae1d
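
For context, a minimal illustration (plain PyTorch, not part of this commit) of the two calling forms the commit distinguishes: with dim=None, argmax/argmin return a single index into the flattened tensor, while the dim=int form reduces along one axis.

    import torch

    x = torch.tensor([[1.0, 9.0],
                      [4.0, 3.0]])

    # dim=None (the default): a 0-d tensor indexing the *flattened* input.
    # This is the form the partitioner now rejects for CoreML.
    print(torch.argmax(x))         # tensor(1), i.e. x.reshape(-1)[1] == 9.0

    # dim=int: reduces along one axis; this form is still delegated.
    print(torch.argmax(x, dim=1))  # tensor([1, 0])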

2 files changed: 62 additions & 0 deletions

backends/apple/coreml/partition/coreml_partitioner.py

Lines changed: 17 additions & 0 deletions
@@ -116,6 +116,23 @@ def should_override_support(self, node) -> bool:
             )
             return True
 
+        # https://github.com/pytorch/executorch/issues/11715
+        # argmin/argmax with dim=None reduces over the flattened input, which
+        # CoreML does not support and causes intermittent process crashes.
+        if node.target in [
+            torch.ops.aten.argmax.default,
+            torch.ops.aten.argmin.default,
+            exir_ops.edge.aten.argmax.default,
+            exir_ops.edge.aten.argmin.default,
+        ]:
+            dim = node.args[1] if len(node.args) >= 2 else node.kwargs.get("dim", None)
+            if dim is None:
+                self.log_once(
+                    "torch.ops.aten.{argmax, argmin}.default with dim=None is "
+                    "not supported by CoreML. Overriding op support."
+                )
+                return True
+
         # TODO: enable this after bugs in ExecuTorch's partitioner are fixed
         # # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args
         # # in the placeholders due to partitioning, which CoreML does not support
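
After this change, models that use the dim=None form keep those ops in the portable backend rather than the CoreML delegate. If full delegation matters, one possible workaround on the model side (my suggestion, not part of this commit) is to flatten explicitly and reduce over an explicit dim, which is numerically equivalent and uses only the supported form:

    import torch

    x = torch.randn(10, 10)

    # Equivalent to torch.argmax(x) with dim=None, but expressed with an
    # explicit dim so the CoreML partitioner can still delegate it.
    flat_argmax = torch.argmax(x.reshape(-1), dim=0)
    assert flat_argmax == torch.argmax(x)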

backends/apple/coreml/test/test_coreml_partitioner.py

Lines changed: 45 additions & 0 deletions
@@ -338,6 +338,51 @@ def forward(self, x):
                 torch.allclose(et_outputs, eager_outputs, atol=1e-02, rtol=1e-02)
             )
 
+    def test_argmax_argmin_dim_none_is_skipped(self):
+        """
+        Regression test for https://github.com/pytorch/executorch/issues/11715.
+
+        argmax/argmin with dim=None reduces over the flattened tensor, which
+        CoreML does not support; the resulting model intermittently crashes
+        the process at runtime. The partitioner must reject these so they
+        fall back to the portable backend, while still delegating the
+        ordinary dim=int form.
+        """
+
+        class FlatModel(torch.nn.Module):
+            def forward(self, x):
+                return torch.argmax(x, dim=None, keepdim=False) + torch.argmin(
+                    x, dim=None
+                )
+
+        ep = torch.export.export(FlatModel().eval(), (torch.randn(10, 10),), strict=True)
+        edge = executorch.exir.to_edge_transform_and_lower(
+            ep, partitioner=[CoreMLPartitioner()]
+        )
+        op_names = [
+            n.target.__name__
+            for n in edge.exported_program().graph.nodes
+            if n.op == "call_function"
+        ]
+        self.assertIn("aten.argmax.default", op_names)
+        self.assertIn("aten.argmin.default", op_names)
+
+        class DimModel(torch.nn.Module):
+            def forward(self, x):
+                return torch.argmax(x, dim=1)
+
+        ep = torch.export.export(DimModel().eval(), (torch.randn(10, 10),), strict=True)
+        edge = executorch.exir.to_edge_transform_and_lower(
+            ep, partitioner=[CoreMLPartitioner()]
+        )
+        op_names = [
+            n.target.__name__
+            for n in edge.exported_program().graph.nodes
+            if n.op == "call_function"
+        ]
+        self.assertIn("executorch_call_delegate", op_names)
+        self.assertNotIn("aten.argmax.default", op_names)
+
     def test_deprecation_warning_for_to_backend_workflow(self):
         """
         Test that the deprecated to_edge + to_backend workflow shows a deprecation warning.
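
Assuming a working ExecuTorch development setup with the CoreML backend installed, the regression test can presumably be run on its own with a pytest filter such as:

    python -m pytest backends/apple/coreml/test/test_coreml_partitioner.py \
        -k test_argmax_argmin_dim_none_is_skipped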
