Skip to content

Commit b4928f7

Browse files
committed
Unify to_quantized_edge_program and to_quantized_executorch_program duplicates.
1 parent b4d4507 commit b4928f7

7 files changed

Lines changed: 150 additions & 284 deletions

File tree

backends/nxp/quantizer/utils.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import itertools
1111
from collections import OrderedDict
1212
from collections.abc import Iterable
13-
from typing import Any, Dict, List, Tuple, Type
13+
from typing import Any, Callable, Dict, List, Tuple, Type
1414

1515
import torch
1616
from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_linear_pass import (
@@ -30,8 +30,10 @@
3030
check_subgraphs_connected,
3131
SourcePartition,
3232
)
33+
3334
from torchao.quantization.pt2e import (
3435
move_exported_model_to_eval,
36+
move_exported_model_to_train,
3537
ObserverOrFakeQuantize,
3638
)
3739
from torchao.quantization.pt2e.quantize_pt2e import (
@@ -176,16 +178,17 @@ def calibrate_and_quantize(
176178
calibration_inputs: Iterable[tuple[torch.Tensor, ...]],
177179
quantizer: Quantizer,
178180
is_qat: bool = False,
181+
train_fn: Callable[[torch.fx.GraphModule], None] | None = None,
179182
) -> fx.GraphModule:
180183
"""Quantize the provided model.
181184
182185
:param model: Aten model (or its GraphModule representation) to quantize.
183-
:param calibration_inputs: Either a tuple of calibration input tensors where each element corresponds to a model
184-
input. Or an iterator over such tuples.
186+
:param calibration_inputs: An iterator over tuples of calibration input tensors where each tensor corresponds to a
187+
model input.
185188
:param quantizer: Quantizer to use.
186189
:param is_qat: Whether quantization is done using Quantization Aware Training (QAT) or not.
187190
Note: In QAT mode, training is not performed. Only calibration (in eval mode) is done.
188-
191+
:param train_fn: Optional training function to be called during QAT.
189192
:return: Quantized GraphModule.
190193
"""
191194

@@ -195,12 +198,20 @@ def calibrate_and_quantize(
195198
if is_qat:
196199
m = prepare_qat_pt2e(model, quantizer)
197200
m = AddSimulatedLinearBatchNormFusionQATPass()(m).graph_module
201+
202+
if train_fn:
203+
m = move_exported_model_to_train(m)
204+
train_fn(m)
205+
198206
m = move_exported_model_to_eval(m)
207+
m = RemoveSimulatedLinearBatchNormFusionQATPass()(m).graph_module
208+
m = FuseBatchNormWithLinearPass()(m).graph_module
199209
else:
200210
m = prepare_pt2e(model, quantizer)
201211

202-
for data in calibration_inputs:
203-
m(*data)
212+
if not is_qat or (is_qat and not train_fn):
213+
for data in calibration_inputs:
214+
m(*data)
204215

205216
if is_qat:
206217
m = RemoveSimulatedLinearBatchNormFusionQATPass()(m).graph_module

backends/nxp/tests/executorch_pipeline.py

Lines changed: 118 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
import re
99
from dataclasses import dataclass
1010
from functools import partial
11-
from typing import Callable
11+
from typing import Callable, Iterable
1212

1313
import eiq_neutron_sdk
14+
import numpy as np
1415
import torch
1516

1617
from executorch import exir
@@ -28,7 +29,6 @@
2829
RemoveIOQuantOpsPass,
2930
)
3031
from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
31-
3232
from executorch.backends.nxp.nxp_backend import (
3333
core_aten_ops_exception_list,
3434
generate_neutron_compile_spec,
@@ -42,7 +42,7 @@
4242
ExecutorchProgramManager,
4343
to_edge_transform_and_lower,
4444
)
45-
from torch import nn
45+
from torch import memory_format, nn
4646
from torch.export import export
4747
from torchao.quantization.pt2e.quantizer import Quantizer
4848

@@ -52,7 +52,9 @@
5252
@dataclass
5353
class ModelInputSpec:
5454
shape: tuple[int, ...]
55+
type: np.dtype = np.float32
5556
dtype: torch.dtype = torch.float32
57+
dim_order: memory_format = torch.contiguous_format
5658

5759

5860
def handle_kernel_selection(model_name: str = ""):
@@ -81,11 +83,11 @@ def handle_kernel_selection(model_name: str = ""):
8183

8284

8385
def get_random_calibration_inputs(
84-
input_spec: tuple[ModelInputSpec, ...]
86+
input_spec: Iterable[ModelInputSpec], num_samples: int = 4
8587
) -> list[tuple[torch.Tensor, ...]]:
8688
return [
8789
tuple([torch.randn(spec.shape, dtype=spec.dtype) for spec in input_spec])
88-
for _ in range(4)
90+
for _ in range(num_samples)
8991
]
9092

9193

@@ -96,33 +98,89 @@ def _get_default_quantizer(target_spec: NeutronTargetSpec, use_qat: bool) -> Qua
9698
def to_model_input_spec(
9799
input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]]
98100
) -> tuple[ModelInputSpec, ...]:
99-
if isinstance(input_spec, tuple) and all(
100-
isinstance(spec, ModelInputSpec) for spec in input_spec
101-
):
102-
return input_spec
103-
104-
elif isinstance(input_spec, tuple) and all(
105-
isinstance(spec, int) for spec in input_spec
106-
):
107-
return (ModelInputSpec(input_spec),)
108-
109-
elif isinstance(input_spec, list) and all(
110-
isinstance(input_shape, tuple) for input_shape in input_spec
111-
):
112-
return tuple([ModelInputSpec(spec) for spec in input_spec])
113-
else:
114-
raise TypeError(f"Unsupported type {type(input_spec)}")
101+
match input_spec:
102+
case tuple() | list() if all(
103+
isinstance(spec, ModelInputSpec) for spec in input_spec
104+
):
105+
return input_spec
106+
case tuple() if all(isinstance(spec, int) for spec in input_spec):
107+
return (ModelInputSpec(input_spec),)
108+
case list() if all(
109+
isinstance(input_shape, tuple) for input_shape in input_spec
110+
):
111+
return tuple([ModelInputSpec(spec) for spec in input_spec])
112+
case _:
113+
raise TypeError(f"Unsupported type {type(input_spec)}")
114+
115+
116+
GetCalibrationInputsFn = Callable[
117+
[tuple[ModelInputSpec, ...]], Iterable[tuple[torch.Tensor, ...]]
118+
]
119+
120+
121+
def get_calibration_inputs_fn_from_dataset_dir(dataset_dir) -> GetCalibrationInputsFn:
122+
def _nested(
123+
input_spec: tuple[ModelInputSpec, ...]
124+
) -> Iterable[tuple[torch.Tensor, ...]]:
125+
data = sorted(os.listdir(dataset_dir))
126+
inputs_needed = len(input_spec)
127+
128+
for path in data:
129+
path = os.path.join(dataset_dir, path)
130+
files = []
131+
132+
if os.path.isdir(path):
133+
files = [os.path.join(path, x) for x in sorted(os.listdir(path))]
134+
else:
135+
files.append(path)
136+
137+
input_data = []
138+
for idx, file in enumerate(files):
139+
if len(input_data) == inputs_needed:
140+
break
141+
142+
tensor = np.fromfile(file, dtype=input_spec[idx].type).reshape(
143+
input_spec[idx].shape
144+
)
145+
input_data += (torch.from_numpy(tensor),)
146+
continue
147+
148+
if len(input_data) < inputs_needed:
149+
continue
150+
151+
yield tuple(input_data)
152+
153+
return _nested
154+
155+
156+
def _get_example_input(
157+
input_spec: tuple[ModelInputSpec, ...]
158+
) -> tuple[torch.Tensor, ...]:
159+
example_input = []
160+
for spec in input_spec:
161+
match spec.dim_order:
162+
case torch.contiguous_format:
163+
sample = torch.ones(spec.shape, dtype=spec.dtype)
164+
case torch.channels_last:
165+
sample = torch.ones(spec.shape, dtype=spec.dtype).to(
166+
memory_format=torch.channels_last
167+
)
168+
case _:
169+
raise ValueError(f"Unsupported dim_order: {spec.dim_order}")
170+
# noinspection PyUnboundLocalVariable
171+
example_input.append(sample)
172+
173+
return tuple(example_input)
115174

