Skip to content

Commit 20119fd

Browse files
authored
feature: support Qwen3_VL_MoE static quantization (#209) (#210)
1 parent 97c467b commit 20119fd

14 files changed

Lines changed: 695 additions & 21 deletions

File tree

angelslim/compressor/quant/core/quant_func.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,17 @@ def quantize_weight_per_tensor_fp8(
7373
) -> Tuple[torch.Tensor, float]:
7474
finfo = torch.finfo(torch.float8_e4m3fn)
7575

76+
squeeze_dim = False
77+
while scale.ndim < tensor.ndim:
78+
scale = scale.unsqueeze(-1)
79+
squeeze_dim = True
80+
7681
qweight = (tensor / scale).clamp(min=finfo.min, max=finfo.max)
7782
# Return both float8 data and the inverse scale (as float),
7883
# as both required as inputs to torch._scaled_mm
7984
qweight = qweight.to(torch.float8_e4m3fn)
85+
if squeeze_dim:
86+
scale = scale.squeeze(-1)
8087
scale = scale.float()
8188
return qweight, scale
8289

angelslim/compressor/quant/modules/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .gptq.gptq import GPTQ # noqa: F401
2020
from .gptq.gptq_module import GPTQModule # noqa: F401
2121
from .helper_layer import GPTQQuantLinear # noqa: F401
22+
from .helper_layer import MoEQDQModule # noqa: F401
2223
from .helper_layer import NVFP4QDQModule # noqa: F401
2324
from .helper_layer import QDQModule # noqa: F401
2425
from .helper_layer import QDQSingleModule # noqa: F401

angelslim/compressor/quant/modules/helper_layer.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,3 +1023,80 @@ def _unpack_tensor(input: torch.Tensor):
10231023
deq_data.shape[0], deq_data.shape[1] // block_size, -1
10241024
) * per_block_scale.unsqueeze(-1)
10251025
return deq_data.view(-1)[: np.prod(self.shape)].reshape(self.shape).to(dtype)
1026+
1027+
1028+
class MoEQDQModule(torch.nn.Module):
    """Storage-only container for FP8-quantized MoE expert weights and scales.

    The constructor quantizes the gate, up and down projection weights with
    ``quantize_weight_per_tensor_fp8``, concatenates the quantized gate and up
    weights along the last dimension, and registers the quantized weights
    together with their weight scales and the activation (input) scales as
    frozen parameters.  ``forward`` performs no computation.
    """

    def __init__(
        self,
        gate_proj: torch.nn.Parameter,
        up_proj: torch.nn.Parameter,
        down_proj: torch.nn.Parameter,
        gate_proj_weight_scale: torch.nn.Parameter,
        up_proj_weight_scale: torch.nn.Parameter,
        down_proj_weight_scale: torch.nn.Parameter,
        gate_up_proj_input_scale: torch.nn.Parameter,
        down_proj_input_scale: torch.nn.Parameter,
    ):
        """Quantize the expert weights and register weights and scales.

        Args:
            gate_proj: Gate projection weight to quantize.
            up_proj: Up projection weight to quantize.
            down_proj: Down projection weight to quantize.
            gate_proj_weight_scale: Scale used to quantize ``gate_proj``.
            up_proj_weight_scale: Scale used to quantize ``up_proj``.
            down_proj_weight_scale: Scale used to quantize ``down_proj``.
            gate_up_proj_input_scale: Activation scale stored for the fused
                gate/up projection.
            down_proj_input_scale: Activation scale stored for the down
                projection.
        """
        super().__init__()

        # Only the quantized data is kept; the scale returned by the helper
        # is discarded because the caller-supplied scales are stored below.
        quant_gate_weight, _ = quantize_weight_per_tensor_fp8(
            gate_proj, gate_proj_weight_scale
        )
        quant_up_weight, _ = quantize_weight_per_tensor_fp8(
            up_proj, up_proj_weight_scale
        )
        quant_down_weight, _ = quantize_weight_per_tensor_fp8(
            down_proj, down_proj_weight_scale
        )

        self.gate_up_proj = torch.nn.Parameter(
            torch.cat([quant_gate_weight, quant_up_weight], dim=-1),
            requires_grad=False,
        )
        self.down_proj = torch.nn.Parameter(quant_down_weight, requires_grad=False)

        # Weight scales: 0-dim scalars are promoted to 1-D so they can be
        # concatenated; tensors that already have dimensions are kept as-is.
        gate_up_proj_weight_scale = torch.cat(
            [
                self._normalize_scale(gate_proj_weight_scale),
                self._normalize_scale(up_proj_weight_scale),
            ],
            dim=-1,
        )
        self.gate_up_proj_weight_scale = torch.nn.Parameter(
            gate_up_proj_weight_scale, requires_grad=False
        )
        self.down_proj_weight_scale = torch.nn.Parameter(
            self._normalize_scale(down_proj_weight_scale), requires_grad=False
        )

        # Input scales additionally drop size-1 dimensions.
        self.gate_up_proj_input_scale = torch.nn.Parameter(
            self._normalize_scale(gate_up_proj_input_scale, squeeze=True),
            requires_grad=False,
        )
        self.down_proj_input_scale = torch.nn.Parameter(
            self._normalize_scale(down_proj_input_scale, squeeze=True),
            requires_grad=False,
        )

    @staticmethod
    def _normalize_scale(scale, squeeze=False):
        """Return ``scale`` with at least one dimension.

        Args:
            scale: Scale tensor of any rank.
            squeeze: When true, size-1 dimensions are removed first.

        Returns:
            ``scale`` unchanged when it already has dimensions (after the
            optional squeeze), otherwise a 1-element 1-D tensor.
        """
        if squeeze and scale.ndim > 0:
            scale = scale.squeeze()
        # ``squeeze()`` turns a single-element tensor into a 0-dim one, so the
        # promotion to 1-D has to happen after it, not instead of it.
        if scale.ndim == 0:
            scale = scale.reshape(-1)
        return scale

    def forward(self, x):
        """Do nothing and return ``None``; the module only stores tensors."""
        return None

