-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils_eval.py
More file actions
26 lines (20 loc) · 843 Bytes
/
utils_eval.py
File metadata and controls
26 lines (20 loc) · 843 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
import torch.nn.functional as F
from tqdm import tqdm
from utils_loss_metric import dice_coeff
@torch.inference_mode()
def evaluate(net, dataloader, device, amp):
    """Compute the average Dice score of ``net`` over a validation dataloader.

    The network is put in eval mode for the pass. Its previous train/eval
    mode is restored afterwards, even if an exception escapes the loop.
    Gradient tracking is disabled by ``torch.inference_mode``.

    Args:
        net: The ``torch.nn.Module`` to evaluate. Each output element is
            treated as an independent logit: a pixel is predicted as
            foreground where ``sigmoid(logit) > 0.5``.
        dataloader: Sized iterable yielding ``(image, mask_true)`` pairs.
        device: Device that each batch is moved to (a ``torch.device`` or a
            string such as ``"cuda"``).
        amp: If truthy, run the forward pass and the metric under
            ``torch.autocast`` (mixed precision).

    Returns:
        The sum of the per-batch ``dice_coeff`` values divided by the number
        of batches (every batch weighted equally, regardless of its size),
        or ``0.0`` when ``dataloader`` is empty.
    """
    num_val_batches = len(dataloader)
    dice_score = 0

    # torch.autocast takes a device *type*; normalise strings and
    # torch.device objects alike. Older PyTorch releases reject autocast on
    # 'mps', so fall back to CPU autocast there.
    device_type = torch.device(device).type
    autocast_type = 'cpu' if device_type == 'mps' else device_type

    # Remember the caller's mode so it is restored rather than forced to train.
    was_training = net.training
    net.eval()
    try:
        with torch.autocast(autocast_type, enabled=bool(amp)):
            for image, mask_true in tqdm(dataloader, total=num_val_batches,
                                         desc='Validation round', unit='batch', leave=False):
                # move images and labels to correct device and type
                image = image.to(device=device, dtype=torch.float32)
                mask_true = mask_true.to(device=device, dtype=torch.long)

                # predict the mask: binarise the logits at probability 0.5
                logits = net(image)
                mask_pred = (torch.sigmoid(logits) > 0.5).float()
                dice_score += dice_coeff(mask_pred, mask_true, reduce_batch_first=False)
    finally:
        net.train(was_training)

    # max(..., 1) guards against division by zero for an empty dataloader.
    return dice_score / max(num_val_batches, 1)