116175

117176
def to_quantized_edge_program(
118177
model: torch.nn.Module,
119-
input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]],
178+
input_spec: list[ModelInputSpec] | tuple[int, ...] | list[tuple[int, ...]],
120179
operators_not_to_delegate: list[str] = None,
121-
get_calibration_inputs_fn: Callable[
122-
[tuple[ModelInputSpec, ...]], list[tuple[torch.Tensor, ...]]
123-
] = get_random_calibration_inputs,
180+
get_calibration_inputs_fn: GetCalibrationInputsFn = get_random_calibration_inputs,
124181
target: str = "imxrt700",
125182
use_qat: bool = False,
183+
train_fn: Callable[[torch.fx.GraphModule], None] | None = None,
126184
remove_quant_io_ops: bool = False,
127185
custom_delegation_options: CustomDelegationOptions = CustomDelegationOptions(), # noqa B008
128186
get_quantizer_fn: Callable[[], Quantizer] | None = None,
@@ -131,15 +189,16 @@ def to_quantized_edge_program(
131189
fetch_constants_to_sram: bool = False,
132190
dump_kernel_selection_code: bool = False,
133191
use_new_flow_neutron_c: bool = False,
192+
delegate_to_npu=True,
134193
) -> EdgeProgramManager:
135194
_neutron_target_spec = NeutronTargetSpec(target)
136195
if get_quantizer_fn is None:
137196
get_quantizer_fn = partial(
138197
_get_default_quantizer, _neutron_target_spec, use_qat
139198
)
140-
141-
calibration_inputs = get_calibration_inputs_fn(to_model_input_spec(input_spec))
142-
example_input = calibration_inputs[0]
199+
input_spec = to_model_input_spec(input_spec)
200+
calibration_inputs = get_calibration_inputs_fn(input_spec)
201+
example_input = _get_example_input(input_spec)
143202

