Skip to content

Commit f76d316

Browse files
authored
fix: update qwen3_omni with transformers>=5.0 (#296)
1 parent c76d35e commit f76d316

12 files changed

Lines changed: 135 additions & 25 deletions

File tree

angelslim/compressor/quant/core/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def __init__(self, config, global_config=None):
6666
kv_cache_quant_method = quantization_args.quant_method.get("kv_cache", None)
6767
self.cpu_convert = quantization_args.cpu_convert
6868
self.save_name = quantization_args.save_name
69+
self.quant_talker = getattr(quantization_args, "quant_talker", False)
6970

7071
if global_config:
7172
self.max_seq_length = global_config.max_seq_length

angelslim/data/dataloader.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def create_data_loader(
4545
model_name: str = None,
4646
quantization_config: str = None,
4747
is_sft_data: bool = False,
48+
dtype=None,
4849
) -> DataLoader:
4950
"""
5051
Create appropriate DataLoader based on data source
@@ -114,6 +115,7 @@ def create_data_loader(
114115
data_source=data_source,
115116
is_hf_dataset=not os.path.isfile(data_source),
116117
use_audio_in_video=use_audio_in_video,
118+
dtype=dtype,
117119
)
118120
elif data_type == "AudioDataset":
119121
dataset = AudioDataset(

angelslim/data/omni_dataset.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,12 @@ def __init__(
3535
data_source: Union[str, Dict] = None,
3636
is_hf_dataset: bool = False,
3737
use_audio_in_video: bool = False,
38+
dtype=None,
3839
):
3940
super().__init__(processor, device, max_length)
4041
self.is_hf_dataset = is_hf_dataset
4142
self.use_audio_in_video = use_audio_in_video
43+
self.dtype = dtype
4244

4345
self._load_file_based_dataset(data_source, num_samples)
4446

@@ -112,10 +114,11 @@ def _process_and_append(self, messages: List[Dict]):
112114
inputs = self.processor(
113115
text=text,
114116
images=images,
115-
audios=audios,
117+
audio=audios,
116118
videos=videos,
117119
padding=True,
118120
return_tensors="pt",
119121
use_audio_in_video=self.use_audio_in_video,
120122
)
123+
inputs = inputs.to(self.device).to(self.dtype)
121124
self.data.append(inputs)

angelslim/engine.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def prepare_data(
160160
model_name=None,
161161
quantization_config=None,
162162
is_sft_data=False,
163+
dtype=None,
163164
) -> Optional[Any]:
164165
"""Prepare compression dataset"""
165166
if custom_dataloader is not None:
@@ -187,6 +188,7 @@ def prepare_data(
187188
model_name=model_name,
188189
quantization_config=quantization_config,
189190
is_sft_data=is_sft_data,
191+
dtype=dtype,
190192
)
191193
self.max_seq_length = max_length
192194

angelslim/models/omni/qwen3_omni.py

Lines changed: 107 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,17 @@
1919
AutoTokenizer,
2020
Qwen3OmniMoeForConditionalGeneration,
2121
)
22+
from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import (
23+
Qwen3OmniMoeTalkerTextExperts,
24+
Qwen3OmniMoeTalkerTextTopKRouter,
25+
Qwen3OmniMoeThinkerTextExperts,
26+
Qwen3OmniMoeThinkerTextTopKRouter,
27+
)
2228

2329
from ...compressor.quant.core import PTQVLMSaveVllmHF
24-
from ...utils import find_layers, print_info
30+
from ...utils import find_layers, find_parent_layer_and_sub_name, print_info
2531
from ..base_model import BaseLLMModel
32+
from ..llm.qwen import QwenMoeExpertsWithLinear
2633
from ..model_factory import SlimModelFactory
2734

2835

@@ -38,7 +45,92 @@ def __init__(
3845
deploy_backend=deploy_backend,
3946
)
4047
self.modal_type = "Omni"
41-
self.block_name = ["thinker.model.layers", "talker.model.layers"]
48+
self.thinker_block_name = "thinker.model.layers"
49+
self.talker_block_name = "talker.model.layers"
50+
self.observer_layer_classes = [
51+
torch.nn.Linear,
52+
Qwen3OmniMoeThinkerTextTopKRouter,
53+
Qwen3OmniMoeTalkerTextTopKRouter,
54+
]
55+
self.observed_names = [
56+
"k_proj",
57+
"v_proj",
58+
"q_proj",
59+
"o_proj",
60+
"gate_proj",
61+
"up_proj",
62+
"down_proj",
63+
]
64+
65+
def replace_moe(self):
    """Replace fused MoE expert modules with ``QwenMoeExpertsWithLinear``.

    Walks the ``thinker`` and ``talker`` sub-models of ``self.model`` and
    swaps every ``Qwen3OmniMoeThinkerTextExperts`` /
    ``Qwen3OmniMoeTalkerTextExperts`` instance for a
    ``QwenMoeExpertsWithLinear`` built from it. Modules that are already
    ``QwenMoeExpertsWithLinear`` instances are left alone, so calling this
    method more than once is harmless.

    A sub-model that is not present on ``self.model`` (looked up with
    ``getattr(..., None)``) is skipped instead of raising
    ``AttributeError``.
    """
    targets = (
        ("thinker", Qwen3OmniMoeThinkerTextExperts),
        ("talker", Qwen3OmniMoeTalkerTextExperts),
    )
    for attr_name, experts_cls in targets:
        root = getattr(self.model, attr_name, None)
        if root is None:
            continue

        # Snapshot the matches first: the loop below rewrites the module
        # tree, which must not happen while named_modules() is iterating.
        matches = [
            (name, module)
            for name, module in root.named_modules()
            if isinstance(module, experts_cls)
            and not isinstance(module, QwenMoeExpertsWithLinear)
        ]
        for name, module in matches:
            print_info(f"Replacing MoE experts with linear experts: {attr_name}.{name}")
            parent_layer, sub_name = find_parent_layer_and_sub_name(root, name)
            setattr(parent_layer, sub_name, QwenMoeExpertsWithLinear(module))
86+
def _patch_inputs_embeds_generate_device(self, module):
    """Wrap ``module.generate`` so its kwargs land on the module's device.

    When the wrapped ``generate`` is called with a non-``None``
    ``inputs_embeds`` keyword argument, every tensor found in the keyword
    arguments (recursing through tuples, lists and dicts) is moved to
    ``module.device``. If the module has no ``device`` attribute, or that
    device is ``meta``, the device of ``inputs_embeds`` is used instead.
    Values stored under the keys ``past_key_values`` and
    ``encoder_outputs`` are passed through untouched, as are tensors that
    live on the ``meta`` device and all positional arguments.

    The patch is applied at most once per module; a marker attribute
    ``_angelslim_generate_device_patch`` records that it is installed.

    Args:
        module: Object exposing ``generate``; ``None`` is accepted and
            ignored.
    """
    # Local import keeps this method self-contained.
    import functools

    if module is None or getattr(module, "_angelslim_generate_device_patch", False):
        return

    original_generate = module.generate
    skip_keys = {"past_key_values", "encoder_outputs"}

    def move_to_target_device(value, target_device):
        """Return ``value`` with every contained tensor moved to ``target_device``."""
        if isinstance(value, torch.Tensor):
            # Meta tensors carry no data, and same-device tensors need no copy.
            if value.device.type == "meta" or value.device == target_device:
                return value
            return value.to(target_device)
        if isinstance(value, tuple):
            return tuple(move_to_target_device(item, target_device) for item in value)
        if isinstance(value, list):
            return [move_to_target_device(item, target_device) for item in value]
        if isinstance(value, dict):
            return {
                key: item if key in skip_keys else move_to_target_device(item, target_device)
                for key, item in value.items()
            }
        return value

    # wraps() keeps the original name/docstring and exposes __wrapped__, so
    # signature introspection of ``module.generate`` still sees the original.
    @functools.wraps(original_generate)
    def generate_on_module_device(*args, **kwargs):
        inputs_embeds = kwargs.get("inputs_embeds")
        if inputs_embeds is not None:
            target_device = getattr(module, "device", inputs_embeds.device)
            if target_device.type == "meta":
                target_device = inputs_embeds.device

            kwargs = {
                key: value if key in skip_keys else move_to_target_device(value, target_device)
                for key, value in kwargs.items()
            }

        return original_generate(*args, **kwargs)

    module.generate = generate_on_module_device
    module._angelslim_generate_device_patch = True
125+
126+
def _patch_omni_generate_devices(self):
    """Install the generate-device patch on the talker and its code predictor.

    Both lookups fall back to ``None`` when the attribute is absent, and
    each result is handed to ``_patch_inputs_embeds_generate_device`` in
    turn: first the talker, then its ``code_predictor``.
    """
    talker_module = getattr(self.model, "talker", None)
    for candidate in (talker_module, getattr(talker_module, "code_predictor", None)):
        self._patch_inputs_embeds_generate_device(candidate)
130+
131+
def init_ptq(self, slim_config):
    """Run the parent class's PTQ initialisation, then rewrite the MoE experts.

    Args:
        slim_config: Passed through unchanged to the parent ``init_ptq``.
    """
    super().init_ptq(slim_config)
    # Done after the parent set-up; replace_moe() swaps the fused expert
    # modules for QwenMoeExpertsWithLinear wrappers.
    self.replace_moe()
42134

43135
def from_pretrained(
44136
self,
@@ -63,6 +155,7 @@ def from_pretrained(
63155
device_map=device_map,
64156
attn_implementation=attn_implementation,
65157
)
158+
self._patch_omni_generate_devices()
66159

67160
# Load tokenizer
68161
self.tokenizer = AutoTokenizer.from_pretrained(
@@ -74,24 +167,21 @@ def from_pretrained(
74167
model_path, trust_remote_code=trust_remote_code
75168
)
76169

77-
def get_observer_layers(self):
78-
names = [
79-
"k_proj",
80-
"v_proj",
81-
"q_proj",
82-
"o_proj",
83-
"up_proj",
84-
"gate_proj",
85-
"down_proj",
86-
]
170+
def _get_quant_block_names(self):
    """Return the layer-name prefixes of the blocks selected for quantization.

    The thinker block is always included. The talker block is added only
    when ``self.quant_config.quant_talker`` is truthy. If the attribute is
    missing the talker is left out, matching the ``False`` default that
    the quantization config declares for ``quant_talker``.

    Returns:
        list[str]: ``[thinker_block_name]`` or
        ``[thinker_block_name, talker_block_name]``.
    """
    block_names = [self.thinker_block_name]
    # Fallback must agree with the config default (False), not opt in silently.
    if getattr(self.quant_config, "quant_talker", False):
        block_names.append(self.talker_block_name)
    return block_names
87175

176+
def get_observer_layers(self):
88177
observer_layers_dict = {}
89178
layers_dict = find_layers(self.model, layers=self.observer_layer_classes)
179+
block_names = self._get_quant_block_names()
90180

91181
ignore_layers = self.skip_layer_names()
92182
for name, module in layers_dict.items():
93-
block_condition = any(name.startswith(block) for block in self.block_name)
94-
if block_condition and name.split(".")[-1] in names:
183+
block_condition = any(name.startswith(block) for block in block_names)
184+
if block_condition and name.split(".")[-1] in self.observed_names:
95185
observer_layers_dict[name] = module
96186
else:
97187
ignore_layers.append(name)
@@ -106,10 +196,11 @@ def get_observer_layers(self):
106196

107197
def get_kvcache_observer_layers_names(self, observe_names):
108198
names = ["self_attn.k_proj", "self_attn.v_proj"]
199+
block_names = self._get_quant_block_names()
109200
return [
110201
k
111202
for k in observe_names
112-
if any(k.startswith(block) for block in self.block_name)
203+
if any(k.startswith(block) for block in block_names)
113204
and k.split(".")[-2] + "." + k.split(".")[-1] in names
114205
]
115206

@@ -129,10 +220,9 @@ def model_forward(self, dataloader, **kwargs):
129220
if dataloader is not None:
130221
with torch.no_grad():
131222
for batch in tqdm(dataloader, desc="calibrating...", total=len(dataloader)):
132-
inputs = {k: v.to(device) for k, v in batch.items()}
133223
try:
134224
text_ids, audio = self.model.generate(
135-
**inputs, use_audio_in_video=self.use_audio_in_video
225+
**batch, use_audio_in_video=self.use_audio_in_video
136226
)
137227
calibrated_cnt += 1
138228
except ValueError:

angelslim/utils/config_parser.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ class QuantizationConfig:
200200
quant_method: Algorithm used for quantization
201201
modules_to_quantize: List of module types to quantize
202202
ignore_layers: List of layer names to skip
203+
quant_talker: Whether to quantize Qwen3-Omni talker LLM module
203204
"""
204205

205206
name: str = field(default="fp8_dynamic")
@@ -222,6 +223,7 @@ class QuantizationConfig:
222223
ignore_layers: List[str] = field(default_factory=list)
223224
quant_analyse: bool = field(default=False)
224225
quant_vit: bool = field(default=False)
226+
quant_talker: bool = field(default=False)
225227
# DAQ-specific fields
226228
base_model_path: Optional[str] = field(default=None)
227229
base_is_fp8: bool = field(default=False)

configs/qwen3_omni/fp8_dynamic/qwen3_omni_fp8_dynamic.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,4 @@ compression:
2323
quant_method:
2424
weight: "per-tensor"
2525
activation: "per-tensor"
26+
quant_talker: false # Whether to quantize Qwen3-Omni talker LLM module

configs/qwen3_omni/fp8_static/qwen3_omni_fp8_static.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ compression:
2323
quant_method:
2424
weight: "per-tensor"
2525
activation: "per-tensor"
26+
quant_talker: false # Whether to quantize Qwen3-Omni talker LLM module
2627

2728
# Dataset for calibration
2829
dataset:
324 KB
Binary file not shown.
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
{"messages": [{"role": "user", "content": "What happens after the text disappears from the screen?"}], "video_path": "./videos/0.mp4"}
2-
{"messages": [{"role": "user", "content": "How many food item is shown in the bar graph?"}], "image_path": "./images/0.png"}
3-
{"messages": [{"role": "user", "content": "Why is the speech described as rich in frequency content?"}], "audio_path": "./audios/0.png"}
1+
{"messages": [{"role": "user", "content": "描述这个视频的内容。"}], "video_path": "./videos/0.mp4"}
2+
{"messages": [{"role": "user", "content": "请描述这张图片的内容。"}], "image_path": "./images/0.png"}
3+
{"messages": [{"role": "user", "content": "请将这段语音转写成文字。"}], "audio_path": "./audios/0.wav"}

0 commit comments

Comments
 (0)