Skip to content

Commit c219dfe

Browse files
committed
Improved computational efficiency when evaluating segmentation models by reducing the saved information.
1 parent f5e51ba commit c219dfe

2 files changed

Lines changed: 72 additions & 49 deletions

File tree

src/thunder/tasks/train_eval_probe.py

Lines changed: 43 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -580,32 +580,43 @@ def train_eval(
580580
)
581581

582582
# Applying masking (removing pixels where gt == -1)
583-
unmasked_label = label != -1
584-
label = label[unmasked_label]
583+
if task_type == "segmentation":
584+
unmasked_label = [l != -1 for l in label]
585+
label = [l[u] for l, u in zip(label, unmasked_label)]
586+
else:
587+
label = label.view(-1)
585588
loss = 0
586589
for i in range(len(outputs)):
587590
output = outputs[i]
588591
out = []
589-
for c in range(output.shape[1]):
590-
out.append(output[:, c][unmasked_label].unsqueeze(-1))
591-
out = torch.cat(out, dim=-1)
592-
593-
# Compute loss
594-
curr_loss = criterion(out, label)
592+
if task_type == "segmentation":
593+
for o, m in zip(output, unmasked_label):
594+
out.append(torch.cat([o[c][m].unsqueeze(-1) for c in range(o.shape[0])], dim=-1))
595+
curr_loss = sum([criterion(o, l) for o, l in zip(out, label)]) / len(out)
596+
else:
597+
for c in range(output.shape[1]):
598+
out.append(output[:, c].unsqueeze(-1))
599+
out = torch.cat(out, dim=-1)
600+
curr_loss = criterion(out, label)
595601
loss += curr_loss
596602

597603
# Logging
598604
if batch_id == 0:
599605
tot_loss.append([curr_loss.item()])
600606
if comp_metrics:
601-
all_out.append([out.detach().cpu()])
607+
all_out.append([[o.detach().cpu() for o in out]] if task_type == 'segmentation'
608+
else [out.detach().cpu()])
602609
else:
603610
tot_loss[i].append(curr_loss.item())
604611
if comp_metrics:
605-
all_out[i].append(out.detach().cpu())
612+
all_out[i].append([o.detach().cpu() for o in out] if task_type == 'segmentation'
613+
else out.detach().cpu())
606614
# Logging
607615
if comp_metrics:
608-
all_label.append(label.cpu())
616+
if task_type == "segmentation":
617+
all_label.extend([l.cpu() for l in label])
618+
else:
619+
all_label.append(label.cpu())
609620

