Skip to content
This repository was archived by the owner on Apr 21, 2026. It is now read-only.

Commit d202eff

Browse files
Merge pull request #1291 from ChrisRackauckas-Claude/precompile-subarray-dual-broadcast
Add precompile workload for Dual and SubArray broadcast operations
2 parents 7dfa3b8 + 0ba42fc commit d202eff

2 files changed

Lines changed: 131 additions & 1 deletion

File tree

ext/DiffEqBaseForwardDiffExt.jl

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,4 +225,134 @@ if !hasmethod(nextfloat, Tuple{ForwardDiff.Dual})
225225
end
226226
end
227227

228+
import PrecompileTools
229+
PrecompileTools.@compile_workload begin
230+
# Scalar operations on Dual numbers (arithmetic, math functions, comparisons)
231+
d1 = dualT(1.0, ForwardDiff.Partials((0.5,)))
232+
d2 = dualT(2.0, ForwardDiff.Partials((1.0,)))
233+
s = 3.14
234+
235+
# Arithmetic: Dual-Dual and Dual-scalar
236+
d1 + d2
237+
d1 - d2
238+
d1 * d2
239+
d1 / d2
240+
d1 + s
241+
s + d1
242+
d1 - s
243+
s - d1
244+
d1 * s
245+
s * d1
246+
d1 / s
247+
s / d1
248+
-d1
249+
abs(d1)
250+
251+
# Powers and roots
252+
d1^2
253+
d1^3
254+
d2^0.5
255+
sqrt(d2)
256+
cbrt(d2)
257+
258+
# Transcendental functions
259+
exp(d1)
260+
log(d2)
261+
sin(d1)
262+
cos(d1)
263+
tan(d1)
264+
asin(dualT(0.5, ForwardDiff.Partials((1.0,))))
265+
acos(dualT(0.5, ForwardDiff.Partials((1.0,))))
266+
atan(d1)
267+
atan(d1, d2)
268+
sinh(d1)
269+
cosh(d1)
270+
tanh(d1)
271+
272+
# Comparisons (used in step size control, event detection)
273+
d1 < d2
274+
d1 > d2
275+
d1 <= d2
276+
d1 >= d2
277+
d1 == d2
278+
isnan(d1)
279+
isinf(d1)
280+
isfinite(d1)
281+
282+
# min/max (used in limiters and error control)
283+
min(d1, d2)
284+
max(d1, d2)
285+
min(d1, s)
286+
max(d1, s)
287+
288+
# Conversion and promotion
289+
zero(dualT)
290+
one(dualT)
291+
float(d1)
292+
ForwardDiff.value(d1)
293+
ForwardDiff.partials(d1)
294+
295+
# Array operations on Vector{dualT}
296+
v1 = [d1, d2, dualT(0.0, ForwardDiff.Partials((0.0,)))]
297+
v2 = [d2, d1, dualT(1.0, ForwardDiff.Partials((0.1,)))]
298+
299+
# Basic array ops
300+
v1 + v2
301+
v1 - v2
302+
v1 .* v2
303+
v1 ./ v2
304+
s .* v1
305+
v1 .+ s
306+
v1 .- s
307+
v1 .^ 2
308+
v1 .^ 0.5
309+
310+
# In-place array operations
311+
out = similar(v1)
312+
out .= v1 .+ v2
313+
out .= v1 .- v2
314+
out .= v1 .* v2
315+
out .= s .* v1
316+
out .= v1 .* s .+ v2
317+
out .= v1 .* s .- v2 .* s
318+
319+
# Reductions (used in norm calculations, error estimation)
320+
sum(v1)
321+
sum(abs2, v1)
322+
maximum(abs, v1)
323+
324+
# LinearAlgebra operations
325+
using LinearAlgebra
326+
dot(v1, v2)
327+
norm(v1)
328+
norm(v1, Inf)
329+
norm(v1, 1)
330+
331+
# copy / fill
332+
copy(v1)
333+
fill!(out, zero(dualT))
334+
335+
# SubArray primitive broadcast operations for Float64 and Dual types.
336+
# These are generic building blocks used by any ODE function with views.
337+
# Note: fused multi-operand broadcast expressions (e.g. `dy .= k .* y1 .+ k .* y2 .* y3`)
338+
# create unique nested Broadcasted types per expression and cannot be generically precompiled.
339+
for T in (Float64, dualT)
340+
x = zeros(T, 4)
341+
dx = zeros(T, 4)
342+
sv1 = @view x[1:2]
343+
sv2 = @view x[3:4]
344+
dsv1 = @view dx[1:2]
345+
k = 0.04
346+
347+
# Primitive SubArray broadcast operations
348+
dsv1 .= sv1
349+
dsv1 .= k .* sv1
350+
dsv1 .= sv1 .* sv2
351+
dsv1 .= sv1 .+ sv2
352+
dsv1 .= sv1 .- sv2
353+
dsv1 .= sv1 .^ 2
354+
dsv1 .= .-sv1
355+
end
356+
end
357+
228358
end

test/downstream/community_callback_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ cb = VectorContinuousCallback(cond!, terminate_affect!, nothing, 1)
225225
u0 = [0.0, 0.0, 1.0]
226226
prob = ODEProblem(f!, u0, (0.0, 10.0); callback = cb)
227227
soln = solve(prob, Tsit5())
228-
@test soln.t[end] 4.712347213360699
228+
@test soln.t[end] 4.712347213360699 atol = 1e-4
229229

230230
odefun = ODEFunction((u, p, t) -> [u[2], u[2] - p]; mass_matrix = [1 0; 0 0])
231231
callback = PresetTimeCallback(0.5, integ -> (integ.p = -integ.p))

0 commit comments

Comments
 (0)