Skip to content
This repository was archived by the owner on May 11, 2025. It is now read-only.

Commit 45c1a72

Browse files
tiger-of-shawnwenhu.xwh
andauthored
Add Qwen2.5-Omni support. (#759)
Signed-off-by: wenhu.xwh <wenhu.xwh@alibaba-inc.com> Co-authored-by: wenhu.xwh <wenhu.xwh@alibaba-inc.com>
1 parent f0dd15d commit 45c1a72

5 files changed

Lines changed: 96 additions & 5 deletions

File tree

awq/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,5 @@
3131
from .internlm2 import InternLM2AWQForCausalLM
3232
from .minicpm3 import MiniCPM3AWQForCausalLM
3333
from .qwen2vl import Qwen2VLAWQForCausalLM
34-
from .qwen2_5_vl import Qwen2_5_VLAWQForCausalLM
34+
from .qwen2_5_vl import Qwen2_5_VLAWQForCausalLM
35+
from .qwen2_5_omni import Qwen2_5_OmniAWQForConditionalGeneration

awq/models/auto.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
"minicpm3": MiniCPM3AWQForCausalLM,
4444
"qwen2_vl": Qwen2VLAWQForCausalLM,
4545
"qwen2_5_vl": Qwen2_5_VLAWQForCausalLM,
46+
"qwen2_5_omni": Qwen2_5_OmniAWQForConditionalGeneration
4647
}
4748

4849

@@ -79,7 +80,6 @@ def from_pretrained(
7980
model_type = check_and_get_model_type(
8081
model_path, trust_remote_code, **model_init_kwargs
8182
)
82-
8383
return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
8484
model_path,
8585
model_type,

awq/models/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@
9090
"internlm2": "AutoModelForCausalLM",
9191
"qwen2_vl": "AutoModelForVision2Seq",
9292
"qwen2_5_vl": "AutoModelForVision2Seq",
93+
"qwen2_5_omni": "AutoModelForTextToWaveform",
9394
}
9495

9596

@@ -377,12 +378,11 @@ def from_pretrained(
377378
target_cls = getattr(transformers, target_cls_name)
378379

379380
processor = None
380-
if target_cls_name == "AutoModelForVision2Seq":
381+
if target_cls_name == "AutoModelForVision2Seq" or target_cls_name == "AutoModelForTextToWaveform":
381382
processor = AutoProcessor.from_pretrained(model_weights_path)
382-
383383
if model_init_kwargs.get("low_cpu_mem_usage") is None:
384384
model_init_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
385-
if model_init_kwargs.get("use_cache") is None and target_cls_name != "AutoModelForVision2Seq":
385+
if model_init_kwargs.get("use_cache") is None and not ((target_cls_name == "AutoModelForVision2Seq") or (target_cls_name == "AutoModelForTextToWaveform")):
386386
model_init_kwargs["use_cache"] = use_cache
387387

388388
# If not quantized, must load with AutoModelForCausalLM

awq/models/qwen2_5_omni.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
from .base import BaseAWQForCausalLM
2+
from typing_extensions import TYPE_CHECKING
3+
4+
if TYPE_CHECKING:
5+
from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import (
6+
Qwen2_5OmniDecoderLayer
7+
)
8+
from transformers import Qwen2_5OmniForConditionalGeneration
9+
10+
11+
class Qwen2_5_OmniAWQForConditionalGeneration(BaseAWQForCausalLM):
12+
layer_type = "Qwen2_5OmniDecoderLayer"
13+
max_seq_len_key = "max_position_embeddings"
14+
modules_to_not_convert = ["visual"]
15+
@staticmethod
16+
def get_model_layers(model: "Qwen2_5OmniForConditionalGeneration"):
17+
return model.thinker.model.layers
18+
19+
@staticmethod
20+
def get_act_for_scaling(module: "Qwen2_5OmniForConditionalGeneration"):
21+
return dict(is_scalable=False)
22+
23+
@staticmethod
24+
def move_embed(model: "Qwen2_5OmniForConditionalGeneration", device: str):
25+
model.thinker.model.embed_tokens = model.thinker.model.embed_tokens.to(device)
26+
model.thinker.visual = model.thinker.visual.to(device)
27+
model.thinker.audio_tower = model.thinker.audio_tower.to(device)
28+
29+
model.thinker.visual.rotary_pos_emb = model.thinker.visual.rotary_pos_emb.to(device)
30+
model.thinker.model.rotary_emb = model.thinker.model.rotary_emb.to(device)
31+
32+
for layer in model.thinker.model.layers:
33+
layer.self_attn.rotary_emb = layer.self_attn.rotary_emb.to(device)
34+
35+
@staticmethod
36+
def get_layers_for_scaling(
37+
module: "Qwen2_5OmniDecoderLayer", input_feat, module_kwargs
38+
):
39+
layers = []
40+
41+
# attention input
42+
layers.append(
43+
dict(
44+
prev_op=module.input_layernorm,
45+
layers=[
46+
module.self_attn.q_proj,
47+
module.self_attn.k_proj,
48+
module.self_attn.v_proj,
49+
],
50+
inp=input_feat["self_attn.q_proj"],
51+
module2inspect=module.self_attn,
52+
kwargs=module_kwargs,
53+
)
54+
)
55+
56+
# attention out
57+
# Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
58+
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
59+
layers.append(
60+
dict(
61+
prev_op=module.self_attn.v_proj,
62+
layers=[module.self_attn.o_proj],
63+
inp=input_feat["self_attn.o_proj"],
64+
)
65+
)
66+
67+
# linear 1
68+
layers.append(
69+
dict(
70+
prev_op=module.post_attention_layernorm,
71+
layers=[module.mlp.gate_proj, module.mlp.up_proj],
72+
inp=input_feat["mlp.gate_proj"],
73+
module2inspect=module.mlp,
74+
)
75+
)
76+
77+
# linear 2
78+
layers.append(
79+
dict(
80+
prev_op=module.mlp.up_proj,
81+
layers=[module.mlp.down_proj],
82+
inp=input_feat["mlp.down_proj"],
83+
)
84+
)
85+
86+
return layers

awq/quantize/quantizer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,10 @@ def quantize(self):
163163
self.inps, self.module_kwargs["position_ids"]
164164
)
165165

166+
if (transformers.__version__ >= "4.48.0"
167+
and self.module_kwargs.get('attention_mask') is None):
168+
self.module_kwargs['attention_mask'] = None
169+
166170
for k, v in self.module_kwargs.items():
167171
# position embeddings found in tuple
168172
if isinstance(v, tuple):

0 commit comments

Comments
 (0)