Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 38 additions & 4 deletions backends/apple/coreml/partition/coreml_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,19 @@
logger = logging.getLogger(__name__)
logger.setLevel(get_coreml_log_level(default_level=logging.INFO))

# Mirrors coremltools' TORCH_DTYPE_TO_MIL_DTYPE. Ops whose inputs use any
# other dtype (e.g. torch.uint8) cannot be lowered, so the partitioner must
# reject them. See pytorch/executorch#11686.
_COREML_SUPPORTED_INPUT_DTYPES = {
torch.bool,
torch.float16,
torch.float32,
torch.float64,
torch.int16,
torch.int32,
torch.int64,
}


def _is_view_op(op: torch._ops.OpOverload) -> bool:
schema = op._schema
Expand Down Expand Up @@ -75,6 +88,18 @@ def should_skip_op_for_delegation(self, node_target_name: str) -> bool:
return False

def should_override_support(self, node) -> bool:
# https://github.com/pytorch/executorch/issues/11686
for arg in node.all_input_nodes:
val = arg.meta.get("val", None)
dtype = getattr(val, "dtype", None)
if dtype is not None and dtype not in _COREML_SUPPORTED_INPUT_DTYPES:
self.log_once(
"Skipping op for CoreML delegation because input dtype "
f"{dtype} is not supported by CoreML: "
+ getattr(node.target, "__name__", str(node.target))
)
return True

# https://github.com/apple/coremltools/issues/2573
if (
node.target
Expand Down Expand Up @@ -147,10 +172,19 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
if self.should_skip_op_for_delegation(node_target_name):
return False

# query coremltools to see if node is supported
is_supported = ct.converters.mil.frontend.torch.is_torch_fx_node_supported(
node
)
# query coremltools to see if node is supported. coremltools may
# raise instead of returning False; treat any exception as
# unsupported and let the op fall back to the portable backend.
try:
is_supported = (
ct.converters.mil.frontend.torch.is_torch_fx_node_supported(node)
)
except Exception as e:
self.log_once(
"Skipping op for CoreML delegation because coremltools raised "
f"while checking support for {node_target_name}: {e}"
)
is_supported = False
if self.should_override_support(node):
is_supported = False

Expand Down
30 changes: 30 additions & 0 deletions backends/apple/coreml/test/test_coreml_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,35 @@ def forward(self, x):
torch.allclose(et_outputs, eager_outputs, atol=1e-02, rtol=1e-02)
)

def test_unsupported_dtype_does_not_crash_partitioner(self):
    """
    Regression test for https://github.com/pytorch/executorch/issues/11686.

    coremltools raises KeyError when asked about ops with unsupported input
    dtypes (e.g. torch.uint8). The partitioner must swallow that and report
    "not supported" instead of propagating the exception, leaving the op to
    fall back to the portable backend.
    """

    class AbsModule(torch.nn.Module):
        def forward(self, inp):
            return torch.abs(inp)

    # uint8 is deliberately outside CoreML's supported input dtypes.
    uint8_inputs = (torch.randint(0, 255, (1, 10)).to(torch.uint8),)
    aten_program = torch.export.export(
        AbsModule().eval(), uint8_inputs, strict=True
    )
    # Must not raise even though no node can be delegated.
    lowered = executorch.exir.to_edge_transform_and_lower(
        aten_program, partitioner=[CoreMLPartitioner()]
    )

    # Collect every call_function target name from the lowered graph.
    call_targets = []
    for graph_node in lowered.exported_program().graph.nodes:
        if graph_node.op == "call_function":
            call_targets.append(graph_node.target.__name__)

    # Nothing was delegated; abs stayed on the portable backend.
    self.assertNotIn("executorch_call_delegate", call_targets)
    self.assertIn("aten.abs.default", call_targets)

def test_deprecation_warning_for_to_backend_workflow(self):
"""
Test that the deprecated to_edge + to_backend workflow shows a deprecation warning.
Expand Down Expand Up @@ -435,5 +464,6 @@ def forward(self, x):
test_runner.test_lower_full_graph()
# test_runner.test_symint_arg()
test_runner.test_take_over_constant_data_false()
test_runner.test_unsupported_dtype_does_not_crash_partitioner()
test_runner.test_deprecation_warning_for_to_backend_workflow()
test_runner.test_no_warning_for_to_edge_transform_and_lower_workflow()
Loading