-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathtest.py
More file actions
103 lines (76 loc) · 3.56 KB
/
test.py
File metadata and controls
103 lines (76 loc) · 3.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import torch
# Enable TensorFloat-32 on Ampere+ GPUs: faster matmul/conv kernels at
# slightly reduced float32 precision (fine for inference).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import glob
import os
# Project-local modules: DiT backbone, diffusion schedule, VAE and decoders.
from model_incontext_revise import DiT_incontext_revise
from diffusion import create_diffusion
from vae.autoencoder import AutoencoderKL
from vae.cond_encoder import CondEncoder
from vae.encoder_decoder import Decoder2
from utils import util
from torchvision.utils import save_image
from download import load_model
from torch.nn import functional as F
import natsort
from torchvision.transforms import ToTensor
import cv2
import numpy as np
def fiFindByWildcard(wildcard):
    """Return all paths matching *wildcard* (recursive), in natural sort order."""
    matches = glob.glob(wildcard, recursive=True)
    return natsort.natsorted(matches)
def t(array): return torch.Tensor(np.expand_dims(array.transpose([2, 0, 1]), axis=0).astype(np.float32)) / 255
def rgb(t):
    """Convert a CHW (or 1xCHW batched) float tensor in [0, 1] to an HxWxC uint8 array."""
    img = t[0] if len(t.shape) == 4 else t
    hwc = img.detach().cpu().numpy().transpose([1, 2, 0])
    clipped = np.clip(hwc, 0, 1)
    return (clipped * 255).astype(np.uint8)
def imread(path):
    """Load *path* with OpenCV and swap the BGR channel order to RGB."""
    bgr = cv2.imread(path)
    return bgr[:, :, [2, 1, 0]]
def main(inp_dir):
    """Run low-light enhancement inference on every PNG under *inp_dir*.

    Expects these sub-folders inside *inp_dir*:
        low/          - low-light input images (*.png)
        global_score/ - per-image global prior tensors (*.pt)
        local_prior/  - per-image local prior tensors (*.pt)
    Enhanced images are written to *inp_dir*/outputs with the same file names.
    The three prior/image lists are paired by natural-sorted order, so the
    folders must contain matching files.
    """
    lr_dir = os.path.join(inp_dir, 'low')
    global_prior_dir = os.path.join(inp_dir, 'global_score')
    local_prior_dir = os.path.join(inp_dir, 'local_prior')
    out_dir = os.path.join(inp_dir, 'outputs')
    os.makedirs(out_dir, exist_ok=True)

    lr_paths = fiFindByWildcard(os.path.join(lr_dir, '*.png'))
    global_prior_paths = fiFindByWildcard(os.path.join(global_prior_dir, '*.pt'))
    local_prior_paths = fiFindByWildcard(os.path.join(local_prior_dir, '*.pt'))

    device = torch.device('cuda:0')
    # map_location decouples loading from the GPU index the checkpoint was
    # saved on; weights are moved to `device` explicitly below.
    state_dict = torch.load('weight_lolv2_syn.pth', map_location='cpu')

    model = DiT_incontext_revise()
    model.load_state_dict(state_dict['dit'], strict=True)
    model = model.to(device)
    vae = AutoencoderKL()
    vae.load_state_dict(state_dict['vae'], strict=True)
    vae = vae.to(device)
    cond_lq = CondEncoder()
    cond_lq.load_state_dict(state_dict['cond'], strict=True)
    cond_lq = cond_lq.to(device)
    second_decoder = Decoder2()
    second_decoder.load_state_dict(state_dict['second_decoder'], strict=True)
    second_decoder = second_decoder.to(device)

    # Put every module in eval mode (the original only switched the DiT), so
    # any dropout/batch-norm layers behave deterministically at test time.
    model.eval()
    vae.eval()
    cond_lq.eval()
    second_decoder.eval()

    diffusion_val = create_diffusion(str(25))  # number of sample steps
    to_tensor = ToTensor()

    for lr_path, global_path, local_path in zip(lr_paths, global_prior_paths, local_prior_paths):
        # Read BGR, convert to RGB, scale to [0,1], add a batch dimension.
        y = to_tensor(cv2.cvtColor(cv2.imread(lr_path), cv2.COLOR_BGR2RGB)).unsqueeze(0)
        global_prior = torch.load(global_path, map_location=device)
        local_prior = torch.load(local_path, map_location=device)
        b, c, h, w = y.shape
        # no_grad around the WHOLE inference path (the original covered only
        # the condition encoder), so sampling/decoding builds no autograd
        # graph and GPU memory stays bounded.
        with torch.no_grad():
            y, enc_feat = cond_lq(y.to(device), True)
            # Latent resolution is 1/4 of the image resolution.
            latent_size_h = h // 4
            latent_size_w = w // 4
            z = torch.randn(1, 3, latent_size_h, latent_size_w, device=device)
            model_kwargs = dict(y=y, vis=global_prior, q_map=local_prior)
            # Sample latents with the diffusion model:
            samples = diffusion_val.p_sample_loop(
                model.forward, z.shape, z, clip_denoised=False,
                model_kwargs=model_kwargs, progress=False, device=device
            )
            # Decode latents, then refine with the second decoder using both
            # decoder mid-features and the condition-encoder features.
            dec_feat = vae.decode(samples, mid_feat=True)
            sr = second_decoder(samples, dec_feat, enc_feat)
        save_img_path = os.path.join(out_dir, os.path.basename(lr_path))
        save_image(sr, save_img_path)
if __name__ == "__main__":
    # Update this path as needed; it must contain at least these sub-folders:
    # low, global_score, local_prior.
    input_dir = 'dataset/LOLv2_syn/Test'
    main(input_dir)