Commit fe53b23

yebai and Copilot committed
Fix Enzyme: type-stable _unflatten, @inline v&g, DimensionMismatch checks

Replace the Any[]-based _unflatten with a recursive peel approach that avoids Union types Enzyme cannot differentiate through. Add @inline to all value_and_gradient methods to prevent boxing of the (value, grad) sret tuple. Add DimensionMismatch length checks to all vector adapters.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 011c123 · commit fe53b23

11 files changed

Lines changed: 61 additions & 48 deletions

docs/src/interface.md

Lines changed: 1 addition & 1 deletion
@@ -316,7 +316,7 @@ There are two things that make them more special, though:
 vi[@varname(x[1])] = 1
 vi[@varname(x[2])] = 2
 keys(vi) == [x[1], x[2]]
-
+
 vi[@varname(x)] = [1, 2]
 keys(vi) == [x]
 ```

ext/AbstractPPLDifferentiationInterfaceExt.jl

Lines changed: 3 additions & 1 deletion
@@ -21,6 +21,8 @@ function (p::DIPrepared)(values::NamedTuple)
 end
 
 function (p::DIPrepared)(x::AbstractVector)
+    length(x) == p.dim ||
+        throw(DimensionMismatch("expected vector of length $(p.dim), got $(length(x))"))
     return p.f_vec(x)
 end
 
@@ -34,7 +36,7 @@ function AbstractPPL.prepare(adtype::AbstractADType, problem, prototype::NamedTu
     return DIPrepared(evaluator, f_vec, adtype, prep, prototype, length(x0))
 end
 
-function AbstractPPL.value_and_gradient(p::DIPrepared, values::NamedTuple)
+@inline function AbstractPPL.value_and_gradient(p::DIPrepared, values::NamedTuple)
     x = AbstractPPL.flatten_to_vec(values)
     val, dx = DI.value_and_gradient(p.f_vec, p.prep, p.backend, x)
     grad_nt = AbstractPPL.unflatten_from_vec(p.prototype, dx)
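
For context, the guard added above is the usual callable-wrapper pattern: validate the flattened length before the vector reaches the AD backend. A minimal standalone sketch, assuming a hypothetical `Guarded` type (not this package's code) and any DifferentiationInterface-compatible backend:

```julia
using DifferentiationInterface
import ForwardDiff  # any DI-supported backend package would do

# Hypothetical minimal analogue of the adapter above: store the expected
# dimension and reject wrong-length inputs before AD ever runs.
struct Guarded{F}
    f::F
    dim::Int
end

function (g::Guarded)(x::AbstractVector)
    length(x) == g.dim ||
        throw(DimensionMismatch("expected vector of length $(g.dim), got $(length(x))"))
    return g.f(x)
end

g = Guarded(x -> sum(abs2, x), 3)
backend = AutoForwardDiff()
prep = prepare_gradient(g, backend, zeros(3))
value_and_gradient(g, prep, backend, [1.0, 2.0, 3.0])  # (14.0, [2.0, 4.0, 6.0])
```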

ext/AbstractPPLEnzymeExt.jl

Lines changed: 3 additions & 1 deletion
@@ -20,6 +20,8 @@ function (p::EnzymePrepared)(values::NamedTuple)
 end
 
 function (p::EnzymePrepared)(x::AbstractVector)
+    length(x) == p.dim ||
+        throw(DimensionMismatch("expected vector of length $(p.dim), got $(length(x))"))
     return p.f_vec(x)
 end
 
@@ -33,7 +35,7 @@ function AbstractPPL.prepare(::AutoEnzyme, problem, prototype::NamedTuple)
     return EnzymePrepared(evaluator, f_vec, grad_buf, prototype, length(x0))
 end
 
-function AbstractPPL.value_and_gradient(p::EnzymePrepared, values::NamedTuple)
+@inline function AbstractPPL.value_and_gradient(p::EnzymePrepared, values::NamedTuple)
     x = AbstractPPL.flatten_to_vec(values)
     fill!(p.gradient_buffer, 0.0)
     result = Enzyme.autodiff(
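
The diff is cut off mid-call, so the exact `autodiff` invocation is not visible here. A sketch of the typical reverse-mode pattern it suggests, with a preallocated gradient buffer that must be zeroed before each call because Enzyme accumulates into the shadow:

```julia
using Enzyme

# Sketch only: the actual call in this file is truncated above.
f(x) = sum(abs2, x)
x = [1.0, 2.0, 3.0]
grad_buf = zero(x)

fill!(grad_buf, 0.0)  # Enzyme *accumulates* into the shadow, so reset first
result = Enzyme.autodiff(
    Enzyme.ReverseWithPrimal, Const(f), Active, Duplicated(x, grad_buf)
)
val = result[2]  # primal: 14.0; grad_buf now holds [2.0, 4.0, 6.0]
```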

ext/AbstractPPLForwardDiffExt.jl

Lines changed: 3 additions & 1 deletion
@@ -21,6 +21,8 @@ function (p::ForwardDiffPrepared)(values::NamedTuple)
 end
 
 function (p::ForwardDiffPrepared)(x::AbstractVector)
+    length(x) == p.dim ||
+        throw(DimensionMismatch("expected vector of length $(p.dim), got $(length(x))"))
     return p.f_vec(x)
 end
 
@@ -36,7 +38,7 @@ function AbstractPPL.prepare(::AutoForwardDiff, problem, prototype::NamedTuple)
     return ForwardDiffPrepared(evaluator, f_vec, cfg, result, prototype, length(x0))
 end
 
-function AbstractPPL.value_and_gradient(p::ForwardDiffPrepared, values::NamedTuple)
+@inline function AbstractPPL.value_and_gradient(p::ForwardDiffPrepared, values::NamedTuple)
     x = AbstractPPL.flatten_to_vec(values)
     ForwardDiff.gradient!(p.result, p.f_vec, x, p.config)
     val = ForwardDiff.DiffResults.value(p.result)
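
For reference, the preallocated `GradientConfig` + `DiffResults` pattern used above looks like this in isolation (a sketch with a stand-in function `f`, not this package's code):

```julia
using ForwardDiff
using ForwardDiff: DiffResults

# One GradientConfig and one DiffResults buffer, built once and reused on
# every evaluation, so the hot path avoids reallocating work buffers.
f(x) = sum(abs2, x)
x0 = zeros(3)
cfg = ForwardDiff.GradientConfig(f, x0)
result = DiffResults.GradientResult(x0)

ForwardDiff.gradient!(result, f, [1.0, 2.0, 3.0], cfg)
DiffResults.value(result)     # 14.0
DiffResults.gradient(result)  # [2.0, 4.0, 6.0]
```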

ext/AbstractPPLMooncakeExt.jl

Lines changed: 3 additions & 1 deletion
@@ -20,6 +20,8 @@ function (p::MooncakePrepared)(values::NamedTuple)
 end
 
 function (p::MooncakePrepared)(x::AbstractVector)
+    length(x) == p.dim ||
+        throw(DimensionMismatch("expected vector of length $(p.dim), got $(length(x))"))
     return p.f_vec(x)
 end
 
@@ -33,7 +35,7 @@ function AbstractPPL.prepare(adtype::AutoMooncake, problem, prototype::NamedTupl
     return MooncakePrepared(evaluator, f_vec, cache, prototype, length(x0))
 end
 
-function AbstractPPL.value_and_gradient(p::MooncakePrepared, values::NamedTuple)
+@inline function AbstractPPL.value_and_gradient(p::MooncakePrepared, values::NamedTuple)
     x = AbstractPPL.flatten_to_vec(values)
     val, (_, dx) = Mooncake.value_and_gradient!!(p.cache, p.f_vec, x)
     grad_nt = AbstractPPL.unflatten_from_vec(p.prototype, dx)
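
A sketch of the Mooncake cache-reuse pattern at this call site; building the cache via `Mooncake.prepare_gradient_cache` is an assumption, since the diff only shows the call:

```julia
using Mooncake

# Assumption: `p.cache` is built once, along these lines, and reused.
f(x) = sum(abs2, x)
x = [1.0, 2.0, 3.0]
cache = Mooncake.prepare_gradient_cache(f, x)

# Returns the primal plus one gradient per argument; the first entry is
# the gradient w.r.t. `f` itself, which the call site discards with `_`.
val, (_, dx) = Mooncake.value_and_gradient!!(cache, f, x)
# val == 14.0, dx == [2.0, 4.0, 6.0]
```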

src/evaluator.jl

Lines changed: 9 additions & 7 deletions
@@ -99,13 +99,15 @@ function _unflatten(proto::AbstractArray{<:Real}, vec::AbstractVector, offset::I
     result = reshape(@view(vec[offset:(offset + n - 1)]), size(proto))
     return result, offset + n
 end
-function _unflatten(proto::NamedTuple, vec::AbstractVector, offset::Int)
-    rebuilt = Any[]
-    for v in values(proto)
-        val, offset = _unflatten(v, vec, offset)
-        push!(rebuilt, val)
-    end
-    return NamedTuple{keys(proto)}(Tuple(rebuilt)), offset
+# Recursive peel keeps this type-stable (Enzyme needs it).
+function _unflatten(::NamedTuple{(),Tuple{}}, ::AbstractVector, offset::Int)
+    return NamedTuple(), offset
+end
+function _unflatten(proto::NamedTuple{K}, vec::AbstractVector, offset::Int) where {K}
+    first_val, offset = _unflatten(first(values(proto)), vec, offset)
+    rest_proto = NamedTuple{Base.tail(K)}(Base.tail(values(proto)))
+    rest_nt, offset = _unflatten(rest_proto, vec, offset)
+    return merge(NamedTuple{(first(K),)}((first_val,)), rest_nt), offset
 end
 
 function unflatten_from_vec(prototype::NamedTuple, vec::AbstractVector)
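
Why the peel helps: the old method pushed into an `Any[]`, so the element type of `rebuilt` (and hence of the returned `NamedTuple`) could not be inferred, leaving `Union`-typed code for Enzyme to differentiate through. With the recursive peel, each method's return type is determined by the prototype's type alone. A hypothetical scalar-only miniature (`peel` is not this package's function) whose inference can be checked with `@inferred`:

```julia
using Test: @inferred

# The empty NamedTuple terminates the recursion; every return type follows
# from the prototype's type, so no Any[] accumulator is ever needed.
peel(::NamedTuple{(),Tuple{}}, vec, offset) = (NamedTuple(), offset)
function peel(proto::NamedTuple{K}, vec, offset) where {K}
    first_val = vec[offset]  # the real code recurses into arrays here
    rest = NamedTuple{Base.tail(K)}(Base.tail(values(proto)))
    rest_nt, offset = peel(rest, vec, offset + 1)
    return merge(NamedTuple{(first(K),)}((first_val,)), rest_nt), offset
end

nt, _ = @inferred peel((a=0.0, b=0.0, c=0.0), [1.0, 2.0, 3.0], 1)
nt  # (a = 1.0, b = 2.0, c = 3.0); inference succeeds end-to-end
```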

test/varname/hasvalue.jl

Lines changed: 11 additions & 11 deletions
@@ -34,11 +34,11 @@ using Test
     @test canview(@opticof(_[1:2]), x)
     @test canview(@opticof(_[:]), x)
     @test !canview(@opticof(_[4]), x)
-    @test canview(@opticof(_[i=1]), x)
+    @test canview(@opticof(_[i = 1]), x)
     # For some weird reason DimData does not error on these two but just warns that
     # there's no index j!
-    @test canview(@opticof(_[j=2]), x)
-    @test canview(@opticof(_[i=1, j=2]), x)
+    @test canview(@opticof(_[j = 2]), x)
+    @test canview(@opticof(_[i = 1, j = 2]), x)
 end
 
 @testset "Dict" begin
@@ -246,14 +246,14 @@ end
     @test getvalue(x, @varname(a[1, 2])) == x.a[1, 2]
     @test hasvalue(x, @varname(a[:]))
     @test getvalue(x, @varname(a[:])) == x.a[:]
-    @test canview(@opticof(_[i=1]), x.a)
-    @test hasvalue(x, @varname(a[i=1]))
-    @test getvalue(x, @varname(a[i=1])) == x.a[i=1]
-    @test canview(@opticof(_[i=1, j=2]), x.a)
-    @test hasvalue(x, @varname(a[i=1, j=2]))
-    @test getvalue(x, @varname(a[i=1, j=2])) == x.a[i=1, j=2]
-    @test hasvalue(x, @varname(a[i=DD.Not(1)]))
-    @test getvalue(x, @varname(a[i=DD.Not(1)])) == x.a[i=DD.Not(1)]
+    @test canview(@opticof(_[i = 1]), x.a)
+    @test hasvalue(x, @varname(a[i = 1]))
+    @test getvalue(x, @varname(a[i = 1])) == x.a[i = 1]
+    @test canview(@opticof(_[i = 1, j = 2]), x.a)
+    @test hasvalue(x, @varname(a[i = 1, j = 2]))
+    @test getvalue(x, @varname(a[i = 1, j = 2])) == x.a[i = 1, j = 2]
+    @test hasvalue(x, @varname(a[i = DD.Not(1)]))
+    @test getvalue(x, @varname(a[i = DD.Not(1)])) == x.a[i = DD.Not(1)]
 
     y = (; b=DD.DimArray(randn(2, 3), (DD.X, DD.Y)))
     @test hasvalue(y, @varname(b))

test/varname/optic.jl

Lines changed: 7 additions & 7 deletions
@@ -31,7 +31,7 @@ using AbstractPPL
     @opticof(_.a[2]),
     @opticof(_.a[1, :]),
     @opticof(_[1].a),
-    @opticof(_[1, x=1].a),
+    @opticof(_[1, x = 1].a),
 )
 for (i, optic1) in enumerate(optics)
     for (j, optic2) in enumerate(optics)
@@ -103,7 +103,7 @@ using AbstractPPL
     @opticof(_.a),
     @opticof(_.a.b),
     @opticof(_[1].a),
-    @opticof(_[1, x=1].a),
+    @opticof(_[1, x = 1].a),
     @opticof(_[].a[:]),
 )
 for optic in optics
@@ -196,8 +196,8 @@ using AbstractPPL
 
 @testset "keyword arguments to getindex" begin
     dimarray = DD.DimArray([0.0 1.0; 2.0 3.0], (:x, :y))
-    @test @opticof(_[x=1])(dimarray) == dimarray[x=1]
-    @test set(dimarray, @opticof(_[y=2]), [9.0; 8.0]) ==
+    @test @opticof(_[x = 1])(dimarray) == dimarray[x = 1]
+    @test set(dimarray, @opticof(_[y = 2]), [9.0; 8.0]) ==
         DD.DimArray([0.0 9.0; 2.0 8.0], (:x, :y))
 end
 
@@ -288,10 +288,10 @@ using AbstractPPL
 @testset "keyword index" begin
     x = DD.DimArray(zeros(2, 2), (:x, :y))
     old_objid = objectid(x)
-    optic = with_mutation(@opticof(_[x=1, y=2]))
-    @test optic(x) === x[x=1, y=2]
+    optic = with_mutation(@opticof(_[x = 1, y = 2]))
+    @test optic(x) === x[x = 1, y = 2]
     set(x, optic, 2.0)
-    @test x[x=1, y=2] == 2.0
+    @test x[x = 1, y = 2] == 2.0
     @test collect(x) == [0.0 2.0; 0.0 0.0]
     @test objectid(x) == old_objid
 end

test/varname/serialize.jl

Lines changed: 7 additions & 6 deletions
@@ -37,10 +37,10 @@ using Test
     @varname(z[:, :], true),
     @varname(z[2:5, :], false),
     @varname(z[2:5, :], true),
-    @varname(x[i=1]),
-    @varname(x[j=2, i=1]),
-    @varname(x[i=1, j=2]),
-    @varname(x[].a[j=2].b[3, 4, 5, [6]]),
+    @varname(x[i = 1]),
+    @varname(x[j = 2, i = 1]),
+    @varname(x[i = 1, j = 2]),
+    @varname(x[].a[j = 2].b[3, 4, 5, [6]]),
     @varname(x[[1, 2, 5, 6]]),
 ]
 for vn in vns
@@ -65,8 +65,9 @@ using Test
     "type" => "InvertedIndices.InvertedIndex",
     "skip" => AbstractPPL.index_to_dict(o.skip),
 )
-AbstractPPL.dict_to_index(::Val{Symbol("InvertedIndices.InvertedIndex")}, d) =
-    InvertedIndex(AbstractPPL.dict_to_index(d["skip"]))
+AbstractPPL.dict_to_index(::Val{Symbol("InvertedIndices.InvertedIndex")}, d) = InvertedIndex(
+    AbstractPPL.dict_to_index(d["skip"])
+)
 
 # Serialisation should now work
 @test string_to_varname(varname_to_string(vn)) == vn

test/varname/subsumes.jl

Lines changed: 8 additions & 6 deletions
@@ -65,12 +65,14 @@ using Test
 end
 
 @testset "keyword indices" begin
-    @test strictly_subsumes(@varname(x), @varname(x[a=1]))
-    @test strictly_subsumes(@varname(x[a=1:10, b=1:10]), @varname(x[a=1:10]))
-    @test strictly_subsumes(@varname(x[a=1:10, b=1:10]), @varname(x[a=1:5, b=1:5]))
-    @test strictly_subsumes(@varname(x[a=:]), @varname(x[a=1]))
-    @test uncomparable(@varname(x[a=1:10, b=5]), @varname(x[a=5, b=1:10]))
-    @test uncomparable(@varname(x[a=1]), @varname(x[b=1]))
+    @test strictly_subsumes(@varname(x), @varname(x[a = 1]))
+    @test strictly_subsumes(@varname(x[a = 1:10, b = 1:10]), @varname(x[a = 1:10]))
+    @test strictly_subsumes(
+        @varname(x[a = 1:10, b = 1:10]), @varname(x[a = 1:5, b = 1:5])
+    )
+    @test strictly_subsumes(@varname(x[a=:]), @varname(x[a = 1]))
+    @test uncomparable(@varname(x[a = 1:10, b = 5]), @varname(x[a = 5, b = 1:10]))
+    @test uncomparable(@varname(x[a = 1]), @varname(x[b = 1]))
 end
 end
 end
