-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathkernel_density_ratio.py
More file actions
92 lines (74 loc) · 3.18 KB
/
kernel_density_ratio.py
File metadata and controls
92 lines (74 loc) · 3.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# from https://github.com/tpopordanoska/ece-kde/blob/main/ece_kde.py
import torch
from torch import nn
def get_bandwidth(f, device='cpu'):
    """
    Select a bandwidth for the kernel based on maximizing the leave-one-out likelihood (LOO MLE).

    A fixed grid of candidates is scanned (15 log-spaced values in [1e-5, 1e-1] followed by
    5 linearly spaced values in [0.2, 1]); for each candidate the leave-one-out log-likelihood of
    the samples under the kernel density estimate is computed and the best candidate is kept.

    :param f: The vector containing the probability scores, shape [num_samples, num_classes]
    :param device: The device type: 'cpu' or 'cuda'
    :return: The bandwidth of the kernel (a 0-dim tensor taken from the candidate grid), or -1 if
        no candidate produced a log-likelihood greater than -inf (e.g. every value was NaN)
    """
    bandwidths = torch.cat((torch.logspace(start=-5, end=-1, steps=15), torch.linspace(0.2, 1, steps=5)))
    n = torch.tensor(f.shape[0])
    # Each leave-one-out estimate averages over the n - 1 remaining samples; this term does not
    # depend on the bandwidth, so compute it once.
    log_n_minus_1 = torch.log(n - 1)

    max_b = -1
    # Log-likelihoods are frequently negative, so the running maximum has to start at -inf.
    # Starting at 0 would reject every candidate and return the -1 sentinel whenever all
    # likelihoods are non-positive.
    max_l = float('-inf')
    for b in bandwidths:
        log_kern = get_kernel(f, b, device)
        # Row i: log of the average kernel value at sample i over all other samples
        # (get_kernel suppresses the diagonal, i.e. the sample itself).
        log_fhat = torch.logsumexp(log_kern, 1) - log_n_minus_1
        l = torch.sum(log_fhat)
        if l > max_l:
            max_l = l
            max_b = b
    return max_b
# Note for training: Make sure there are at least two examples for every class present in the batch, otherwise
# LogsumexpBackward returns nans.
def get_kde_ratio_log(f, y, bandwidth, p=2, device='cpu'):
    """
    Compute the kernel density ratio statistic entirely in log space.

    For every sample i and class k, the leave-one-out kernel-weighted fraction of the other
    samples that carry label k is formed as
    exp(logsumexp_j(log K[i, j] + log 1[y_j == k]) - logsumexp_j(log K[i, j])).
    These fractions are raised to the power ``p``, summed over the classes and averaged over
    the samples.

    :param f: The vector containing the probability scores, shape [num_samples, num_classes]
    :param y: Integer class labels, one per sample (passed to ``one_hot``)
    :param bandwidth: The bandwidth of the kernel
    :param p: The exponent applied to each per-class ratio
    :param device: The device type: 'cpu' or 'cuda' (forwarded to ``get_kernel``)
    :return: The averaged statistic as a Python float
    """
    log_kern = get_kernel(f, bandwidth, device)
    y_onehot = nn.functional.one_hot(y, num_classes=f.shape[1]).to(torch.float32)
    # log(1) = 0 keeps a neighbour in the sum, log(0) = -inf removes it from the logsumexp.
    log_y = torch.log(y_onehot)
    log_den = torch.logsumexp(log_kern, dim=1)
    final_ratio = 0
    for k in range(f.shape[1]):
        # Broadcasting adds the [1, num_samples] class mask to every row of the kernel matrix.
        # No explicit CPU-allocated ones-matrix is needed; building one breaks when the inputs
        # live on another device.
        log_kern_y = log_kern + log_y[:, k].unsqueeze(0)
        log_inner_ratio = torch.logsumexp(log_kern_y, dim=1) - log_den
        inner_ratio = torch.exp(log_inner_ratio)
        inner_diff = inner_ratio**p
        final_ratio += inner_diff
    return torch.mean(final_ratio).item()