angelslim/compressor/quant/observers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .abs_max_activation import AbsmaxPerchannelObserver # noqa: F401
1616
from .abs_max_activation import AbsmaxPertensorObserver # noqa: F401
1717
from .abs_max_activation import AbsMaxTokenWiseActObserver # noqa: F401; noqa: F401
18+
from .abs_max_activation import MoEAbsmaxPertensorObserver # noqa: F401
1819
from .abs_max_weight import AbsMaxChannelWiseWeightObserver # noqa: F401
1920
from .base_observer import BaseObserver, ParentObserver # noqa: F401
2021
from .ema_activation import EMAObserver # noqa: F401

angelslim/compressor/quant/observers/abs_max_activation.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
"AbsmaxPertensorObserver",
2121
"AbsMaxTokenWiseActObserver",
2222
"AbsmaxPerchannelObserver",
23+
"MoEAbsmaxPertensorObserver",
2324
]
2425

2526

@@ -217,3 +218,91 @@ def zero_points(self):
217218
if self._zero_point is None:
218219
self.cal_thresholds()
219220
return self._zero_point
221+
222+
223+
class MoEAbsmaxPertensorObserver(BaseObserver):
    """Abs-max activation observer that tracks one maximum per leading index.

    For inputs with two or more dimensions the absolute maximum is taken over
    everything except dimension 0, giving a ``[inputs.shape[0], 1]`` tensor.
    Lower-rank inputs give a ``[1, 1]`` tensor.  Each new maximum is combined
    with the previously stored one, so ``_max`` acts as a running maximum.
    """

    def __init__(self, layer_name=None, quant_bits=8, **kwargs):
        """Initialise the observer.

        Args:
            layer_name: Optional name of the observed layer.
            quant_bits: Bit width forwarded to ``BaseObserver``.
            **kwargs: May contain ``parent_observer``; other keys are ignored.
        """
        super(MoEAbsmaxPertensorObserver, self).__init__(quant_bits=quant_bits)
        self.layer_name = layer_name
        self._scale = None
        self._zero_point = None
        self._min = None
        # Starts as a tiny positive floor so the resulting scale is never zero.
        self._max = torch.tensor(1e-7, dtype=torch.float32)
        self.step = 0
        self.dtype = None
        self.parent_observer = kwargs.get("parent_observer")

    def forward(self, inputs):
        """Record statistics for ``inputs`` and return ``inputs`` unchanged.

        Non-empty inputs update the running maximum and, when a parent
        observer is set, are reported to it via ``update``.  Empty inputs
        take their statistics from the parent observer instead.

        Raises:
            AssertionError: If ``inputs`` is empty and no parent observer is
                configured.
        """
        self.step += 1
        if self.dtype is None:
            self.dtype = inputs.dtype
        if inputs.numel() > 0:
            self._min, self._max = self._cal_min_max(inputs)
            if self.parent_observer is not None:
                self.parent_observer.update(self._min, self._max, self.step)
        else:
            # Raised explicitly (rather than via ``assert``) so the check is
            # not stripped when Python runs with -O.
            if self.parent_observer is None:
                raise AssertionError(
                    "MoEAbsmaxPertensorObserver received an empty input but "
                    "has no parent_observer to take statistics from"
                )
            self._update_min_max(self.parent_observer.min, self.parent_observer.max)
        return inputs

    def _cal_min_max(self, inputs):
        """Return ``(0, abs_max)`` for ``inputs`` merged with the stored max.

        Returns:
            A tuple of the constant ``0`` and a 2-D tensor on
            ``inputs.device``: ``[inputs.shape[0], 1]`` for inputs with at
            least two dimensions, ``[1, 1]`` otherwise.
        """
        abs_inputs = torch.abs(inputs)
        running_max = self._max.to(abs_inputs.device)
        if inputs.dim() >= 2:
            # ``reshape`` (not ``view``) so non-contiguous inputs also work.
            flat = abs_inputs.reshape(abs_inputs.shape[0], -1)
            abs_max_val, _ = torch.max(flat, dim=1, keepdim=True)
            abs_max_val = torch.maximum(
                abs_max_val, running_max.expand_as(abs_max_val)
            )
        else:
            # Reduce the stored max to a scalar before merging and reshape the
            # result explicitly: the stored max is already [1, 1] after the
            # first call, and unsqueezing it again would add two dimensions
            # on every call in which it wins the comparison.
            abs_max_val = torch.maximum(
                torch.max(abs_inputs), running_max.max()
            ).reshape(1, 1)
        return 0, abs_max_val.to(inputs.device)

    def _update_min_max(self, min, max):
        """Widen the stored range with ``min``/``max`` when both are given."""
        # NOTE(review): these comparisons rely on truthiness of the result, so
        # they presumably expect scalar or single-element values — confirm
        # against what the parent observer exposes as ``min``/``max``.
        if min is not None and max is not None:
            if self._min is None or min < self._min:
                self._min = min
            if self._max is None or max > self._max:
                self._max = max

    def cal_thresholds(self):
        """Derive scale and zero point from the stored maximum (once)."""
        if self._scale is None:
            self._scale = self._max
        self._zero_point = torch.zeros_like(self._scale)

    def quant_axis(self):
        """Return quantization axis."""
        return -1

    def scales(self):
        """Return the scale, cast to the dtype of the first observed input.

        If this observer never ran but has a parent observer, the parent's
        statistics and step count are adopted first.

        Raises:
            ValueError: If no calibration step has been recorded.
        """
        if self.step == 0 and self.parent_observer is not None:
            self._update_min_max(self.parent_observer.min, self.parent_observer.max)
            self.step = self.parent_observer.step
        if self.step == 0:
            raise ValueError(
                "MoEAbsmaxPertensorObserver scales must calibrate data first!"
            )
        if self._scale is None:
            self.cal_thresholds()
        if self.dtype:
            self._scale = self._scale.type(self.dtype)
        return self._scale

    def zero_points(self):
        """Return output zero points."""
        if self._zero_point is None:
            self.cal_thresholds()
        return self._zero_point

