Skip to content

Commit b0af806

Browse files
C-KRSW and krizaltang authored
[fix]: Fix reference for weights to quant. (#276)
Co-authored-by: krizaltang <krizaltang@tencent.com>
1 parent 8f66b65 commit b0af806

5 files changed

Lines changed: 30 additions & 44 deletions

File tree

angelslim/compressor/quant/modules/daq/daq.py

Lines changed: 30 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@
1818
import json
1919
import multiprocessing as mp
2020
import os
21+
import shutil
2122
from collections import OrderedDict, defaultdict
2223
from concurrent.futures import ThreadPoolExecutor
2324
from dataclasses import dataclass
2425
from glob import glob
2526

2627
import torch
27-
from huggingface_hub import snapshot_download
2828
from safetensors.torch import load_file, save_file
2929
from tqdm import tqdm
3030

@@ -40,6 +40,27 @@
4040
prefetch_base_shard,
4141
)
4242

43+
# Suffixes that identify weight tensors to be quantized.
44+
# Copied from fp8_quant_blockwise for consistency; any weight whose name
45+
# ends with one of these suffixes will be quantized by DAQ.
46+
SUFFIX_TO_QUANT = [
47+
".gate_and_up_proj.weight",
48+
".gate_proj.weight",
49+
".up_proj.weight",
50+
".down_proj.weight",
51+
".q_a_proj.weight",
52+
".q_b_proj.weight",
53+
".kv_a_proj_with_mqa.weight",
54+
".kv_b_proj.weight",
55+
".qkv_proj.weight",
56+
".q_proj.weight",
57+
".k_proj.weight",
58+
".v_proj.weight",
59+
".o_proj.weight",
60+
".experts.gate_up_proj",
61+
".experts.down_proj",
62+
]
63+
4364
__all__ = ["DAQ"]
4465

4566

@@ -127,7 +148,6 @@ def __init__(self, quant_config, sft_model_path: str):
127148
self.quantization_method = quant_config.quantization_method
128149
self.num_workers = quant_config.num_workers
129150
self.ignore_layers = getattr(quant_config, "ignore_layers", []) or []
130-
self.base_model_repo = quant_config.base_model_repo
131151

132152
gpus_str = quant_config.gpus
133153
if gpus_str:
@@ -228,7 +248,6 @@ def run(self, save_path: str):
228248
model_index_file = os.path.join(save_path, "model.safetensors.index.json")
229249
with open(model_index_file, "r") as f:
230250
model_index = json.load(f)
231-
weight_map = model_index["weight_map"]
232251

233252
base_weight_map = get_weight_map(self.base_model_path)
234253
if not base_weight_map:
@@ -253,7 +272,6 @@ def run(self, save_path: str):
253272
safetensor_files,
254273
self.base_model_path,
255274
save_path,
256-
weight_map,
257275
base_weight_map,
258276
dynamic_cache_size,
259277
)
@@ -262,7 +280,6 @@ def run(self, save_path: str):
262280
safetensor_files,
263281
self.base_model_path,
264282
save_path,
265-
weight_map,
266283
base_weight_map,
267284
dynamic_cache_size,
268285
)
@@ -284,32 +301,12 @@ def run(self, save_path: str):
284301
print_info("DAQ quantization complete!")
285302

286303
def _prepare_output_dir(self, save_path: str):
287-
# TODO: Currently we only support quantizing BF16 DeepSeek V3/R1 models to FP8.
288-
# To support all model architectures, the logic for determining which weights
289-
# to quantize should be changed from referencing the target model's
290-
# model.safetensors.index.json to using regex-based include/exclude lists
291-
# (e.g. regex patterns for weights to quantize and weights to ignore).
292-
model_index_file = os.path.join(save_path, "model.safetensors.index.json")
293-
config_file = os.path.join(save_path, "config.json")
294-
295-
# Check if files need to be downloaded
296-
if not os.path.exists(model_index_file) or not os.path.exists(config_file):
297-
print(f"Model index or config file not found in {save_path}")
298-
print(f"Downloading config files from HuggingFace: {self.base_model_repo}")
299-
try:
300-
snapshot_download(
301-
repo_id=self.base_model_repo,
302-
ignore_patterns=["*.safetensors"],
303-
local_dir=save_path,
304-
local_dir_use_symlinks=False,
305-
)
306-
except Exception as e:
307-
raise RuntimeError(
308-
f"Failed to download config files from HuggingFace repo "
309-
f"'{self.base_model_repo}'. Please check your network connection "
310-
f"and ensure the repo_id is correct. Original error: {e}"
311-
) from e
312-
print(f"✓ Model index file and config file downloaded to {save_path}")
304+
for item in os.listdir(self.sft_model_path):
305+
src = os.path.join(self.sft_model_path, item)
306+
dst = os.path.join(save_path, item)
307+
if os.path.isfile(src) and not item.endswith(".safetensors"):
308+
if not os.path.exists(dst):
309+
shutil.copy2(src, dst)
313310

