Skip to content

Commit 097b054

Browse files
committed
add fp8 support
1 parent 01af1bf commit 097b054

16 files changed

Lines changed: 1208 additions & 50 deletions

File tree

gptqmodel/looper/weight_only_looper.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@
2424
from ..models import BaseQModel
2525
from ..models._const import CPU, SUPPORTS_MODULE_TYPES
2626
from ..nn_modules.converter import MODULE_CONVERTER_MAP
27-
from ..quantization.config import GGUFQuantizeConfig, RTNQuantizeConfig
27+
from ..quantization.config import FP8Config, GGUFQuantizeConfig, RTNQuantizeConfig
2828
from ..utils.logger import setup_logger
2929
from ..utils.model import find_modules, get_module, get_module_by_name_prefix, move_to
3030
from ..utils.offload import offload_to_disk
@@ -96,9 +96,9 @@ def _offload_quantized_module(self, module: NamedModule) -> None:
9696
def loop(self, **kwargs):
9797
"""Quantize layers directly from weights without calibration forwards."""
9898
quant_config = self.gptq_model.quantize_config
99-
if not isinstance(quant_config, (RTNQuantizeConfig, GGUFQuantizeConfig)):
99+
if not isinstance(quant_config, (RTNQuantizeConfig, GGUFQuantizeConfig, FP8Config)):
100100
raise NotImplementedError(
101-
"Weight-only looper only supports `RTNQuantizeConfig` and `GGUFQuantizeConfig` today."
101+
"Weight-only looper only supports `RTNQuantizeConfig`, `GGUFQuantizeConfig`, and `FP8Config` today."
102102
)
103103

104104
if quant_config.lm_head:

gptqmodel/looper/weight_only_processor.py

