Skip to content

Commit a6306bf

Browse files
committed
Simplify code and fix test
1 parent 190e85f commit a6306bf

2 files changed

Lines changed: 11 additions & 18 deletions

File tree

ext/ForwardDiffStaticArraysExt.jl

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module ForwardDiffStaticArraysExt
33
using ForwardDiff, StaticArrays
44
using ForwardDiff.LinearAlgebra
55
using ForwardDiff.DiffResults
6-
using ForwardDiff: Dual, partials, Partials, GradientConfig, JacobianConfig, HessianConfig, Tag, Chunk,
6+
using ForwardDiff: Dual, partials, npartials, Partials, GradientConfig, JacobianConfig, HessianConfig, Tag, Chunk,
77
gradient, hessian, jacobian, gradient!, hessian!, jacobian!,
88
extract_gradient!, extract_jacobian!, extract_value!,
99
vector_mode_gradient, vector_mode_gradient!,
@@ -71,17 +71,8 @@ end
7171
@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::JacobianConfig) where {F} = jacobian!(result, f, x)
7272
@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::JacobianConfig, ::Val) where {F} = jacobian!(result, f, x)
7373

74-
@generated function extract_jacobian(::Type{T}, ydual::StaticArray, x::S) where {T,S<:StaticArray}
75-
M, N = length(ydual), length(x)
76-
result = Expr(:tuple, [:(partials(T, ydual[$i], $j)) for i in 1:M, j in 1:N]...)
77-
return quote
78-
$(Expr(:meta, :inline))
79-
V = StaticArrays.similar_type(S, valtype(eltype($ydual)), Size($M, $N))
80-
return V($result)
81-
end
82-
end
83-
84-
@generated function extract_jacobian(::Type{T}, ydual::Partials{M}, x::S) where {M, T, S<:StaticArray}
74+
@generated function extract_jacobian(::Type{T}, ydual::Union{StaticArray,Partials}, x::S) where {T,S<:StaticArray}
75+
M = ydual <: Partials ? npartials(ydual) : length(ydual)
8576
N = length(x)
8677
result = Expr(:tuple, [:(partials(T, ydual[$i], $j)) for i in 1:M, j in 1:N]...)
8778
return quote

test/HessianTest.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -163,13 +163,15 @@ end
163163
@test ForwardDiff.hessian(x->dot(x,H,x), zeros(3)) [2 6 10; 6 10 14; 10 14 18]
164164
end
165165

166+
#https://github.com/JuliaDiff/ForwardDiff.jl/issues/720
166167
@testset "allocation-free hessian with StaticArrays" begin
167-
#https://github.com/JuliaDiff/ForwardDiff.jl/issues/720
168-
g = r -> (r[1]^2 - 3) * (r[2]^2 - 2)
169-
x = SA_F32[0.5, 2.7]
170-
hres = DiffResults.HessianResult(x)
171-
ForwardDiff.hessian!(hres, g, x)
172-
@test @allocated(ForwardDiff.hessian!(hres, g, x)) == 0
168+
function hessian_allocs()
169+
g = r -> (r[1]^2 - 3) * (r[2]^2 - 2)
170+
x = SVector(0.5, 2.8)
171+
hres = DiffResults.HessianResult(x)
172+
return @allocated(ForwardDiff.hessian!(hres, g, x))
173+
end
174+
@test iszero(hessian_allocs())
173175
end
174176

175177
end # module

0 commit comments

Comments
 (0)