Skip to content

Commit 7f19a2e

Browse files
authored
Revert "Arm backend: Lower MXFP Linear to TOSA" (pytorch#20047)
Reverts pytorch#19969
1 parent 5563ee9 commit 7f19a2e

19 files changed

Lines changed: 35 additions & 1212 deletions

backends/arm/_passes/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,6 @@
165165
from .rewrite_le_lt_to_ge_gt_pass import RewriteLeLtToGeGtPass # noqa
166166
from .rewrite_matmul import RewriteMatmulPass # noqa
167167
from .rewrite_max_pool2d_pass import RewriteMaxPool2dPass # noqa
168-
from .rewrite_mxfp_linear import RewriteMXFPLinearPass # noqa
169168
from .rewrite_pad import RewritePadPass # noqa
170169
from .rewrite_slice import RewriteSlicePass # noqa
171170
from .rewrite_upsample import RewriteUpsamplePass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,6 @@
141141
RewriteLeLtToGeGtPass,
142142
RewriteMatmulPass,
143143
RewriteMaxPool2dPass,
144-
RewriteMXFPLinearPass,
145144
RewritePadPass,
146145
RewriteSlicePass,
147146
RewriteUpsamplePass,
@@ -525,7 +524,6 @@ def _tosa_pipeline(
525524
RewriteUpsamplePass(),
526525
RewriteMaxPool2dPass(),
527526
RewriteConvPass(exported_program),
528-
RewriteMXFPLinearPass(exported_program),
529527
RewriteMatmulPass(),
530528
RewritePadPass(),
531529
FuseViewCopyTransformPass(),

backends/arm/_passes/rewrite_mxfp_linear.py

Lines changed: 0 additions & 318 deletions
This file was deleted.

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -237,17 +237,6 @@ def get_registered_tosa_support_checks(
237237
return checks
238238

239239

240-
class MXOpsSupportList(OperatorSupportBase):
241-
"""Accept Arm MX custom ops when the active spec enables MX support."""
242-
243-
targets = (exir_ops.edge.tosa_mxfp.linear.default,)
244-
245-
def is_node_supported(
246-
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
247-
) -> bool:
248-
return node.op == "call_function" and node.target in self.targets
249-
250-
251240
def tosa_support_factory(
252241
tosa_spec: TosaSpecification,
253242
exported_program: ExportedProgram,
@@ -282,8 +271,6 @@ def tosa_support_factory(
282271
positive_checks.append(TOSAProINTSupportList())
283272
elif tosa_spec.support_float():
284273
positive_checks.append(TOSAProFPSupportList())
285-
if tosa_spec.support_extension("mxfp"):
286-
positive_checks.append(MXOpsSupportList())
287274
# TODO: Refactor to use TOSAProSupportLists + negtive checks
288275
positive_checks += [
289276
check(tosa_spec, reporter)
@@ -309,13 +296,9 @@ def tosa_support_factory(
309296
disallowed_dtypes = [torch.float64]
310297
if not tosa_spec.support_extension("bf16"):
311298
disallowed_dtypes.append(torch.bfloat16)
312-
if not (
313-
tosa_spec.support_extension("fp8e4m3") or tosa_spec.support_extension("mxfp")
314-
):
299+
if not tosa_spec.support_extension("fp8e4m3"):
315300
disallowed_dtypes.append(torch.float8_e4m3fn)
316-
if not (
317-
tosa_spec.support_extension("fp8e5m2") or tosa_spec.support_extension("mxfp")
318-
):
301+
if not tosa_spec.support_extension("fp8e5m2"):
319302
disallowed_dtypes.append(torch.float8_e5m2)
320303
if tosa_spec.is_U55_subset:
321304
disallowed_dtypes.append(torch.bool)
@@ -763,9 +746,6 @@ def is_node_supported(
763746
):
764747
return True
765748

766-
if node.target in MXOpsSupportList.targets:
767-
return True
768-
769749
floating_dtypes = set()
770750
for input_node in (
771751
input_node

0 commit comments

Comments
 (0)