@@ -249,8 +249,8 @@ def set_quant_config(self):
249249 self .config ['model' ]['type' ] in ['Opt' , 'Llama' ]
250250 ), 'Please set online_rotate=False'
251251 self .fp32_had = special_config .get ('fp32_had' , False )
252- self .hidden_size = self . model . model_config . hidden_size
253- self .set_model_config ()
252+ if self .quant_config . modality != 'video_gen' :
253+ self .set_model_config ()
254254 self .modality = self .quant_config .modality
255255 logger .info (f'self.quant_objects : { self .quant_config .modality } ' )
256256
@@ -373,12 +373,12 @@ def block_forward(self, block, input_data=None):
373373 if torch .is_tensor (self .input ['kwargs' ][i ][k ]):
374374 self .input ['kwargs' ][i ][k ] = self .input ['kwargs' ][i ][k ].to (
375375 device = next (block .parameters ()).device
376- ) # noqa
376+ )
377377 if isinstance (self .input ['kwargs' ][i ][k ], tuple ):
378378 self .input ['kwargs' ][i ][k ] = tuple (
379379 tmp .to (device = next (block .parameters ()).device )
380380 for tmp in self .input ['kwargs' ][i ][k ]
381- ) # noqa
381+ )
382382 with torch .no_grad ():
383383 out = block (input_data [i ], ** self .input ['kwargs' ][i ])
384384 if isinstance (out , tuple ):
@@ -474,9 +474,10 @@ def block_transform(self, block, input_feat, block_kwargs):
474474 inspect_has_kwargs = subset ['has_kwargs' ]
475475 if inspect_has_kwargs :
476476 if 'sub_keys' in subset :
477- subset_kwargs = [
478- {k : block_kwargs [0 ][v ] for k , v in subset ['sub_keys' ].items ()}
479- ]
477+ subset_kwargs = []
478+ for i in range (len (block_kwargs )):
479+ for k , v in subset ['sub_keys' ].items ():
480+ subset_kwargs .append ({k : block_kwargs [i ][v ]})
480481 else :
481482 subset_kwargs = block_kwargs
482483 else :
@@ -746,7 +747,10 @@ def shift_ln_fcs(self, ln, fcs, shifts):
746747 def scale_ln_fcs (self , ln , fcs , scales ):
747748 if not isinstance (fcs , list ):
748749 fcs = [fcs ]
750+
749751 scales = scales .to (ln .weight .device )
752+ scales = scales .to (ln .weight .dtype )
753+
750754 ln .weight .div_ (scales )
751755
752756 if hasattr (ln , 'bias' ) and ln .bias is not None :
@@ -954,6 +958,13 @@ def deploy(self, quant_format, keep_device=False):
954958 self .get_replacement_params (mode = quant_format , w_only = self .w_only ),
955959 keep_device = keep_device ,
956960 )
961+ if self .modality == 'video_gen' :
962+ self .model .replace_video_gen_module_all (
963+ module ,
964+ self .get_replacement_params (mode = quant_format , w_only = self .w_only ),
965+ keep_device = keep_device ,
966+ )
967+
957968 self .set_non_linear_mode (quant_format , self .model .model , False )
958969
959970 if self .quant_kvcache :
@@ -973,8 +984,11 @@ def deploy(self, quant_format, keep_device=False):
973984
@torch.no_grad()
def copy_tokenizer(self, path):
    """Save the model's tokenizer into ``path`` when one is available.

    When ``self.model.tokenizer`` is ``None`` nothing is written and a
    skip message is logged instead.

    Args:
        path: Destination handed unchanged to the tokenizer's
            ``save_pretrained`` method.
    """
    tokenizer = self.model.tokenizer

    # Guard clause: without a tokenizer there is nothing to save.
    if tokenizer is None:
        logger.info('no tokenizer, skip --')
        return

    tokenizer.save_pretrained(path)
    logger.info('copy tokenizer done --')
978992
979993 @torch .no_grad ()
980994 def contiguous_params (self ):
0 commit comments