Commit 21ec63e

Refactor SKCE (#126)
1 parent b7bb04a commit 21ec63e

14 files changed

Lines changed: 377 additions & 447 deletions


README.md

Lines changed: 2 additions & 3 deletions
@@ -31,9 +31,8 @@ estimator(predictions, targets)
 
 The sets of predictions and targets have to be provided as vectors.
 
-This package implements the estimator `ECE` of the ECE, the estimators
-`BiasedSKCE`, `UnbiasedSKCE`, and `BlockUnbiasedSKCE` for the SKCE, and `UCME` for the
-UCME.
+This package implements the estimator `ECE` of the ECE, the estimator `SKCE` for the SKCE
+(unbiased and biased variants with different sample complexity), and `UCME` for the UCME.
 
 ## Related packages

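For readers migrating code, the removed estimator types map onto keyword arguments of the single `SKCE` constructor. A minimal sketch (the kernel `k` is only an example; any KernelFunctions.jl kernel on prediction–target tuples works):

```julia
using CalibrationErrors, KernelFunctions

k = GaussianKernel() ⊗ WhiteKernel()  # example kernel on (prediction, target) tuples

skce_unbiased = SKCE(k)                # was: UnbiasedSKCE(k)
skce_biased = SKCE(k; unbiased=false)  # was: BiasedSKCE(k)
skce_blocks = SKCE(k; blocksize=2)     # was: BlockUnbiasedSKCE(k, 2)
```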
docs/src/kce.md

Lines changed: 4 additions & 20 deletions
@@ -74,28 +74,12 @@ which only the most-confident predictions are considered.[^WLZ19]
 
 [^WLZ21]: Widmann, D., Lindsten, F., & Zachariah, D. (2021). [Calibration tests beyond classification](https://openreview.net/forum?id=-bxf89v3Nx). To be presented at *ICLR 2021*.
 
-## Estimators
+## Estimator
 
 For the SKCE biased and unbiased estimators exist. In CalibrationErrors.jl
-three types of estimators are available, namely [`BiasedSKCE`](@ref),
-[`UnbiasedSKCE`](@ref), and [`BlockUnbiasedSKCE`](@ref). Unsurprisingly,
-[`BiasedSKCE`](@ref) is a biased estimator whereas the other two
-estimators are unbiased. [`BiasedSKCE`](@ref) and [`UnbiasedSKCE`](@ref)
-have quadratic sample complexity whereas [`BlockUnbiasedSKCE`](@ref)
-is an estimator with linear sample complexity.
-
-### Biased estimator
-
-```@docs
-BiasedSKCE
-```
-
-### Unbiased estimators
-
-```@docs
-UnbiasedSKCE
-```
+[`SKCE`](@ref) lets you construct unbiased and biased estimators with quadratic
+and sub-quadratic sample complexity.
 
 ```@docs
-BlockUnbiasedSKCE
+SKCE
 ```

examples/classification/script.jl

Lines changed: 3 additions & 3 deletions
@@ -228,14 +228,14 @@ kernel = with_lengthscale(GaussianKernel(), ν) ⊗ WhiteKernel();
 
 # We obtain the following biased estimate of the squared KCE (SKCE):
 
-skce = BiasedSKCE(kernel)
+skce = SKCE(kernel; unbiased=false)
 skce(val_probs, val_yint)
 
 # Similar to the biased estimates of the ECE, the biased estimates of the SKCE are always
 # non-negative. The unbiased estimates can be negative as well, in particular if the model
 # is (close to being) calibrated:
 
-skce = UnbiasedSKCE(kernel)
+skce = SKCE(kernel)
 skce(val_probs, val_yint)
 
 # When the datasets are large, the quadratic sample complexity of the standard biased and
@@ -246,5 +246,5 @@ skce(val_probs, val_yint)
 # Here we consider the extreme case of blocks with two samples, which yields an estimator
 # with linear sample complexity:
 
-skce = BlockUnbiasedSKCE(kernel, 2)
+skce = SKCE(kernel; blocksize=2)
 skce(val_probs, val_yint)

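For a quick check of the refactored interface outside this script, a minimal self-contained sketch with hypothetical toy data (probability-vector predictions, integer class targets) in place of `val_probs` and `val_yint`:

```julia
using CalibrationErrors, KernelFunctions

# hypothetical toy data: three samples with two classes
predictions = [[0.7, 0.3], [0.4, 0.6], [0.5, 0.5]]
targets = [1, 2, 2]

kernel = GaussianKernel() ⊗ WhiteKernel()

SKCE(kernel; unbiased=false)(predictions, targets)  # biased, always non-negative
SKCE(kernel)(predictions, targets)                  # unbiased, may be negative
SKCE(kernel; blocksize=2)(predictions, targets)     # unbiased, blocks of two samples
```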
examples/distribution/script.jl

Lines changed: 11 additions & 7 deletions
@@ -226,20 +226,24 @@ Random.seed!(1234)
 data = estimates(_ -> ECE(MedianVarianceBinning(10), TotalVariation()))
 plot_estimates(data; ece=true)
 
-# ## Biased estimator of the squared kernel calibration error
+# ## Unbiased estimators of the squared kernel calibration error
 #
 
 Random.seed!(1234)
-data = estimates(BiasedSKCE ∘ MedianHeuristicKernel(250))
+data = estimates(SKCE ∘ MedianHeuristicKernel(250))
 plot_estimates(data)
 
-# ## Unbiased estimators of the squared kernel calibration error
-#
-
 Random.seed!(1234)
-data = estimates(UnbiasedSKCE ∘ MedianHeuristicKernel(250))
+data = estimates() do predictions_targets
+    return SKCE(MedianHeuristicKernel(250)(predictions_targets); blocksize=2)
+end
 plot_estimates(data)
 
+# ## Biased estimator of the squared kernel calibration error
+#
+
 Random.seed!(1234)
-data = estimates(BlockUnbiasedSKCE ∘ MedianHeuristicKernel(250))
+data = estimates() do predictions_targets
+    return SKCE(MedianHeuristicKernel(250)(predictions_targets); unbiased=false)
+end
 plot_estimates(data)

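The do-block forms introduced above are just expanded versions of the `∘` composition: both construct a kernel from the data first and then wrap it in an `SKCE` estimator, but the explicit form is needed as soon as keyword arguments are passed. A brief sketch, assuming the `estimates` and `MedianHeuristicKernel` helpers defined earlier in the script:

```julia
# composition form: kernel chosen from the data, default (unbiased) SKCE estimator
data = estimates(SKCE ∘ MedianHeuristicKernel(250))

# expanded do-block form: same mechanism, but keyword arguments can be passed to SKCE
data = estimates() do predictions_targets
    kernel = MedianHeuristicKernel(250)(predictions_targets)
    return SKCE(kernel; blocksize=2)
end
```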
src/CalibrationErrors.jl

Lines changed: 2 additions & 4 deletions
@@ -21,7 +21,7 @@ const OT = ExactOptimalTransport
 export calibrationerror
 
 # estimators
-export ECE, BiasedSKCE, UnbiasedSKCE, BlockUnbiasedSKCE, UCME
+export ECE, SKCE, UCME
 
 # binning algorithms
 export UniformBinning, MedianVarianceBinning
@@ -35,9 +35,7 @@ include("binning/uniform.jl")
 include("binning/medianvariance.jl")
 include("ece.jl")
 
-include("skce/generic.jl")
-include("skce/biased.jl")
-include("skce/unbiased.jl")
+include("skce.jl")
 
 include("ucme.jl")

src/skce.jl

Lines changed: 215 additions & 1 deletion

@@ -1,4 +1,218 @@
-abstract type SKCE <: CalibrationErrorEstimator end
+@doc raw"""
+    SKCE(k; unbiased::Bool=true, blocksize=identity)
+
+Estimator of the squared kernel calibration error (SKCE) with kernel `k`.
+
+Kernel `k` on the product space of predictions and targets has to be a `Kernel` from the
+Julia package
+[KernelFunctions.jl](https://github.com/JuliaGaussianProcesses/KernelFunctions.jl)
+that can be evaluated for inputs that are tuples of predictions and targets.
+
+One can choose an unbiased or a biased variant with `unbiased=true` or `unbiased=false`,
+respectively (see details below).
+
+The SKCE is estimated as the average estimate of different blocks of samples. The number of
+samples per block is set by `blocksize`:
+- If `blocksize` is a function `blocksize(n::Int)`, then the number of samples per block is
+  set to `blocksize(n)` where `n` is the total number of samples.
+- If `blocksize` is an integer, then the number of samples per block is set to `blocksize`,
+  independent of the total number of samples.
+The default setting `blocksize=identity` implies that a single block with all samples is
+used.
+
+The number of samples per block must be at least 1 if `unbiased=false` and 2 if
+`unbiased=true`. Additionally, it must be at most the total number of samples. Note that the
+last block is neglected if it is incomplete (see details below).
+
+# Details
+
+The unbiased estimator is not guaranteed to be non-negative whereas the biased estimator is
+always non-negative.
+
+The sample complexity of the estimator is ``O(mn)``, where ``m`` is the block size and ``n``
+is the total number of samples. In particular, with the default setting `blocksize=identity`
+the estimator has a quadratic sample complexity.
+
+Let ``(P_{X_i}, Y_i)_{i=1,\ldots,n}`` be a data set of predictions and corresponding
+targets. The estimator with block size ``m`` is defined as
+```math
+{\bigg\lfloor \frac{n}{m} \bigg\rfloor}^{-1} \sum_{b=1}^{\lfloor n/m \rfloor}
+|B_b|^{-1} \sum_{(i, j) \in B_b} h_k\big((P_{X_i}, Y_i), (P_{X_j}, Y_j)\big),
+```
+where
+```math
+\begin{aligned}
+h_k\big((μ, y), (μ', y')\big) ={}& k\big((μ, y), (μ', y')\big)
+    - 𝔼_{Z ∼ μ} k\big((μ, Z), (μ', y')\big) \\
+& - 𝔼_{Z' ∼ μ'} k\big((μ, y), (μ', Z')\big)
+    + 𝔼_{Z ∼ μ, Z' ∼ μ'} k\big((μ, Z), (μ', Z')\big)
+\end{aligned}
+```
+and blocks ``B_b`` (``b = 1, \ldots, \lfloor n/m \rfloor``) are defined as
+```math
+B_b = \begin{cases}
+\{(i, j): (b - 1) m < i < j \leq bm \} & \text{(unbiased)}, \\
+\{(i, j): (b - 1) m < i, j \leq bm \} & \text{(biased)}.
+\end{cases}
+```
+
+# References
+
+Widmann, D., Lindsten, F., & Zachariah, D. (2019). [Calibration tests in multi-class
+classification: A unifying framework](https://proceedings.neurips.cc/paper/2019/hash/1c336b8080f82bcc2cd2499b4c57261d-Abstract.html).
+In: Advances in Neural Information Processing Systems (NeurIPS 2019) (pp. 12257–12267).
+
+Widmann, D., Lindsten, F., & Zachariah, D. (2021). [Calibration tests beyond
+classification](https://openreview.net/forum?id=-bxf89v3Nx).
+"""
+struct SKCE{K<:Kernel,B} <: CalibrationErrorEstimator
+    """Kernel of estimator."""
+    kernel::K
+    """Whether the unbiased estimator is used."""
+    unbiased::Bool
+    """Number of samples per block."""
+    blocksize::B
+
+    function SKCE{K,B}(kernel::K, unbiased::Bool, blocksize::B) where {K,B}
+        if blocksize isa Integer
+            blocksize ≥ 1 + unbiased || throw(
+                ArgumentError(
+                    "there must be at least $(1 + unbiased) $(unbiased ? "samples" : "sample") per block",
+                ),
+            )
+        end
+        return new{K,B}(kernel, unbiased, blocksize)
+    end
+end
+
+function SKCE(kernel::Kernel; unbiased::Bool=true, blocksize::B=identity) where {B}
+    return SKCE{typeof(kernel),B}(kernel, unbiased, blocksize)
+end
+
+## estimators without blocks
+function (skce::SKCE{<:Kernel,typeof(identity)})(
+    predictions::AbstractVector, targets::AbstractVector
+)
+    @unpack kernel, unbiased = skce
+    return if unbiased
+        unbiasedskce(kernel, predictions, targets)
+    else
+        biasedskce(kernel, predictions, targets)
+    end
+end
+
+### unbiased estimator (no blocks)
+function unbiasedskce(kernel::Kernel, predictions::AbstractVector, targets::AbstractVector)
+    # obtain number of samples
+    nsamples = check_nsamples(predictions, targets, 2)
+
+    @inbounds begin
+        # evaluate the kernel function for the first pair of samples
+        hij = unsafe_skce_eval(
+            kernel, predictions[1], targets[1], predictions[2], targets[2]
+        )
+
+        # initialize the estimate
+        estimate = hij / 1
+
+        # for all other pairs of samples
+        n = 1
+        for j in 3:nsamples
+            predictionj = predictions[j]
+            targetj = targets[j]
+
+            for i in 1:(j - 1)
+                predictioni = predictions[i]
+                targeti = targets[i]
+
+                # evaluate the kernel function
+                hij = unsafe_skce_eval(kernel, predictioni, targeti, predictionj, targetj)
+
+                # update the estimate
+                n += 1
+                estimate += (hij - estimate) / n
+            end
+        end
+    end
+
+    return estimate
+end
+
+### biased estimator (no blocks)
+function biasedskce(kernel::Kernel, predictions::AbstractVector, targets::AbstractVector)
+    # obtain number of samples
+    nsamples = check_nsamples(predictions, targets, 1)
+
+    @inbounds begin
+        # evaluate kernel function for the first sample
+        prediction = predictions[1]
+        target = targets[1]
+        hij = unsafe_skce_eval(kernel, prediction, target, prediction, target)
+
+        # initialize the calibration error estimate
+        estimate = hij / 1
+
+        # for all other pairs of samples
+        n = 1
+        for i in 2:nsamples
+            predictioni = predictions[i]
+            targeti = targets[i]
+
+            for j in 1:(i - 1)
+                predictionj = predictions[j]
+                targetj = targets[j]
+
+                # evaluate the kernel function
+                hij = unsafe_skce_eval(kernel, predictioni, targeti, predictionj, targetj)
+
+                # update the estimate (add two terms due to symmetry!)
+                n += 2
+                estimate += 2 * (hij - estimate) / n
+            end
+
+            # evaluate the kernel function
+            hij = unsafe_skce_eval(kernel, predictioni, targeti, predictioni, targeti)
+
+            # update the estimate
+            n += 1
+            estimate += (hij - estimate) / n
+        end
+    end
+
+    return estimate
+end
+
+## estimators with blocks
+function (skce::SKCE)(predictions::AbstractVector, targets::AbstractVector)
+    @unpack kernel, unbiased, blocksize = skce
+
+    # obtain number of samples
+    nsamples = check_nsamples(predictions, targets, 1 + unbiased)
+
+    # compute number of blocks
+    _blocksize = blocksize isa Integer ? blocksize : blocksize(nsamples)
+    (_blocksize isa Integer && _blocksize >= 1 + unbiased) ||
+        error("number of samples per block must be an integer >= $(1 + unbiased)")
+    nblocks = nsamples ÷ _blocksize
+    nblocks >= 1 || error("at least one block of samples is required")
+
+    # create iterator of partitions
+    blocks = Iterators.take(
+        zip(
+            Iterators.partition(predictions, _blocksize),
+            Iterators.partition(targets, _blocksize),
+        ),
+        nblocks,
+    )
+
+    # compute average estimate
+    estimator = SKCE(kernel; unbiased=unbiased)
+    estimate = mean(
+        estimator(_predictions, _targets) for (_predictions, _targets) in blocks
+    )
+
+    return estimate
+end
 
 """
     unsafe_skce_eval(k, p, y, p̃, ỹ)

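As a supplement to the docstring above, a hedged sketch of the `blocksize`-as-a-function behaviour; the square-root heuristic and the toy data are illustrative assumptions, not part of the package:

```julia
using CalibrationErrors, KernelFunctions

kernel = GaussianKernel() ⊗ WhiteKernel()

# Roughly sqrt(n) samples per block: each block costs O(blocksize^2) kernel evaluations
# and there are about n / blocksize blocks, so the total cost is O(n * blocksize),
# i.e. sub-quadratic in the number of samples n.
skce = SKCE(kernel; blocksize=n -> max(2, floor(Int, sqrt(n))))

# hypothetical toy data: normalized random probability vectors and class labels
predictions = map(1:100) do _
    p = rand(3)
    p ./ sum(p)
end
targets = rand(1:3, 100)

skce(predictions, targets)
```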