@@ -717,12 +717,14 @@ def main():
     unet.requires_grad_(False)

     # Freeze all parameters except for the token embeddings in text encoder
-    text_encoder_1.text_model.encoder.requires_grad_(False)
-    text_encoder_1.text_model.final_layer_norm.requires_grad_(False)
-    text_encoder_1.text_model.embeddings.position_embedding.requires_grad_(False)
-    text_encoder_2.text_model.encoder.requires_grad_(False)
-    text_encoder_2.text_model.final_layer_norm.requires_grad_(False)
-    text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False)
+    text_module_1 = text_encoder_1.text_model if hasattr(text_encoder_1, "text_model") else text_encoder_1
+    text_module_1.encoder.requires_grad_(False)
+    text_module_1.final_layer_norm.requires_grad_(False)
+    text_module_1.embeddings.position_embedding.requires_grad_(False)
+    text_module_2 = text_encoder_2.text_model if hasattr(text_encoder_2, "text_model") else text_encoder_2
+    text_module_2.encoder.requires_grad_(False)
+    text_module_2.final_layer_norm.requires_grad_(False)
+    text_module_2.embeddings.position_embedding.requires_grad_(False)

     if args.gradient_checkpointing:
         text_encoder_1.gradient_checkpointing_enable()
@@ -767,8 +769,12 @@ def main():
     optimizer = optimizer_class(
         # only optimize the embeddings
         [
-            text_encoder_1.text_model.embeddings.token_embedding.weight,
-            text_encoder_2.text_model.embeddings.token_embedding.weight,
+            (
+                text_encoder_1.text_model if hasattr(text_encoder_1, "text_model") else text_encoder_1
+            ).embeddings.token_embedding.weight,
+            (
+                text_encoder_2.text_model if hasattr(text_encoder_2, "text_model") else text_encoder_2
+            ).embeddings.token_embedding.weight,
         ],
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
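The guard added in both hunks follows the same pattern: if the encoder is a transformers CLIPTextModel (or CLIPTextModelWithProjection), the actual transformer lives under its .text_model attribute; if the object passed in is already the bare text module, it is used directly. Below is a minimal, self-contained sketch of that pattern, assuming a transformers install; the helper name resolve_text_model is hypothetical and introduced only to show the fallback in isolation, not part of the script.

# Illustrative sketch of the hasattr fallback used in the diff above.
from transformers import CLIPTextConfig, CLIPTextModel


def resolve_text_model(text_encoder):
    # Hypothetical helper: prefer the wrapped transformer when present;
    # otherwise assume the encoder itself exposes .encoder,
    # .final_layer_norm, and .embeddings.
    return text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder


# Tiny default-config model so the sketch runs without downloading weights.
text_encoder_1 = CLIPTextModel(CLIPTextConfig())
text_module_1 = resolve_text_model(text_encoder_1)

# Freeze everything except the token embeddings, mirroring the first hunk.
text_module_1.encoder.requires_grad_(False)
text_module_1.final_layer_norm.requires_grad_(False)
text_module_1.embeddings.position_embedding.requires_grad_(False)

# The token embedding weights stay trainable; they are exactly what the
# optimizer receives in the second hunk.
assert text_module_1.embeddings.token_embedding.weight.requires_grad

Writing the check once per encoder and reusing the resolved module, as the first hunk does, avoids repeating the hasattr expression on every line; the optimizer hunk inlines it instead because the parameter list is built in a single expression.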