-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_experiments.py
More file actions
57 lines (49 loc) · 1.93 KB
/
run_experiments.py
File metadata and controls
57 lines (49 loc) · 1.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import itertools
import os
import pickle

import pandas as pd
import numpy as np
import torch
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.special import softmax
from sklearn.metrics import pairwise
from tqdm import tqdm

from ece_bin import ECE_bin
from CV_pipeline import ECE_bin_CV
from experiments import TCE_experiments, CWCE_experiments, CCE_experiments, unpickle_probs
# Selects which experiment suite this script runs (e.g. "tce" or "cce").
notion = "tce"

# Each setting is a (dataset name, model name, logits path) triple.
settings_c10 = [
    ("Cifar10", "LeNet-5", "logits/probs_lenet5_c10_logits.p"),
    ("Cifar10", "Densenet-40", "logits/probs_densenet40_c10_logits.p"),
    ("Cifar10", "ResNetWide-32", "logits/probs_resnet_wide32_c10_logits.p"),
    ("Cifar10", "Resnet-110", "logits/probs_resnet110_c10_logits.p"),
    ("Cifar10", "Resnet-110 SD", "logits/probs_resnet110_SD_c10_logits.p"),
]

settings_c100 = [
    ("Cifar100", "LeNet-5", "logits/probs_lenet5_c100_logits.p"),
    ("Cifar100", "Densenet-40", "logits/probs_densenet40_c100_logits.p"),
    ("Cifar100", "ResNetWide-32", "logits/probs_resnet_wide32_c100_logits.p"),
    ("Cifar100", "Resnet-110", "logits/probs_resnet110_c100_logits.p"),
    ("Cifar100", "Resnet-110 SD", "logits/probs_resnet110_SD_c100_logits.p"),
]

# The ImageNet models are only included when running the "tce" suite; for
# any other notion this list is empty so only the CIFAR settings are used.
settings_imgnet = (
    [
        ("ImageNet", "DenseNet-161", "logits/diag_densenet161_imgnet"),
        ("ImageNet", "Resnet-152", "logits/diag_resnet152_imgnet"),
        ("ImageNet", "Pnasnet-5", "logits/diag_pnasnet5_large_imgnet"),
    ]
    if notion == "tce"
    else []
)
# Run the experiment suite selected by ``notion`` over every configured
# (dataset, model, logits path) setting, then pickle the returned results.
# The k_folds_val / k_folds_test values are passed straight through to the
# experiment runners in ``experiments``; their exact semantics live there.
all_settings = settings_c10 + settings_c100 + settings_imgnet

if notion == 'tce':
    exp_results = TCE_experiments(
        all_settings, k_folds_val=5, k_folds_test=1
    )
    filename = 'results/real_data/TCE_binning_krr_kkrr_kde.pkl'
elif notion == 'cce':
    exp_results = CCE_experiments(
        all_settings, k_folds_val=5, k_folds_test=1,
        estimators=['kkrr', 'krr', 'kde']
    )
    filename = 'results/real_data/CCE_krr_kkrr_kde.pkl'
else:
    # Without this branch an unknown notion fell through and crashed later
    # with a NameError on ``filename``/``exp_results``. NOTE(review):
    # CWCE_experiments is imported at the top of the file but is not wired
    # up here.
    raise ValueError(
        f"Unsupported notion {notion!r}; expected 'tce' or 'cce'."
    )

# Make sure the output folder exists; otherwise open() would raise
# FileNotFoundError only after the experiments above had already run.
os.makedirs(os.path.dirname(filename), exist_ok=True)
with open(filename, 'wb') as file:
    pickle.dump(exp_results, file)