
Commit 6823b33

Authored by anzr299, with co-authors Copilot and suryasidd

[OpenVINO][Examples] Add Quantization for the OpenVINO Stable Diffusion Example (pytorch#17807)

### Summary

Extend the stable diffusion example for the OpenVINO backend with quantization support.

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Surya Siddharth Pemmaraju <surya.siddharth.pemmaraju@intel.com>

1 parent 300e368 · commit 6823b33

6 files changed: 248 additions & 16 deletions


examples/models/stable_diffusion/__init__.py

Lines changed: 14 additions & 2 deletions
```diff
@@ -4,6 +4,18 @@
 # except in compliance with the License. See the license file found in the
 # LICENSE file in the root directory of this source tree.
 
-from .model import LCMModelLoader, TextEncoderWrapper, UNetWrapper, VAEDecoder
+from .model import (
+    LCMModelLoader,
+    StableDiffusionComponent,
+    TextEncoderWrapper,
+    UNetWrapper,
+    VAEDecoder,
+)
 
-__all__ = ["LCMModelLoader", "TextEncoderWrapper", "UNetWrapper", "VAEDecoder"]
+__all__ = [
+    "LCMModelLoader",
+    "StableDiffusionComponent",
+    "TextEncoderWrapper",
+    "UNetWrapper",
+    "VAEDecoder",
+]
```

examples/models/stable_diffusion/model.py

Lines changed: 13 additions & 4 deletions
```diff
@@ -12,6 +12,7 @@
 """
 
 import logging
+from enum import Enum
 from typing import Any, Optional
 
 import torch
@@ -26,6 +27,14 @@
 logger = logging.getLogger(__name__)
 
 
+class StableDiffusionComponent(Enum):
+    """Maintain Stable Diffusion model component names reliably"""
+
+    TEXT_ENCODER = "text_encoder"
+    UNET = "unet"
+    VAE_DECODER = "vae_decoder"
+
+
 class TextEncoderWrapper(torch.nn.Module):
     """Wrapper for CLIP text encoder that extracts last_hidden_state"""
 
@@ -150,7 +159,7 @@ def get_vae_decoder(self) -> VAEDecoder:
             raise ValueError("Models not loaded. Call load_models() first.")
         return VAEDecoder(self.vae)
 
-    def get_dummy_inputs(self):
+    def get_dummy_inputs(self) -> dict[StableDiffusionComponent, tuple[Any, ...]]:
         """
         Get dummy inputs for each model component.
 
@@ -187,7 +196,7 @@ def get_dummy_inputs(self):
         vae_input = torch.randn(1, 4, 64, 64, dtype=self.dtype)
 
         return {
-            "text_encoder": (text_encoder_input,),
-            "unet": unet_inputs,
-            "vae_decoder": (vae_input,),
+            StableDiffusionComponent.TEXT_ENCODER: (text_encoder_input,),
+            StableDiffusionComponent.UNET: unet_inputs,
+            StableDiffusionComponent.VAE_DECODER: (vae_input,),
         }
```
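
For readers of the example, a minimal usage sketch of the enum-keyed API (not part of the commit; the model id and loader calls follow the defaults shown in this diff). One benefit of the change: a misspelled component name now fails as an `AttributeError` on the enum, visible to type checkers, rather than a stringly-typed `KeyError` deep in the export path.

```python
import torch

from executorch.examples.models.stable_diffusion import (
    LCMModelLoader,
    StableDiffusionComponent,
)

# Fetch per-component dummy inputs via the new enum keys
# (previously dummy_inputs["unet"], dummy_inputs["text_encoder"], ...).
loader = LCMModelLoader(model_id="SimianLuo/LCM_Dreamshaper_v7", dtype=torch.float32)
if loader.load_models():
    dummy_inputs = loader.get_dummy_inputs()
    unet_inputs = dummy_inputs[StableDiffusionComponent.UNET]
    print(tuple(t.shape for t in unet_inputs if isinstance(t, torch.Tensor)))
```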

examples/openvino/stable_diffusion/README.md

Lines changed: 11 additions & 0 deletions
````diff
@@ -25,6 +25,16 @@ python export_lcm.py \
   --device CPU \
   --dtype fp16
 ```
+
+To quantize the UNet with 8-bit activations and 8-bit weights (8a8w) and apply weights-only 8-bit quantization (16a8w) to the remaining components, run:
+```bash
+python export_lcm.py \
+  --model_id SimianLuo/LCM_Dreamshaper_v7 \
+  --output_dir ./lcm_models \
+  --device CPU \
+  --dtype int8
+```
+
 This will create three files in `./lcm_models/`:
 - `text_encoder.pte`
 - `unet.pte`
@@ -33,6 +43,7 @@ This will create three files in `./lcm_models/`:
 ### Generate Images
 
 Run inference with the exported model:
+Note: For quantized models, pass `--dtype int8`
 
 ```bash
 python openvino_lcm.py \
````
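
A note on the shorthand in the added README text: `8a8w` reads as 8-bit activations and 8-bit weights, i.e. full post-training quantization of the UNet, while `16a8w` keeps activations at higher precision and compresses only the weights to 8 bits — the weights-only scheme applied to the text encoder and VAE decoder.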

examples/openvino/stable_diffusion/export_lcm.py

Lines changed: 207 additions & 9 deletions
```diff
@@ -10,15 +10,25 @@
 import logging
 import os
 
+import datasets  # type: ignore[import-untyped]
+import nncf  # type: ignore[import-untyped]
+
 import torch
 
 from executorch.backends.openvino.partitioner import OpenvinoPartitioner
+from executorch.backends.openvino.quantizer import (
+    OpenVINOQuantizer,
+    QuantizationMode,
+    quantize_model,
+)
 from executorch.examples.models.stable_diffusion.model import (  # type: ignore[import-untyped]
     LCMModelLoader,
+    StableDiffusionComponent,
 )
 from executorch.exir import ExecutorchBackendConfig, to_edge_transform_and_lower
 from executorch.exir.backend.backend_details import CompileSpec
 from torch.export import export
+from tqdm import tqdm  # type: ignore[import-untyped]
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -31,27 +41,180 @@ class LCMOpenVINOExporter:
     def __init__(
         self,
         model_id: str = "SimianLuo/LCM_Dreamshaper_v7",
+        is_quantization_enabled: bool = False,
         dtype: torch.dtype = torch.float16,
+        calibration_dataset_name: str = "google-research-datasets/conceptual_captions",
+        calibration_dataset_column: str = "caption",
     ):
+        if is_quantization_enabled:
+            dtype = torch.float32
+        self.is_quantization_enabled = is_quantization_enabled
+        self.calibration_dataset_name = calibration_dataset_name
+        self.calibration_dataset_column = calibration_dataset_column
         self.model_loader = LCMModelLoader(model_id=model_id, dtype=dtype)
 
     def load_models(self) -> bool:
         """Load the LCM pipeline and extract components"""
         return self.model_loader.load_models()
 
+    @staticmethod
+    def get_unet_calibration_dataset(
+        pipeline,
+        dataset_name: str,
+        dataset_column: str,
+        calibration_dataset_size: int = 200,
+        num_inference_steps: int = 4,
+    ) -> list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+        """Collect UNet calibration inputs from prompts."""
+
+        class UNetWrapper(torch.nn.Module):
+            def __init__(self, model: torch.nn.Module, config):
+                super().__init__()
+                self.model = model
+                self.config = config
+                self.captured_args: list[
+                    tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+                ] = []
+
+            def _pick_correct_arg_or_kwarg(
+                self,
+                name: str,
+                args,
+                kwargs,
+                idx: int,
+            ):
+                if name in kwargs and kwargs[name] is not None:
+                    return kwargs[name]
+                if len(args) > idx:
+                    return args[idx]
+                raise KeyError(f"Missing required UNet input: {name}")
+
+            def _process_inputs(
+                self, *args, **kwargs
+            ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+                sample = self._pick_correct_arg_or_kwarg("sample", args, kwargs, 0)
+                timestep = self._pick_correct_arg_or_kwarg("timestep", args, kwargs, 1)
+                encoder_hidden_states = self._pick_correct_arg_or_kwarg(
+                    "encoder_hidden_states", args, kwargs, 2
+                )
+                timestep = (
+                    timestep.unsqueeze(0)
+                    if isinstance(timestep, torch.Tensor) and timestep.dim() == 0
+                    else timestep
+                )
+                processed_args = (
+                    sample,
+                    timestep,
+                    encoder_hidden_states,
+                )
+                return processed_args
+
+            def forward(self, *args, **kwargs):
+                """
+                Obtain and pass each input individually to ensure the order is maintained
+                and the right values are being passed according to the expected inputs by
+                the OpenVINO LCM runner.
+                """
+                unet_args = self._process_inputs(*args, **kwargs)
+                self.captured_args.append(unet_args)
+                return self.model(*args, **kwargs)
+
+        calibration_data = []
+        dataset = datasets.load_dataset(
+            dataset_name,
+            split="train",
+            streaming=True,
+        ).shuffle(seed=42)
+        original_unet = pipeline.unet
+        wrapped_unet = UNetWrapper(pipeline.unet, pipeline.unet.config)
+        pipeline.unet = wrapped_unet
+        # Run inference for data collection
+        pbar = tqdm(total=calibration_dataset_size)
+        try:
+            for batch in dataset:
+                if dataset_column not in batch:
+                    raise RuntimeError(
+                        f"Column '{dataset_column}' was not found in dataset '{dataset_name}'"
+                    )
+                prompt = batch[dataset_column]
+                tokenized_prompt = pipeline.tokenizer.encode(prompt)
+                if len(tokenized_prompt) > pipeline.tokenizer.model_max_length:
+                    continue
+                # Run the pipeline
+                pipeline(
+                    prompt,
+                    num_inference_steps=num_inference_steps,
+                    height=512,
+                    width=512,
+                    output_type="latent",
+                )
+                calibration_data.extend(wrapped_unet.captured_args)
+                wrapped_unet.captured_args = []
+                pbar.update(len(calibration_data) - pbar.n)
+                if pbar.n >= calibration_dataset_size:
+                    break
+        finally:
+            pipeline.unet = original_unet
+            pbar.close()
+        return calibration_data
+
+    def quantize_unet_model(
+        self,
+        model: torch.export.ExportedProgram,
+        dummy_inputs,
+    ) -> torch.export.ExportedProgram:
+        """Quantize UNet using activation-aware PTQ."""
+        pipeline = self.model_loader.pipeline
+        calibration_dataset = self.get_unet_calibration_dataset(
+            pipeline,
+            self.calibration_dataset_name,
+            self.calibration_dataset_column,
+        )
+        model = model.module()
+        quantized_model = quantize_model(
+            model,
+            mode=QuantizationMode.INT8_TRANSFORMER,
+            calibration_dataset=calibration_dataset,  # type: ignore[arg-type]
+            smooth_quant=True,
+        )
+        # Re-export the transformed torch.fx.GraphModule to ExportedProgram
+        quantized_exported_program = export(quantized_model, dummy_inputs)
+        return quantized_exported_program
+
+    @staticmethod
+    def compress_model(
+        model: torch.export.ExportedProgram,
+        dummy_inputs,
+    ) -> torch.export.ExportedProgram:
+        """Apply weights-only compression for non-UNet components."""
+        model = model.module()
+        ov_quantizer = OpenVINOQuantizer(mode=QuantizationMode.INT8WO_ASYM)
+        quantized_model = nncf.experimental.torch.fx.compress_pt2e(
+            model, quantizer=ov_quantizer
+        )
+        # Re-export the transformed torch.fx.GraphModule to ExportedProgram
+        quantized_exported_program = export(quantized_model, dummy_inputs)
+        return quantized_exported_program
+
     def export_text_encoder(self, output_path: str, device: str = "CPU") -> bool:
         """Export CLIP text encoder to PTE file"""
         try:
             logger.info("Exporting text encoder with OpenVINO backend...")
 
+            sd_model_component = StableDiffusionComponent.TEXT_ENCODER
+
             # Get wrapped model and dummy inputs
             text_encoder_wrapper = self.model_loader.get_text_encoder_wrapper()
             dummy_inputs = self.model_loader.get_dummy_inputs()
 
             # Export to ATEN graph
-            exported_program = export(
-                text_encoder_wrapper, dummy_inputs["text_encoder"]
-            )
+            component_dummy_inputs = dummy_inputs[sd_model_component]
+            exported_program = export(text_encoder_wrapper, component_dummy_inputs)
+
+            if self.is_quantization_enabled:
+                exported_program = self.compress_model(
+                    exported_program, component_dummy_inputs
+                )
 
             # Configure OpenVINO compilation
             compile_spec = [CompileSpec("device", device.encode())]
@@ -85,13 +248,20 @@ def export_unet(self, output_path: str, device: str = "CPU") -> bool:
         """Export UNet model to PTE file"""
         try:
             logger.info("Exporting UNet model with OpenVINO backend...")
+            sd_model_component = StableDiffusionComponent.UNET
 
             # Get wrapped model and dummy inputs
             unet_wrapper = self.model_loader.get_unet_wrapper()
             dummy_inputs = self.model_loader.get_dummy_inputs()
 
             # Export to ATEN graph
-            exported_program = export(unet_wrapper, dummy_inputs["unet"])
+            component_dummy_inputs = dummy_inputs[sd_model_component]
+            exported_program = export(unet_wrapper, component_dummy_inputs)
+
+            if self.is_quantization_enabled:
+                exported_program = self.quantize_unet_model(
+                    exported_program, component_dummy_inputs
+                )
 
             # Configure OpenVINO compilation
             compile_spec = [CompileSpec("device", device.encode())]
@@ -125,13 +295,20 @@ def export_vae_decoder(self, output_path: str, device: str = "CPU") -> bool:
         """Export VAE decoder to PTE file"""
         try:
             logger.info("Exporting VAE decoder with OpenVINO backend...")
+            sd_model_component = StableDiffusionComponent.VAE_DECODER
 
             # Get wrapped model and dummy inputs
             vae_decoder = self.model_loader.get_vae_decoder()
             dummy_inputs = self.model_loader.get_dummy_inputs()
 
             # Export to ATEN graph
-            exported_program = export(vae_decoder, dummy_inputs["vae_decoder"])
+            component_dummy_inputs = dummy_inputs[sd_model_component]
+            exported_program = export(vae_decoder, component_dummy_inputs)
+
+            if self.is_quantization_enabled:
+                exported_program = self.compress_model(
+                    exported_program, component_dummy_inputs
+                )
 
             # Configure OpenVINO compilation
             compile_spec = [CompileSpec("device", device.encode())]
@@ -223,9 +400,23 @@ def create_argument_parser():
 
     parser.add_argument(
         "--dtype",
-        choices=["fp16", "fp32"],
+        choices=["fp16", "fp32", "int8"],
         default="fp16",
-        help="Model data type (default: fp16)",
+        help="Model data type. Use int8 to enable PTQ quantization (default: fp16)",
+    )
+
+    parser.add_argument(
+        "--calibration_dataset_name",
+        type=str,
+        default="google-research-datasets/conceptual_captions",
+        help="HuggingFace dataset used for UNet calibration when INT8 quantization is enabled",
+    )
+
+    parser.add_argument(
+        "--calibration_dataset_column",
+        type=str,
+        default="caption",
+        help="Dataset column name used as prompt text for UNet calibration",
     )
 
     parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")
@@ -249,11 +440,18 @@ def main() -> int:
     logger.info("=" * 60)
 
     # Map dtype string to torch dtype
-    dtype_map = {"fp16": torch.float16, "fp32": torch.float32}
+    is_quantization_enabled = args.dtype == "int8"
+    dtype_map = {"fp16": torch.float16, "fp32": torch.float32, "int8": torch.float32}
     dtype = dtype_map[args.dtype]
 
     # Create exporter and load models
-    exporter = LCMOpenVINOExporter(args.model_id, dtype=dtype)
+    exporter = LCMOpenVINOExporter(
+        args.model_id,
+        is_quantization_enabled=is_quantization_enabled,
+        dtype=dtype,
+        calibration_dataset_name=args.calibration_dataset_name,
+        calibration_dataset_column=args.calibration_dataset_column,
+    )
 
     if not exporter.load_models():
         logger.error("Failed to load models")
```
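
Putting the export side together, a minimal sketch (not from the commit) of driving the exporter with quantization on; `main()` does the equivalent when `--dtype int8` is passed. Class and method names match the diff above; the import path and output file paths are illustrative.

```python
import torch

from export_lcm import LCMOpenVINOExporter  # assumes the example directory is on sys.path

# "--dtype int8" maps to fp32 weights plus is_quantization_enabled=True;
# the constructor forces fp32 internally whenever quantization is enabled.
exporter = LCMOpenVINOExporter(
    "SimianLuo/LCM_Dreamshaper_v7",
    is_quantization_enabled=True,
    dtype=torch.float32,
    calibration_dataset_name="google-research-datasets/conceptual_captions",
    calibration_dataset_column="caption",
)
if exporter.load_models():
    # The UNet gets 8a8w PTQ with SmoothQuant via quantize_unet_model();
    # the other components get weights-only int8 via compress_model().
    exporter.export_unet("./lcm_models/unet.pte", device="CPU")
    exporter.export_text_encoder("./lcm_models/text_encoder.pte", device="CPU")
    exporter.export_vae_decoder("./lcm_models/vae_decoder.pte", device="CPU")
```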

examples/openvino/stable_diffusion/openvino_lcm.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -331,7 +331,7 @@ def create_argument_parser():
         "--device", choices=["CPU", "GPU"], default="CPU", help="Target device"
     )
     parser.add_argument(
-        "--dtype", choices=["fp16", "fp32"], default="fp16", help="Model dtype"
+        "--dtype", choices=["fp16", "fp32", "int8"], default="fp16", help="Model dtype"
     )
     parser.add_argument(
         "--output_dir", type=str, default="./lcm_outputs", help="Output directory"
```