314311
def _update_config_json(self, save_path: str):
315312
config_file = os.path.join(save_path, "config.json")
@@ -346,7 +343,6 @@ def _run_single_process(
346343
safetensor_files,
347344
base_path,
348345
save_path,
349-
weight_map,
350346
base_weight_map,
351347
dynamic_cache_size,
352348
):
@@ -357,7 +353,6 @@ def _run_single_process(
357353
safetensor_file,
358354
base_path,
359355
save_path,
360-
weight_map,
361356
base_weight_map,
362357
self.scale_search_kwargs,
363358
True,
@@ -377,7 +372,6 @@ def _run_multiprocess(
377372
safetensor_files,
378373
base_path,
379374
save_path,
380-
weight_map,
381375
base_weight_map,
382376
dynamic_cache_size,
383377
):
@@ -403,7 +397,6 @@ def _run_multiprocess(
403397
worker_file_groups[wid],
404398
base_path,
405399
save_path,
406-
weight_map,
407400
base_weight_map,
408401
self.scale_search_kwargs,
409402
worker_devices[wid],
@@ -487,7 +480,6 @@ def _worker_process_files(args):
487480
file_list,
488481
base_path,
489482
save_path,
490-
weight_map,
491483
base_weight_map,
492484
scale_search_kwargs,
493485
device,
@@ -512,7 +504,6 @@ def _worker_process_files(args):
512504
safetensor_file,
513505
base_path,
514506
save_path,
515-
weight_map,
516507
base_weight_map,
517508
scale_search_kwargs,
518509
False,
@@ -532,7 +523,6 @@ def _process_single_file(
532523
safetensor_file,
533524
base_path,
534525
fp8_path,
535-
weight_map,
536526
base_weight_map,
537527
scale_search_kwargs,
538528
verbose,
@@ -622,8 +612,9 @@ def _process_single_file(
622612
scale_inv_name = f"{weight_name}_scale_inv"
623613

624614
should_ignore = any(ignore_pattern in weight_name for ignore_pattern in ignore_layers)
615+
should_quant = any(weight_name.endswith(suffix) for suffix in SUFFIX_TO_QUANT)
625616

626-
if scale_inv_name in weight_map and not should_ignore:
617+
if should_quant and not should_ignore:
627618
assert weight.element_size() == 2, f"Expected BF16, got {weight.dtype}"
628619

629620
base_weight = load_base_weight(

angelslim/utils/config_parser.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,6 @@ class QuantizationConfig:
229229
scale_search: Optional[Dict[str, Any]] = field(default=None)
230230
num_workers: int = field(default=8)
231231
gpus: Optional[str] = field(default=None)
232-
base_model_repo: Optional[str] = field(default=None)
233232

234233

235234
@dataclass

configs/deepseek_r1/fp8_daq/deepseek_r1_daq_fp8_w8a8_block.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ compression:
2020
bits: 8
2121
# DAQ-specific: path to the base (pretrained) model
2222
base_model_path: deepseek-ai/DeepSeek-R1-Base
23-
base_model_repo: deepseek-ai/DeepSeek-R1
2423
# Set to true if the base model is FP8 format
2524
base_is_fp8: true
2625
# Optimization metric: "sign" (sign preservation rate),

configs/deepseek_r1/fp8_daq/deepseek_r1_daq_fp8_w8a8_channel.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ compression:
2020
bits: 8
2121
# DAQ-specific: path to the base (pretrained) model
2222
base_model_path: deepseek-ai/DeepSeek-R1-Base
23-
base_model_repo: deepseek-ai/DeepSeek-R1
2423
# Set to true if the base model is FP8 format
2524
base_is_fp8: true
2625
# Optimization metric: "sign" (sign preservation rate),

docs/source/features/quantization/daq.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ python3 tools/run.py -c configs/deepseek_r1/fp8_daq/deepseek_r1_daq_fp8_w8a8_blo
4343
- `quantization.name`:压缩算法选填`daq`
4444
- `quantization.bits`:目标量化比特数,如fp8量化对应填写8bit。
4545
- `quantization.base_model_path`:基座模型路径。
46-
- `quantization.base_model_repo`:基座模型在huggingface的路径。
4746
- `quantization.base_is_fp8`:基座模型是否是FP8格式。
4847
- `quantization.metric`:优化指标,选填`sign`、`cosine`、`mse`。详细说明可参见[指标说明](#指标说明)和[技术报告](https://arxiv.org/abs/2603.22324)。
4948
- `quantization.quantization_method`:量化方式,选填`blockwise`、`per_channel`。详细说明可参见[量化方式](#量化方式)。
@@ -63,7 +62,6 @@ compression:
6362
name: daq
6463
bits: 8
6564
base_model_path: deepseek-ai/DeepSeek-R1-Base # DAQ-specific: path to the base model
66-
base_model_repo: deepseek-ai/DeepSeek-R1
6765
base_is_fp8: true # Set to true if the base model is FP8 format
6866
metric: cosine # Optimization metric: "sign", "cosine", or "mse"
6967
quantization_method: blockwise # Quantization method: "blockwise" or "per_channel"

0 commit comments

Comments
 (0)