144203
# Make sure the model is in the evaluation mode.
145204
model.eval()
@@ -151,6 +210,7 @@ def to_quantized_edge_program(
151210
calibration_inputs=calibration_inputs,
152211
quantizer=get_quantizer_fn(),
153212
is_qat=use_qat,
213+
train_fn=train_fn,
154214
)
155215

156216
# List of operators to not decompose during the lowering.
@@ -166,15 +226,18 @@ def to_quantized_edge_program(
166226
post_quant_state_dict = (
167227
exir_program_aten__module_quant.state_dict() if use_quant_state_dict else None
168228
)
169-
partitioners = [
170-
NeutronPartitioner(
171-
compile_spec,
172-
_neutron_target_spec,
173-
custom_delegation_options,
174-
post_quant_state_dict,
175-
preserve_ops=preserve_ops,
176-
)
177-
]
229+
if delegate_to_npu:
230+
partitioners = [
231+
NeutronPartitioner(
232+
compile_spec,
233+
_neutron_target_spec,
234+
custom_delegation_options,
235+
post_quant_state_dict,
236+
preserve_ops=preserve_ops,
237+
)
238+
]
239+
else:
240+
partitioners = []
178241

179242
edge_program_manager = to_edge_transform_and_lower(
180243
export(exir_program_aten__module_quant, example_input, strict=True),
@@ -205,13 +268,31 @@ def to_quantized_executorch_program(
205268
model: torch.nn.Module,
206269
input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]],
207270
use_qat: bool = False,
271+
train_fn: Callable[[torch.fx.GraphModule], None] | None = None,
208272
use_neutron_for_format_conversion: bool = True,
273+
dataset_dir: str | None = None,
274+
delegate_to_npu=True,
275+
use_new_flow_neutron_c: bool = False,
209276
) -> ExecutorchProgramManager:
277+
if dataset_dir:
278+
# Extract calibration data from a directory.
279+
get_calibration_inputs_fn = {
280+
"get_calibration_inputs_fn": get_calibration_inputs_fn_from_dataset_dir(
281+
dataset_dir
282+
)
283+
}
284+
else:
285+
get_calibration_inputs_fn = {} # Use default parameter value.
286+
210287
edge_program_manager = to_quantized_edge_program(
211288
model,
212289
input_spec,
213290
use_qat=use_qat,
291+
train_fn=train_fn,
214292
use_neutron_for_format_conversion=use_neutron_for_format_conversion,
293+
delegate_to_npu=delegate_to_npu,
294+
use_new_flow_neutron_c=use_new_flow_neutron_c,
295+
**get_calibration_inputs_fn,
215296
)
216297

