-
Notifications
You must be signed in to change notification settings - Fork 141
Expand file tree
/
Copy pathdiffusers_sample.py
More file actions
80 lines (69 loc) · 3.36 KB
/
Copy pathdiffusers_sample.py
File metadata and controls
80 lines (69 loc) · 3.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# ==========================================================================================
#
# MIT License. To view a copy of the license, visit MIT_LICENSE.md.
#
# ==========================================================================================
import argparse
import sys
import os
import numpy as np
import torch
from PIL import Image
from diffusers import StableDiffusionXLPipeline
sys.path.append('./')
from src.diffusers_model_pipeline import CustomDiffusionPipeline, CustomDiffusionXLPipeline
def _prompt_to_filename(prompt_text):
    """Build a file-name stem from the first 50 characters of a prompt.

    Whitespace runs are joined with '-'. Path separators are replaced with
    '_' so a prompt containing '/' cannot point the save path into a
    (probably non-existent) subdirectory.
    """
    stem = '-'.join(prompt_text[:50].split())
    return stem.replace(os.sep, '_').replace('/', '_')


def _save_strip(images, prompt_text, outdir):
    """Concatenate PIL images left-to-right and save them as one PNG in outdir."""
    strip = np.hstack([np.array(im) for im in images])
    Image.fromarray(strip).save(os.path.join(outdir, f'{_prompt_to_filename(prompt_text)}.png'))


def sample(ckpt, delta_ckpt, from_file, prompt, compress, batch_size, freeze_model, sdxl=False):
    """Generate ``batch_size`` images per prompt and save them next to ``delta_ckpt``.

    Args:
        ckpt: model id or path passed to ``from_pretrained``.
        delta_ckpt: path passed to ``pipe.load_model``. Its directory (or the
            current directory if it has none) is used as the output directory.
        from_file: text file with one prompt per line. Read only when
            ``prompt`` is None. Blank lines are skipped.
        prompt: a single prompt. When given, ``from_file`` is ignored.
        compress: forwarded unchanged to ``pipe.load_model``.
        batch_size: number of images generated per prompt.
        freeze_model: accepted for CLI compatibility; not used here.
        sdxl: if True, load ``CustomDiffusionXLPipeline``; otherwise
            ``CustomDiffusionPipeline``.

    Outputs:
        * ``<outdir>/<prompt-stem>.png``: that prompt's images side by side.
        * ``<outdir>/samples/<i>.jpg``: every generated image, numbered in
          generation order across all prompts.
    """
    model_id = ckpt
    if sdxl:
        pipe = CustomDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
        pipe = pipe.to("cuda")
    else:
        pipe = CustomDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
    pipe.load_model(delta_ckpt, compress)

    # dirname() of a bare file name is ''; without the fallback the
    # outputs below would be written to the filesystem root ('/...').
    outdir = os.path.dirname(delta_ckpt) or '.'
    # One fixed-seed generator is shared by all prompts, so a run is reproducible.
    generator = torch.Generator(device='cuda').manual_seed(42)

    if prompt is not None:
        prompts = [prompt]
    else:
        print(f"reading prompts from {from_file}")
        with open(from_file, "r", encoding="utf-8") as f:
            # Blank lines (e.g. a trailing newline) would otherwise produce an
            # empty-prompt batch saved as a hidden '.png'.
            prompts = [line for line in f.read().splitlines() if line.strip()]

    all_images = []
    for text in prompts:
        images = pipe([text] * batch_size, num_inference_steps=200, guidance_scale=6., eta=1.,
                      generator=generator).images
        all_images += images
        _save_strip(images, text, outdir)

    samples_dir = os.path.join(outdir, 'samples')
    os.makedirs(samples_dir, exist_ok=True)
    for i, im in enumerate(all_images):
        im.save(os.path.join(samples_dir, f'{i}.jpg'))
def parse_args():
    """Parse command-line options for ``sample``.

    The prompt source is ``--prompt`` if given, otherwise the file named by
    ``--from-file``. If neither is supplied, exit with a usage error before any
    model is loaded.

    Returns:
        argparse.Namespace with attributes ckpt, delta_ckpt, from_file, prompt,
        compress, sdxl, batch_size and freeze_model.
    """
    parser = argparse.ArgumentParser(
        description='Sample images with a custom-diffusion pipeline and a delta checkpoint.')
    parser.add_argument('--ckpt', help='base model id or path passed to from_pretrained',
                        required=True, type=str)
    parser.add_argument('--delta_ckpt',
                        help='path to the delta weights; outputs are written to its directory',
                        required=True, type=str)
    parser.add_argument('--from-file',
                        help='path to prompt file, one prompt per line (used when --prompt is not given)',
                        default=None, type=str)
    parser.add_argument('--prompt', help='prompt to generate', default=None,
                        type=str)
    parser.add_argument("--compress", action='store_true')
    parser.add_argument("--sdxl", action='store_true')
    parser.add_argument("--batch_size", default=5, type=int)
    parser.add_argument('--freeze_model', help='crossattn or crossattn_kv', default='crossattn_kv',
                        type=str)
    args = parser.parse_args()
    # The old default ('./') was a directory and only failed after the model
    # had been loaded; report a missing prompt source up front instead.
    if args.prompt is None and args.from_file is None:
        parser.error('one of --prompt or --from-file is required')
    return args
if __name__ == "__main__":
    # Script entry point: parse the CLI options and run one sampling pass.
    cli = parse_args()
    sample(
        ckpt=cli.ckpt,
        delta_ckpt=cli.delta_ckpt,
        from_file=cli.from_file,
        prompt=cli.prompt,
        compress=cli.compress,
        batch_size=cli.batch_size,
        freeze_model=cli.freeze_model,
        sdxl=cli.sdxl,
    )