Skip to content

Commit f3a085b

Browse files
committed
Handle both XNNPACK quantization import paths
1 parent f59a113 commit f3a085b

1 file changed

Lines changed: 105 additions & 58 deletions

File tree

backends/xnnpack/debugger/observatory/lenses/xnnpack_patches.py

Lines changed: 105 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66

77
"""XNNPACK backend patches for PipelineGraphCollectorLens.
88
9-
Installs a monkey-patch on executorch.examples.xnnpack.quantization.utils.quantize
10-
to capture the float ExportedProgram with from_node metadata populated.
9+
Installs a monkey-patch on XNNPACK quantization helpers (both
10+
``examples`` and ``executorch.examples`` import paths) to capture the
11+
float ExportedProgram with from_node metadata populated.
1112
"""
1213

1314
from __future__ import annotations
1415

16+
import importlib
1517
import logging
1618
from typing import TYPE_CHECKING
1719

@@ -21,63 +23,108 @@
2123
)
2224

2325

26+
MODULE_CANDIDATES = (
27+
"examples.xnnpack.quantization.utils",
28+
"executorch.examples.xnnpack.quantization.utils",
29+
)
30+
31+
32+
def _install_patch_for_module(
33+
cls: type[PipelineGraphCollectorLens], module, alias: str
34+
) -> bool:
35+
try:
36+
original = module.quantize
37+
except AttributeError:
38+
logging.debug(
39+
"[PipelineGraphCollector] XNNPACK patch skipped; no quantize in %s",
40+
alias,
41+
)
42+
return False
43+
44+
key = f"xnnpack.quantize[{alias}]"
45+
if key in cls._originals:
46+
return True
47+
48+
cls._originals[key] = original
49+
50+
def patched_quantize(model, example_inputs, quant_type=None):
51+
sample = None
52+
try:
53+
if isinstance(example_inputs, (tuple, list)):
54+
sample = tuple(example_inputs)
55+
else:
56+
sample = (example_inputs,)
57+
cls._set_accuracy_fallback_dataset([sample], source=key)
58+
except Exception:
59+
pass
60+
61+
collect_target = model
62+
try:
63+
import torch
64+
65+
if sample is not None:
66+
ep = torch.export.export(model, sample, strict=False)
67+
collect_target = ep.run_decompositions({})
68+
except Exception as exc:
69+
logging.debug(
70+
"[PipelineGraphCollector] XNNPACK from_node re-export skipped: %s",
71+
exc,
72+
)
73+
74+
try:
75+
cls._collect_fn("Exported Float", collect_target)
76+
except Exception as exc:
77+
logging.debug(
78+
"[PipelineGraphCollector] collect skipped (Exported Float): %s",
79+
exc,
80+
)
81+
82+
if quant_type is None:
83+
return original(model, example_inputs)
84+
return original(model, example_inputs, quant_type)
85+
86+
module.quantize = patched_quantize
87+
logging.info(
88+
"[PipelineGraphCollector] Installed XNNPACK patch: quantize (%s)", alias
89+
)
90+
91+
def _uninstall():
92+
try:
93+
module.quantize = original
94+
except Exception:
95+
pass
96+
97+
cls._backend_uninstallers.append(_uninstall)
98+
return True
99+
100+
24101
def install_xnnpack_patches(cls: type[PipelineGraphCollectorLens]) -> None:
25102
"""Install XNNPACK quantize patch on the PipelineGraphCollectorLens."""
26-
try:
27-
import executorch.examples.xnnpack.quantization.utils as xnnpack_qutils
28-
29-
original = xnnpack_qutils.quantize
30-
cls._originals["xnnpack.quantize"] = original
31-
32-
def patched_quantize(model, example_inputs, quant_type=None):
33-
sample = None
34-
try:
35-
if isinstance(example_inputs, (tuple, list)):
36-
sample = tuple(example_inputs)
37-
else:
38-
sample = (example_inputs,)
39-
cls._set_accuracy_fallback_dataset(
40-
[sample], source="xnnpack.quantize"
41-
)
42-
except Exception:
43-
pass
44-
45-
collect_target = model
46-
try:
47-
import torch
48-
49-
if sample is not None:
50-
ep = torch.export.export(model, sample, strict=False)
51-
collect_target = ep.run_decompositions({})
52-
except Exception as exc:
53-
logging.debug(
54-
"[PipelineGraphCollector] XNNPACK from_node re-export skipped: %s",
55-
exc,
56-
)
57-
58-
try:
59-
cls._collect_fn("Exported Float", collect_target)
60-
except Exception as exc:
61-
logging.debug(
62-
"[PipelineGraphCollector] collect skipped (Exported Float): %s",
63-
exc,
64-
)
65-
66-
if quant_type is None:
67-
return original(model, example_inputs)
68-
return original(model, example_inputs, quant_type)
69-
70-
xnnpack_qutils.quantize = patched_quantize
71-
logging.info("[PipelineGraphCollector] Installed XNNPACK patch: quantize")
72-
73-
def _uninstall():
74-
try:
75-
xnnpack_qutils.quantize = original
76-
except Exception:
77-
pass
78-
79-
cls._backend_uninstallers.append(_uninstall)
80-
except Exception as exc:
103+
104+
patched = False
105+
seen_modules: set[int] = set()
106+
107+
for alias in MODULE_CANDIDATES:
108+
try:
109+
module = importlib.import_module(alias)
110+
except ImportError:
111+
continue
112+
113+
module_id = id(module)
114+
if module_id in seen_modules:
115+
continue
116+
seen_modules.add(module_id)
117+
118+
try:
119+
patched |= _install_patch_for_module(cls, module, alias)
120+
except Exception as exc:
121+
logging.warning(
122+
"[PipelineGraphCollector] Failed to patch XNNPACK quantize (%s): %s",
123+
alias,
124+
exc,
125+
)
126+
127+
if not patched:
81128
logging.warning(
82-
"[PipelineGraphCollector] Failed to patch XNNPACK quantize: %s", exc
129+
"[PipelineGraphCollector] Failed to patch XNNPACK quantize: no candidate module found"
83130
)

0 commit comments

Comments
 (0)