-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbenchmark_model.py
More file actions
29 lines (23 loc) · 1.35 KB
/
benchmark_model.py
File metadata and controls
29 lines (23 loc) · 1.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import torch.nn as nn
from robustbench import benchmark
import torchvision.transforms as transforms
from timm import create_model
import swin_transformer_timm_version
from benchmark_model_utils import get_RodriguezMunoz2024Characterizing_model, get_data_transform
# ---------------------------------------------------------------------------
# Configuration. Edit data_dir and ckpt_location to your own values.
# ---------------------------------------------------------------------------
data_dir = '/data/vision/torralba/datasets/imagenet_pytorch'
# Alternative checkpoint (Swin-B variant); swap these in for arch/ckpt_location below:
# arch = 'swin_base_patch4_window7_224'
# ckpt_location = '/vision-nfs/torralba/projects/adrianr/input_norm/eccv_outputs/gradnorm_swinb_variant/2024-02-14_11-30-41/last.pth.tar'
arch = 'swin_large_patch4_window7_224'
ckpt_location = '/vision-nfs/torralba/projects/adrianr/robustness_input_gradients/outputs/gradnorm_swinl_variant/2024-10-03_13-09-02/last.pth.tar'
threat_model = "Linf"  # one of {"Linf", "L2", "corruptions"}
dataset = "imagenet"  # one of {"cifar10", "cifar100", "imagenet"}
model_name = "RodriguezMunoz2024Characterizing"
n_examples = 5000  # number of evaluation images passed to robustbench
eps = 4. / 255.  # perturbation budget for the chosen threat model


def main(device="cuda:0"):
    """Build the configured model and evaluate it with RobustBench.

    Loads the checkpoint at ``ckpt_location`` for architecture ``arch`` via the
    project helper, then runs ``robustbench.benchmark`` with the module-level
    configuration above. Results are also written to disk by robustbench
    (``to_disk=True``).

    Args:
        device: Anything accepted by ``torch.device``; defaults to the first GPU,
            matching the original script.

    Returns:
        Tuple ``(clean_acc, robust_acc)`` as returned by ``robustbench.benchmark``.
    """
    transform = get_data_transform()
    model = get_RodriguezMunoz2024Characterizing_model(arch, ckpt_location)
    # robustbench does not switch the model to eval mode itself (it only warns),
    # so do it explicitly to keep dropout/BN-style layers deterministic.
    model.eval()
    clean_acc, robust_acc = benchmark(model, model_name=model_name, n_examples=n_examples, dataset=dataset,
                                      threat_model=threat_model, eps=eps, device=torch.device(device),
                                      preprocessing=transform, data_dir=data_dir,
                                      to_disk=True)
    print(f"{model_name} ({arch}) on {dataset}/{threat_model}: "
          f"clean accuracy = {clean_acc}, robust accuracy = {robust_acc}")
    return clean_acc, robust_acc


# Only run the (expensive) evaluation when executed as a script, not on import.
if __name__ == "__main__":
    main()