Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions angelslim/compressor/quant/core/quant_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,17 @@ def quantize_weight_per_tensor_fp8(
) -> Tuple[torch.Tensor, float]:
finfo = torch.finfo(torch.float8_e4m3fn)

squeeze_dim = False
while scale.ndim < tensor.ndim:
scale = scale.unsqueeze(-1)
squeeze_dim = True

qweight = (tensor / scale).clamp(min=finfo.min, max=finfo.max)
# Return both float8 data and the inverse scale (as float),
# as both required as inputs to torch._scaled_mm
qweight = qweight.to(torch.float8_e4m3fn)
if squeeze_dim:
scale = scale.squeeze(-1)
scale = scale.float()
return qweight, scale

Expand Down
1 change: 1 addition & 0 deletions angelslim/compressor/quant/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .gptq.gptq import GPTQ # noqa: F401
from .gptq.gptq_module import GPTQModule # noqa: F401
from .helper_layer import GPTQQuantLinear # noqa: F401
from .helper_layer import MoEQDQModule # noqa: F401
from .helper_layer import NVFP4QDQModule # noqa: F401
from .helper_layer import QDQModule # noqa: F401
from .helper_layer import QDQSingleModule # noqa: F401
Expand Down
77 changes: 77 additions & 0 deletions angelslim/compressor/quant/modules/helper_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,3 +1023,80 @@ def _unpack_tensor(input: torch.Tensor):
deq_data.shape[0], deq_data.shape[1] // block_size, -1
) * per_block_scale.unsqueeze(-1)
return deq_data.view(-1)[: np.prod(self.shape)].reshape(self.shape).to(dtype)


class MoEQDQModule(torch.nn.Module):
def __init__(
self,
gate_proj: torch.nn.Parameter,
up_proj: torch.nn.Parameter,
down_proj: torch.nn.Parameter,
gate_proj_weight_scale: torch.nn.Parameter,
up_proj_weight_scale: torch.nn.Parameter,
down_proj_weight_scale: torch.nn.Parameter,
gate_up_proj_input_scale: torch.nn.Parameter,
down_proj_input_scale: torch.nn.Parameter,
):
super().__init__()
quant_gate_weight, _ = quantize_weight_per_tensor_fp8(
gate_proj, gate_proj_weight_scale
)
quant_up_weight, _ = quantize_weight_per_tensor_fp8(
up_proj, up_proj_weight_scale
)
quant_down_weight, _ = quantize_weight_per_tensor_fp8(
down_proj, down_proj_weight_scale
)
quant_gate_up_weight = torch.cat([quant_gate_weight, quant_up_weight], dim=-1)

self.gate_up_proj = torch.nn.Parameter(
quant_gate_up_weight, requires_grad=False
)
self.down_proj = torch.nn.Parameter(quant_down_weight, requires_grad=False)

gate_proj_weight_scale = (
gate_proj_weight_scale.view(-1)
if gate_proj_weight_scale.ndim == 0
else gate_proj_weight_scale
)
up_proj_weight_scale = (
up_proj_weight_scale.view(-1)
if up_proj_weight_scale.ndim == 0
else up_proj_weight_scale
)
down_proj_weight_scale = (
down_proj_weight_scale.view(-1)
if down_proj_weight_scale.ndim == 0
else down_proj_weight_scale
)
gate_up_proj_weight_scale = torch.cat(
[gate_proj_weight_scale, up_proj_weight_scale], dim=-1
)

self.gate_up_proj_weight_scale = torch.nn.Parameter(
gate_up_proj_weight_scale, requires_grad=False
)
self.down_proj_weight_scale = torch.nn.Parameter(
down_proj_weight_scale, requires_grad=False
)

down_proj_input_scale = (
down_proj_input_scale.view(-1)
if down_proj_input_scale.ndim == 0
else down_proj_input_scale.squeeze()
)
gate_up_proj_input_scale = (
gate_up_proj_input_scale.view(-1)
if gate_up_proj_input_scale.ndim == 0
else gate_up_proj_input_scale.squeeze()
)

self.gate_up_proj_input_scale = torch.nn.Parameter(
gate_up_proj_input_scale, requires_grad=False
)
self.down_proj_input_scale = torch.nn.Parameter(
down_proj_input_scale, requires_grad=False
)

