88import re
99from dataclasses import dataclass
1010from functools import partial
11- from typing import Callable
11+ from typing import Callable , Iterable
1212
1313import eiq_neutron_sdk
14+ import numpy as np
1415import torch
1516
1617from executorch import exir
2829 RemoveIOQuantOpsPass ,
2930)
3031from executorch .backends .nxp .neutron_partitioner import NeutronPartitioner
31-
3232from executorch .backends .nxp .nxp_backend import (
3333 core_aten_ops_exception_list ,
3434 generate_neutron_compile_spec ,
4242 ExecutorchProgramManager ,
4343 to_edge_transform_and_lower ,
4444)
45- from torch import nn
45+ from torch import memory_format , nn
4646from torch .export import export
4747from torchao .quantization .pt2e .quantizer import Quantizer
4848
5252@dataclass
5353class ModelInputSpec :
5454 shape : tuple [int , ...]
55+ type : np .dtype = np .float32
5556 dtype : torch .dtype = torch .float32
57+ dim_order : memory_format = torch .contiguous_format
5658
5759
5860def handle_kernel_selection (model_name : str = "" ):
@@ -81,11 +83,11 @@ def handle_kernel_selection(model_name: str = ""):
8183
8284
def get_random_calibration_inputs(
    input_spec: Iterable[ModelInputSpec], num_samples: int = 4
) -> list[tuple[torch.Tensor, ...]]:
    """Generate `num_samples` tuples of random tensors matching `input_spec`.

    Args:
        input_spec: Specs (shape + torch dtype) of the model inputs. Any
            iterable is accepted; one-shot iterators are materialized so they
            can be re-used for every sample.
        num_samples: Number of calibration samples to generate.

    Returns:
        A list of `num_samples` input tuples, one random tensor per spec.
    """
    # Materialize up front: with a generator the inner comprehension would
    # exhaust it on the first sample, silently yielding empty tuples for all
    # remaining iterations.
    specs = tuple(input_spec)
    return [
        tuple(torch.randn(spec.shape, dtype=spec.dtype) for spec in specs)
        for _ in range(num_samples)
    ]
9092
9193
@@ -94,35 +96,91 @@ def _get_default_quantizer(target_spec: NeutronTargetSpec, use_qat: bool) -> Qua
9496
9597
9698def to_model_input_spec (
97- input_spec : tuple [ModelInputSpec , ... ] | tuple [int , ...] | list [tuple [int , ...]]
99+ input_spec : Iterable [ModelInputSpec ] | tuple [int , ...] | list [tuple [int , ...]]
98100) -> tuple [ModelInputSpec , ...]:
99- if isinstance (input_spec , tuple ) and all (
100- isinstance (spec , ModelInputSpec ) for spec in input_spec
101- ):
102- return input_spec
103-
104- elif isinstance (input_spec , tuple ) and all (
105- isinstance (spec , int ) for spec in input_spec
106- ):
107- return (ModelInputSpec (input_spec ),)
108-
109- elif isinstance (input_spec , list ) and all (
110- isinstance (input_shape , tuple ) for input_shape in input_spec
111- ):
112- return tuple ([ModelInputSpec (spec ) for spec in input_spec ])
113- else :
114- raise TypeError (f"Unsupported type { type (input_spec )} " )
101+ match input_spec :
102+ case _ if isinstance (input_spec , Iterable ) and all (
103+ isinstance (spec , ModelInputSpec ) for spec in input_spec
104+ ):
105+ return tuple (input_spec )
106+ case tuple () if all (isinstance (spec , int ) for spec in input_spec ):
107+ return (ModelInputSpec (input_spec ),)
108+ case list () if all (
109+ isinstance (input_shape , tuple ) for input_shape in input_spec
110+ ):
111+ return tuple ([ModelInputSpec (spec ) for spec in input_spec ])
112+ case _:
113+ raise TypeError (f"Unsupported type { type (input_spec )} " )
114+
115+
# Signature of a calibration-data provider: given the normalized input specs,
# produces an iterable of input tuples (one tensor per model input) used to
# calibrate quantization.
GetCalibrationInputsFn = Callable[
    [tuple[ModelInputSpec, ...]], Iterable[tuple[torch.Tensor, ...]]
]
119+
120+
def get_calibration_inputs_fn_from_dataset_dir(dataset_dir) -> GetCalibrationInputsFn:
    """Build a calibration-input provider that streams raw tensors from disk.

    `dataset_dir` is expected to contain one entry per calibration sample:
    either a single raw file (single-input models) or a sub-directory holding
    one raw file per model input, sorted to match the order of `input_spec`.
    Each file is read with ``np.fromfile`` using the spec's numpy dtype and
    reshaped to the spec's shape. Samples with fewer files than model inputs
    are skipped; surplus files in a sample directory are ignored.
    """

    def _nested(
        input_spec: tuple[ModelInputSpec, ...]
    ) -> Iterable[tuple[torch.Tensor, ...]]:
        inputs_needed = len(input_spec)

        for entry in sorted(os.listdir(dataset_dir)):
            entry = os.path.join(dataset_dir, entry)

            if os.path.isdir(entry):
                files = [os.path.join(entry, name) for name in sorted(os.listdir(entry))]
            else:
                files = [entry]

            # Pair each file with its spec; zip stops at the shorter of the
            # two, which both truncates surplus files and detects shortfalls.
            input_data = []
            for spec, file in zip(input_spec, files):
                array = np.fromfile(file, dtype=spec.type).reshape(spec.shape)
                input_data.append(torch.from_numpy(array))

            if len(input_data) < inputs_needed:
                continue  # incomplete sample — skip it

            yield tuple(input_data)

    return _nested
