Skip to content

Commit 967c58e

Browse files
committed
Fix UCME (#118)
1 parent 8ba7021 commit 967c58e

7 files changed

Lines changed: 50 additions & 48 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "CalibrationErrors"
22
uuid = "33913031-fe46-5864-950f-100836f47845"
33
authors = ["David Widmann <david.widmann@it.uu.se>"]
4-
version = "0.5.21"
4+
version = "0.5.22"
55

66
[deps]
77
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"

src/ucme.jl

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,8 @@ function unsafe_ucme_eval(
9292
return res
9393
end
9494
function unsafe_ucme_eval(kernel::Kernel, p::Real, y::Bool, testp::Real, testy::Bool)
95-
noty = !y
96-
return (y - p) * kernel((p, y), (testp, testy)) +
97-
(noty - p) * kernel((p, noty), (testp, testy))
95+
return (kernel((p, true), (testp, testy)) - kernel((p, false), (testp, testy))) *
96+
(y - p)
9897
end
9998

10099
function unsafe_ucme_eval(kernel::KernelTensorProduct, p, y, testp, testy)
@@ -126,15 +125,12 @@ function unsafe_ucme_eval_targets(
126125
testp::AbstractVector{<:Real},
127126
testy::Integer,
128127
)
129-
res = sum(((z == y) - pz) * kernel(z, testy) for (z, pz) in enumerate(p))
130-
return res
128+
return sum(((z == y) - pz) * kernel(z, testy) for (z, pz) in enumerate(p))
131129
end
132130
function unsafe_ucme_eval_targets(
133131
kernel::Kernel, p::Real, y::Bool, testp::Real, testy::Bool
134132
)
135-
noty = !y
136-
res = (y - p) * kernel(y, testy) + (noty - p) * kernel(noty, testy)
137-
return res
133+
return (kernel(true, testy) - kernel(false, testy)) * (y - p)
138134
end
139135

140136
function unsafe_ucme_eval_targets(
@@ -144,11 +140,10 @@ function unsafe_ucme_eval_targets(
144140
testp::AbstractVector{<:Real},
145141
testy::Integer,
146142
)
147-
@inbounds res = (y == testy) - p[testy]
148-
return res
143+
return @inbounds (y == testy) - p[testy]
149144
end
150145
function unsafe_ucme_eval_targets(
151146
kernel::WhiteKernel, p::Real, y::Bool, testp::Real, testy::Bool
152147
)
153-
return (2 * testy - 1) * (y - p)
148+
return (testy - !testy) * (y - p)
154149
end

test/binning/medianvariance.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
@test sum(bin -> bin.nsamples, bins) == nsamples
2828
@test sum(bin -> bin.nsamples .* bin.mean_predictions, bins) sum(predictions)
2929
@test sum(bin -> bin.nsamples .* bin.proportions_targets, bins)
30-
counts(targets, nclasses)
30+
counts(targets, nclasses)
3131
end
3232

3333
# set maximum number of bins
@@ -47,7 +47,7 @@
4747
@test sum(bin -> bin.nsamples, bins) == nsamples
4848
@test sum(bin -> bin.nsamples .* bin.mean_predictions, bins) sum(predictions)
4949
@test sum(bin -> bin.nsamples .* bin.proportions_targets, bins)
50-
counts(targets, nclasses)
50+
counts(targets, nclasses)
5151
end
5252
end
5353

@@ -102,7 +102,7 @@
102102
@test bins[i].nsamples == length(idxs)
103103
@test bins[i].mean_predictions mean(predictions[idxs])
104104
@test bins[i].proportions_targets ==
105-
vec(mean(Matrix{Float64}(I, 3, 3)[:, targets[idxs]]; dims=2))
105+
vec(mean(Matrix{Float64}(I, 3, 3)[:, targets[idxs]]; dims=2))
106106
end
107107

108108
bins = CalibrationErrors.perform(MedianVarianceBinning(3), predictions, targets)

test/binning/uniform.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@
9898
@test bins[i].nsamples == length(idxs)
9999
@test bins[i].mean_predictions == mean(predictions[idxs])
100100
@test bins[i].proportions_targets ==
101-
vec(mean(Matrix{Float64}(I, 3, 3)[:, targets[idxs]]; dims=2))
101+
vec(mean(Matrix{Float64}(I, 3, 3)[:, targets[idxs]]; dims=2))
102102
end
103103

104104
bins = CalibrationErrors.perform(UniformBinning(1), predictions, targets)
@@ -134,7 +134,7 @@
134134
@test bins[i].nsamples == length(idxs)
135135
@test bins[i].mean_predictions == mean(predictions[idxs])
136136
@test bins[i].proportions_targets ==
137-
vec(mean(Matrix{Float64}(I, 3, 3)[:, targets[idxs]]; dims=2))
137+
vec(mean(Matrix{Float64}(I, 3, 3)[:, targets[idxs]]; dims=2))
138138
end
139139

140140
bins = CalibrationErrors.perform(UniformBinning(1), predictions, targets)

test/deprecated.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@
2020
for estimator in (ece, skce1, skce2, skce3, ucme)
2121
estimate = estimator(predictions, targets)
2222
@test @test_deprecated(calibrationerror(estimator, predictions, targets)) ==
23-
estimate
23+
estimate
2424
@test @test_deprecated(calibrationerror(estimator, (predictions, targets))) ==
25-
estimate
25+
estimate
2626
@test @test_deprecated(
2727
calibrationerror(estimator, reduce(hcat, predictions), targets)
2828
) == estimate

test/skce/generic.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232

3333
@test unsafe_skce_eval(kernelfull WhiteKernel(), pfull, yint, p̃full, ỹint) val
3434
@test unsafe_skce_eval(kernelfull WhiteKernel2(), pfull, yint, p̃full, ỹint)
35-
val
35+
val
3636
@test unsafe_skce_eval(
3737
TensorProduct2(kernelfull, WhiteKernel()), pfull, yint, p̃full, ỹint
3838
) val

test/ucme.jl

Lines changed: 35 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -93,37 +93,44 @@
9393
end
9494

9595
@testset "binary classification" begin
96-
# probabilities and boolean targets
96+
# probabilities and corresponding full categorical distribution
9797
p, testp = rand(2)
98-
y, testy = rand(Bool, 2)
99-
scale = rand()
100-
kernel = SqExponentialKernel() ScaleTransform(scale)
101-
val = unsafe_ucme_eval(kernel WhiteKernel(), p, y, testp, testy)
102-
@test unsafe_ucme_eval(kernel WhiteKernel2(), p, y, testp, testy) val
103-
@test unsafe_ucme_eval(TensorProduct2(kernel, WhiteKernel()), p, y, testp, testy)
104-
val
105-
@test unsafe_ucme_eval(TensorProduct2(kernel, WhiteKernel2()), p, y, testp, testy)
106-
val
107-
108-
# corresponding values and kernel for full categorical distribution
10998
pfull = [p, 1 - p]
110-
yint = 2 - y
11199
testpfull = [testp, 1 - testp]
112-
testyint = 2 - testy
100+
101+
# kernel for probabilities and corresponding one for full categorical distributions
102+
scale = rand()
103+
kernel = SqExponentialKernel() ScaleTransform(scale)
113104
kernelfull = SqExponentialKernel() ScaleTransform(scale / sqrt(2))
114105

115-
@test unsafe_ucme_eval(
116-
kernelfull WhiteKernel(), pfull, yint, testpfull, testyint
117-
) val
118-
@test unsafe_ucme_eval(
119-
kernelfull WhiteKernel2(), pfull, yint, testpfull, testyint
120-
) val
121-
@test unsafe_ucme_eval(
122-
TensorProduct2(kernelfull, WhiteKernel()), pfull, yint, testpfull, testyint
123-
) val
124-
@test unsafe_ucme_eval(
125-
TensorProduct2(kernelfull, WhiteKernel2()), pfull, yint, testpfull, testyint
126-
) val
106+
# for different targets
107+
for y in (true, false), testy in (true, false)
108+
# check values for probabilities
109+
val = unsafe_ucme_eval(kernel WhiteKernel(), p, y, testp, testy)
110+
@test unsafe_ucme_eval(kernel WhiteKernel2(), p, y, testp, testy) val
111+
@test unsafe_ucme_eval(
112+
TensorProduct2(kernel, WhiteKernel()), p, y, testp, testy
113+
) val
114+
@test unsafe_ucme_eval(
115+
TensorProduct2(kernel, WhiteKernel2()), p, y, testp, testy
116+
) val
117+
118+
# check values for categorical distributions
119+
yint = 2 - y
120+
testyint = 2 - testy
121+
@test unsafe_ucme_eval(
122+
kernelfull WhiteKernel(), pfull, yint, testpfull, testyint
123+
) val
124+
@test unsafe_ucme_eval(
125+
kernelfull WhiteKernel2(), pfull, yint, testpfull, testyint
126+
) val
127+
@test unsafe_ucme_eval(
128+
TensorProduct2(kernelfull, WhiteKernel()), pfull, yint, testpfull, testyint
129+
) val
130+
@test unsafe_ucme_eval(
131+
TensorProduct2(kernelfull, WhiteKernel2()), pfull, yint, testpfull, testyint
132+
) val
133+
end
127134
end
128135

129136
@testset "multi-class classification" begin
@@ -140,8 +147,8 @@
140147

141148
@test unsafe_ucme_eval(kernel WhiteKernel2(), p, y, testp, testy) val
142149
@test unsafe_ucme_eval(TensorProduct2(kernel, WhiteKernel()), p, y, testp, testy)
143-
val
150+
val
144151
@test unsafe_ucme_eval(TensorProduct2(kernel, WhiteKernel2()), p, y, testp, testy)
145-
val
152+
val
146153
end
147154
end

0 commit comments

Comments
 (0)