@@ -36,6 +36,7 @@ class DebugMode(Enum):
3636 compiler_flags : list [str ] = field (default_factory = list )
3737 path_for_intermediates : str | None = None
3838 tosa_debug_mode : DebugMode | None = None
39+ preserve_io_quantization : bool = False
3940
4041 _TOSA_SPEC_KEY = "tosa_spec"
4142 _COMPILE_FLAGS_KEY = "compile_flags"
@@ -44,6 +45,7 @@ class DebugMode(Enum):
4445 _DEBUG_MODE_KEY = "dump_debug_info"
4546 _OUTPUT_REORDER_KEY = "ouput_reorder_workaround"
4647 _TRANSFORM_PIPELINE_CONFIG_KEY = "transform_pipeline_config"
48+ _PRESERVE_IO_QUANT_KEY = "preserve_io_quantization"
4749
4850 def _set_compile_specs (
4951 self ,
@@ -53,6 +55,7 @@ def _set_compile_specs(
5355 tosa_debug_mode : DebugMode | None = None ,
5456 output_order_workaround : bool = False ,
5557 pipeline_config : ArmPassPipelineConfig | None = None ,
58+ preserve_io_quantization : bool = False ,
5659 ):
5760 """Set all values of dataclass directly."""
5861 self .tosa_spec = tosa_spec
@@ -61,6 +64,8 @@ def _set_compile_specs(
6164 self .tosa_debug_mode = tosa_debug_mode
6265 self ._pipeline_config = pipeline_config
6366 self .output_order_workaround = output_order_workaround
67+ self .preserve_io_quantization = preserve_io_quantization
68+ self ._warn_if_redundant_preserve_io_quantization ()
6469 if output_order_workaround :
6570 warnings .warn (
6671 "ArmCompileSpec(output_order_workaround=True) is deprecated and will be "
@@ -78,6 +83,7 @@ def _from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
7883 tosa_debug_mode : ArmCompileSpec .DebugMode | None = None
7984 output_order_workaround : bool = False
8085 pipeline_config : ArmPassPipelineConfig | None = None
86+ preserve_io_quantization : bool = False
8187 unknown_specs : dict [str , str ] = {}
8288 for spec in compile_specs :
8389 key = spec .key
@@ -128,6 +134,8 @@ def _from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
128134 "More than one transform pipeline entry in compile spec."
129135 )
130136 pipeline_config = ArmPassPipelineConfig .from_dict (json .loads (val ))
137+ elif key == ArmCompileSpec ._PRESERVE_IO_QUANT_KEY :
138+ preserve_io_quantization = str (val ).lower () in ("1" , "true" , "yes" )
131139 else :
132140 unknown_specs [key ] = val
133141
@@ -151,6 +159,7 @@ def _from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
151159 tosa_debug_mode = tosa_debug_mode ,
152160 output_order_workaround = output_order_workaround ,
153161 pipeline_config = pipeline_config ,
162+ preserve_io_quantization = preserve_io_quantization ,
154163 )
155164 cls ._from_list_hook (compile_spec , unknown_specs )
156165 compile_spec ._validate ()
@@ -227,8 +236,35 @@ def _to_list(self):
227236 self ._pipeline_config .serialize (),
228237 )
229238 )
239+ compile_spec .append (
240+ CompileSpec (
241+ ArmCompileSpec ._PRESERVE_IO_QUANT_KEY ,
242+ str (bool (self .preserve_io_quantization )).encode (),
243+ )
244+ )
230245 return compile_spec
231246
247+ def _set_preserve_io_quantization (self , enabled : bool ) -> "ArmCompileSpec" :
248+ """Preserve Q/DQ nodes at IO boundaries when lowering."""
249+ self .preserve_io_quantization = enabled
250+ self ._warn_if_redundant_preserve_io_quantization ()
251+ return self
252+
253+ def _warn_if_redundant_preserve_io_quantization (self ) -> None :
254+ """Warn when preserve_io_quantization has no effect for INT-only
255+ specs.
256+ """
257+ if (
258+ self .preserve_io_quantization
259+ and self .tosa_spec .support_integer ()
260+ and not self .tosa_spec .support_float ()
261+ ):
262+ warnings .warn (
263+ "preserve_io_quantization=True is redundant for INT-only TOSA "
264+ "specifications because boundary Q/DQ are already de-tagged." ,
265+ stacklevel = 3 ,
266+ )
267+
232268 def _get_pass_pipeline_config (self ) -> ArmPassPipelineConfig :
233269 """Returns configuration that controls how the Arm pass pipeline should
234270 behave.
0 commit comments