1818import json
1919import multiprocessing as mp
2020import os
21+ import shutil
2122from collections import OrderedDict , defaultdict
2223from concurrent .futures import ThreadPoolExecutor
2324from dataclasses import dataclass
2425from glob import glob
2526
2627import torch
27- from huggingface_hub import snapshot_download
2828from safetensors .torch import load_file , save_file
2929from tqdm import tqdm
3030
4040 prefetch_base_shard ,
4141)
4242
43+ # Suffixes that identify weight tensors to be quantized.
44+ # Imported from fp8_quant_blockwise for consistency; any weight whose name
45+ # ends with one of these suffixes will be quantized by DAQ.
46+ SUFFIX_TO_QUANT = [
47+ ".gate_and_up_proj.weight" ,
48+ ".gate_proj.weight" ,
49+ ".up_proj.weight" ,
50+ ".down_proj.weight" ,
51+ ".q_a_proj.weight" ,
52+ ".q_b_proj.weight" ,
53+ ".kv_a_proj_with_mqa.weight" ,
54+ ".kv_b_proj.weight" ,
55+ ".qkv_proj.weight" ,
56+ ".q_proj.weight" ,
57+ ".k_proj.weight" ,
58+ ".v_proj.weight" ,
59+ ".o_proj.weight" ,
60+ ".experts.gate_up_proj" ,
61+ ".experts.down_proj" ,
62+ ]
63+
4364__all__ = ["DAQ" ]
4465
4566
@@ -127,7 +148,6 @@ def __init__(self, quant_config, sft_model_path: str):
127148 self .quantization_method = quant_config .quantization_method
128149 self .num_workers = quant_config .num_workers
129150 self .ignore_layers = getattr (quant_config , "ignore_layers" , []) or []
130- self .base_model_repo = quant_config .base_model_repo
131151
132152 gpus_str = quant_config .gpus
133153 if gpus_str :
@@ -228,7 +248,6 @@ def run(self, save_path: str):
228248 model_index_file = os .path .join (save_path , "model.safetensors.index.json" )
229249 with open (model_index_file , "r" ) as f :
230250 model_index = json .load (f )
231- weight_map = model_index ["weight_map" ]
232251
233252 base_weight_map = get_weight_map (self .base_model_path )
234253 if not base_weight_map :
@@ -253,7 +272,6 @@ def run(self, save_path: str):
253272 safetensor_files ,
254273 self .base_model_path ,
255274 save_path ,
256- weight_map ,
257275 base_weight_map ,
258276 dynamic_cache_size ,
259277 )
@@ -262,7 +280,6 @@ def run(self, save_path: str):
262280 safetensor_files ,
263281 self .base_model_path ,
264282 save_path ,
265- weight_map ,
266283 base_weight_map ,
267284 dynamic_cache_size ,
268285 )
@@ -284,32 +301,12 @@ def run(self, save_path: str):
284301 print_info ("DAQ quantization complete!" )
285302
286303 def _prepare_output_dir (self , save_path : str ):
287- # TODO: Currently we only support quantizing BF16 DeepSeek V3/R1 models to FP8.
288- # To support all model architectures, the logic for determining which weights
289- # to quantize should be changed from referencing the target model's
290- # model.safetensors.index.json to using regex-based include/exclude lists
291- # (e.g. regex patterns for weights to quantize and weights to ignore).
292- model_index_file = os .path .join (save_path , "model.safetensors.index.json" )
293- config_file = os .path .join (save_path , "config.json" )
294-
295- # Check if files need to be downloaded
296- if not os .path .exists (model_index_file ) or not os .path .exists (config_file ):
297- print (f"Model index or config file not found in { save_path } " )
298- print (f"Downloading config files from HuggingFace: { self .base_model_repo } " )
299- try :
300- snapshot_download (
301- repo_id = self .base_model_repo ,
302- ignore_patterns = ["*.safetensors" ],
303- local_dir = save_path ,
304- local_dir_use_symlinks = False ,
305- )
306- except Exception as e :
307- raise RuntimeError (
308- f"Failed to download config files from HuggingFace repo "
309- f"'{ self .base_model_repo } '. Please check your network connection "
310- f"and ensure the repo_id is correct. Original error: { e } "
311- ) from e
312- print (f"✓ Model index file and config file downloaded to { save_path } " )
304+ for item in os .listdir (self .sft_model_path ):
305+ src = os .path .join (self .sft_model_path , item )
306+ dst = os .path .join (save_path , item )
307+ if os .path .isfile (src ) and not item .endswith (".safetensors" ):
308+ if not os .path .exists (dst ):
309+ shutil .copy2 (src , dst )
313310
314311 def _update_config_json (self , save_path : str ):
315312 config_file = os .path .join (save_path , "config.json" )
@@ -346,7 +343,6 @@ def _run_single_process(
346343 safetensor_files ,
347344 base_path ,
348345 save_path ,
349- weight_map ,
350346 base_weight_map ,
351347 dynamic_cache_size ,
352348 ):
@@ -357,7 +353,6 @@ def _run_single_process(
357353 safetensor_file ,
358354 base_path ,
359355 save_path ,
360- weight_map ,
361356 base_weight_map ,
362357 self .scale_search_kwargs ,
363358 True ,
@@ -377,7 +372,6 @@ def _run_multiprocess(
377372 safetensor_files ,
378373 base_path ,
379374 save_path ,
380- weight_map ,
381375 base_weight_map ,
382376 dynamic_cache_size ,
383377 ):
@@ -403,7 +397,6 @@ def _run_multiprocess(
403397 worker_file_groups [wid ],
404398 base_path ,
405399 save_path ,
406- weight_map ,
407400 base_weight_map ,
408401 self .scale_search_kwargs ,
409402 worker_devices [wid ],
@@ -487,7 +480,6 @@ def _worker_process_files(args):
487480 file_list ,
488481 base_path ,
489482 save_path ,
490- weight_map ,
491483 base_weight_map ,
492484 scale_search_kwargs ,
493485 device ,
@@ -512,7 +504,6 @@ def _worker_process_files(args):
512504 safetensor_file ,
513505 base_path ,
514506 save_path ,
515- weight_map ,
516507 base_weight_map ,
517508 scale_search_kwargs ,
518509 False ,
@@ -532,7 +523,6 @@ def _process_single_file(
532523 safetensor_file ,
533524 base_path ,
534525 fp8_path ,
535- weight_map ,
536526 base_weight_map ,
537527 scale_search_kwargs ,
538528 verbose ,
@@ -622,8 +612,9 @@ def _process_single_file(
622612 scale_inv_name = f"{ weight_name } _scale_inv"
623613
624614 should_ignore = any (ignore_pattern in weight_name for ignore_pattern in ignore_layers )
615+ should_quant = any (weight_name .endswith (suffix ) for suffix in SUFFIX_TO_QUANT )
625616
626- if scale_inv_name in weight_map and not should_ignore :
617+ if should_quant and not should_ignore :
627618 assert weight .element_size () == 2 , f"Expected BF16, got { weight .dtype } "
628619
629620 base_weight = load_base_weight (
0 commit comments