Skip to content

Commit 6aa1fce

Browse files
authored
Arm backend: Add MXFP support for Conv2d (#20421)
Add the possibility to convert `torch.nn.Conv2d` submodules to the custom implemented MXFP counterpart `MXFPConv2dOp`. Rewrite the MXFP Conv2d custom op into block-scaled TOSA ops and fix the serializer, typing. Signed-off-by: Sebastian Larsson <sebastian.larsson@arm.com> Signed-off-by: Martin Lindström <Martin.Lindstroem@arm.com>
1 parent 7eeeb86 commit 6aa1fce

21 files changed

Lines changed: 2096 additions & 37 deletions

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@
171171
from .rewrite_le_lt_to_ge_gt_pass import RewriteLeLtToGeGtPass # noqa
172172
from .rewrite_matmul import RewriteMatmulPass # noqa
173173
from .rewrite_max_pool2d_pass import RewriteMaxPool2dPass # noqa
174+
from .rewrite_mxfp_conv2d import RewriteMXFPConv2dPass # noqa
174175
from .rewrite_mxfp_linear import RewriteMXFPLinearPass # noqa
175176
from .rewrite_pad import RewritePadPass # noqa
176177
from .rewrite_slice import RewriteSlicePass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@
147147
RewriteLeLtToGeGtPass,
148148
RewriteMatmulPass,
149149
RewriteMaxPool2dPass,
150+
RewriteMXFPConv2dPass,
150151
RewriteMXFPLinearPass,
151152
RewritePadPass,
152153
RewriteSlicePass,
@@ -612,6 +613,7 @@ def _tosa_pipeline(
612613
RewriteMaxPool2dPass(),
613614
DecomposeAdaptiveMaxPool2dPass(),
614615
RewriteConvPass(exported_program),
616+
RewriteMXFPConv2dPass(exported_program),
615617
RewriteMXFPLinearPass(exported_program),
616618
RewriteMatmulPass(),
617619
RewritePadPass(),

0 commit comments

Comments
 (0)