1919 AutoTokenizer ,
2020 Qwen3OmniMoeForConditionalGeneration ,
2121)
22+ from transformers .models .qwen3_omni_moe .modeling_qwen3_omni_moe import (
23+ Qwen3OmniMoeTalkerTextExperts ,
24+ Qwen3OmniMoeTalkerTextTopKRouter ,
25+ Qwen3OmniMoeThinkerTextExperts ,
26+ Qwen3OmniMoeThinkerTextTopKRouter ,
27+ )
2228
2329from ...compressor .quant .core import PTQVLMSaveVllmHF
24- from ...utils import find_layers , print_info
30+ from ...utils import find_layers , find_parent_layer_and_sub_name , print_info
2531from ..base_model import BaseLLMModel
32+ from ..llm .qwen import QwenMoeExpertsWithLinear
2633from ..model_factory import SlimModelFactory
2734
2835
@@ -38,7 +45,92 @@ def __init__(
3845 deploy_backend = deploy_backend ,
3946 )
4047 self .modal_type = "Omni"
41- self .block_name = ["thinker.model.layers" , "talker.model.layers" ]
48+ self .thinker_block_name = "thinker.model.layers"
49+ self .talker_block_name = "talker.model.layers"
50+ self .observer_layer_classes = [
51+ torch .nn .Linear ,
52+ Qwen3OmniMoeThinkerTextTopKRouter ,
53+ Qwen3OmniMoeTalkerTextTopKRouter ,
54+ ]
55+ self .observed_names = [
56+ "k_proj" ,
57+ "v_proj" ,
58+ "q_proj" ,
59+ "o_proj" ,
60+ "gate_proj" ,
61+ "up_proj" ,
62+ "down_proj" ,
63+ ]
64+
def replace_moe(self):
    """Replace fused MoE expert modules with ``QwenMoeExpertsWithLinear``.

    Walks the ``thinker`` and ``talker`` sub-models of ``self.model`` and,
    for every ``Qwen3OmniMoeThinkerTextExperts`` /
    ``Qwen3OmniMoeTalkerTextExperts`` instance that is not already a
    ``QwenMoeExpertsWithLinear``, installs ``QwenMoeExpertsWithLinear(module)``
    in its place on the parent layer.

    A sub-model that is absent from ``self.model`` is skipped, matching the
    way ``_patch_omni_generate_devices`` treats ``talker`` as optional.
    """
    targets = (
        ("thinker", Qwen3OmniMoeThinkerTextExperts),
        ("talker", Qwen3OmniMoeTalkerTextExperts),
    )
    for attr_name, experts_cls in targets:
        sub_model = getattr(self.model, attr_name, None)
        if sub_model is None:
            continue

        # Snapshot only the *names* before touching the tree: this avoids
        # mutating the module hierarchy while ``named_modules()`` is still
        # iterating it, and keeps no extra references to the old expert
        # modules so each can be released as soon as it is replaced.
        pending = [
            name
            for name, module in sub_model.named_modules()
            if isinstance(module, experts_cls)
            and not isinstance(module, QwenMoeExpertsWithLinear)
        ]

        for name in pending:
            print(name)
            parent_layer, sub_name = find_parent_layer_and_sub_name(sub_model, name)
            moe_linear = QwenMoeExpertsWithLinear(sub_model.get_submodule(name))
            setattr(parent_layer, sub_name, moe_linear)
85+
def _patch_inputs_embeds_generate_device(self, module):
    """Wrap ``module.generate`` so keyword inputs follow ``inputs_embeds``' target device.

    The wrapper only acts when ``inputs_embeds`` is passed as a keyword and is
    not ``None``. In that case every keyword argument except those named in
    ``skip_keys`` is walked recursively (tensors, tuples, lists, dicts) and
    tensors are moved to ``module.device``. If the module has no ``device``
    attribute, or its device type is ``"meta"``, the device of
    ``inputs_embeds`` is used instead. Positional arguments are forwarded
    untouched.

    The patch is applied at most once per module; the marker attribute
    ``_angelslim_generate_device_patch`` records that it has been installed.

    Args:
        module: Object exposing ``generate``; ``None`` is accepted and ignored.
    """
    # Local import keeps this block self-contained without touching the
    # file-level import section.
    import functools

    if module is None or getattr(module, "_angelslim_generate_device_patch", False):
        return

    original_generate = module.generate
    # Values under these keys are forwarded exactly as received, at every
    # nesting level of a dict.
    skip_keys = {"past_key_values", "encoder_outputs"}

    def move_to_target_device(value, target_device):
        """Return ``value`` with contained tensors moved to ``target_device``."""
        if isinstance(value, torch.Tensor):
            # Meta tensors carry no data, and same-device tensors need no copy.
            if value.device.type == "meta" or value.device == target_device:
                return value
            return value.to(target_device)
        if isinstance(value, tuple):
            moved = [move_to_target_device(item, target_device) for item in value]
            # Rebuild namedtuples with their own type so field access by name
            # keeps working downstream; plain tuples stay plain tuples.
            if hasattr(value, "_fields"):
                return type(value)(*moved)
            return tuple(moved)
        if isinstance(value, list):
            return [move_to_target_device(item, target_device) for item in value]
        if isinstance(value, dict):
            return {
                key: item if key in skip_keys else move_to_target_device(item, target_device)
                for key, item in value.items()
            }
        return value

    # ``wraps`` preserves the name, docstring and (via ``__wrapped__``) the
    # introspectable signature of the original ``generate``.
    @functools.wraps(original_generate)
    def generate_on_module_device(*args, **kwargs):
        inputs_embeds = kwargs.get("inputs_embeds")
        if inputs_embeds is not None:
            target_device = getattr(module, "device", inputs_embeds.device)
            if target_device.type == "meta":
                target_device = inputs_embeds.device

            kwargs = {
                key: value if key in skip_keys else move_to_target_device(value, target_device)
                for key, value in kwargs.items()
            }

        return original_generate(*args, **kwargs)

    module.generate = generate_on_module_device
    module._angelslim_generate_device_patch = True
