Skip to content

Commit 4565da8

Browse files
committed
Added computation of 95% bootstrap confidence intervals
1 parent 1aa03a3 commit 4565da8

8 files changed

Lines changed: 272 additions & 56 deletions

File tree

src/thunder/tasks/adversarial_attack.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,19 @@
1515
import torchvision.transforms as T
1616
import wandb
1717
from omegaconf import DictConfig
18+
from sklearn.metrics import (
19+
accuracy_score,
20+
balanced_accuracy_score,
21+
f1_score,
22+
jaccard_score,
23+
)
1824
from torch.utils.data import DataLoader, Subset
1925
from tqdm import tqdm
2026

2127
from ..models.pretrained_models import load_pretrained_model
2228
from ..utils.constants import UtilsConstants
2329
from ..utils.data import PatchDataset, get_data
24-
from ..utils.downstream_metrics import compute_metrics
30+
from ..utils.downstream_metrics import compute_metric, compute_metrics
2531
from ..utils.pgd_attack_linear import PGDImageAttack
2632
from ..utils.utils import log_metrics, save_outputs
2733

@@ -177,14 +183,51 @@ def adversarial_attack(
177183
all_adv_preds = torch.cat(all_adv_preds).cpu().numpy()
178184
all_labels = torch.cat(all_labels).cpu().numpy()
179185

180-
# Compute accuracy before and after attack
186+
# Compute metrics before and after attack
181187
clean_metrics = compute_metrics(None, all_clean_preds, all_labels)
182188
adv_metrics = compute_metrics(None, all_adv_preds, all_labels)
183189
metrics = {
184190
"clean": clean_metrics,
185191
"adversarial": adv_metrics,
186192
}
187193

194+
# Drop in metrics (mean and confidence intervals)
195+
f1_drop = compute_metric(
196+
all_labels,
197+
np.concatenate([all_clean_preds[:, None], all_adv_preds[:, None]], axis=1),
198+
lambda y, y_pred: f1_score(y_true=y, y_pred=y_pred[:, 0], average="macro")
199+
- f1_score(y_true=y, y_pred=y_pred[:, 1], average="macro"),
200+
label_indices=np.arange(len(all_labels)),
201+
)
202+
accuracy_drop = compute_metric(
203+
all_labels,
204+
np.concatenate([all_clean_preds[:, None], all_adv_preds[:, None]], axis=1),
205+
lambda y, y_pred: accuracy_score(y_true=y, y_pred=y_pred[:, 0])
206+
- accuracy_score(y_true=y, y_pred=y_pred[:, 1]),
207+
label_indices=np.arange(len(all_labels)),
208+
)
209+
jaccard_drop = compute_metric(
210+
all_labels,
211+
np.concatenate([all_clean_preds[:, None], all_adv_preds[:, None]], axis=1),
212+
lambda y, y_pred: jaccard_score(y_true=y, y_pred=y_pred[:, 0], average="macro")
213+
- jaccard_score(y_true=y, y_pred=y_pred[:, 1], average="macro"),
214+
label_indices=np.arange(len(all_labels)),
215+
)
216+
balanced_accuracy_drop = compute_metric(
217+
all_labels,
218+
np.concatenate([all_clean_preds[:, None], all_adv_preds[:, None]], axis=1),
219+
lambda y, y_pred: balanced_accuracy_score(y_true=y, y_pred=y_pred[:, 0])
220+
- balanced_accuracy_score(y_true=y, y_pred=y_pred[:, 1]),
221+
label_indices=np.arange(len(all_labels)),
222+
)
223+
224+
metrics["drop"] = {
225+
"f1": f1_drop,
226+
"accuracy": accuracy_drop,
227+
"jaccard": jaccard_drop,
228+
"balanced_accuracy": balanced_accuracy_drop,
229+
}
230+
188231
# save ---------------------------------------------------------------
189232
save_outputs(res_folder, metrics)
190233

src/thunder/tasks/image_retrieval.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def topk_retrieval(
3535
chunk_size: int = 10000,
3636
return_viz_data: bool = False,
3737
disable_progress_bar: bool = False,
38+
compute_ci: bool = True,
3839
) -> tuple[dict, dict, dict]:
3940
"""
4041
Computing similarities between queries and keys.
@@ -46,6 +47,7 @@ def topk_retrieval(
4647
:param chunk_size: maximum number of query embeddings for which we compute dot product similarity with all key embeddings.
4748
:param return_viz_data: whether to return data to visualize topk samples.
4849
:param disable_progress_bar: whether to hide the progress bar.
50+
:param compute_ci: whether to compute confidence intervals.
4951
:return: dict of metrics, sorted image ids and viz data if required.
5052
"""
5153
# Normalizing embeddings
@@ -93,7 +95,9 @@ def topk_retrieval(
9395
# Metrics
9496
metrics_per_k = {}
9597
for k in k_vals:
96-
metrics_per_k[k] = compute_metrics(None, np.array(preds_per_k[k]), query_labels)
98+
metrics_per_k[k] = compute_metrics(
99+
None, np.array(preds_per_k[k]), query_labels, compute_ci=compute_ci
100+
)
97101

98102
return metrics_per_k, sorted_ids_per_k, viz_data
99103

@@ -180,7 +184,7 @@ def image_retrieval(
180184

181185
# Logging
182186
for k in k_vals:
183-
log_metrics(wandb_base_folder, metrics, "test", step=k)
187+
log_metrics(wandb_base_folder, metrics[k], "test", step=k)
184188
save_outputs(res_folder, metrics)
185189

186190
return metrics

src/thunder/tasks/knn_classification.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def knn(
3030
embs["val"],
3131
labels["val"],
3232
k_vals,
33+
compute_ci=False,
3334
)
3435

3536
# Logging
@@ -40,7 +41,7 @@ def knn(
4041
best_k = None
4142
best_val_f1 = -float("inf")
4243
for k in k_vals:
43-
val_f1 = np.array(val_metrics[k]["f1"]).mean().item()
44+
val_f1 = np.array(val_metrics[k]["f1"]["metric_score"]).mean().item()
4445
if val_f1 > best_val_f1:
4546
best_val_f1 = val_f1
4647
best_k = k

src/thunder/tasks/simple_shot.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def simple_shot(
7777
test_labels,
7878
[1],
7979
disable_progress_bar=True,
80+
compute_ci=False,
8081
)
8182

8283
# Logging

src/thunder/tasks/train_eval_probe.py

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from ..utils.constants import UtilsConstants
2727
from ..utils.calibration_metrics import compute_calibration_metrics
2828
from ..utils.data import PatchDataset
29-
from ..utils.downstream_metrics import compute_metrics
29+
from ..utils.downstream_metrics import compute_metric, compute_metrics
3030
from ..utils.utils import (
3131
get_hyperaparams_dict,
3232
local_seed,
@@ -272,13 +272,13 @@ def train_probe(
272272
# Updating best ckpt
273273
if (
274274
task_type == "linear_probing"
275-
and metrics[i]["f1"] > best_val_perf
275+
and metrics[i]["f1"]["metric_score"] > best_val_perf
276276
) or (
277277
task_type == "segmentation"
278278
and np.array(losses[i]).mean().item() < best_val_perf
279279
):
280280
if task_type == "linear_probing":
281-
best_val_perf = metrics[i]["f1"]
281+
best_val_perf = metrics[i]["f1"]["metric_score"]
282282
elif task_type == "segmentation":
283283
best_val_perf = np.array(losses[i]).mean().item()
284284
best_ckpt_hyperparam_id = i
@@ -538,7 +538,9 @@ def train_eval(
538538
tot_loss = []
539539
all_out = []
540540
all_label = []
541-
for batch_id, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
541+
for batch_id, batch in tqdm(
542+
enumerate(dataloader), total=len(dataloader), disable=hyperparam_search
543+
):
542544
# Batch data
543545
if "emb" in batch.keys():
544546
emb = batch["emb"].to(device)
@@ -591,8 +593,14 @@ def train_eval(
591593
out = []
592594
if task_type == "segmentation":
593595
for o, m in zip(output, unmasked_label):
594-
out.append(torch.cat([o[c][m].unsqueeze(-1) for c in range(o.shape[0])], dim=-1))
595-
curr_loss = sum([criterion(o, l) for o, l in zip(out, label)]) / len(out)
596+
out.append(
597+
torch.cat(
598+
[o[c][m].unsqueeze(-1) for c in range(o.shape[0])], dim=-1
599+
)
600+
)
601+
curr_loss = sum([criterion(o, l) for o, l in zip(out, label)]) / len(
602+
out
603+
)
596604
else:
597605
for c in range(output.shape[1]):
598606
out.append(output[:, c].unsqueeze(-1))
@@ -604,13 +612,19 @@ def train_eval(
604612
if batch_id == 0:
605613
tot_loss.append([curr_loss.item()])
606614
if comp_metrics:
607-
all_out.append([[o.detach().cpu() for o in out]] if task_type == 'segmentation'
608-
else [out.detach().cpu()])
615+
all_out.append(
616+
[[o.detach().cpu() for o in out]]
617+
if task_type == "segmentation"
618+
else [out.detach().cpu()]
619+
)
609620
else:
610621
tot_loss[i].append(curr_loss.item())
611622
if comp_metrics:
612-
all_out[i].append([o.detach().cpu() for o in out] if task_type == 'segmentation'
613-
else out.detach().cpu())
623+
all_out[i].append(
624+
[o.detach().cpu() for o in out]
625+
if task_type == "segmentation"
626+
else out.detach().cpu()
627+
)
614628
# Logging
615629
if comp_metrics:
616630
if task_type == "segmentation":
@@ -641,20 +655,42 @@ def train_eval(
641655
if task_type == "segmentation":
642656
metrics = []
643657
for i in range(len(all_out)):
644-
all_out[i] = [F.softmax(item, dim=1) for batch in all_out[i] for item in batch]
645-
all_metrics = [compute_metrics(o, None, l, True) for o, l in zip(all_out[i], all_label) if len(l) > 0]
658+
all_out[i] = [
659+
F.softmax(item, dim=1) for batch in all_out[i] for item in batch
660+
]
661+
all_metrics = [
662+
compute_metrics(o, None, l, True, compute_ci=False)
663+
for o, l in zip(all_out[i], all_label)
664+
if len(l) > 0
665+
]
646666
weights = [len(l) for l in all_label if len(l) > 0]
647-
metrics.append({key: np.average([d[key] for d in all_metrics], weights=weights) for key in all_metrics[0]} |
648-
{f'{key}_per_sample': [d[key] for d in all_metrics] for key in all_metrics[0]})
667+
668+
# Averaging per-image performance and computing confidence intervals
669+
all_metrics_out = {}
670+
for key in all_metrics[0]:
671+
metric_vals = [d[key]["metric_score"] for d in all_metrics]
672+
all_metrics_out[key] = compute_metric(
673+
weights,
674+
metric_vals,
675+
lambda weights, metric_vals: np.average(
676+
metric_vals, weights=weights
677+
),
678+
)
679+
all_metrics_out[f"per_sample_{key}"] = metric_vals
680+
metrics.append(all_metrics_out)
649681
else:
650682
# Computing metrics
651683
all_label = torch.cat(all_label)
652684
metrics = []
653685
for i in range(len(all_out)):
654686
all_out[i] = torch.cat(all_out[i])
655687
all_out[i] = F.softmax(all_out[i], dim=1)
656-
classification_metrics = compute_metrics(all_out[i], None, all_label)
657-
conformal_metrics = compute_calibration_metrics(all_out[i], all_label)
688+
classification_metrics = compute_metrics(
689+
all_out[i], None, all_label, compute_ci=(not hyperparam_search)
690+
)
691+
conformal_metrics = compute_calibration_metrics(
692+
all_out[i], all_label, compute_ci=(not hyperparam_search)
693+
)
658694
curr_metrics = (
659695
classification_metrics | conformal_metrics
660696
) # merging dictionaries

src/thunder/utils/calibration_metrics.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
import numpy as np
22
import torch
33

4+
from ..utils.downstream_metrics import compute_metric
45

5-
def compute_calibration_metrics(out_proba: torch.Tensor, label: torch.Tensor) -> dict:
6+
7+
def compute_calibration_metrics(
8+
out_proba: torch.Tensor, label: torch.Tensor, compute_ci: bool = True
9+
) -> dict:
610
"""
711
Computing performance metrics.
812
:param out_proba: tensor of proba predictions.
913
:param label: tensor of ground-truth labels.
14+
:param compute_ci: whether to compute confidence intervals.
1015
:return: dict of conformal prediction metrics.
1116
"""
1217

@@ -15,11 +20,41 @@ def compute_calibration_metrics(out_proba: torch.Tensor, label: torch.Tensor) ->
1520
label = label.numpy()
1621

1722
# Computing metrics
18-
ece = expected_calibration_error(out_proba, label)
19-
mce = maximum_calibration_error(out_proba, label)
20-
sce = static_calibration_error(out_proba, label)
21-
ace = adaptive_calibration_error(out_proba, label)
22-
tace = thresholded_adaptive_calibration_error(out_proba, label)
23+
ece = compute_metric(
24+
label,
25+
out_proba,
26+
lambda y, y_proba: expected_calibration_error(y_proba, y),
27+
label_indices=np.arange(len(label)),
28+
compute_ci=compute_ci,
29+
)
30+
mce = compute_metric(
31+
label,
32+
out_proba,
33+
lambda y, y_proba: maximum_calibration_error(y_proba, y),
34+
label_indices=np.arange(len(label)),
35+
compute_ci=compute_ci,
36+
)
37+
sce = compute_metric(
38+
label,
39+
out_proba,
40+
lambda y, y_proba: static_calibration_error(y_proba, y),
41+
label_indices=np.arange(len(label)),
42+
compute_ci=compute_ci,
43+
)
44+
ace = compute_metric(
45+
label,
46+
out_proba,
47+
lambda y, y_proba: adaptive_calibration_error(y_proba, y),
48+
label_indices=np.arange(len(label)),
49+
compute_ci=compute_ci,
50+
)
51+
tace = compute_metric(
52+
label,
53+
out_proba,
54+
lambda y, y_proba: thresholded_adaptive_calibration_error(y_proba, y),
55+
label_indices=np.arange(len(label)),
56+
compute_ci=compute_ci,
57+
)
2358

2459
return {
2560
"ECE": ece,

0 commit comments

Comments
 (0)