Skip to content

Commit e29a2ec

Browse files
authored
feat: add SmoothQuant calibration pipeline for HY3 (#322)
1 parent 0fbaed7 commit e29a2ec

26 files changed

Lines changed: 6313 additions & 16 deletions

angelslim/compressor/quant/core/vllm_calibrate_utils/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,9 +13,9 @@
1313
* :mod:`.search` – KV-cache FP8 scale grid-search (per-tensor and
1414
per-head) with the value-capture hooks needed by the searchers.
1515
16-
The vLLM ``fused_moe.py`` patch only imports
17-
``collect_fused_moe_internal_stats`` from this package, which is
18-
re-exported via :mod:`.hooks`.
16+
Smooth / Smooth-Alpha-Search APIs have been moved to
17+
:mod:`angelslim.compressor.transform.smooth.vllm` — import from there
18+
directly.
1919
"""
2020

2121
from .hooks import (

angelslim/compressor/quant/core/vllm_calibrate_utils/hooks.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -232,10 +232,6 @@ def setup_activation_hooks(model, kv_granularity="per-tensor"):
232232
if hasattr(layer, "w13_weight") and layer.w13_weight is not None:
233233
layer.w13_weight._vllm_layer_name = name
234234
layer.w13_weight._moe_activation_stats_of_model = model._moe_activation_stats
235-
print(
236-
f"[DEBUG] Set w13_weight._vllm_layer_name = {name}, "
237-
f"type={type(layer.w13_weight)}"
238-
)
239235
else:
240236
print(
241237
f"[DEBUG] Cannot set w13_weight._vllm_layer_name: "

angelslim/compressor/quant/core/vllm_calibrate_utils/search.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,9 @@
1313
are kept module-private with the underscore prefix.
1414
"""
1515

16+
import os
17+
from concurrent.futures import ThreadPoolExecutor, as_completed
18+
1619
import torch
1720

1821
from ._common import _compute_perhead_layout, _find_layers, _get_dist_info, _get_kv_role
@@ -283,9 +286,6 @@ def __init__(
283286
self.num_steps = num_steps
284287

285288
def __call__(self, model):
286-
import os
287-
from concurrent.futures import ThreadPoolExecutor, as_completed
288-
289289
fp8_max = torch.finfo(torch.float8_e4m3fn).max # 448.0
290290

291291
# Collect raw kv tensors stored by the value hook
Lines changed: 37 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,37 @@
1+
# Copyright 2025 Tencent Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""SmoothQuant transform module.
16+
17+
The package has three sub-packages; :mod:`.vllm` and :mod:`.convert` share the common :mod:`.core` algorithm layer:
18+
19+
* :mod:`.core` — backend-agnostic tensor primitives (formulas, QDQ,
20+
RoPE-aware pairing, GQA expansion, alpha-search inner loop, smooth-stats
21+
serialisation). Imported by both the vLLM and convert pipelines.
22+
* :mod:`.vllm` — online stat collection on a live vLLM model: hook
23+
classes, ``setup_smooth_hooks`` / ``get_smooth_stats``, the TP-aware
24+
``SmoothAlphaSearcher``, and FusedMoE kernel-injection entry points.
25+
* :mod:`.convert` — offline weight conversion on a HuggingFace model:
26+
``apply_qk_smooth`` / ``apply_vo_smooth`` / ``apply_down_proj_smooth``
27+
(+ alpha-search variant), plus snapshot/verify utilities.
28+
29+
Top-level :mod:`.config` holds the dataclasses that travel with both
30+
pipelines (:class:`SmoothAlphaSearchConfig`).
31+
"""
32+
33+
from .config import SmoothAlphaSearchConfig
34+
35+
__all__ = [
36+
"SmoothAlphaSearchConfig",
37+
]
Lines changed: 51 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,51 @@
1+
# Copyright 2025 Tencent Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Configuration dataclasses shared by the vLLM (online) and convert
16+
(offline) smooth pipelines.
17+
18+
These are *pure data containers* — keep the module free of any heavy
19+
imports so it can be loaded from CLI scripts without pulling in torch
20+
distributed / vLLM machinery.
21+
"""
22+
23+
from dataclasses import dataclass
24+
25+
__all__ = [
26+
"SmoothAlphaSearchConfig",
27+
]
28+
29+
30+
@dataclass
31+
class SmoothAlphaSearchConfig:
32+
"""Configuration for smooth alpha grid search."""
33+
34+
alpha_min: float = 0.3
35+
alpha_max: float = 1.0
36+
alpha_steps: int = 8 # [0.3, 0.4, ..., 1.0]
37+
act_quant_method: str = "per_token" # "per_tensor" | "per_token"
38+
act_quant_type: str = "int8" # "int8" | "fp8"
39+
weight_quant_method: str = (
40+
"per_channel" # "per_tensor" | "per_channel" | "per_group" | "per_block"
41+
)
42+
weight_quant_type: str = "int8" # "int8" | "int4" | "fp8"
43+
weight_quant_bits: int = 8
44+
weight_group_size: int = 128 # per_group, -1 = per_channel
45+
block_size: int = 128 # per_block fp8
46+
use_ema_for_absmax: bool = False
47+
smooth_search_mode: str = "default" # "default" | "per-tensor-act-first"
48+
act_mul_min: float = 0.1 # per-tensor-act-first: multiplier range min
49+
act_mul_max: float = 1.0 # per-tensor-act-first: multiplier range max
50+
smooth_min: float = 1e-6 # per-tensor-act-first: smooth clamp lower bound
51+
smooth_max: float = 1e6 # per-tensor-act-first: smooth clamp upper bound
Lines changed: 67 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,67 @@
1+
# Copyright 2025 Tencent Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""HuggingFace-side smooth pipeline (offline weight conversion).
16+
17+
Re-exports the public API consumed by
18+
``tools/smooth/convert_smooth_weights.py`` (Phase 2 driver).
19+
"""
20+
21+
from .apply_funcs import (
22+
apply_down_proj_smooth,
23+
apply_down_proj_smooth_from_search,
24+
apply_qk_smooth,
25+
apply_vo_smooth,
26+
)
27+
from .utils import (
28+
DEFAULT_KEY_MAP,
29+
HY_V3_KEY_MAP,
30+
LLAMA_KEY_MAP,
31+
MIXTRAL_KEY_MAP,
32+
PREDEFINED_KEY_MAPS,
33+
QWEN3_MOE_KEY_MAP,
34+
attn_key_to_hf_prefix,
35+
find_first_attn_module,
36+
get_submodule_safe,
37+
maybe_materialize,
38+
snapshot_attn_output_before,
39+
snapshot_mlp_outputs_before,
40+
verify_attn_output_diff,
41+
verify_mlp_output_diff,
42+
)
43+
44+
__all__ = [
45+
# apply
46+
"apply_qk_smooth",
47+
"apply_vo_smooth",
48+
"apply_down_proj_smooth",
49+
"apply_down_proj_smooth_from_search",
50+
# key maps
51+
"DEFAULT_KEY_MAP",
52+
"HY_V3_KEY_MAP",
53+
"LLAMA_KEY_MAP",
54+
"MIXTRAL_KEY_MAP",
55+
"QWEN3_MOE_KEY_MAP",
56+
"PREDEFINED_KEY_MAPS",
57+
# helpers
58+
"get_submodule_safe",
59+
"maybe_materialize",
60+
"attn_key_to_hf_prefix",
61+
# snapshot / verify
62+
"find_first_attn_module",
63+
"snapshot_attn_output_before",
64+
"snapshot_mlp_outputs_before",
65+
"verify_attn_output_diff",
66+
"verify_mlp_output_diff",
67+
]

0 commit comments

Comments
 (0)