Skip to content

Commit 019e33d

Browse files
committed
Use getsensdomain in getadjointsensitivities & use Reversediff if length(p) larger than threshold
1 parent a7fe3c3 commit 019e33d

1 file changed

Lines changed: 57 additions & 2 deletions

File tree

src/Simulation.jl

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,41 @@ function getadjointsensitivities(bsol::Q,target::String,solver::W;sensalg::W2=In
148148
end
149149
else
150150
ind = findfirst(isequal(target),bsol.names)
151+
sensdomain,sensspcnames,senstooriginspcind,senstooriginrxnind = getsensdomain(bsol.domain,ind)
152+
if :thermovariabledict in fieldnames(typeof(bsol.domain))
153+
yinds = vcat(senstooriginspcind,collect(values(bsol.domain.thermovariabledict)))
154+
else
155+
yinds = vcat(senstooriginspcind)
156+
end
157+
pinds = vcat(senstooriginspcind,length(bsol.domain.phase.species).+senstooriginrxnind)
158+
ind = findfirst(isequal(target),sensspcnames)
159+
end
160+
161+
function sensg(y::X,p::Array{Y,1},t::Z) where {Q,V,X,Y<:Float64,Z}
162+
sensy = y[yinds]
163+
sensp = p[pinds]
164+
dy = similar(sensy,length(sensy))
165+
return dydtreactor!(dy,sensy,t,sensdomain,[],p=sensp)[ind]
166+
end
167+
function sensg(y::Array{X,1},p::Y,t::Z) where {Q,V,X<:Float64,Y,Z}
168+
sensy = y[yinds]
169+
sensp = p[pinds]
170+
dy = similar(sensp,length(sensy))
171+
return dydtreactor!(dy,sensy,t,sensdomain,[],p=sensp)[ind]
151172
end
173+
function sensg(y::Array{X,1},p::Array{Y,1},t::Z) where {Q,V,X<:Float64,Y<:Float64,Z}
174+
sensy = y[yinds]
175+
sensp = p[pinds]
176+
dy = similar(sensy,length(sensy))
177+
return dydtreactor!(dy,sensy,t,sensdomain,[],p=sensp)[ind]
178+
end
179+
function sensg(y::Array{X,1},p::Array{Y,1},t::Z) where {Q,V,X<:ForwardDiff.Dual,Y<:ForwardDiff.Dual,Z}
180+
sensy = y[yinds]
181+
sensp = p[pinds]
182+
dy = similar(sensy,length(sensy))
183+
return dydtreactor!(dy,sensy,t,sensdomain,[],p=sensp)[ind]
184+
end
185+
152186
function g(y::X,p::Array{Y,1},t::Z) where {Q,V,X,Y<:Float64,Z}
153187
dy = similar(y,length(y))
154188
return dydtreactor!(dy,y,t,bsol.domain,[],p=p)[ind]
@@ -165,12 +199,33 @@ function getadjointsensitivities(bsol::Q,target::String,solver::W;sensalg::W2=In
165199
dy = similar(y,length(y))
166200
return dydtreactor!(dy,y,t,bsol.domain,[],p=p)[ind]
167201
end
202+
203+
dsensgdu(out, y, p, t) = ForwardDiff.gradient!(out, y -> sensg(y, p, t), y)
204+
dsensgdp(out, y, p, t) = ForwardDiff.gradient!(out, p -> sensg(y, p, t), p)
168205
dgdu(out, y, p, t) = ForwardDiff.gradient!(out, y -> g(y, p, t), y)
169206
dgdp(out, y, p, t) = ForwardDiff.gradient!(out, p -> g(y, p, t), p)
170-
du0,dpadj = adjoint_sensitivities(bsol.sol,solver,g,nothing,(dgdu,dgdp);sensealg=sensalg,abstol=abstol,reltol=reltol,kwargs...)
207+
dsensgdurevdiff(out, y, p, t) = ReverseDiff.gradient!(out, y -> sensg(y, p, t), y)
208+
dsensgdprevdiff(out, y, p, t) = ReverseDiff.gradient!(out, p -> sensg(y, p, t), p)
209+
dgdurevdiff(out, y, p, t) = ReverseDiff.gradient!(out, y -> g(y, p, t), y)
210+
dgdprevdiff(out, y, p, t) = ReverseDiff.gradient!(out, p -> g(y, p, t), p)
211+
212+
pethane = 160
213+
if length(bsol.domain.p)<= pethane
214+
if target in ["T","V","P"]
215+
du0,dpadj = adjoint_sensitivities(bsol.sol,solver,g,nothing,(dgdu,dgdp);sensealg=sensalg,abstol=abstol,reltol=reltol,kwargs...)
216+
else
217+
du0,dpadj = adjoint_sensitivities(bsol.sol,solver,sensg,nothing,(dsensgdu,dsensgdp);sensealg=sensalg,abstol=abstol,reltol=reltol,kwargs...)
218+
end
219+
else
220+
if target in ["T","V","P"]
221+
du0,dpadj = adjoint_sensitivities(bsol.sol,solver,g,nothing,(dgdurevdiff,dgdprevdiff);sensealg=sensalg,abstol=abstol,reltol=reltol,kwargs...)
222+
else
223+
du0,dpadj = adjoint_sensitivities(bsol.sol,solver,sensg,nothing,(dsensgdurevdiff,dsensgdprevdiff);sensealg=sensalg,abstol=abstol,reltol=reltol,kwargs...)
224+
end
225+
end
171226
dpadj[length(bsol.domain.phase.species)+1:end] .*= bsol.domain.p[length(bsol.domain.phase.species)+1:end]
172227
if !(target in ["T","V","P"])
173-
dpadj ./= bsol.sol(bsol.sol.t[end])[ind]
228+
dpadj ./= bsol.sol(bsol.sol.t[end])[senstooriginspcind[ind]]
174229
end
175230
return dpadj
176231
end

0 commit comments

Comments
 (0)