 import logging
 import os
 
+import datasets  # type: ignore[import-untyped]
+import nncf  # type: ignore[import-untyped]
+
 import torch
 
 from executorch.backends.openvino.partitioner import OpenvinoPartitioner
+from executorch.backends.openvino.quantizer import (
+    OpenVINOQuantizer,
+    QuantizationMode,
+    quantize_model,
+)
 from executorch.examples.models.stable_diffusion.model import (  # type: ignore[import-untyped]
     LCMModelLoader,
+    StableDiffusionComponent,
 )
 from executorch.exir import ExecutorchBackendConfig, to_edge_transform_and_lower
 from executorch.exir.backend.backend_details import CompileSpec
 from torch.export import export
+from tqdm import tqdm  # type: ignore[import-untyped]
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -31,27 +41,180 @@ class LCMOpenVINOExporter:
     def __init__(
         self,
         model_id: str = "SimianLuo/LCM_Dreamshaper_v7",
+        is_quantization_enabled: bool = False,
         dtype: torch.dtype = torch.float16,
+        calibration_dataset_name: str = "google-research-datasets/conceptual_captions",
+        calibration_dataset_column: str = "caption",
     ):
+        if is_quantization_enabled:
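+            # Calibration runs the pipeline in full precision, so override any fp16 request.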
+            dtype = torch.float32
+        self.is_quantization_enabled = is_quantization_enabled
+        self.calibration_dataset_name = calibration_dataset_name
+        self.calibration_dataset_column = calibration_dataset_column
         self.model_loader = LCMModelLoader(model_id=model_id, dtype=dtype)
 
     def load_models(self) -> bool:
         """Load the LCM pipeline and extract components"""
         return self.model_loader.load_models()
 
+    @staticmethod
+    def get_unet_calibration_dataset(
+        pipeline,
+        dataset_name: str,
+        dataset_column: str,
+        calibration_dataset_size: int = 200,
+        num_inference_steps: int = 4,
+    ) -> list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+        """Collect UNet calibration inputs from prompts."""
+
+        class UNetWrapper(torch.nn.Module):
+            def __init__(self, model: torch.nn.Module, config):
+                super().__init__()
+                self.model = model
+                self.config = config
+                self.captured_args: list[
+                    tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+                ] = []
+
+            def _pick_correct_arg_or_kwarg(
+                self,
+                name: str,
+                args,
+                kwargs,
+                idx: int,
+            ):
+                if name in kwargs and kwargs[name] is not None:
+                    return kwargs[name]
+                if len(args) > idx:
+                    return args[idx]
+                raise KeyError(f"Missing required UNet input: {name}")
+
+            def _process_inputs(
+                self, *args, **kwargs
+            ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+                sample = self._pick_correct_arg_or_kwarg("sample", args, kwargs, 0)
+                timestep = self._pick_correct_arg_or_kwarg("timestep", args, kwargs, 1)
+                encoder_hidden_states = self._pick_correct_arg_or_kwarg(
+                    "encoder_hidden_states", args, kwargs, 2
+                )
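+                # Promote 0-dim timesteps to shape (1,) so every captured sample has a consistent rank.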
+                timestep = (
+                    timestep.unsqueeze(0)
+                    if isinstance(timestep, torch.Tensor) and timestep.dim() == 0
+                    else timestep
+                )
+                processed_args = (
+                    sample,
+                    timestep,
+                    encoder_hidden_states,
+                )
+                return processed_args
+
+            def forward(self, *args, **kwargs):
+                """
+                Capture each input individually so that the order and values match
+                the inputs expected by the OpenVINO LCM runner, then delegate to
+                the wrapped UNet.
+                """
+                unet_args = self._process_inputs(*args, **kwargs)
+                self.captured_args.append(unet_args)
+                return self.model(*args, **kwargs)
+
+        calibration_data = []
+        dataset = datasets.load_dataset(
+            dataset_name,
+            split="train",
+            streaming=True,
+        ).shuffle(seed=42)
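+        # Temporarily swap the pipeline's UNet for a wrapper that records every call's inputs.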
+        original_unet = pipeline.unet
+        wrapped_unet = UNetWrapper(pipeline.unet, pipeline.unet.config)
+        pipeline.unet = wrapped_unet
+        # Run inference for data collection
+        pbar = tqdm(total=calibration_dataset_size)
+        try:
+            for batch in dataset:
+                if dataset_column not in batch:
+                    raise RuntimeError(
+                        f"Column '{dataset_column}' was not found in dataset '{dataset_name}'"
+                    )
+                prompt = batch[dataset_column]
+                tokenized_prompt = pipeline.tokenizer.encode(prompt)
+                if len(tokenized_prompt) > pipeline.tokenizer.model_max_length:
+                    continue
+                # Run the pipeline
+                pipeline(
+                    prompt,
+                    num_inference_steps=num_inference_steps,
+                    height=512,
+                    width=512,
+                    output_type="latent",
+                )
+                calibration_data.extend(wrapped_unet.captured_args)
+                wrapped_unet.captured_args = []
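+                # Each pipeline call captures one input tuple per denoising step; advance the bar to the new total.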
+                pbar.update(len(calibration_data) - pbar.n)
+                if pbar.n >= calibration_dataset_size:
+                    break
+        finally:
+            pipeline.unet = original_unet
+            pbar.close()
+        return calibration_data
+
+    def quantize_unet_model(
+        self,
+        model: torch.export.ExportedProgram,
+        dummy_inputs,
+    ) -> torch.export.ExportedProgram:
+        """Quantize UNet using activation-aware PTQ."""
+        pipeline = self.model_loader.pipeline
+        calibration_dataset = self.get_unet_calibration_dataset(
+            pipeline,
+            self.calibration_dataset_name,
+            self.calibration_dataset_column,
+        )
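+        # Unwrap the ExportedProgram into its underlying torch.fx.GraphModule for the quantizer.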
+        model = model.module()
+        quantized_model = quantize_model(
+            model,
+            mode=QuantizationMode.INT8_TRANSFORMER,
+            calibration_dataset=calibration_dataset,  # type: ignore[arg-type]
+            smooth_quant=True,
+        )
+        # Re-export the transformed torch.fx.GraphModule to ExportedProgram
+        quantized_exported_program = export(quantized_model, dummy_inputs)
+        return quantized_exported_program
+
+    @staticmethod
+    def compress_model(
+        model: torch.export.ExportedProgram,
+        dummy_inputs,
+    ) -> torch.export.ExportedProgram:
+        """Apply weights-only compression for non-UNet components."""
+        model = model.module()
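+        # Asymmetric INT8 weight-only compression; activations stay in floating point.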
+        ov_quantizer = OpenVINOQuantizer(mode=QuantizationMode.INT8WO_ASYM)
+        quantized_model = nncf.experimental.torch.fx.compress_pt2e(
+            model, quantizer=ov_quantizer
+        )
+        # Re-export the transformed torch.fx.GraphModule to ExportedProgram
+        quantized_exported_program = export(quantized_model, dummy_inputs)
+        return quantized_exported_program
+
     def export_text_encoder(self, output_path: str, device: str = "CPU") -> bool:
         """Export CLIP text encoder to PTE file"""
         try:
             logger.info("Exporting text encoder with OpenVINO backend...")
 
+            sd_model_component = StableDiffusionComponent.TEXT_ENCODER
+
             # Get wrapped model and dummy inputs
             text_encoder_wrapper = self.model_loader.get_text_encoder_wrapper()
             dummy_inputs = self.model_loader.get_dummy_inputs()
 
             # Export to ATEN graph
-            exported_program = export(
-                text_encoder_wrapper, dummy_inputs["text_encoder"]
-            )
+            component_dummy_inputs = dummy_inputs[sd_model_component]
+            exported_program = export(text_encoder_wrapper, component_dummy_inputs)
+
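+            # Non-UNet components only need weight-only compression; full PTQ is reserved for the UNet.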
+            if self.is_quantization_enabled:
+                exported_program = self.compress_model(
+                    exported_program, component_dummy_inputs
+                )
 
             # Configure OpenVINO compilation
             compile_spec = [CompileSpec("device", device.encode())]
@@ -85,13 +248,20 @@ def export_unet(self, output_path: str, device: str = "CPU") -> bool:
85248 """Export UNet model to PTE file"""
86249 try :
87250 logger .info ("Exporting UNet model with OpenVINO backend..." )
251+ sd_model_component = StableDiffusionComponent .UNET
88252
89253 # Get wrapped model and dummy inputs
90254 unet_wrapper = self .model_loader .get_unet_wrapper ()
91255 dummy_inputs = self .model_loader .get_dummy_inputs ()
92256
93257 # Export to ATEN graph
94- exported_program = export (unet_wrapper , dummy_inputs ["unet" ])
258+ component_dummy_inputs = dummy_inputs [sd_model_component ]
259+ exported_program = export (unet_wrapper , component_dummy_inputs )
260+
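+            # The UNet gets full activation-and-weight PTQ, driven by the calibration prompts.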
+            if self.is_quantization_enabled:
+                exported_program = self.quantize_unet_model(
+                    exported_program, component_dummy_inputs
+                )
 
             # Configure OpenVINO compilation
             compile_spec = [CompileSpec("device", device.encode())]
@@ -125,13 +295,20 @@ def export_vae_decoder(self, output_path: str, device: str = "CPU") -> bool:
125295 """Export VAE decoder to PTE file"""
126296 try :
127297 logger .info ("Exporting VAE decoder with OpenVINO backend..." )
298+ sd_model_component = StableDiffusionComponent .VAE_DECODER
128299
129300 # Get wrapped model and dummy inputs
130301 vae_decoder = self .model_loader .get_vae_decoder ()
131302 dummy_inputs = self .model_loader .get_dummy_inputs ()
132303
133304 # Export to ATEN graph
134- exported_program = export (vae_decoder , dummy_inputs ["vae_decoder" ])
305+ component_dummy_inputs = dummy_inputs [sd_model_component ]
306+ exported_program = export (vae_decoder , component_dummy_inputs )
307+
308+ if self .is_quantization_enabled :
309+ exported_program = self .compress_model (
310+ exported_program , component_dummy_inputs
311+ )
135312
136313 # Configure OpenVINO compilation
137314 compile_spec = [CompileSpec ("device" , device .encode ())]
@@ -223,9 +400,23 @@ def create_argument_parser():
 
     parser.add_argument(
         "--dtype",
-        choices=["fp16", "fp32"],
+        choices=["fp16", "fp32", "int8"],
         default="fp16",
-        help="Model data type (default: fp16)",
+        help="Model data type. Use int8 to enable post-training quantization (default: fp16)",
+    )
+
+    parser.add_argument(
+        "--calibration_dataset_name",
+        type=str,
+        default="google-research-datasets/conceptual_captions",
+        help="HuggingFace dataset used for UNet calibration when INT8 quantization is enabled",
+    )
+
+    parser.add_argument(
+        "--calibration_dataset_column",
+        type=str,
+        default="caption",
+        help="Dataset column name used as prompt text for UNet calibration",
     )
 
     parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")
@@ -249,11 +440,18 @@ def main() -> int:
     logger.info("=" * 60)
 
     # Map dtype string to torch dtype
-    dtype_map = {"fp16": torch.float16, "fp32": torch.float32}
+    is_quantization_enabled = args.dtype == "int8"
+    dtype_map = {"fp16": torch.float16, "fp32": torch.float32, "int8": torch.float32}
     dtype = dtype_map[args.dtype]
 
     # Create exporter and load models
-    exporter = LCMOpenVINOExporter(args.model_id, dtype=dtype)
+    exporter = LCMOpenVINOExporter(
+        args.model_id,
+        is_quantization_enabled=is_quantization_enabled,
+        dtype=dtype,
+        calibration_dataset_name=args.calibration_dataset_name,
+        calibration_dataset_column=args.calibration_dataset_column,
+    )
 
     if not exporter.load_models():
         logger.error("Failed to load models")