|
41 | 41 | from torchao.quantization.utils import compute_error |
42 | 42 |
|
43 | 43 |
|
| 44 | +def _has_quantized_decomposed_out_variants() -> bool: |
| 45 | + """Check if the quantized_decomposed .out variants are registered. |
| 46 | +
|
| 47 | + These are built by the quantized_ops_aot_lib and loaded via |
| 48 | + executorch.kernels.quantized. Under BUCK the library is preloaded; |
| 49 | + under plain pytest it is usually absent. |
| 50 | + """ |
| 51 | + try: |
| 52 | + # Attempt to load the library (no-op if already loaded) |
| 53 | + import executorch.kernels.quantized # noqa: F401 |
| 54 | + |
| 55 | + return ( |
| 56 | + hasattr(torch.ops, "quantized_decomposed") |
| 57 | + and hasattr(torch.ops.quantized_decomposed, "quantize_per_tensor") |
| 58 | + and hasattr( |
| 59 | + torch.ops.quantized_decomposed.quantize_per_tensor, "out" |
| 60 | + ) |
| 61 | + ) |
| 62 | + except Exception: |
| 63 | + return False |
| 64 | + |
| 65 | + |
| 66 | +_skip_no_qd_out = unittest.skipUnless( |
| 67 | + _has_quantized_decomposed_out_variants(), |
| 68 | + "quantized_decomposed .out variants not registered " |
| 69 | + "(build quantized_ops_aot_lib or run via BUCK)", |
| 70 | +) |
| 71 | + |
| 72 | + |
44 | 73 | class TestQuantFusionPass(unittest.TestCase): |
45 | 74 | @classmethod |
46 | 75 | def setUpClass(cls) -> None: |
47 | 76 | register_additional_test_aten_ops() |
48 | 77 |
|
| 78 | + @_skip_no_qd_out |
49 | 79 | def test_add(self) -> None: |
50 | 80 | class M(torch.nn.Module): |
51 | 81 | def forward(self, x, y): |
@@ -85,6 +115,7 @@ def forward(self, x, y): |
85 | 115 | m.exported_program().graph_module.code |
86 | 116 | ) |
87 | 117 |
|
| 118 | + @_skip_no_qd_out |
88 | 119 | def test_reshape(self) -> None: |
89 | 120 | class M(torch.nn.Module): |
90 | 121 | def forward(self, x, y): |
@@ -133,6 +164,7 @@ def forward(self, x, y): |
133 | 164 | "torch.ops.aten.view_copy.out" |
134 | 165 | ).run(m.exported_program().graph_module.code) |
135 | 166 |
|
| 167 | + @_skip_no_qd_out |
136 | 168 | def test_slice(self) -> None: |
137 | 169 | """We don't proactively quantize slice today, but we'll fuse the dq-slice-q |
138 | 170 |
|
@@ -188,6 +220,7 @@ def forward(self, x, y): |
188 | 220 | "torch.ops.aten.slice_copy.Tensor_out" |
189 | 221 | ).run(m.exported_program().graph_module.code) |
190 | 222 |
|
| 223 | + @_skip_no_qd_out |
191 | 224 | def test_cat(self) -> None: |
192 | 225 | class M(torch.nn.Module): |
193 | 226 | def forward(self, x, y): |
|
0 commit comments