Skip to content

Commit faae072

Browse files
run.sh
1 parent a1d2593 commit faae072

4 files changed

Lines changed: 49 additions & 22 deletions

File tree

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -56,5 +56,4 @@ wandb
5656
pretrained-*
5757
tuning-*
5858
models
59-
*.sh
6059
grid.png

README.md

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -32,8 +32,8 @@ If your target image is your face, you need to pre-train on a large face image d
3232
Or, if you have an artistic image, you might want to train on WikiArt like so.
3333
```
3434
accelerate launch pretrain_e4t.py \
35-
--pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \
36-
--clip_model_name_or_path="ViT-H-14::laion2b_s32b_b79k" \
35+
--mixed_precision="fp16" \
36+
--pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
3737
--domain_class_token="art" \
3838
--placeholder_token="*s" \
3939
--prompt_template="art" \
@@ -44,13 +44,12 @@ accelerate launch pretrain_e4t.py \
4444
--train_image_dataset="Artificio/WikiArt" \
4545
--iterable_dataset \
4646
--resolution=512 \
47-
--train_batch_size=16 \
47+
--train_batch_size=1 \
4848
--learning_rate=1e-6 --scale_lr \
4949
--checkpointing_steps=10000 \
5050
--log_steps=1000 \
5151
--max_train_steps=100000 \
5252
--unfreeze_clip_vision \
53-
--mixed_precision="fp16" \
5453
--enable_xformers_memory_efficient_attention
5554
```
5655

e4t/utils.py

Lines changed: 26 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ def __getstate__(self):
2222
return self.obj.items()
2323

2424
def __setstate__(self, items):
25-
if not hasattr(self, 'obj'):
25+
if not hasattr(self, "obj"):
2626
self.obj = {}
2727
for key, val in items:
2828
self.obj[key] = val
@@ -43,11 +43,7 @@ def keys(self):
4343
def download_from_huggingface(repo, filename, **kwargs):
4444
while True:
4545
try:
46-
return huggingface_hub.hf_hub_download(
47-
repo,
48-
filename=filename,
49-
**kwargs
50-
)
46+
return huggingface_hub.hf_hub_download(repo, filename=filename, **kwargs)
5147
except HTTPError as e:
5248
if e.response.status_code == 401:
5349
# Need to log into huggingface api
@@ -76,13 +72,17 @@ def download_from_huggingface(repo, filename, **kwargs):
7672
def load_config_from_pretrained(pretrained_model_name_or_path):
7773
if os.path.exists(pretrained_model_name_or_path):
7874
if "config.json" not in pretrained_model_name_or_path:
79-
pretrained_model_name_or_path = os.path.join(pretrained_model_name_or_path, "config.json")
75+
pretrained_model_name_or_path = os.path.join(
76+
pretrained_model_name_or_path, "config.json"
77+
)
8078
else:
81-
assert pretrained_model_name_or_path in MODELS, f"Choose from {list(MODELS.keys())}"
79+
assert (
80+
pretrained_model_name_or_path in MODELS
81+
), f"Choose from {list(MODELS.keys())}"
8282
pretrained_model_name_or_path = download_from_huggingface(
8383
repo=MODELS[pretrained_model_name_or_path]["repo"],
8484
filename="config.json",
85-
subfolder=MODELS[pretrained_model_name_or_path]["subfolder"]
85+
subfolder=MODELS[pretrained_model_name_or_path]["subfolder"],
8686
)
8787
with open(pretrained_model_name_or_path, "r", encoding="utf-8") as f:
8888
pretrained_args = AttributeDict(json.load(f))
@@ -91,9 +91,12 @@ def load_config_from_pretrained(pretrained_model_name_or_path):
9191

