Skip to content

Commit 941486e

Browse files
authored
NXP backend: Add Linear+BN fusion to conversion pipeline of unit tests (#18527)
### Summary Enables proper quantization and conversion of Linear+BN based models in our integration tests by adding Linear+BN fusion-related passes. It also introduces accuracy-testing support for non-softmax-based models. ### Test plan Covered by our NXP internal integration tests. cc @robert-kalmar @JakeStevens @digantdesai
1 parent 656850a commit 941486e

3 files changed

Lines changed: 88 additions & 17 deletions

File tree

backends/nxp/aten_passes/fuse_batch_norm_with_linear_pass.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,12 @@ def _is_linear(node_: Node):
177177

178178
# Replace the uses of the BatchNorm with the Linear.
179179
bn_node.replace_all_uses_with(linear_node)
180+
graph_module.graph.erase_node(bn_node)
180181

181182
made_changes = True
182183

184+
if made_changes:
185+
graph_module.graph.eliminate_dead_code()
186+
graph_module.recompile()
187+
183188
return PassResult(graph_module, made_changes)

backends/nxp/tests_models/model_output_comparator.py

Lines changed: 73 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import os
88
from abc import abstractmethod
99
from pathlib import Path
10+
from typing import Callable
1011

1112
import numpy as np
1213
import polars as pl
@@ -57,13 +58,11 @@ def compare_results(self, cpu_results_dir, npu_results_dir, output_tensor_spec):
5758
cpu_tensor = np.fromfile(
5859
cpu_tensor_path, dtype=torch_type_to_numpy_type(tensor_spec.dtype)
5960
)
60-
np.reshape(cpu_tensor, tensor_spec.shape)
6161
cpu_output_tensors.append((output_tensor_name, cpu_tensor))
6262

6363
npu_tensor = np.fromfile(
6464
npu_tensor_path, dtype=torch_type_to_numpy_type(tensor_spec.dtype)
6565
)
66-
np.reshape(npu_tensor, tensor_spec.shape)
6766
npu_output_tensors.append((output_tensor_name, npu_tensor))
6867

6968
self.compare_sample(sample_dir, cpu_output_tensors, npu_output_tensors)
@@ -95,17 +94,30 @@ def compare_sample(self, sample_dir, cpu_output_tensors, npu_output_tensors):
9594
assert np.allclose(cpu_tensor, npu_tensor, atol=self.atol)
9695

9796

97+
def _default_postprocess_fn(outputs: np.ndarray, _: str):
98+
return np.argmax(outputs, axis=-1)
99+
100+
98101
class ClassificationAccuracyOutputComparator(BaseOutputComparator):
99102

100-
def __init__(self, class_dict: dict[int, str], tolerance=0.0):
103+
def __init__(
104+
self,
105+
class_dict: dict[int, str],
106+
postprocess_fn: Callable[
107+
[np.ndarray, str], np.ndarray
108+
] = _default_postprocess_fn,
109+
tolerance=0.0,
110+
):
101111
"""
102112
Comparator for comparing model prediction accuracies based on ground-truth annotations.
103113
The comparator passes if finetuned model results have higher accuracy than baseline (accounting for a tolerance).
104114
105-
:param class_dict: Dictionary mapping class names to class indices.
115+
:param class_dict: Dictionary mapping class indices to class names.
116+
:param postprocess_fn: An optional callback for postprocessing model output into classification predictions.
106117
:param tolerance: Tolerance threshold for accuracy comparison.
107118
Used for checking `baseline_acc + tolerance < finetuned_acc`.
108119
"""
120+
self.postprocess_fn = postprocess_fn
109121
self.tolerance = tolerance
110122
self.inv_class_dict = {v: k for k, v in class_dict.items()}
111123

@@ -141,6 +153,9 @@ def compare_results(
141153
total_samples = 0
142154

143155
for sample_dir in sample_dirs:
156+
finetuned_sample_paths = []
157+
baseline_sample_paths = []
158+
144159
finetuned_output_tensors = []
145160
baseline_output_tensors = []
146161

@@ -157,18 +172,24 @@ def compare_results(
157172
baseline_tensor_path,
158173
dtype=torch_type_to_numpy_type(tensor_spec.dtype),
159174
)
160-
np.reshape(baseline_tensor, tensor_spec.shape)
175+
baseline_tensor = np.reshape(baseline_tensor, tensor_spec.shape)
176+
baseline_sample_paths.append(baseline_tensor_path)
161177
baseline_output_tensors.append((output_tensor_name, baseline_tensor))
162178

163179
finetuned_tensor = np.fromfile(
164180
finetuned_tensor_path,
165181
dtype=torch_type_to_numpy_type(tensor_spec.dtype),
166182
)
167-
np.reshape(finetuned_tensor, tensor_spec.shape)
183+
finetuned_tensor = np.reshape(finetuned_tensor, tensor_spec.shape)
184+
finetuned_sample_paths.append(finetuned_tensor_path)
168185
finetuned_output_tensors.append((output_tensor_name, finetuned_tensor))
169186

170187
finetuned_correct, baseline_correct, total = self.compare_sample(
171-
sample_dir, baseline_output_tensors, finetuned_output_tensors
188+
sample_dir,
189+
baseline_sample_paths,
190+
baseline_output_tensors,
191+
finetuned_sample_paths,
192+
finetuned_output_tensors,
172193
)
173194

174195
finetuned_total_correct += finetuned_correct
@@ -187,35 +208,70 @@ def compare_results(
187208
)
188209

189210
def compare_sample(
190-
self, sample_dir, baseline_output_tensors, finetuned_output_tensors
211+
self,
212+
sample_dir,
213+
baseline_filepaths,
214+
baseline_output_tensors,
215+
finetuned_filepaths,
216+
finetuned_output_tensors,
191217
) -> tuple[int, int, int]:
192-
baseline_correct = 0
193-
finetuned_correct = 0
218+
baseline_correct_total = 0
219+
finetuned_correct_total = 0
220+
total_samples = 0
221+
222+
if not isinstance(sample_dir, str) or len(sample_dir.split("_")) < 3:
223+
raise ValueError(
224+
f"Sample dir format invalid. Expected format: 'example_classname_0', got {sample_dir}"
225+
)
194226

195-
if not isinstance(sample_dir, str) or len(sample_dir.split("_")) < 2:
227+
dir_parts = sample_dir.split("_")
228+
first_numerical_index = next(
229+
(i for i, s in enumerate(dir_parts) if s.isdigit()), -1
230+
)
231+
232+
if first_numerical_index < 2:
196233
raise ValueError(
197234
f"Sample dir format invalid. Expected format: 'example_classname_0', got {sample_dir}"
198235
)
199236

200-
class_name = sample_dir.split("_")[1]
237+
class_name = "_".join(dir_parts[1:first_numerical_index])
201238
class_id = self.inv_class_dict[class_name]
202239

203240
for idx in range(len(baseline_output_tensors)):
204241
(baseline_output_name, baseline_tensor) = baseline_output_tensors[idx]
205242
(finetuned_output_name, finetuned_tensor) = finetuned_output_tensors[idx]
206243

207244
assert baseline_output_name == finetuned_output_name
245+
assert baseline_tensor.shape == finetuned_tensor.shape
208246
assert np.any(
209247
baseline_tensor
210248
), "Output tensor contains only zeros. This is suspicious."
211249

212-
finetuned_class = np.argmax(finetuned_tensor, axis=-1)
213-
baseline_class = np.argmax(baseline_tensor, axis=-1)
250+
finetuned_class = self.postprocess_fn(
251+
finetuned_tensor, finetuned_filepaths[idx]
252+
)
253+
baseline_class = self.postprocess_fn(
254+
baseline_tensor, baseline_filepaths[idx]
255+
)
256+
257+
baseline_correct = baseline_class == class_id
258+
finetuned_correct = finetuned_class == class_id
214259

215-
baseline_correct += baseline_class == class_id
216-
finetuned_correct += finetuned_class == class_id
260+
baseline_correct_total += (
261+
baseline_correct
262+
if np.isscalar(baseline_correct)
263+
else sum(baseline_correct)
264+
)
265+
finetuned_correct_total += (
266+
finetuned_correct
267+
if np.isscalar(finetuned_correct)
268+
else sum(finetuned_correct)
269+
)
270+
total_samples += (
271+
1 if np.isscalar(finetuned_correct) else len(baseline_correct)
272+
)
217273

218-
return finetuned_correct, baseline_correct, len(baseline_output_tensors)
274+
return finetuned_correct_total, baseline_correct_total, total_samples
219275

220276

221277
class NumericalStatsOutputComparator(BaseOutputComparator):

backends/nxp/tests_models/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@
1414
import numpy as np
1515
import torch
1616

17+
from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_linear_pass import (
18+
FuseBatchNormWithLinearPass,
19+
)
20+
from executorch.backends.nxp.aten_passes.simulated_linear_bn_fusion_passes import (
21+
AddSimulatedLinearBatchNormFusionQATPass,
22+
RemoveSimulatedLinearBatchNormFusionQATPass,
23+
)
1724
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
1825
from executorch.backends.nxp.edge_passes.neutron_edge_pass_manager import (
1926
NeutronEdgePassManager,
@@ -92,12 +99,15 @@ def to_quantized_edge_program(
9299
)
93100
if use_qat:
94101
m = prepare_qat_pt2e(module, quantizer)
102+
m = AddSimulatedLinearBatchNormFusionQATPass()(m).graph_module
95103

96104
if train_fn:
97105
m = move_exported_model_to_train(m)
98106
train_fn(m)
99107

100108
m = move_exported_model_to_eval(m)
109+
m = RemoveSimulatedLinearBatchNormFusionQATPass()(m).graph_module
110+
m = FuseBatchNormWithLinearPass()(m).graph_module
101111
else:
102112
m = prepare_pt2e(module, quantizer)
103113

0 commit comments

Comments (0)