Skip to content

Commit 85db48d

Browse files
fix: fix AD through inline linear SCCs
1 parent 2c2d513 commit 85db48d

2 files changed

Lines changed: 26 additions & 1 deletion

File tree

lib/ModelingToolkitBase/src/systems/problem_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -849,7 +849,7 @@ function get_mtkparameters_reconstructor(
849849
initialvals = getters[2](valp)
850850
nonnumerics = getters[5](valp)
851851
if !iszero(diffcache_buffer_idx)
852-
@set! nonnumerics[diffcache_buffer_idx] = DiffCacheAllocatorAPIWrapper{eltype(initialvals)}.(nonnumerics[diffcache_buffer_idx])
852+
@set! nonnumerics[diffcache_buffer_idx] = DiffCacheAllocatorAPIWrapper{ForwardDiff.valtype(eltype(initialvals))}.(nonnumerics[diffcache_buffer_idx])
853853
end
854854
return promote_with_nothing(
855855
promote_type_with_nothing(eltype(tunablevals), initialvals),

test/structural_transformation/tearing.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ using SymbolicIndexingInterface
99
using ModelingToolkit: t_nounits as t, D_nounits as D
1010
import StateSelection
1111
import SymbolicUtils as SU
12+
using ForwardDiff
13+
1214
###
1315
### Nonlinear system
1416
###
@@ -291,3 +293,26 @@ end
291293
@test Initial(sys.resistor1.v) in Set(ModelingToolkit.get_ps(sys))
292294
@test Initial(sys.resistor2.v) in Set(ModelingToolkit.get_ps(sys))
293295
end
296+
297+
@testset "AD through inline linear SCCs works" begin
298+
reassemble_alg = StructuralTransformations.DefaultReassembleAlgorithm(; inline_linear_sccs = true)
299+
@mtkcompile sys = RCModel() reassemble_alg = reassemble_alg
300+
prob = ODEProblem(sys, [], (0.0, 10.0))
301+
@assert prob.p.nonnumeric[1] isa
302+
Vector{ModelingToolkitBase.DiffCacheAllocatorAPIWrapper{Float64}}
303+
@assert SciMLBase.has_initializeprob(prob.f)
304+
305+
setter = setsym_oop(prob, [sys.R, sys.C])
306+
307+
function loss(x)
308+
new_u0, new_p = setter(prob, x)
309+
new_prob = remake(prob; u0 = new_u0, p = new_p)
310+
sol = solve(new_prob, Tsit5(); abstol = 1.0e-8, reltol = 1.0e-8)
311+
return sol[sys.capacitor.v][end]
312+
end
313+
314+
# Primal works: returns ~0.993
315+
@test_nowarn loss([1.0, 1.0])
316+
317+
@test_nowarn ForwardDiff.gradient(loss, [1.0, 1.0])
318+
end

0 commit comments

Comments
 (0)