610621
if run_type == "train":
611622
# Compute gradients
@@ -627,18 +638,27 @@ def train_eval(
627638
viz_im = None
628639

629640
if comp_metrics:
630-
# Computing metrics
631-
all_label = torch.cat(all_label)
632-
metrics = []
633-
for i in range(len(all_out)):
634-
all_out[i] = torch.cat(all_out[i])
635-
all_out[i] = F.softmax(all_out[i], dim=1)
636-
classification_metrics = compute_metrics(all_out[i], None, all_label)
637-
conformal_metrics = compute_calibration_metrics(all_out[i], all_label)
638-
curr_metrics = (
639-
classification_metrics | conformal_metrics
640-
) # merging dictionaries
641-
metrics.append(curr_metrics)
641+
if task_type == "segmentation":
642+
metrics = []
643+
for i in range(len(all_out)):
644+
all_out[i] = [F.softmax(item, dim=1) for batch in all_out[i] for item in batch]
645+
all_metrics = [compute_metrics(o, None, l, True) for o, l in zip(all_out[i], all_label) if len(l) > 0]
646+
weights = [len(l) for l in all_label if len(l) > 0]
647+
metrics.append({key: np.average([d[key] for d in all_metrics], weights=weights) for key in all_metrics[0]} |
648+
{f'{key}_per_sample': [d[key] for d in all_metrics] for key in all_metrics[0]})
649+
else:
650+
# Computing metrics
651+
all_label = torch.cat(all_label)
652+
metrics = []
653+
for i in range(len(all_out)):
654+
all_out[i] = torch.cat(all_out[i])
655+
all_out[i] = F.softmax(all_out[i], dim=1)
656+
classification_metrics = compute_metrics(all_out[i], None, all_label)
657+
conformal_metrics = compute_calibration_metrics(all_out[i], all_label)
658+
curr_metrics = (
659+
classification_metrics | conformal_metrics
660+
) # merging dictionaries
661+
metrics.append(curr_metrics)
642662
else:
643663
metrics = None
644664

src/thunder/utils/downstream_metrics.py

Lines changed: 29 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@ def compute_metrics(
1414
out_proba: Union[torch.Tensor, np.array],
1515
out_pred: Union[torch.Tensor, np.array],
1616
label: Union[torch.Tensor, np.array],
17+
is_segmentation: bool = False
1718
) -> dict:
1819
"""
1920
Computing performance metrics.
@@ -34,35 +35,37 @@ def compute_metrics(
3435

3536
# Computing metrics
3637
accuracy = accuracy_score(y_true=label, y_pred=out_pred)
37-
balanced_accuracy = balanced_accuracy_score(y_true=label, y_pred=out_pred)
3838
f1 = f1_score(y_true=label, y_pred=out_pred, average="macro")
3939
jaccard = jaccard_score(y_true=label, y_pred=out_pred, average="macro")
4040

41-
if out_proba is not None:
42-
if out_proba.shape[1] > 2:
43-
roc_auc = roc_auc_score(
44-
y_true=label,
45-
y_score=out_proba,
46-
multi_class="ovo",
47-
labels=torch.arange(out_proba.shape[1]).tolist(),
48-
)
41+
metrics = {"f1": f1,
42+
"accuracy": accuracy,
43+
"jaccard": jaccard}
44+
45+
if not is_segmentation:
46+
balanced_accuracy = balanced_accuracy_score(y_true=label, y_pred=out_pred)
47+
if out_proba is not None:
48+
if out_proba.shape[1] > 2:
49+
roc_auc = roc_auc_score(
50+
y_true=label,
51+
y_score=out_proba,
52+
multi_class="ovo",
53+
labels=torch.arange(out_proba.shape[1]).tolist(),
54+
)
55+
else:
56+
assert out_proba.shape[1] == 2
57+
roc_auc = roc_auc_score(y_true=label, y_score=out_proba[:, 1])
4958
else:
50-
assert out_proba.shape[1] == 2
51-
roc_auc = roc_auc_score(y_true=label, y_score=out_proba[:, 1])
52-
else:
53-
roc_auc = None
59+
roc_auc = None
60+
61+
# Per-sample metrics
62+
per_sample_acc = (out_pred == label).astype(np.int8).tolist()
5463

55-
# Per-sample metrics
56-
per_sample_acc = (out_pred == label).astype(np.int8).tolist()
64+
metrics = (metrics | {"balanced_accuracy": balanced_accuracy,
65+
"roc_auc": roc_auc,
66+
"per_sample_acc": per_sample_acc,
67+
"per_sample_pred": out_pred.tolist(),
68+
"per_sample_proba": out_proba.tolist() if out_proba is not None else None,
69+
"label": label.tolist()})
5770

58-
return {
59-
"accuracy": accuracy,
60-
"balanced_accuracy": balanced_accuracy,
61-
"f1": f1,
62-
"jaccard": jaccard,
63-
"roc_auc": roc_auc,
64-
"per_sample_acc": per_sample_acc,
65-
"per_sample_pred": out_pred.tolist(),
66-
"per_sample_proba": out_proba.tolist() if out_proba is not None else None,
67-
"label": label.tolist(),
68-
}
71+
return metrics

0 commit comments

Comments
 (0)