9292
def load_e4t_unet(pretrained_model_name_or_path=None, ckpt_path=None, **kwargs):
9393
assert pretrained_model_name_or_path is not None or ckpt_path is not None
94-
if pretrained_model_name_or_path is None or not os.path.exists(ckpt_path):
94+
if pretrained_model_name_or_path is None:
9595
if os.path.exists(ckpt_path):
96-
assert os.path.basename(ckpt_path) == "unet.pt" or os.path.basename(ckpt_path) == "weight_offsets.pt", "You must specify the filename! (`unet.pt` or `weight_offsets.pt`)"
96+
assert (
97+
os.path.basename(ckpt_path) == "unet.pt"
98+
or os.path.basename(ckpt_path) == "weight_offsets.pt"
99+
), "You must specify the filename! (`unet.pt` or `weight_offsets.pt`)"
97100
config = load_config_from_pretrained(os.path.dirname(ckpt_path))
98101
else:
99102
assert ckpt_path in MODELS, f"Choose from {list(MODELS.keys())}"
@@ -102,16 +105,22 @@ def load_e4t_unet(pretrained_model_name_or_path=None, ckpt_path=None, **kwargs):
102105
ckpt_path = download_from_huggingface(
103106
repo=MODELS[ckpt_path]["repo"],
104107
filename="weight_offsets.pt",
105-
subfolder=MODELS[ckpt_path]["subfolder"]
108+
subfolder=MODELS[ckpt_path]["subfolder"],
106109
)
107110
except EntryNotFoundError:
108111
ckpt_path = download_from_huggingface(
109112
repo=MODELS[ckpt_path]["repo"],
110113
filename="unet.pt",
111-
subfolder=MODELS[ckpt_path]["subfolder"]
114+
subfolder=MODELS[ckpt_path]["subfolder"],
112115
)
113-
pretrained_model_name_or_path = config.pretrained_model_name_or_path if config.pretrained_args is None else config.pretrained_args["pretrained_model_name_or_path"]
114-
unet = OriginalUNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet", **kwargs)
116+
pretrained_model_name_or_path = (
117+
config.pretrained_model_name_or_path
118+
if config.pretrained_args is None
119+
else config.pretrained_args["pretrained_model_name_or_path"]
120+
)
121+
unet = OriginalUNet2DConditionModel.from_pretrained(
122+
pretrained_model_name_or_path, subfolder="unet", **kwargs
123+
)
115124
state_dict = dict(unet.state_dict())
116125
if ckpt_path:
117126
ckpt_sd = torch.load(ckpt_path, map_location="cpu")
@@ -142,7 +151,7 @@ def load_e4t_encoder(ckpt_path=None, **kwargs):
142151
ckpt_path = download_from_huggingface(
143152
repo=MODELS[ckpt_path]["repo"],
144153
filename="encoder.pt",
145-
subfolder=MODELS[ckpt_path]["subfolder"]
154+
subfolder=MODELS[ckpt_path]["subfolder"],
146155
)
147156
state_dict = torch.load(ckpt_path, map_location="cpu")
148157
print(f"Resuming from {ckpt_path}")
@@ -182,7 +191,7 @@ def image_grid(imgs, rows, cols):
182191
assert len(imgs) == rows * cols
183192

184193
w, h = imgs[0].size
185-
grid = Image.new('RGB', size=(cols * w, rows * h))
194+
grid = Image.new("RGB", size=(cols * w, rows * h))
186195
grid_w, grid_h = grid.size
187196

188197
for i, img in enumerate(imgs):

run.sh

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
accelerate launch pretrain_e4t.py \
2+
--mixed_precision="fp16" \
3+
--pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
4+
--domain_class_token="art" \
5+
--placeholder_token="*s" \
6+
--prompt_template="art" \
7+
--save_sample_prompt="a photo of the *s,a photo of the *s in monet style" \
8+
--reg_lambda=0.01 \
9+
--domain_embed_scale=0.1 \
10+
--output_dir="pretrained-wikiart" \
11+
--train_image_dataset="Artificio/WikiArt" \
12+
--iterable_dataset \
13+
--resolution=512 \
14+
--train_batch_size=1 \
15+
--learning_rate=1e-6 --scale_lr \
16+
--checkpointing_steps=10000 \
17+
--log_steps=1000 \
18+
--max_train_steps=100000 \
19+
--unfreeze_clip_vision \
20+
--enable_xformers_memory_efficient_attention

0 commit comments

Comments
 (0)