Skip to content

Commit 6150522

Browse files
Add deprecation warning for to_edge + to_backend workflow in CoreMLPa… (#17082)
Fixes #15960 This PR adds a deprecation warning in the CoreMLPartitioner to guide users away from the deprecated [to_edge() + to_backend()](cci:1://file:///c:/Users/moham/Documents/osc/executorch-contribution/backends/xnnpack/test/test_xnnpack_partitioner.py:62:4-88:61) workflow and toward the recommended [to_edge_transform_and_lower()](cci:1://file:///c:/Users/moham/Documents/osc/executorch-contribution/backends/xnnpack/test/test_xnnpack_partitioner.py:62:4-88:61) flow. ## Changes - Added [_check_if_called_from_to_backend()](cci:1://file:///c:/Users/moham/Documents/osc/executorch-contribution/backends/apple/coreml/partition/coreml_partitioner.py:225:4-241:20) method to detect deprecated workflow - Modified [partition()](cci:1://file:///c:/Users/moham/Documents/osc/executorch-contribution/backends/apple/coreml/partition/coreml_partitioner.py:243:4-292:9) to log warning when deprecated flow is detected - Added two unit tests to verify warning behavior ## Pattern Following the same pattern as #13209 (XNNPACK). ## Related This picks up the abandoned PR #15963. cc @kimishpatel @YifanShenSZ @cymbalrush @metascroy --------- Signed-off-by: mohammed-saalim <mohammed.saalim.k@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent e96ab14 commit 6150522

2 files changed

Lines changed: 123 additions & 0 deletions

File tree

backends/apple/coreml/partition/coreml_partitioner.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# Please refer to the license found in the LICENSE file in the root directory of the source tree.
44

5+
import inspect
56
import logging
67
from typing import Callable, List, Optional, Tuple
78

@@ -222,7 +223,38 @@ def __init__(
222223
self.take_over_mutable_buffer
223224
), "When lower_full_graph=True, you must set take_over_mutable_buffer=True"
224225

226+
def _check_if_called_from_to_backend(self) -> bool:
227+
"""
228+
Check if the partition method is being called from the deprecated to_backend workflow.
229+
Returns True if called from deprecated direct to_backend, False if called from to_edge_transform_and_lower.
230+
"""
231+
stack = inspect.stack()
232+
233+
for frame_info in stack:
234+
if frame_info.function == "to_edge_transform_and_lower":
235+
return False
236+
237+
for frame_info in stack:
238+
if frame_info.function == "to_backend":
239+
filename = frame_info.filename
240+
if "program/_program.py" in filename:
241+
return True
242+
return False
243+
225244
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
245+
"""
246+
Override partition to add deprecation warning when called from to_backend.
247+
"""
248+
# Check if we're being called from the deprecated to_backend workflow
249+
if self._check_if_called_from_to_backend():
250+
logger.warning(
251+
"\nDEPRECATION WARNING: You are using the deprecated 'to_edge() + to_backend()' workflow. "
252+
"This may result in decreased performance because ExecuTorch decomposes ops (e.g., SDPA) "
253+
"that CoreML has optimized implementations for. "
254+
"Please consider migrating to 'to_edge_transform_and_lower()' for better performance. "
255+
"See: https://docs.pytorch.org/executorch/main/backends/coreml/coreml-overview.html#using-the-core-ml-backend"
256+
)
257+
226258
# Run the CapabilityBasedPartitioner to return the largest possible
227259
# subgraphs containing the nodes with the tags
228260
logger.info("CoreMLPartitioner::partition")

backends/apple/coreml/test/test_coreml_partitioner.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
# Please refer to the license found in the LICENSE file in the root directory of the source tree.
44

55
import copy
6+
import io
7+
import logging
68
import sys
79
import unittest
810

@@ -336,6 +338,93 @@ def forward(self, x):
336338
torch.allclose(et_outputs, eager_outputs, atol=1e-02, rtol=1e-02)
337339
)
338340

341+
def test_deprecation_warning_for_to_backend_workflow(self):
342+
"""
343+
Test that the deprecated to_edge + to_backend workflow shows a deprecation warning.
344+
"""
345+
346+
class SimpleModel(torch.nn.Module):
347+
def __init__(self):
348+
super().__init__()
349+
self.linear = torch.nn.Linear(10, 5)
350+
351+
def forward(self, x):
352+
return self.linear(x)
353+
354+
model = SimpleModel()
355+
model.eval()
356+
x = torch.randn(1, 10)
357+
358+
exported_model = torch.export.export(model, (x,), strict=True)
359+
360+
# Capture log output to check for deprecation warning
361+
log_capture_string = io.StringIO()
362+
ch = logging.StreamHandler(log_capture_string)
363+
ch.setLevel(logging.WARNING)
364+
365+
partitioner_logger = logging.getLogger(
366+
"executorch.backends.apple.coreml.partition.coreml_partitioner"
367+
)
368+
partitioner_logger.addHandler(ch)
369+
partitioner_logger.setLevel(logging.WARNING)
370+
371+
edge = executorch.exir.to_edge(
372+
exported_model, compile_config=self.edge_compile_config
373+
)
374+
partitioner = CoreMLPartitioner()
375+
376+
edge.to_backend(partitioner)
377+
378+
log_contents = log_capture_string.getvalue()
379+
self.assertIn("DEPRECATION WARNING", log_contents)
380+
self.assertIn("to_edge() + to_backend()", log_contents)
381+
self.assertIn("to_edge_transform_and_lower()", log_contents)
382+
383+
# Clean up handler
384+
partitioner_logger.removeHandler(ch)
385+
386+
def test_no_warning_for_to_edge_transform_and_lower_workflow(self):
387+
"""
388+
Test that the recommended to_edge_transform_and_lower workflow does NOT show a deprecation warning.
389+
"""
390+
391+
class SimpleModel(torch.nn.Module):
392+
def __init__(self):
393+
super().__init__()
394+
self.linear = torch.nn.Linear(10, 5)
395+
396+
def forward(self, x):
397+
return self.linear(x)
398+
399+
model = SimpleModel()
400+
model.eval()
401+
x = torch.randn(1, 10)
402+
403+
exported_model = torch.export.export(model, (x,), strict=True)
404+
405+
# Capture log output to check for deprecation warning
406+
log_capture_string = io.StringIO()
407+
ch = logging.StreamHandler(log_capture_string)
408+
ch.setLevel(logging.WARNING)
409+
410+
partitioner_logger = logging.getLogger(
411+
"executorch.backends.apple.coreml.partition.coreml_partitioner"
412+
)
413+
partitioner_logger.addHandler(ch)
414+
partitioner_logger.setLevel(logging.WARNING)
415+
416+
partitioner = CoreMLPartitioner()
417+
418+
executorch.exir.to_edge_transform_and_lower(
419+
exported_model, partitioner=[partitioner]
420+
)
421+
422+
log_contents = log_capture_string.getvalue()
423+
self.assertNotIn("DEPRECATION WARNING", log_contents)
424+
425+
# Clean up handler
426+
partitioner_logger.removeHandler(ch)
427+
339428

340429
if __name__ == "__main__":
341430
test_runner = TestCoreMLPartitioner()
@@ -346,3 +435,5 @@ def forward(self, x):
346435
test_runner.test_lower_full_graph()
347436
# test_runner.test_symint_arg()
348437
test_runner.test_take_over_constant_data_false()
438+
test_runner.test_deprecation_warning_for_to_backend_workflow()
439+
test_runner.test_no_warning_for_to_edge_transform_and_lower_workflow()

0 commit comments

Comments
 (0)