Lines changed: 15 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@
2727
)
2828
from ..quantization.config import (
2929
BaseQuantizeConfig,
30+
FP8Config,
3031
GGUFQuantizeConfig,
3132
METHOD,
3233
RTNQuantizeConfig,
@@ -50,7 +51,7 @@ class WeightOnlyProcessor(LoopProcessor):
5051
def __init__(
5152
self,
5253
tokenizer,
53-
qcfg: RTNQuantizeConfig | GGUFQuantizeConfig,
54+
qcfg: RTNQuantizeConfig | GGUFQuantizeConfig | FP8Config,
5455
):
5556
super().__init__(
5657
tokenizer=tokenizer,
@@ -67,8 +68,8 @@ def __init__(
6768
self.lock = threading.Lock()
6869

6970
@staticmethod
70-
def _uses_direct_gguf(qcfg: RTNQuantizeConfig | GGUFQuantizeConfig) -> bool:
71-
return qcfg.quant_method == METHOD.GGUF
71+
def _uses_direct_pack(qcfg: RTNQuantizeConfig | GGUFQuantizeConfig | FP8Config) -> bool:
72+
return qcfg.quant_method in {METHOD.GGUF, METHOD.FP8}
7273

7374
def _update_logged_loss(self, module: NamedModule, avg_loss: str) -> None:
7475
with self.lock:
@@ -94,15 +95,15 @@ def _annotate_tp_padding(self, module: NamedModule, qcfg: BaseQuantizeConfig) ->
9495
"original_columns": columns,
9596
}
9697

97-
def quantize_module(self, module: NamedModule) -> Optional[RTNQuantizeConfig | GGUFQuantizeConfig]:
98+
def quantize_module(self, module: NamedModule) -> Optional[RTNQuantizeConfig | GGUFQuantizeConfig | FP8Config]:
9899
qcfg_clone = clone_weight_only_config_for_module(self.qcfg, module.full_name)
99100
if qcfg_clone is None:
100101
return None
101102

102-
if self._uses_direct_gguf(qcfg_clone):
103+
if self._uses_direct_pack(qcfg_clone):
103104
start_time = time.time()
104105
duration = time.time() - start_time
105-
avg_loss = "gguf: pending"
106+
avg_loss = f"{qcfg_clone.quant_method.value}: pending"
106107
damp_percent = 0.0
107108
nsamples = 0
108109
else:
@@ -139,7 +140,7 @@ def quantize_module(self, module: NamedModule) -> Optional[RTNQuantizeConfig | G
139140
self.log.append(stat)
140141
self.log_new_row(stat)
141142

142-
if not self._uses_direct_gguf(qcfg_clone):
143+
if not self._uses_direct_pack(qcfg_clone):
143144
module.weight.data = wq
144145
return qcfg_clone
145146

@@ -148,11 +149,11 @@ def submodule_finalize(
148149
module: NamedModule,
149150
model: BaseQModel,
150151
*,
151-
qcfg: Optional[RTNQuantizeConfig | GGUFQuantizeConfig] = None,
152+
qcfg: Optional[RTNQuantizeConfig | GGUFQuantizeConfig | FP8Config] = None,
152153
**kwargs,
153154
):
154155
active_qcfg = qcfg or self.qcfg
155-
if not self._uses_direct_gguf(active_qcfg):
156+
if not self._uses_direct_pack(active_qcfg):
156157
module.stream_sync()
157158
with self.lock:
158159
q_zeros = module.state.pop("q_zeros").clone()
@@ -187,6 +188,7 @@ def submodule_finalize(
187188
pack_dtype=active_qcfg.pack_dtype,
188189
format=resolve_quant_format(active_qcfg.format, active_qcfg.quant_method),
189190
register_buffers=False,
191+
init_kwargs=active_qcfg.quant_linear_init_kwargs(),
190192
)
191193
if timer is not None and create_start is not None:
192194
timer.record("submodule_finalize_create", time.perf_counter() - create_start, source=module_label)
@@ -197,7 +199,7 @@ def submodule_finalize(
197199
if name == module.full_name
198200
}
199201

200-
if self._uses_direct_gguf(active_qcfg):
202+
if self._uses_direct_pack(active_qcfg):
201203
pack_start = time.perf_counter() if timer is not None else None
202204
with log_time_block("module.pack_original", logger=log, module_name=module_label):
203205
with parent_module_lock(parent_key):
@@ -219,7 +221,7 @@ def submodule_finalize(
219221
reference_weight = qmodule._weight_to_matrix(original_layer).detach().cpu().to(torch.float32)
220222
dequant_weight = qmodule.dequantize_weight().T.detach().cpu().to(torch.float32)
221223
mean_abs_err = (dequant_weight - reference_weight).abs().mean().item()
222-
self._update_logged_loss(module, f"gguf: {mean_abs_err:.7f}")
224+
self._update_logged_loss(module, f"{active_qcfg.quant_method.value}: {mean_abs_err:.7f}")
223225
module.unregister_parameter("weight")
224226
return
225227

@@ -254,6 +256,8 @@ def finalize(self, model: BaseQModel, **kwargs):
254256
def name(self) -> str:
255257
if self.qcfg.quant_method == METHOD.GGUF:
256258
return "weight_only_gguf"
259+
if self.qcfg.quant_method == METHOD.FP8:
260+
return "weight_only_fp8"
257261
return "weight_only_rtn"
258262

259263
__all__ = ["WeightOnlyProcessor"]

gptqmodel/models/auto.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -289,6 +289,7 @@ def _is_supported_quantization_config(config: AutoConfig) -> bool:
289289
if isinstance(quant_format, str) and quant_format.lower() in (
290290
METHOD.GPTQ,
291291
METHOD.GGUF,
292+
METHOD.FP8,
292293
METHOD.AWQ,
293294
METHOD.QQQ,
294295
METHOD.EXL3,
@@ -299,6 +300,7 @@ def _is_supported_quantization_config(config: AutoConfig) -> bool:
299300
if isinstance(quant_method, str) and quant_method.lower() in (
300301
METHOD.GPTQ,
301302
METHOD.GGUF,
303+
METHOD.FP8,
302304
METHOD.AWQ,
303305
METHOD.QQQ,
304306
METHOD.EXL3,

gptqmodel/models/base.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -708,6 +708,8 @@ def quantize(
708708
preferred_backend = BACKEND.EXLLAMA_V3
709709
elif self.quantize_config.quant_method == METHOD.GGUF:
710710
preferred_backend = BACKEND.AUTO
711+
elif self.quantize_config.quant_method == METHOD.FP8:
712+
preferred_backend = BACKEND.TORCH
711713
else:
712714
preferred_backend = BACKEND.TORCH
713715

@@ -2015,8 +2017,8 @@ def __getattr__(self, item):
20152017
def _auto_detect_module_tree(self, model: PreTrainedModel, quant_method: METHOD):
20162018
log.warn("Model not yet support, attempting Module Tree AutoCompat...")
20172019

2018-
if quant_method not in {METHOD.GPTQ, METHOD.GGUF, METHOD.EXL3}:
2019-
log.warn(f"Module Tree AutoCompat: Failed, quant_method={quant_method}, only support GPTQ/GGUF/EXL3")
2020+
if quant_method not in {METHOD.GPTQ, METHOD.GGUF, METHOD.FP8, METHOD.EXL3}:
2021+
log.warn(f"Module Tree AutoCompat: Failed, quant_method={quant_method}, only support GPTQ/GGUF/FP8/EXL3")
20202022
return None
20212023

20222024
def _get(path):

0 commit comments

Comments
 (0)