def kde_ratio_estimator(f, y, bandwidth, device='cpu'):
    """
    Estimate the kernel density ratio statistic with exponent p = 2.

    For every sample, the leave-one-out kernel-weighted fraction of the other samples belonging
    to each class is computed; the squared fractions are summed over the classes and averaged
    over the samples. With more than 60 classes the computation is delegated to the log-space
    implementation ``get_kde_ratio_log``.

    :param f: The vector containing the probability scores, shape [num_samples, num_classes]
    :param y: Integer class labels, one per sample (passed to ``one_hot``)
    :param bandwidth: The bandwidth of the kernel
    :param device: The device type: 'cpu' or 'cuda' (forwarded to ``get_kernel``)
    :return: The averaged statistic as a Python float
    """
    p = 2
    if f.shape[1] > 60:
        # Slower but more numerically stable implementation for larger number of classes
        return get_kde_ratio_log(f, y, bandwidth, p, device)
    log_kern = get_kernel(f, bandwidth, device)
    kern = torch.exp(log_kern)
    # matmul does not promote dtypes, so the one-hot labels must match the kernel's dtype
    # (a hard-coded float64 fails for float32 scores).
    y_onehot = nn.functional.one_hot(y, num_classes=f.shape[1]).to(kern.dtype)
    # kern_y[i, k]: total kernel weight of the other samples that carry label k.
    kern_y = torch.matmul(kern, y_onehot)
    den = torch.sum(kern, dim=1)
    # to avoid division by 0
    den = torch.clamp(den, min=1e-10)
    ratio = kern_y / den.unsqueeze(-1)
    ratio = torch.sum(ratio**p, dim=1)
    return torch.mean(ratio).item()
def get_kernel(f, bandwidth, device):
    """
    Build the pairwise log-kernel matrix of the samples with the diagonal suppressed.

    A beta kernel is used when there is a single score column, a Dirichlet kernel otherwise.

    :param f: The vector containing the probability scores, shape [num_samples, num_classes]
    :param bandwidth: The bandwidth of the kernel
    :param device: The device the diagonal mask is moved to before it is added
    :return: The log-kernel values with the most negative float value added on the diagonal
    """
    is_single_class = f.shape[1] == 1
    if is_single_class:
        raw_log_kern = beta_kernel(f, f, bandwidth)
    else:
        raw_log_kern = dirichlet_kernel(f, bandwidth)
    log_kern = raw_log_kern.squeeze()

    # Leave-one-out trick: the most negative float stands in for -inf on the diagonal, so a
    # sample contributes (numerically) nothing to its own density estimate.
    lowest = torch.finfo(torch.float).min
    diagonal_mask = torch.diag(lowest * torch.ones(len(f)))
    return log_kern + diagonal_mask.to(device)
def beta_kernel(z, zi, bandwidth=0.1):
    """
    Evaluate the log-density of beta kernels parameterised by ``zi`` at the points ``z``.

    The shape parameters are ``zi / bandwidth + 1`` and ``(1 - zi) / bandwidth + 1``.

    :param z: The points at which the kernels are evaluated; a dimension is inserted before the
        last one so that the points broadcast against the kernel parameters
    :param zi: The points that parameterise the kernels
    :param bandwidth: The bandwidth of the kernel
    :return: The log of the beta pdf, broadcast over ``z.unsqueeze(-2)`` and ``zi``
    """
    shape_a = zi / bandwidth + 1
    shape_b = (1 - zi) / bandwidth + 1
    points = z.unsqueeze(-2)

    # Log of the beta function B(a, b), the normalising constant of the density.
    log_normaliser = torch.lgamma(shape_a) + torch.lgamma(shape_b) - torch.lgamma(shape_a + shape_b)
    # Unnormalised log-density: (a - 1) * log(x) + (b - 1) * log(1 - x).
    log_unnormalised = (shape_a - 1) * torch.log(points) + (shape_b - 1) * torch.log(1 - points)
    return log_unnormalised - log_normaliser
def dirichlet_kernel(z, bandwidth=0.1):
    """
    Evaluate the pairwise log-density of Dirichlet kernels centred on the rows of ``z``.

    Row j of ``z`` defines a Dirichlet kernel with concentration ``z[j] / bandwidth + 1``;
    entry [i, j] of the result is the log-density of that kernel at row i of ``z``.

    :param z: The points used both as kernel parameters and as evaluation points, one per row
    :param bandwidth: The bandwidth of the kernel
    :return: The log of the Dirichlet pdf, shape [num_rows, num_rows]
    """
    concentration = z / bandwidth + 1

    # Log of the multivariate beta function: sum(lgamma(alpha)) - lgamma(sum(alpha)).
    summed_lgamma = torch.sum(torch.lgamma(concentration), dim=1)
    lgamma_of_sum = torch.lgamma(torch.sum(concentration, dim=1))
    log_normaliser = summed_lgamma - lgamma_of_sum

    # Entry [i, j] = sum_k log(z[i, k]) * (alpha[j, k] - 1).
    log_unnormalised = torch.matmul(torch.log(z), (concentration - 1).T)
    return log_unnormalised - log_normaliser