
Commit d4386f4

fix textual inversion
1 parent 4548e68 commit d4386f4

2 files changed, 18 additions & 11 deletions


examples/textual_inversion/textual_inversion.py

Lines changed: 4 additions & 3 deletions
@@ -702,9 +702,10 @@ def main():
     vae.requires_grad_(False)
     unet.requires_grad_(False)
     # Freeze all parameters except for the token embeddings in text encoder
-    text_encoder.text_model.encoder.requires_grad_(False)
-    text_encoder.text_model.final_layer_norm.requires_grad_(False)
-    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
+    text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
+    text_module.encoder.requires_grad_(False)
+    text_module.final_layer_norm.requires_grad_(False)
+    text_module.embeddings.position_embedding.requires_grad_(False)

     if args.gradient_checkpointing:
         # Keep unet in train mode if we are using gradient checkpointing to save memory.
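The new guard lets the freeze logic work whether the text encoder exposes the inner transformer as .text_model (as transformers' CLIPTextModel does) or is itself the bare module. A minimal sketch of the pattern, using openai/clip-vit-base-patch32 purely as a stand-in checkpoint, not the one the training script loads:

# Sketch of the fallback pattern introduced above; the checkpoint name
# is illustrative only.
from transformers import CLIPTextModel

text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

# Resolve the inner transformer whether or not the wrapper exposes .text_model.
text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder

# Freeze everything except the token embeddings.
text_module.encoder.requires_grad_(False)
text_module.final_layer_norm.requires_grad_(False)
text_module.embeddings.position_embedding.requires_grad_(False)

# Exactly one tensor should still require gradients: the token embedding
# that textual inversion optimizes.
trainable = [name for name, p in text_encoder.named_parameters() if p.requires_grad]
assert trainable == ["text_model.embeddings.token_embedding.weight"]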

examples/textual_inversion/textual_inversion_sdxl.py

Lines changed: 14 additions & 8 deletions
@@ -717,12 +717,14 @@ def main():
     unet.requires_grad_(False)

     # Freeze all parameters except for the token embeddings in text encoder
-    text_encoder_1.text_model.encoder.requires_grad_(False)
-    text_encoder_1.text_model.final_layer_norm.requires_grad_(False)
-    text_encoder_1.text_model.embeddings.position_embedding.requires_grad_(False)
-    text_encoder_2.text_model.encoder.requires_grad_(False)
-    text_encoder_2.text_model.final_layer_norm.requires_grad_(False)
-    text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False)
+    text_module_1 = text_encoder_1.text_model if hasattr(text_encoder_1, "text_model") else text_encoder_1
+    text_module_1.encoder.requires_grad_(False)
+    text_module_1.final_layer_norm.requires_grad_(False)
+    text_module_1.embeddings.position_embedding.requires_grad_(False)
+    text_module_2 = text_encoder_2.text_model if hasattr(text_encoder_2, "text_model") else text_encoder_2
+    text_module_2.encoder.requires_grad_(False)
+    text_module_2.final_layer_norm.requires_grad_(False)
+    text_module_2.embeddings.position_embedding.requires_grad_(False)

     if args.gradient_checkpointing:
         text_encoder_1.gradient_checkpointing_enable()
@@ -767,8 +769,12 @@ def main():
     optimizer = optimizer_class(
         # only optimize the embeddings
         [
-            text_encoder_1.text_model.embeddings.token_embedding.weight,
-            text_encoder_2.text_model.embeddings.token_embedding.weight,
+            (
+                text_encoder_1.text_model if hasattr(text_encoder_1, "text_model") else text_encoder_1
+            ).embeddings.token_embedding.weight,
+            (
+                text_encoder_2.text_model if hasattr(text_encoder_2, "text_model") else text_encoder_2
+            ).embeddings.token_embedding.weight,
         ],
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
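In the SDXL script the same guard appears twice: when freezing the two text encoders and when collecting their token-embedding weights for the optimizer. A hedged sketch of the optimizer side, with torch.optim.AdamW standing in for whatever optimizer_class resolves to, a fixed learning rate in place of args.learning_rate, and checkpoint names chosen only for illustration:

import torch
from transformers import CLIPTextModel, CLIPTextModelWithProjection

# Illustrative stand-ins for SDXL's two text encoders.
text_encoder_1 = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")

def token_embedding_weight(text_encoder):
    # The hasattr guard from the diff, factored into a hypothetical helper.
    text_module = text_encoder.text_model if hasattr(text_encoder, "text_model") else text_encoder
    return text_module.embeddings.token_embedding.weight

# Per the "only optimize the embeddings" comment, only the two
# token-embedding matrices are handed to the optimizer.
optimizer = torch.optim.AdamW(
    [token_embedding_weight(text_encoder_1), token_embedding_weight(text_encoder_2)],
    lr=1e-4,
)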

0 commit comments