125+
def _patch_omni_generate_devices(self):
    """Install the ``generate`` device patch on the talker and its code predictor.

    Both lookups fall back to ``None`` when the attribute is missing, and
    ``None`` is passed straight through to
    ``_patch_inputs_embeds_generate_device``.
    """
    talker_module = getattr(self.model, "talker", None)
    code_predictor = getattr(talker_module, "code_predictor", None)
    for candidate in (talker_module, code_predictor):
        self._patch_inputs_embeds_generate_device(candidate)
130+
def init_ptq(self, slim_config):
    """Run the parent class's PTQ initialisation, then call ``replace_moe``.

    Args:
        slim_config: Forwarded unchanged to the parent ``init_ptq``.
    """
    # Parent setup runs first; the expert replacement follows it.
    super().init_ptq(slim_config)
    self.replace_moe()
42134
43135 def from_pretrained (
44136 self ,
@@ -63,6 +155,7 @@ def from_pretrained(
63155 device_map = device_map ,
64156 attn_implementation = attn_implementation ,
65157 )
158+ self ._patch_omni_generate_devices ()
66159
67160 # Load tokenizer
68161 self .tokenizer = AutoTokenizer .from_pretrained (
@@ -74,24 +167,21 @@ def from_pretrained(
74167 model_path , trust_remote_code = trust_remote_code
75168 )
76169
77- def get_observer_layers (self ):
78- names = [
79- "k_proj" ,
80- "v_proj" ,
81- "q_proj" ,
82- "o_proj" ,
83- "up_proj" ,
84- "gate_proj" ,
85- "down_proj" ,
86- ]
def _get_quant_block_names(self):
    """Return the block-name prefixes selected for quantization.

    The thinker prefix is always included. The talker prefix is appended
    unless ``self.quant_config.quant_talker`` is falsy; a config without
    that attribute is treated as ``True``.
    """
    include_talker = getattr(self.quant_config, "quant_talker", True)
    if include_talker:
        return [self.thinker_block_name, self.talker_block_name]
    return [self.thinker_block_name]
87175
176+ def get_observer_layers (self ):
88177 observer_layers_dict = {}
89178 layers_dict = find_layers (self .model , layers = self .observer_layer_classes )
179+ block_names = self ._get_quant_block_names ()
90180
91181 ignore_layers = self .skip_layer_names ()
92182 for name , module in layers_dict .items ():
93- block_condition = any (name .startswith (block ) for block in self . block_name )
94- if block_condition and name .split ("." )[- 1 ] in names :
183+ block_condition = any (name .startswith (block ) for block in block_names )
184+ if block_condition and name .split ("." )[- 1 ] in self . observed_names :
95185 observer_layers_dict [name ] = module
96186 else :
97187 ignore_layers .append (name )
@@ -106,10 +196,11 @@ def get_observer_layers(self):
106196
def get_kvcache_observer_layers_names(self, observe_names):
    """Pick the attention key/value projection names out of ``observe_names``.

    A name is kept when it starts with one of the prefixes returned by
    ``_get_quant_block_names`` and its last two dot-separated components are
    ``self_attn.k_proj`` or ``self_attn.v_proj``. Input order is preserved.

    Args:
        observe_names: Iterable of dotted layer names.

    Returns:
        list: The matching names.
    """
    kv_suffixes = ("self_attn.k_proj", "self_attn.v_proj")
    block_prefixes = tuple(self._get_quant_block_names())

    selected = []
    for layer_name in observe_names:
        # str.startswith accepts a tuple, so one call covers every prefix.
        if not layer_name.startswith(block_prefixes):
            continue
        parts = layer_name.split(".")
        if parts[-2] + "." + parts[-1] in kv_suffixes:
            selected.append(layer_name)
    return selected
115206
@@ -129,10 +220,9 @@ def model_forward(self, dataloader, **kwargs):
129220 if dataloader is not None :
130221 with torch .no_grad ():
131222 for batch in tqdm (dataloader , desc = "calibrating..." , total = len (dataloader )):
132- inputs = {k : v .to (device ) for k , v in batch .items ()}
133223 try :
134224 text_ids , audio = self .model .generate (
135- ** inputs , use_audio_in_video = self .use_audio_in_video
225+ ** batch , use_audio_in_video = self .use_audio_in_video
136226 )
137227 calibrated_cnt += 1
138228 except ValueError :
0 commit comments