def forward(self, x):
pass
1 change: 1 addition & 0 deletions angelslim/compressor/quant/observers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .abs_max_activation import AbsmaxPerchannelObserver # noqa: F401
from .abs_max_activation import AbsmaxPertensorObserver # noqa: F401
from .abs_max_activation import AbsMaxTokenWiseActObserver # noqa: F401; noqa: F401
from .abs_max_activation import MoEAbsmaxPertensorObserver # noqa: F401
from .abs_max_weight import AbsMaxChannelWiseWeightObserver # noqa: F401
from .base_observer import BaseObserver, ParentObserver # noqa: F401
from .ema_activation import EMAObserver # noqa: F401
Expand Down
89 changes: 89 additions & 0 deletions angelslim/compressor/quant/observers/abs_max_activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"AbsmaxPertensorObserver",
"AbsMaxTokenWiseActObserver",
"AbsmaxPerchannelObserver",
"MoEAbsmaxPertensorObserver",
]


Expand Down Expand Up @@ -217,3 +218,91 @@ def zero_points(self):
if self._zero_point is None:
self.cal_thresholds()
return self._zero_point


class MoEAbsmaxPertensorObserver(BaseObserver):
def __init__(self, layer_name=None, quant_bits=8, **kwargs):
super(MoEAbsmaxPertensorObserver, self).__init__(quant_bits=quant_bits)
self.layer_name = layer_name
self._scale = None
self._zero_point = None
self._min = None
self._max = torch.tensor(1e-7, dtype=torch.float32)
self.step = 0
self.dtype = None
self.parent_observer = (
kwargs["parent_observer"]
if kwargs and "parent_observer" in kwargs
else None
)

def forward(self, inputs):
"""Calculate forward pass."""
self.step += 1
if not self.dtype:
self.dtype = inputs.dtype
if inputs.numel() > 0:
self._min, self._max = self._cal_min_max(inputs)
if self.parent_observer is not None:
self.parent_observer.update(self._min, self._max, self.step)
else:
assert self.parent_observer is not None
self._update_min_max(self.parent_observer.min, self.parent_observer.max)
return inputs

def _cal_min_max(self, inputs):
if inputs.dim() >= 2:
abs_inputs = torch.abs(inputs)
batch_size = abs_inputs.shape[0]
abs_inputs_flat = abs_inputs.view(
batch_size, -1
) # [batch_size, seq_len * hidden_dim]
abs_max_val, _ = torch.max(
abs_inputs_flat, dim=1, keepdim=True
) # [batch_size, 1]
min_threshold = self._max.to(abs_max_val.device).expand_as(abs_max_val)
abs_max_val = torch.maximum(abs_max_val, min_threshold)
else:
abs_max_val = torch.max(torch.abs(inputs))
if abs_max_val.data < self._max.data:
abs_max_val = self._max
abs_max_val = abs_max_val.unsqueeze(0).unsqueeze(0) # [1, 1]
return 0, abs_max_val.to(inputs.device)

def _update_min_max(self, min, max):
if min is not None and max is not None:
if self._min is None or min < self._min:
self._min = min
if self._max is None or max > self._max:
self._max = max

def cal_thresholds(self):
"""Compute thresholds for MAX function."""
if self._scale is None:
self._scale = self._max
self._zero_point = torch.zeros_like(self._scale)

def quant_axis(self):
"""Return quantization axis."""
return -1

def scales(self):
"""Return output scales."""
if self.step == 0 and self.parent_observer is not None:
self._update_min_max(self.parent_observer.min, self.parent_observer.max)
self.step = self.parent_observer.step
if self.step == 0:
raise ValueError(
"AbsmaxPertensorObserver scales must calibrate data first!"
)
if self._scale is None:
self.cal_thresholds()
if self.dtype:
self._scale = self._scale.type(self.dtype)
return self._scale

def zero_points(self):
"""Return output zero points."""
if self._zero_point is None:
self.cal_thresholds()
return self._zero_point
15 changes: 15 additions & 0 deletions angelslim/compressor/quant/ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import torch
from safetensors.torch import load_file
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextExperts

from ...utils import find_parent_layer_and_sub_name, print_info
from ..compressor_factory import CompressorFactory
Expand Down Expand Up @@ -284,6 +285,20 @@ def _convert(self):

if qdq_module is not sub_layer:
setattr(parent_layer, sub_name, qdq_module)

