2727)
2828from ..quantization .config import (
2929 BaseQuantizeConfig ,
30+ FP8Config ,
3031 GGUFQuantizeConfig ,
3132 METHOD ,
3233 RTNQuantizeConfig ,
@@ -50,7 +51,7 @@ class WeightOnlyProcessor(LoopProcessor):
5051 def __init__ (
5152 self ,
5253 tokenizer ,
53- qcfg : RTNQuantizeConfig | GGUFQuantizeConfig ,
54+ qcfg : RTNQuantizeConfig | GGUFQuantizeConfig | FP8Config ,
5455 ):
5556 super ().__init__ (
5657 tokenizer = tokenizer ,
@@ -67,8 +68,8 @@ def __init__(
6768 self .lock = threading .Lock ()
6869
6970 @staticmethod
70- def _uses_direct_gguf (qcfg : RTNQuantizeConfig | GGUFQuantizeConfig ) -> bool :
71- return qcfg .quant_method == METHOD .GGUF
71+ def _uses_direct_pack (qcfg : RTNQuantizeConfig | GGUFQuantizeConfig | FP8Config ) -> bool :
72+ return qcfg .quant_method in { METHOD .GGUF , METHOD . FP8 }
7273
7374 def _update_logged_loss (self , module : NamedModule , avg_loss : str ) -> None :
7475 with self .lock :
@@ -94,15 +95,15 @@ def _annotate_tp_padding(self, module: NamedModule, qcfg: BaseQuantizeConfig) ->
9495 "original_columns" : columns ,
9596 }
9697
97- def quantize_module (self , module : NamedModule ) -> Optional [RTNQuantizeConfig | GGUFQuantizeConfig ]:
98+ def quantize_module (self , module : NamedModule ) -> Optional [RTNQuantizeConfig | GGUFQuantizeConfig | FP8Config ]:
9899 qcfg_clone = clone_weight_only_config_for_module (self .qcfg , module .full_name )
99100 if qcfg_clone is None :
100101 return None
101102
102- if self ._uses_direct_gguf (qcfg_clone ):
103+ if self ._uses_direct_pack (qcfg_clone ):
103104 start_time = time .time ()
104105 duration = time .time () - start_time
105- avg_loss = "gguf : pending"
106+ avg_loss = f" { qcfg_clone . quant_method . value } : pending"
106107 damp_percent = 0.0
107108 nsamples = 0
108109 else :
@@ -139,7 +140,7 @@ def quantize_module(self, module: NamedModule) -> Optional[RTNQuantizeConfig | G
139140 self .log .append (stat )
140141 self .log_new_row (stat )
141142
142- if not self ._uses_direct_gguf (qcfg_clone ):
143+ if not self ._uses_direct_pack (qcfg_clone ):
143144 module .weight .data = wq
144145 return qcfg_clone
145146
@@ -148,11 +149,11 @@ def submodule_finalize(
148149 module : NamedModule ,
149150 model : BaseQModel ,
150151 * ,
151- qcfg : Optional [RTNQuantizeConfig | GGUFQuantizeConfig ] = None ,
152+ qcfg : Optional [RTNQuantizeConfig | GGUFQuantizeConfig | FP8Config ] = None ,
152153 ** kwargs ,
153154 ):
154155 active_qcfg = qcfg or self .qcfg
155- if not self ._uses_direct_gguf (active_qcfg ):
156+ if not self ._uses_direct_pack (active_qcfg ):
156157 module .stream_sync ()
157158 with self .lock :
158159 q_zeros = module .state .pop ("q_zeros" ).clone ()
@@ -187,6 +188,7 @@ def submodule_finalize(
187188 pack_dtype = active_qcfg .pack_dtype ,
188189 format = resolve_quant_format (active_qcfg .format , active_qcfg .quant_method ),
189190 register_buffers = False ,
191+ init_kwargs = active_qcfg .quant_linear_init_kwargs (),
190192 )
191193 if timer is not None and create_start is not None :
192194 timer .record ("submodule_finalize_create" , time .perf_counter () - create_start , source = module_label )
@@ -197,7 +199,7 @@ def submodule_finalize(
197199 if name == module .full_name
198200 }
199201
200- if self ._uses_direct_gguf (active_qcfg ):
202+ if self ._uses_direct_pack (active_qcfg ):
201203 pack_start = time .perf_counter () if timer is not None else None
202204 with log_time_block ("module.pack_original" , logger = log , module_name = module_label ):
203205 with parent_module_lock (parent_key ):
@@ -219,7 +221,7 @@ def submodule_finalize(
219221 reference_weight = qmodule ._weight_to_matrix (original_layer ).detach ().cpu ().to (torch .float32 )
220222 dequant_weight = qmodule .dequantize_weight ().T .detach ().cpu ().to (torch .float32 )
221223 mean_abs_err = (dequant_weight - reference_weight ).abs ().mean ().item ()
222- self ._update_logged_loss (module , f"gguf : { mean_abs_err :.7f} " )
224+ self ._update_logged_loss (module , f"{ active_qcfg . quant_method . value } : { mean_abs_err :.7f} " )
223225 module .unregister_parameter ("weight" )
224226 return
225227
@@ -254,6 +256,8 @@ def finalize(self, model: BaseQModel, **kwargs):
254256 def name (self ) -> str :
255257 if self .qcfg .quant_method == METHOD .GGUF :
256258 return "weight_only_gguf"
259+ if self .qcfg .quant_method == METHOD .FP8 :
260+ return "weight_only_fp8"
257261 return "weight_only_rtn"
258262
259263__all__ = ["WeightOnlyProcessor" ]
0 commit comments