Skip to content

Commit 8a397b4

Browse files
Arm backend: Reorder BN-fusing and decomposition (#19276)
Make sure the BN-fusing pass is run before the decomposition pass.

cc @digantdesai @freddan80 @per @zingo @mansnils @Sebastian-Larsson @robell

Signed-off-by: Oscar Andersson <oscar.andersson@arm.com>
1 parent 2dca183 commit 8a397b4

5 files changed

Lines changed: 52 additions & 7 deletions

File tree

backends/arm/_passes/arm_pass_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,7 @@ def _tosa_pipeline(
439439
ConvertSplitToSlicePass(),
440440
QuantizeClampArgumentsPass(),
441441
RemoveGetItemPass(),
442+
FuseBatchNorm2dPass(exported_program),
442443
DecomposeBatchNormNoStatsPass(),
443444
DecomposeLogitPass(),
444445
DecomposeMaskedFillPass(),
@@ -502,7 +503,6 @@ def _tosa_pipeline(
502503
RewriteBoolBitwiseToLogicalPass(),
503504
DecomposeRemainderPass(),
504505
DecomposeDivTensorModePass(),
505-
FuseBatchNorm2dPass(exported_program),
506506
ConvertMmToBmmPass(),
507507
DecomposeGluPass(),
508508
DecomposeDivPass(),

backends/arm/_passes/fuse_batch_norm2d_pass.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901
5656
!= exir_ops.edge.aten._native_batch_norm_legit_no_training.default
5757
):
5858
continue
59+
if get_first_fake_tensor(node).dtype == torch.bfloat16:
60+
# Don't fuse if the data type is bfloat16, as the fused weights may
61+
# not be accurate enough and cause significant accuracy drop.
62+
continue
5963

6064
# Get data from batchnorm
6165
input_node = node.all_input_nodes[0]
@@ -153,8 +157,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901
153157
if not (
154158
(input_bias_node is None)
155159
or (
156-
isinstance(input_weight_node, Node)
157-
and input_weight_node.op == "placeholder"
160+
isinstance(input_bias_node, Node)
161+
and input_bias_node.op == "placeholder"
158162
)
159163
):
160164
raise RuntimeError(

backends/arm/test/misc/test_transpose_counts.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -445,16 +445,16 @@ def forward(self, x):
445445
Model6GruLinear(), (torch.randn(2, 16, 8),), 2
446446
),
447447
"model_7_dwconv_batchnorm_linear": TransposeCountCase(
448-
Model7DwConvBatchNormLinear(), (torch.randn(2, 8, 64),), 3
448+
Model7DwConvBatchNormLinear(), (torch.randn(2, 8, 64),), 1
449449
),
450450
"model_8_conv_batchnorm_maxpool_residual": TransposeCountCase(
451-
Model8ConvBatchNormMaxPoolResidual(), (torch.randn(1, 8, 16, 16),), 6
451+
Model8ConvBatchNormMaxPoolResidual(), (torch.randn(1, 8, 16, 16),), 4
452452
),
453453
"model_9_dilated_conv_batchnorm_avgpool_residual": TransposeCountCase(
454-
Model9DilatedConvBatchNormAvgPoolResidual(), (torch.randn(1, 8, 16, 16),), 6
454+
Model9DilatedConvBatchNormAvgPoolResidual(), (torch.randn(1, 8, 16, 16),), 4
455455
),
456456
"model_10_dwconv_batchnorm_linear_cat": TransposeCountCase(
457-
Model10DwConvBatchNormLinearCat(), (torch.randn(2, 8, 64),), 3
457+
Model10DwConvBatchNormLinearCat(), (torch.randn(2, 8, 64),), 1
458458
),
459459
}
460460

backends/arm/test/ops/test_batch_norm.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,20 @@ def test_native_batch_norm_legit_no_training_tosa_FP_conv(test_data: Tuple):
200200
pipeline.run()
201201

202202

203+
@common.parametrize("test_data", test_data_suite)
204+
def test_native_batch_norm_legit_no_training_tosa_FP_conv_fuses_before_decompose(
205+
test_data: Tuple,
206+
):
207+
test_data, model_params = test_data()
208+
pipeline = TosaPipelineFP[input_t1](
209+
BatchNorm2dConv(*model_params),
210+
(test_data,),
211+
aten_op=BatchNorm2dConv.aten_ops,
212+
)
213+
pipeline.count_tosa_ops({"CONV2D": 1, "RSQRT": 0, "SUB": 0})
214+
pipeline.run()
215+
216+
203217
@common.parametrize("test_data", test_data_suite)
204218
def test_native_batch_norm_legit_no_training_tosa_INT_conv(test_data: Tuple):
205219
test_data, model_params = test_data()

backends/arm/test/passes/test_fuse_batchnorm_pass.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
5454
return x
5555

5656

57+
class MergeOneOfTwoBNBf16(MergeOneOfTwoBN):
58+
ops_before_pass: ClassVar[Dict[str, int]] = MergeOneOfTwoBN.ops_before_pass
59+
ops_after_pass: ClassVar[Dict[str, int]] = MergeOneOfTwoBN.ops_before_pass
60+
61+
def __init__(self, affine: bool):
62+
super().__init__(affine)
63+
self.to(torch.bfloat16)
64+
65+
def get_inputs(self) -> input_t:
66+
return (torch.randn(1, 3, 256, 256, dtype=torch.bfloat16),)
67+
68+
5769
class MergeTwosOfTwoBN(torch.nn.Module):
5870
ops_before_pass: ClassVar[Dict[str, int]] = {
5971
"executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2,
@@ -163,3 +175,18 @@ def test_fuse_batch_norm2d_tosa_FP(module: ModuleWithBatchNormAttrs) -> None:
163175
passes_with_exported_program=[FuseBatchNorm2dPass],
164176
)
165177
pipeline.run()
178+
179+
180+
def test_fuse_batch_norm2d_tosa_FP_bf16_skips_fusion() -> None:
181+
module = cast(ModuleWithBatchNormAttrs, MergeOneOfTwoBNBf16(True))
182+
nn_module = cast(torch.nn.Module, module)
183+
pipeline = PassPipeline[input_t](
184+
nn_module,
185+
module.get_inputs(),
186+
quantize=False,
187+
ops_before_pass=module.ops_before_pass,
188+
ops_after_pass=module.ops_after_pass,
189+
passes_with_exported_program=[FuseBatchNorm2dPass],
190+
tosa_extensions=["bf16"],
191+
)
192+
pipeline.run()

0 commit comments

Comments (0)