217298
return edge_program_manager.to_executorch(

backends/nxp/tests_models/dataset_creator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
import numpy as np
1515
import torch
1616
from executorch.backends.nxp.backend.ir.converter.conversion import translator
17+
from executorch.backends.nxp.tests.executorch_pipeline import ModelInputSpec
1718
from executorch.backends.nxp.tests_models.calibration_dataset import CalibrationDataset
18-
from executorch.backends.nxp.tests_models.model_input_spec import ModelInputSpec
1919
from torch import Tensor
2020

2121

backends/nxp/tests_models/executors.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,20 @@
1919
from executorch.backends.nxp.backend.edge_helper import is_channels_last_dim_order
2020
from executorch.backends.nxp.backend.ir.converter.conversion import translator
2121
from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
22+
from executorch.backends.nxp.tests.executorch_pipeline import (
23+
get_calibration_inputs_fn_from_dataset_dir,
24+
ModelInputSpec,
25+
to_quantized_edge_program,
26+
to_quantized_executorch_program,
27+
)
2228
from executorch.backends.nxp.tests_models.config_importer import test_config
2329
from executorch.backends.nxp.tests_models.dataset_creator import RandomDatasetCreator
2430
from executorch.backends.nxp.tests_models.graph_verifier import GraphVerifier
25-
from executorch.backends.nxp.tests_models.model_input_spec import ModelInputSpec
2631
from executorch.backends.nxp.tests_models.model_output_comparator import (
2732
AllCloseOutputComparator,
2833
)
2934
from executorch.backends.nxp.tests_models.outputs_dir_importer import outputs_dir
30-
from executorch.backends.nxp.tests_models.utils import (
31-
save_pte_program,
32-
to_quantized_edge_program,
33-
to_quantized_executorch_program,
34-
)
35+
from executorch.backends.nxp.tests_models.utils import save_pte_program
3536
from executorch.devtools.visualization.visualization_utils import (
3637
visualize_with_clusters,
3738
)
@@ -113,7 +114,7 @@ def wrapper(*args, **kwargs):
113114
delegated_program = to_quantized_executorch_program(
114115
model,
115116
input_spec,
116-
calibration_dataset_dir,
117+
dataset_dir=calibration_dataset_dir,
117118
delegate_to_npu=True,
118119
use_qat=use_qat,
119120
train_fn=train_fn,
@@ -175,7 +176,7 @@ def _run_non_delegated_executorch_program(
175176
non_delegated_program = to_quantized_executorch_program(
176177
model,
177178
input_spec,
178-
calibration_dataset_dir,
179+
dataset_dir=calibration_dataset_dir,
179180
delegate_to_npu=False,
180181
use_qat=use_qat,
181182
train_fn=train_fn,
@@ -463,7 +464,9 @@ def convert_run_compare(
463464
to_quantized_edge_program(
464465
model_to_not_delegate,
465466
input_spec,
466-
calibration_dataset_dir,
467+
get_calibration_inputs_fn=get_calibration_inputs_fn_from_dataset_dir(
468+
calibration_dataset_dir
469+
),
467470
delegate_to_npu=False,
468471
use_qat=use_qat,
469472
train_fn=train_fn,

backends/nxp/tests_models/model_input_spec.py

Lines changed: 0 additions & 18 deletions
This file was deleted.

0 commit comments

Comments
 (0)