Skip to content
This repository was archived by the owner on May 12, 2026. It is now read-only.

Commit 26b83d3

Browse files
Merge pull request #698 from ChrisRackauckas-Claude/fix-skencarp-g1-mutation
Fix SKenCarp g1 mutation bug in non-diagonal noise path
2 parents 8907f29 + 54fcc68 commit 26b83d3

3 files changed

Lines changed: 25 additions & 10 deletions

File tree

src/caches/kencarp_caches.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ end
3939
chi2::randType
4040
g1::rateNoiseType
4141
g4::rateNoiseType
42+
gtmp::rateNoiseType
4243
end
4344

4445
u_cache(c::SKenCarpCache) = (c.z₁, c.z₂, c.z₃, c.z₄, c.nlsolver.dz)
@@ -83,12 +84,13 @@ function alg_cache(
8384

8485
g1 = zero(noise_rate_prototype)
8586
g4 = zero(noise_rate_prototype)
87+
gtmp = zero(noise_rate_prototype)
8688

8789
return SKenCarpCache{
8890
typeof(u), typeof(rate_prototype), typeof(atmp), typeof(nlsolver),
8991
typeof(tab), typeof(k1), typeof(chi2), typeof(g1),
9092
}(
9193
u, uprev, fsalfirst, z₁, z₂, z₃, z₄, k1, k2,
92-
k3, k4, atmp, nlsolver, tab, chi2, g1, g4
94+
k3, k4, atmp, nlsolver, tab, chi2, g1, g4, gtmp
9395
)
9496
end

src/perform_step/kencarp.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ end
150150
(; t, dt, uprev, u, p, f) = integrator
151151
g = integrator.f.g
152152
(; z₁, z₂, z₃, z₄, k1, k2, k3, k4, atmp) = cache
153-
(; g1, g4, chi2, nlsolver) = cache
153+
(; g1, g4, gtmp, chi2, nlsolver) = cache
154154
(; z, tmp) = nlsolver
155155
(; k, dz) = nlsolver.cache # alias to reduce memory
156156
(;
@@ -195,8 +195,9 @@ end
195195

196196
##### Step 2
197197

198-
# TODO: Add a cache so this isn't overwritten near the end, so it can not repeat on fail
199-
g(g1, uprev, p, t)
198+
if !repeat_step && !integrator.last_stepfail
199+
g(g1, uprev, p, t)
200+
end
200201

201202
if is_diagonal_noise(integrator.sol.prob)
202203
@.. z₄ = chi2 * g1 # use z₄ as storage for the g1*chi2
@@ -255,11 +256,9 @@ end
255256
@.. u = tmp + γ * z₃
256257
f2(k3, u, p, t + c3 * dt)
257258
k3 .*= dt
258-
# z₄ is storage for the g1*chi2 from earlier
259259
@.. tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + ea41 * k1 + ea42 * k2 + ea43 * k3 + nb043 * z₄
260260
else
261261
(; α41, α42) = cache.tab
262-
# z₄ is storage for the g1*chi2
263262
@.. tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + nb043 * z₄
264263
end
265264

@@ -285,8 +284,8 @@ end
285284
@.. u = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + γ * z₄ + eb1 * k1 + eb2 * k2 + eb3 * k3 +
286285
eb4 * k4 + integrator.W.dW * g4 + E₂
287286
else
288-
g1 .-= g4
289-
mul!(E₂, g1, chi2)
287+
@.. gtmp = g1 - g4
288+
mul!(E₂, gtmp, chi2)
290289
mul!(tmp, g4, integrator.W.dW)
291290
@.. u = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + γ * z₄ + eb1 * k1 + eb2 * k2 + eb3 * k3 +
292291
eb4 * k4 + tmp + E₂
@@ -296,8 +295,8 @@ end
296295
@.. E₂ = chi2 * (g1 - g4)
297296
@.. u = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + γ * z₄ + integrator.W.dW * g4 + E₂
298297
else
299-
g1 .-= g4
300-
mul!(E₂, g1, chi2)
298+
@.. gtmp = g1 - g4
299+
mul!(E₂, gtmp, chi2)
301300
mul!(tmp, g4, integrator.W.dW)
302301
@.. u = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + γ * z₄ + tmp + E₂
303302
end

test/alloc_tests.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,20 @@ end
9191
@test allocs_per_step == 0
9292
end
9393

94+
@testset "SKenCarp stepping allocation" begin
95+
integrator = init(prob_iip, SKenCarp(), dt = 0.01, adaptive = false, save_on = false)
96+
97+
for _ in 1:50
98+
step_void!(integrator)
99+
end
100+
101+
allocs_per_step = minimum(@allocated(step_void!(integrator)) for _ in 1:5)
102+
# Pkg.test runs with --check-bounds=yes which causes small allocations
103+
# (144 bytes) in the NL solver's broadcast/bounds-checking paths.
104+
# Without --check-bounds=yes, this is zero.
105+
@test allocs_per_step <= 200
106+
end
107+
94108
# Test with scalar SDE (out-of-place)
95109
@testset "Scalar SDE allocations" begin
96110
f_scalar(u, p, t) = 0.1 * u

0 commit comments

Comments
 (0)