-
Notifications
You must be signed in to change notification settings - Fork 353
Expand file tree
/
Copy pathquantize_config.py
More file actions
161 lines (125 loc) · 5.01 KB
/
quantize_config.py
File metadata and controls
161 lines (125 loc) · 5.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any
import torch
from models_utils import MODEL_REGISTRY, ModelType
class DataType(str, Enum):
    """Supported data types for model loading."""

    HALF = "Half"
    BFLOAT16 = "BFloat16"
    FLOAT = "Float"

    @property
    def torch_dtype(self) -> torch.dtype:
        """Torch dtype corresponding to this member."""
        return _TORCH_DTYPE_BY_VALUE[self.value]


# Kept outside the class body: a plain dict assigned inside an Enum body
# would be picked up by the Enum machinery instead of staying an attribute.
_TORCH_DTYPE_BY_VALUE = {
    "Half": torch.float16,
    "BFloat16": torch.bfloat16,
    "Float": torch.float32,
}
class QuantFormat(str, Enum):
    """Quantization formats this tool can emit."""

    INT8 = "int8"  # 8-bit integer
    FP8 = "fp8"  # 8-bit floating point
    FP4 = "fp4"  # 4-bit floating point
class QuantAlgo(str, Enum):
    """Quantization algorithms this tool supports."""

    MAX = "max"
    SVDQUANT = "svdquant"
    SMOOTHQUANT = "smoothquant"
class CollectMethod(str, Enum):
    """Statistic-collection methods used during calibration."""

    GLOBAL_MIN = "global_min"
    MIN_MAX = "min-max"
    MIN_MEAN = "min-mean"
    MEAN_MAX = "mean-max"
    DEFAULT = "default"
    # NOTE(review): value spellings mix snake_case ("global_min") and hyphens
    # ("min-max"); kept byte-identical for CLI/config compatibility.
@dataclass
class QuantizationConfig:
    """Configuration for model quantization."""

    format: QuantFormat = QuantFormat.INT8
    algo: QuantAlgo = QuantAlgo.MAX
    percentile: float = 1.0
    collect_method: CollectMethod = CollectMethod.DEFAULT
    alpha: float = 1.0  # SmoothQuant alpha
    lowrank: int = 32  # SVDQuant lowrank
    quantize_mha: bool = False
    compress: bool = False

    def validate(self) -> None:
        """Reject option combinations that are unsupported.

        Raises:
            NotImplementedError: FP8 with a non-default collect method.
            ValueError: MHA quantization or compression requested with INT8.
        """
        fmt = self.format
        if fmt == QuantFormat.FP8 and self.collect_method != CollectMethod.DEFAULT:
            raise NotImplementedError("Only 'default' collect method is implemented for FP8.")
        if self.quantize_mha and fmt == QuantFormat.INT8:
            raise ValueError("MHA quantization is only supported for FP8, not INT8.")
        if self.compress and fmt == QuantFormat.INT8:
            raise ValueError("Compression is only supported for FP8 and FP4, not INT8.")
@dataclass
class CalibrationConfig:
"""Configuration for calibration process."""
prompts_dataset: dict | Path
batch_size: int = 2
calib_size: int = 128
n_steps: int = 30
def validate(self) -> None:
"""Validate calibration configuration."""
if self.batch_size <= 0:
raise ValueError("Batch size must be positive.")
if self.calib_size <= 0:
raise ValueError("Calibration size must be positive.")
if self.n_steps <= 0:
raise ValueError("Number of steps must be positive.")
@property
def num_batches(self) -> int:
"""Calculate number of calibration batches."""
return math.ceil(self.calib_size / self.batch_size)
@dataclass
class ModelConfig:
    """Configuration for model loading and inference."""

    model_type: ModelType = ModelType.FLUX_DEV
    # Per-name dtype table; "default" is the only key set out of the box —
    # presumably a fallback for unlisted components; verify against loader.
    model_dtype: dict[str, torch.dtype] = field(default_factory=lambda: {"default": torch.float16})
    backbone: str = ""
    trt_high_precision_dtype: DataType = DataType.HALF
    override_model_path: Path | None = None
    cpu_offloading: bool = False
    ltx_skip_upsampler: bool = False  # Skip upsampler for LTX-Video (faster calibration)
    extra_params: dict[str, Any] = field(default_factory=dict)

    @property
    def model_path(self) -> str:
        """Model path: the explicit override when set, else the registry entry."""
        override = self.override_model_path
        return str(override) if override else MODEL_REGISTRY[self.model_type]
@dataclass
class ExportConfig:
"""Configuration for model export."""
quantized_torch_ckpt_path: Path | None = None
onnx_dir: Path | None = None
hf_ckpt_dir: Path | None = None
restore_from: Path | None = None
def validate(self) -> None:
"""Validate export configuration."""
if self.restore_from and not self.restore_from.exists():
raise FileNotFoundError(f"Restore checkpoint not found: {self.restore_from}")
if self.quantized_torch_ckpt_path:
parent_dir = self.quantized_torch_ckpt_path.parent
if not parent_dir.exists():
parent_dir.mkdir(parents=True, exist_ok=True)
if self.onnx_dir and not self.onnx_dir.exists():
self.onnx_dir.mkdir(parents=True, exist_ok=True)
if self.hf_ckpt_dir and not self.hf_ckpt_dir.exists():
self.hf_ckpt_dir.mkdir(parents=True, exist_ok=True)