angelslim/compressor/quant/ptq.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import torch
2020
from safetensors.torch import load_file
21+
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextExperts
2122

2223
from ...utils import find_parent_layer_and_sub_name, print_info
2324
from ..compressor_factory import CompressorFactory
@@ -284,6 +285,20 @@ def _convert(self):
284285

285286
if qdq_module is not sub_layer:
286287
setattr(parent_layer, sub_name, qdq_module)
288+
289+
# 3. insert moe qdq module
290+
# For qwen3_vl_moe models, we need to insert MoEQDQModule for MOE experts,
291+
# since these modules contain gate_up_proj and down_proj, which are defined as
292+
# nn.Parameters, not nn.Linear.
293+
if Qwen3VLMoeTextExperts in self.quant_model.observer_layer_classes:
294+
for name, sub_layer in self.quant_model.model.named_modules():
295+
parent_layer, sub_name = find_parent_layer_and_sub_name(
296+
quant_convert_module, name
297+
)
298+
moe_qdq_module = self.quant_model.get_moe_qdq_module(sub_layer, name)
299+
if moe_qdq_module is not sub_layer:
300+
setattr(parent_layer, sub_name, moe_qdq_module)
301+
287302
self.quant_model.quantized = True
288303

289304
def __getattr__(self, item):

angelslim/data/multimodal_dataset.py

