-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtest.py
More file actions
133 lines (106 loc) · 4.02 KB
/
Copy pathtest.py
File metadata and controls
133 lines (106 loc) · 4.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# -*- coding: utf-8 -*-
# @File : test.py
# @Project: BP-Net
# @Author : jie
# @Time : 10/27/21 3:58 PM
import torch
torch.backends.cudnn.enabled = False
from tqdm import tqdm
import hydra
from PIL import Image
import os
from omegaconf import OmegaConf
from utils import *
def hflip_inputs(rgb, lidar, K):
    """Mirror an RGB / LiDAR input pair left-to-right and adjust the intrinsics.

    Args:
        rgb: B x 3 x H x W image batch.
        lidar: B x 1 x H x W sparse-depth batch.
        K: B x 3 x 3 camera intrinsics.

    Returns:
        ``(rgb_flipped, lidar_flipped, K_flipped)``. Both tensors are flipped
        along their last (width) axis. ``K_flipped`` is a copy of ``K`` whose
        principal-point x coordinate becomes ``W - 1 - cx``, matching pixel
        column ``u`` moving to ``W - 1 - u``. ``K`` itself is not modified.
    """
    width = rgb.shape[-1]
    mirrored_rgb, mirrored_lidar = (torch.flip(t, dims=[-1]) for t in (rgb, lidar))
    # Copy first so the caller's intrinsics stay untouched.
    mirrored_K = K.clone()
    mirrored_K[:, 0, 2] = (width - 1) - K[:, 0, 2]
    return mirrored_rgb, mirrored_lidar, mirrored_K
def _final_output(output):
    """Return the last prediction when the network returns a list/tuple of them.

    Any other output is returned unchanged.
    """
    if isinstance(output, (list, tuple)):
        return output[-1]
    return output


def _save_depth_pngs(output, dir_path, start_index):
    """Save channel 0 of every sample in ``output`` as a 16-bit PNG.

    Values are multiplied by 256 (presumably the KITTI 16-bit depth-PNG
    convention — confirm). They are then clamped to the uint16 range before
    the truncating cast, so negative or very large predictions cannot wrap
    around. Files are named ``{index:010d}.png``, starting at ``start_index``.

    Returns:
        The index to use for the first sample of the next batch.
    """
    # One host transfer for the whole batch instead of one per sample.
    depth = (output[:, 0] * 256.0).clamp(0, 65535).detach().cpu().numpy().astype('uint16')
    for offset, img in enumerate(depth):
        file_path = os.path.join(dir_path, f'{start_index + offset:010d}.png')
        Image.fromarray(img).save(file_path)
    return start_index + len(depth)


def _format_metrics(names, meters):
    """Build the log line: `` name:avg `` per metric, 7 decimal places."""
    return "".join(f" {name}:{meter.avg:.7f} " for name, meter in zip(names, meters))


def _evaluate(run, net, predict, mode, save):
    """Shared evaluation loop used by ``test`` and ``test_aug``.

    Args:
        run: Trainer-like object. This reads ``testloader``, ``init_cuda``,
            ``metric``, ``cfg.name``, ``rank`` and ``ddp_log``.
        net: Network switched to eval mode before the loop.
        predict: Callable mapping a batch (after ``run.init_cuda``) to
            ``(output, gt)``.
        mode: Sub-directory under ``results/<cfg.name>/`` for saved PNGs.
        save: If truthy, write every prediction to disk.
    """
    net.eval()
    tops = [AverageMeter() for _ in range(len(run.metric.metric_name))]
    dir_path = None
    if save:
        dir_path = f'results/{run.cfg.name}/{mode}'
        os.makedirs(dir_path, exist_ok=True)
    # Running file index. The previous ``batch_idx * batch_size + i`` reused
    # earlier indices, overwriting files, when the last batch was smaller.
    # NOTE(review): under DDP every rank may write the same indices into the
    # same directory — confirm the sampler/rank setup before saving there.
    next_index = 0
    with torch.no_grad():
        for datas in tqdm(run.testloader, desc="test ", dynamic_ncols=True,
                          leave=False, disable=run.rank):
            datas = run.init_cuda(*datas)
            output, gt = predict(datas)
            precs = run.metric(output, gt)
            for prec, top in zip(precs, tops):
                # NOTE(review): one update per batch with the batch mean; unless
                # AverageMeter weights by size, batches count equally — confirm.
                top.update(prec.mean().item())
            if save:
                next_index = _save_depth_pngs(output, dir_path, next_index)
    run.ddp_log(_format_metrics(run.metric.metric_name, tops), always=True)


def test(run, mode='selval', save=False):
    """Evaluate ``run.net_ema.module`` on ``run.testloader`` and log the metrics.

    Every batch element except the last is passed to the network. The last
    element is the ground truth given to ``run.metric``. When ``save`` is
    truthy, predictions are written as 16-bit PNGs to
    ``results/<cfg.name>/<mode>/``.
    """
    net = run.net_ema.module

    def predict(datas):
        return _final_output(net(*datas[:-1])), datas[-1]

    _evaluate(run, net, predict, mode, save)


def test_aug(run, mode='selval', save=False):
    """Same as ``test``, but with horizontal-flip test-time augmentation.

    Each batch must unpack as ``(rgb, lidar, K, gt)``. The network runs on the
    original inputs and on mirrored copies from ``hflip_inputs``. The mirrored
    prediction is flipped back, and the two predictions are averaged.
    """
    net = run.net_ema.module

    def predict(datas):
        rgb, lidar, K, gt = datas
        # Prediction on the original inputs.
        output = _final_output(net(rgb, lidar, K))
        # Prediction on the mirrored inputs, flipped back to the original frame.
        rgb_f, lidar_f, K_f = hflip_inputs(rgb, lidar, K)
        output_f = torch.flip(_final_output(net(rgb_f, lidar_f, K_f)), dims=[-1])
        # TTA: average the two predictions.
        return 0.5 * (output + output_f), gt

    _evaluate(run, net, predict, mode, save)
@hydra.main(config_path='configs', config_name='config', version_base='1.2')
def main(cfg):
    """Hydra entry point: build a Trainer from ``cfg`` and run flip-TTA evaluation.

    The evaluation mode comes from ``cfg.data.testset.mode``. The optional
    top-level ``save`` key (default ``False``) controls whether predictions are
    written to disk.
    """
    with Trainer(cfg) as run:
        split_mode = cfg.data.testset.mode
        save_outputs = OmegaConf.select(cfg, 'save', default=False)
        test_aug(run, mode=split_mode, save=save_outputs)


if __name__ == '__main__':
    main()