# 3. insert moe qdq module
# For qwen3_vl_moe models, we need to insert MoEQDQModule for MOE experts,
# since these modules contain gate_up_proj and down_proj, which are defined as
# nn.Parameters, not nn.Linear.
if Qwen3VLMoeTextExperts in self.quant_model.observer_layer_classes:
for name, sub_layer in self.quant_model.model.named_modules():
parent_layer, sub_name = find_parent_layer_and_sub_name(
quant_convert_module, name
)
moe_qdq_module = self.quant_model.get_moe_qdq_module(sub_layer, name)
if moe_qdq_module is not sub_layer:
setattr(parent_layer, sub_name, moe_qdq_module)

self.quant_model.quantized = True

def __getattr__(self, item):
Expand Down
95 changes: 75 additions & 20 deletions angelslim/data/multimodal_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(

def _load_file_based_dataset(self, data_path: str, num_samples: int):
"""Load dataset from local file system"""
image_dir = os.path.join(os.path.dirname(data_path), "images")
self.data_path = data_path
line_count = 0

with open(data_path, "r") as f:
Expand All @@ -58,29 +58,84 @@ def _load_file_based_dataset(self, data_path: str, num_samples: int):
break

data = json.loads(line.strip())
image_path = os.path.join(image_dir, data["img_path"])

# Prepare chat messages with image
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": image_path},
{
"type": "text",
"text": data["question"].replace("<image>", ""),
},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": data["answer"]}],
},
]
# Validate format
assert "messages" in data or "question" in data, "JSON format error"

# Prepare messages
messages = self._prepare_messages(data)

self._process_and_append(messages)
line_count += 1

def _prepare_messages(self, data: Dict) -> List[Dict]:
image_dir = os.path.join(os.path.dirname(self.data_path), "images")
if "question" in data:
# Prepare chat messages with image
messages = []
if "system_prompt" in data:
messages.extend(
[
{
"role": "system",
"content": [
{"type": "text", "text": data["system_prompt"]}
],
}
]
)
if "img_path" in data:
image_path = os.path.join(image_dir, data["img_path"])
messages.extend(
[
{
"role": "user",
"content": [
{"type": "image", "image": image_path},
{
"type": "text",
"text": data["question"].replace("<image>", ""),
},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": data["answer"]}],
},
]
)
else:
messages.extend(
[
{
"role": "user",
"content": [
{"type": "text", "text": data["question"]},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": data["answer"]}],
},
]
)
elif "messages" in data:
messages = data["messages"]
for message in messages:
if message["role"] == "user":
for content in message["content"]:
if content["type"] == "image":
content["image"] = os.path.join(image_dir, content["image"])
else:
raise ValueError("Invalid data format")

# adapt to hunyuan_vl
if self.model_name in ["HunyuanVL"]:
for message in messages:
if message["role"] == "assistant" or message["role"] == "system":
message["content"] = message["content"][0]["text"]
return messages

def _load_hf_dataset(self, dataset: str, num_samples: int):
"""Load dataset from Hugging Face format"""
dataset = load_dataset(dataset, split="test")
Expand Down Expand Up @@ -108,7 +163,7 @@ def _load_hf_dataset(self, dataset: str, num_samples: int):

def _process_and_append(self, messages: List[Dict]):
"""Process messages and append to dataset"""
if self.model_name in ["Qwen3VL"]:
if self.model_name in ["Qwen3VL", "Qwen3VLMoE"]:
inputs = self.processor.apply_chat_template(
messages,
tokenize=True,
Expand Down
3 changes: 3 additions & 0 deletions angelslim/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ def get_qdq_module(self, sub_layer, name):
raise NotImplementedError
return q_linear

def get_moe_qdq_module(self, sub_layer, name):
return sub_layer

def get_nvfp4_qdq_module(self, sub_layer, name):
act_scale, weight_scale, weight_scale_2 = None, None, None
block_size = self.quant_config.quant_algo_info["block_size"]
Expand Down
1 change: 1 addition & 0 deletions angelslim/models/vlm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@

from .hunyuan_vl import HunyuanVL # noqa: F401
from .qwen3_vl import Qwen3VL # noqa: F401
from .qwen3_vl_moe import Qwen3VLMoE # noqa: F401
from .qwen_vl import QwenVL # noqa: F401
Loading
Loading