Skip to content

Commit 8d2ec1f

Browse files
Merge pull request #63 from JuliaComputing/as/fix-diffcache-reference
fix: improve reference calculation for `DiffCache`
2 parents be4e95d + 1dab2bc commit 8d2ec1f

File tree

3 files changed

+35
-1
lines changed

3 files changed

+35
-1
lines changed

lib/ModelingToolkitTearing/Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
2626
BipartiteGraphs = "0.1.3"
2727
CommonSolve = "0.2"
2828
DocStringExtensions = "0.7, 0.8, 0.9"
29+
ForwardDiff = "1.3"
2930
Graphs = "1"
3031
LinearAlgebra = "1"
3132
ModelingToolkit = "11"
@@ -44,8 +45,9 @@ UUIDs = "1"
4445
julia = "1.10"
4546

4647
[extras]
48+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
4749
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
4850
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4951

5052
[targets]
51-
test = ["Test", "ModelingToolkit"]
53+
test = ["Test", "ModelingToolkit", "ForwardDiff"]

lib/ModelingToolkitTearing/src/reassemble.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,10 @@ function get_linear_scc_linsol(state::TearingState, alg_eqs::Vector{Int},
606606
else
607607
reference = fullvars[state_idx]
608608
end
609+
reference = Symbolics.STerm(
610+
promote, Symbolics.SArgsT((reference, MTKBase.get_iv(sys)::SymbolicT));
611+
type = Vector{Real}, shape = [1:2]
612+
)[1]
609613
sys, A_cache = MTKBase.add_diffcache(sys, length(A))
610614
A_allocator = A_cache(reference)
611615
A = SU.Code.with_allocator(A_allocator, SU.Const{VartypeT}(A))

lib/ModelingToolkitTearing/test/runtests.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using Graphs
88
import StateSelection
99
using ModelingToolkit: t_nounits as t, D_nounits as D
1010
import SymbolicUtils as SU
11+
using ForwardDiff
1112

1213
@testset "`InferredDiscrete` validation" begin
1314
k = ShiftIndex()
@@ -110,3 +111,30 @@ end
110111
@test isempty(unknowns(sys))
111112
@test length(observed(sys)) == 2
112113
end
114+
115+
@testset "Duals through inline linear SCC DiffCaches" begin
116+
@variables x(t) y(t) z(t) w(t) q(t)
117+
reassemble_alg = MTKTearing.DefaultReassembleAlgorithm(; inline_linear_sccs = true)
118+
@mtkcompile sys = System(
119+
[
120+
D(x) ~ 2t + 1,
121+
t * y + x * z + w ~ 4,
122+
4y + 3z + 2w ~ 7,
123+
2x * y + 3t * z + w ~ 10,
124+
D(q) ~ 2w,
125+
],
126+
t
127+
) reassemble_alg = reassemble_alg
128+
129+
prob = ODEProblem(sys, [x => 1, q => 1], (0.0, 1.0); guesses = [x => 1, y => 1, z => 1, w => 1, q => 1])
130+
131+
buffer = similar(prob.u0)
132+
@test_nowarn ForwardDiff.jacobian(buffer, prob.u0) do du, u
133+
@test eltype(u) <: ForwardDiff.Dual
134+
prob.f.f.f_iip(du, u, prob.p, 1.0)
135+
end
136+
@test_nowarn ForwardDiff.derivative(buffer, 1.2) do du, t
137+
@test t isa ForwardDiff.Dual
138+
prob.f.f.f_iip(du, prob.u0, prob.p, t)
139+
end
140+
end

0 commit comments

Comments
 (0)