154+
155+
def _get_example_input(
    input_spec: tuple[ModelInputSpec, ...]
) -> tuple[torch.Tensor, ...]:
    """Create one all-ones example tensor per spec, honoring its dim_order."""
    samples = []
    for spec in input_spec:
        if spec.dim_order == torch.contiguous_format:
            samples.append(torch.ones(spec.shape, dtype=spec.dtype))
        elif spec.dim_order == torch.channels_last:
            samples.append(
                torch.ones(spec.shape, dtype=spec.dtype).to(
                    memory_format=torch.channels_last
                )
            )
        else:
            raise ValueError(f"Unsupported dim_order: {spec.dim_order}")
    return tuple(samples)
115174
116175
117176def to_quantized_edge_program (
118177 model : torch .nn .Module ,
119- input_spec : tuple [ModelInputSpec , ... ] | tuple [int , ...] | list [tuple [int , ...]],
178+ input_spec : list [ModelInputSpec ] | tuple [int , ...] | list [tuple [int , ...]],
120179 operators_not_to_delegate : list [str ] = None ,
121- get_calibration_inputs_fn : Callable [
122- [tuple [ModelInputSpec , ...]], list [tuple [torch .Tensor , ...]]
123- ] = get_random_calibration_inputs ,
180+ get_calibration_inputs_fn : GetCalibrationInputsFn = get_random_calibration_inputs ,
124181 target : str = "imxrt700" ,
125182 use_qat : bool = False ,
183+ train_fn : Callable [[torch .fx .GraphModule ], None ] | None = None ,
126184 remove_quant_io_ops : bool = False ,
127185 custom_delegation_options : CustomDelegationOptions = CustomDelegationOptions (), # noqa B008
128186 get_quantizer_fn : Callable [[], Quantizer ] | None = None ,
@@ -131,15 +189,16 @@ def to_quantized_edge_program(
131189 fetch_constants_to_sram : bool = False ,
132190 dump_kernel_selection_code : bool = False ,
133191 use_new_flow_neutron_c : bool = False ,
192+ delegate_to_npu = True ,
134193) -> EdgeProgramManager :
135194 _neutron_target_spec = NeutronTargetSpec (target )
136195 if get_quantizer_fn is None :
137196 get_quantizer_fn = partial (
138197 _get_default_quantizer , _neutron_target_spec , use_qat
139198 )
140-
141- calibration_inputs = get_calibration_inputs_fn (to_model_input_spec ( input_spec ) )
142- example_input = calibration_inputs [ 0 ]
199+ input_spec = to_model_input_spec ( input_spec )
200+ calibration_inputs = get_calibration_inputs_fn (input_spec )
201+ example_input = _get_example_input ( input_spec )
143202
144203 # Make sure the model is in the evaluation mode.
145204 model .eval ()
@@ -151,6 +210,7 @@ def to_quantized_edge_program(
151210 calibration_inputs = calibration_inputs ,
152211 quantizer = get_quantizer_fn (),
153212 is_qat = use_qat ,
213+ train_fn = train_fn ,
154214 )
155215
156216 # List of operators to not decompose during the lowering.
@@ -166,15 +226,18 @@ def to_quantized_edge_program(
166226 post_quant_state_dict = (
167227 exir_program_aten__module_quant .state_dict () if use_quant_state_dict else None
168228 )
169- partitioners = [
170- NeutronPartitioner (
171- compile_spec ,
172- _neutron_target_spec ,
173- custom_delegation_options ,
174- post_quant_state_dict ,
175- preserve_ops = preserve_ops ,
176- )
177- ]
229+ if delegate_to_npu :
230+ partitioners = [
231+ NeutronPartitioner (
232+ compile_spec ,
233+ _neutron_target_spec ,
234+ custom_delegation_options ,
235+ post_quant_state_dict ,
236+ preserve_ops = preserve_ops ,
237+ )
238+ ]
239+ else :
240+ partitioners = []
178241
179242 edge_program_manager = to_edge_transform_and_lower (
180243 export (exir_program_aten__module_quant , example_input , strict = True ),
@@ -205,13 +268,31 @@ def to_quantized_executorch_program(
205268 model : torch .nn .Module ,
206269 input_spec : tuple [ModelInputSpec , ...] | tuple [int , ...] | list [tuple [int , ...]],
207270 use_qat : bool = False ,
271+ train_fn : Callable [[torch .fx .GraphModule ], None ] | None = None ,
208272 use_neutron_for_format_conversion : bool = True ,
273+ dataset_dir : str | None = None ,
274+ delegate_to_npu = True ,
275+ use_new_flow_neutron_c : bool = False ,
209276) -> ExecutorchProgramManager :
277+ if dataset_dir :
278+ # Extract calibration data from a directory.
279+ get_calibration_inputs_fn = {
280+ "get_calibration_inputs_fn" : get_calibration_inputs_fn_from_dataset_dir (
281+ dataset_dir
282+ )
283+ }
284+ else :
285+ get_calibration_inputs_fn = {} # Use default parameter value.
286+
210287 edge_program_manager = to_quantized_edge_program (
211288 model ,
212289 input_spec ,
213290 use_qat = use_qat ,
291+ train_fn = train_fn ,
214292 use_neutron_for_format_conversion = use_neutron_for_format_conversion ,
293+ delegate_to_npu = delegate_to_npu ,
294+ use_new_flow_neutron_c = use_new_flow_neutron_c ,
295+ ** get_calibration_inputs_fn ,
215296 )
216297
217298 return edge_program_manager .to_executorch (
0 commit comments