Lines changed: 75 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __init__(
4949

5050
def _load_file_based_dataset(self, data_path: str, num_samples: int):
5151
"""Load dataset from local file system"""
52-
image_dir = os.path.join(os.path.dirname(data_path), "images")
52+
self.data_path = data_path
5353
line_count = 0
5454

5555
with open(data_path, "r") as f:
@@ -58,29 +58,84 @@ def _load_file_based_dataset(self, data_path: str, num_samples: int):
5858
break
5959

6060
data = json.loads(line.strip())
61-
image_path = os.path.join(image_dir, data["img_path"])
6261

63-
# Prepare chat messages with image
64-
messages = [
65-
{
66-
"role": "user",
67-
"content": [
68-
{"type": "image", "image": image_path},
69-
{
70-
"type": "text",
71-
"text": data["question"].replace("<image>", ""),
72-
},
73-
],
74-
},
75-
{
76-
"role": "assistant",
77-
"content": [{"type": "text", "text": data["answer"]}],
78-
},
79-
]
62+
# Validate format
63+
assert "messages" in data or "question" in data, "JSON format error"
64+
65+
# Prepare messages
66+
messages = self._prepare_messages(data)
8067

8168
self._process_and_append(messages)
8269
line_count += 1
8370

71+
def _prepare_messages(self, data: Dict) -> List[Dict]:
72+
image_dir = os.path.join(os.path.dirname(self.data_path), "images")
73+
if "question" in data:
74+
# Prepare chat messages with image
75+
messages = []
76+
if "system_prompt" in data:
77+
messages.extend(
78+
[
79+
{
80+
"role": "system",
81+
"content": [
82+
{"type": "text", "text": data["system_prompt"]}
83+
],
84+
}
85+
]
86+
)
87+
if "img_path" in data:
88+
image_path = os.path.join(image_dir, data["img_path"])
89+
messages.extend(
90+
[
91+
{
92+
"role": "user",
93+
"content": [
94+
{"type": "image", "image": image_path},
95+
{
96+
"type": "text",
97+
"text": data["question"].replace("<image>", ""),
98+
},
99+
],
100+
},
101+
{
102+
"role": "assistant",
103+
"content": [{"type": "text", "text": data["answer"]}],
104+
},
105+
]
106+
)
107+
else:
108+
messages.extend(
109+
[
110+
{
111+
"role": "user",
112+
"content": [
113+
{"type": "text", "text": data["question"]},
114+
],
115+
},
116+
{
117+
"role": "assistant",
118+
"content": [{"type": "text", "text": data["answer"]}],
119+
},
120+
]
121+
)
122+
elif "messages" in data:
123+
messages = data["messages"]
124+
for message in messages:
125+
if message["role"] == "user":
126+
for content in message["content"]:
127+
if content["type"] == "image":
128+
content["image"] = os.path.join(image_dir, content["image"])
129+
else:
130+
raise ValueError("Invalid data format")
131+
132+
# adapt to hunyuan_vl
133+
if self.model_name in ["HunyuanVL"]:
134+
for message in messages:
135+
if message["role"] == "assistant" or message["role"] == "system":
136+
message["content"] = message["content"][0]["text"]
137+
return messages
138+
84139
def _load_hf_dataset(self, dataset: str, num_samples: int):
85140
"""Load dataset from Hugging Face format"""
86141
dataset = load_dataset(dataset, split="test")
@@ -108,7 +163,7 @@ def _load_hf_dataset(self, dataset: str, num_samples: int):
108163

109164
def _process_and_append(self, messages: List[Dict]):
110165
"""Process messages and append to dataset"""
111-
if self.model_name in ["Qwen3VL"]:
166+
if self.model_name in ["Qwen3VL", "Qwen3VLMoE"]:
112167
inputs = self.processor.apply_chat_template(
113168
messages,
114169
tokenize=True,

angelslim/models/base_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,9 @@ def get_qdq_module(self, sub_layer, name):
147147
raise NotImplementedError
148148
return q_linear
149149

150+
def get_moe_qdq_module(self, sub_layer, name):
    """Return ``sub_layer`` unchanged.

    Args:
        sub_layer: The module under consideration.
        name: Name of ``sub_layer``; unused here.

    Returns:
        The same ``sub_layer`` object that was passed in.
    """
    # NOTE(review): presumably a default hook that model classes with MoE
    # experts override to return a replacement module — confirm in subclasses.
    return sub_layer
152+
150153
def get_nvfp4_qdq_module(self, sub_layer, name):
151154
act_scale, weight_scale, weight_scale_2 = None, None, None
152155
block_size = self.quant_config.quant_algo_info["block_size"]

angelslim/models/vlm/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@
1414

1515
from .hunyuan_vl import HunyuanVL # noqa: F401
1616
from .qwen3_vl import Qwen3VL # noqa: F401
17+
from .qwen3_vl_moe import Qwen3VLMoE # noqa: F401
1718
from .qwen_vl import QwenVL # noqa: F401

0 commit comments

Comments
 (0)