Skip to content

Commit dabcd9b

Browse files
committed
Support UCME with probabilities and boolean targets
1 parent a140a02 commit dabcd9b

4 files changed

Lines changed: 152 additions & 19 deletions

File tree

src/ucme.jl

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,13 @@ function unsafe_ucme_eval(
8888
testp::AbstractVector{<:Real},
8989
testy::Integer,
9090
)
91-
a = kernel((p, y), (testp, testy))
92-
b = sum(p[z] * kernel((p, z), (testp, testy)) for z in 1:length(p))
93-
94-
return a - b
91+
res = sum(((z == y) - pz) * kernel((p, z), (testp, testy)) for (z, pz) in enumerate(p))
92+
return res
93+
end
94+
function unsafe_ucme_eval(kernel::Kernel, p::Real, y::Bool, testp::Real, testy::Bool)
95+
noty = !y
96+
return (y - p) * kernel((p, y), (testp, testy)) +
97+
(noty - p) * kernel((p, noty), (testp, testy))
9598
end
9699

97100
function unsafe_ucme_eval(kernel::KernelTensorProduct, p, y, testp, testy)
@@ -109,6 +112,30 @@ function unsafe_ucme_eval(
109112
κpredictions, κtargets = kernel.kernels
110113
return unsafe_ucme_eval_targets(κtargets, p, y, testp, testy) * κpredictions(p, testp)
111114
end
115+
function unsafe_ucme_eval(
116+
kernel::KernelTensorProduct, p::Real, y::Bool, testp::Real, testy::Bool
117+
)
118+
κpredictions, κtargets = kernel.kernels
119+
return unsafe_ucme_eval_targets(κtargets, p, y, testp, testy) * κpredictions(p, testp)
120+
end
121+
122+
function unsafe_ucme_eval_targets(
123+
kernel::Kernel,
124+
p::AbstractVector{<:Real},
125+
y::Integer,
126+
testp::AbstractVector{<:Real},
127+
testy::Integer,
128+
)
129+
res = sum(((z == y) - pz) * kernel(z, testy) for (z, pz) in enumerate(p))
130+
return res
131+
end
132+
function unsafe_ucme_eval_targets(
133+
kernel::Kernel, p::Real, y::Bool, testp::Real, testy::Bool
134+
)
135+
noty = !y
136+
res = (y - p) * kernel(y, testy) + (noty - p) * kernel(noty, testy)
137+
return res
138+
end
112139

113140
function unsafe_ucme_eval_targets(
114141
κtargets::WhiteKernel,
@@ -120,3 +147,8 @@ function unsafe_ucme_eval_targets(
120147
@inbounds res = (y == testy) - p[testy]
121148
return res
122149
end
150+
function unsafe_ucme_eval_targets(
151+
kernel::WhiteKernel, p::Real, y::Bool, testp::Real, testy::Bool
152+
)
153+
return (2 * testy - 1) * (y - p)
154+
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using Random
88
using Statistics
99
using Test
1010

11-
using CalibrationErrors: unsafe_skce_eval
11+
using CalibrationErrors: unsafe_skce_eval, unsafe_ucme_eval
1212

1313
Random.seed!(1234)
1414

test/skce/generic.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525

2626
# corresponding values and kernel for full categorical distribution
2727
pfull = [p, 1 - p]
28-
yint = y ? 1 : 2
28+
yint = 2 - y
2929
p̃full = [p̃, 1 - p̃]
30-
ỹint = ? 1 : 2
30+
ỹint = 2 -
3131
kernelfull = SqExponentialKernel() ScaleTransform(scale / sqrt(2))
3232

3333
@test unsafe_skce_eval(kernelfull WhiteKernel(), pfull, yint, p̃full, ỹint) val

test/ucme.jl

Lines changed: 113 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,145 @@
11
@testset "ucme.jl" begin
2-
@testset "UCME: Two-dimensional example" begin
3-
# three test locations
2+
@testset "UCME: Binary examples" begin
3+
# categorical distributions
44
ucme = UCME(
55
SqExponentialKernel() WhiteKernel(),
66
[[1.0, 0], [0.5, 0.5], [0.0, 1]],
77
[1, 1, 2],
88
)
9-
10-
# two predictions
119
@test iszero(@inferred(ucme([[1, 0], [0, 1]], [1, 2])))
1210
@test @inferred(ucme([[1, 0], [0, 1]], [1, 1])) (exp(-2) + exp(-0.5) + 1) / 12
1311
@test @inferred(ucme([[1, 0], [0, 1]], [2, 1])) (1 - exp(-1))^2 / 6
1412
@test @inferred(ucme([[1, 0], [0, 1]], [2, 2])) (exp(-2) + exp(-0.5) + 1) / 12
13+
14+
# probabilities
15+
ucme = UCME(
16+
(SqExponentialKernel() ScaleTransform(sqrt(2))) WhiteKernel(),
17+
[1.0, 0.5, 0.0],
18+
[true, true, false],
19+
)
20+
@test iszero(@inferred(ucme([1, 0], [true, false])))
21+
@test @inferred(ucme([1, 0], [true, true])) (exp(-2) + exp(-0.5) + 1) / 12
22+
@test @inferred(ucme([1, 0], [false, true])) (1 - exp(-1))^2 / 6
23+
@test @inferred(ucme([1, 0], [false, false])) (exp(-2) + exp(-0.5) + 1) / 12
1524
end
1625

1726
@testset "UCME: Basic properties" begin
1827
estimates = Vector{Float64}(undef, 1_000)
1928

20-
for ntest in (1, 5, 10), nclasses in (2, 10, 100)
21-
dist = Dirichlet(nclasses, 1.0)
29+
for ntest in (1, 5, 10)
30+
# categorical distributions
31+
for nclasses in (2, 10, 100)
32+
dist = Dirichlet(nclasses, 1.0)
33+
34+
testpredictions = [rand(dist) for _ in 1:ntest]
35+
testtargets = rand(1:nclasses, ntest)
36+
ucme = UCME(
37+
(ExponentialKernel() ScaleTransform(0.1)) WhiteKernel(),
38+
testpredictions,
39+
testtargets,
40+
)
41+
42+
predictions = [Vector{Float64}(undef, nclasses) for _ in 1:20]
43+
targets = Vector{Int}(undef, 20)
2244

23-
testpredictions = [rand(dist) for _ in 1:ntest]
24-
testtargets = rand(1:nclasses, ntest)
45+
for i in 1:length(estimates)
46+
rand!.(Ref(dist), predictions)
47+
targets .= rand.(Categorical.(predictions))
48+
49+
estimates[i] = ucme(predictions, targets)
50+
end
51+
52+
@test all(x > zero(x) for x in estimates)
53+
end
54+
55+
# probabilities
56+
testpredictions = rand(ntest)
57+
testtargets = rand(Bool, ntest)
2558
ucme = UCME(
2659
(ExponentialKernel() ScaleTransform(0.1)) WhiteKernel(),
2760
testpredictions,
2861
testtargets,
2962
)
3063

31-
predictions = [Vector{Float64}(undef, nclasses) for _ in 1:20]
32-
targets = Vector{Int}(undef, 20)
64+
predictions = Vector{Float64}(undef, 20)
65+
targets = Vector{Bool}(undef, 20)
3366

3467
for i in 1:length(estimates)
35-
rand!.(Ref(dist), predictions)
36-
targets .= rand.(Categorical.(predictions))
68+
rand!(predictions)
69+
map!(targets, predictions) do p
70+
return rand() < p
71+
end
3772

3873
estimates[i] = ucme(predictions, targets)
3974
end
4075

4176
@test all(x > zero(x) for x in estimates)
4277
end
4378
end
79+
80+
# alternative implementation of white kernel
81+
struct WhiteKernel2 <: Kernel end
82+
(::WhiteKernel2)(x, y) = x == y
83+
84+
# alternative implementation TensorProductKernel
85+
struct TensorProduct2{K1<:Kernel,K2<:Kernel} <: Kernel
86+
kernel1::K1
87+
kernel2::K2
88+
end
89+
function (kernel::TensorProduct2)((x1, x2), (y1, y2))
90+
return kernel.kernel1(x1, y1) * kernel.kernel2(x2, y2)
91+
end
92+
93+
@testset "binary classification" begin
94+
# probabilities and boolean targets
95+
p, testp = rand(2)
96+
y, testy = rand(Bool, 2)
97+
scale = rand()
98+
kernel = SqExponentialKernel() ScaleTransform(scale)
99+
val = unsafe_ucme_eval(kernel WhiteKernel(), p, y, testp, testy)
100+
@test unsafe_ucme_eval(kernel WhiteKernel2(), p, y, testp, testy) val
101+
@test unsafe_ucme_eval(TensorProduct2(kernel, WhiteKernel()), p, y, testp, testy)
102+
val
103+
@test unsafe_ucme_eval(TensorProduct2(kernel, WhiteKernel2()), p, y, testp, testy)
104+
val
105+
106+
# corresponding values and kernel for full categorical distribution
107+
pfull = [p, 1 - p]
108+
yint = 2 - y
109+
testpfull = [testp, 1 - testp]
110+
testyint = 2 - testy
111+
kernelfull = SqExponentialKernel() ScaleTransform(scale / sqrt(2))
112+
113+
@test unsafe_ucme_eval(
114+
kernelfull WhiteKernel(), pfull, yint, testpfull, testyint
115+
) val
116+
@test unsafe_ucme_eval(
117+
kernelfull WhiteKernel2(), pfull, yint, testpfull, testyint
118+
) val
119+
@test unsafe_ucme_eval(
120+
TensorProduct2(kernelfull, WhiteKernel()), pfull, yint, testpfull, testyint
121+
) val
122+
@test unsafe_ucme_eval(
123+
TensorProduct2(kernelfull, WhiteKernel2()), pfull, yint, testpfull, testyint
124+
) val
125+
end
126+
127+
@testset "multi-class classification" begin
128+
n = 10
129+
p = rand(n)
130+
p ./= sum(p)
131+
y = rand(1:n)
132+
testp = rand(n)
133+
testp ./= sum(testp)
134+
testy = rand(1:n)
135+
136+
kernel = SqExponentialKernel() ScaleTransform(rand())
137+
val = unsafe_ucme_eval(kernel WhiteKernel(), p, y, testp, testy)
138+
139+
@test unsafe_ucme_eval(kernel WhiteKernel2(), p, y, testp, testy) val
140+
@test unsafe_ucme_eval(TensorProduct2(kernel, WhiteKernel()), p, y, testp, testy)
141+
val
142+
@test unsafe_ucme_eval(TensorProduct2(kernel, WhiteKernel2()), p, y, testp, testy)
143+
val
144+
end
44145
end

0 commit comments

Comments
 (0)