Skip to content

Commit 8cfcfd2

Browse files
Add batch/vectorized finite difference Jacobian evaluation
Implements feature requested in #210: allows computing the full Jacobian in a single batched function call instead of N sequential calls. This is useful for GPU-parallelized functions that can evaluate multiple inputs simultaneously. Adds `batch=true` keyword to `finite_difference_jacobian` and `finite_difference_jacobian!`. When enabled, `f` receives a matrix where each column is an input point and returns a matrix of outputs. Supports forward, central, and complex step methods. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 788e098 commit 8cfcfd2

2 files changed

Lines changed: 252 additions & 2 deletions

File tree

src/jacobians.jl

Lines changed: 174 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,167 @@ function _make_Ji(::AbstractArray, xtype, dx, color_i, nrows, ncols)
186186
size(Ji) != (nrows, ncols) ? reshape(Ji, (nrows, ncols)) : Ji #branch when size(dx) == (1,) => size(Ji) == (1,) while size(J) == (1,1)
187187
end
188188

189+
"""
190+
_finite_difference_jacobian_batch(f, x, fdtype, returntype, f_in; relstep, absstep, dir)
191+
192+
Internal function implementing vectorized/batched finite difference Jacobian computation.
193+
194+
When `batch=true` is passed to `finite_difference_jacobian`, this function is called instead
195+
of the standard column-by-column approach. The function `f` is expected to accept a matrix
196+
where each column is an input point, and return a matrix where each column is the
197+
corresponding output. This allows GPU-parallelized or otherwise vectorized functions to
198+
evaluate all perturbations in a single call.
199+
200+
For forward differences, a single call to `f` is made with `n+1` columns (base point + `n`
201+
perturbations) if `f_in` is not provided, or `n` columns if `f_in` is provided.
202+
For central differences, `2n` columns are used (forward and backward perturbations).
203+
For complex step, `n` columns are used with complex perturbations.
204+
"""
205+
function _finite_difference_jacobian_batch(f, x, fdtype, returntype, f_in;
206+
relstep, absstep, dir)
207+
fdtype isa Type && (fdtype = fdtype())
208+
n = length(x)
209+
vecx = _vec(x)
210+
211+
if fdtype == Val(:forward)
212+
epsilons = [compute_epsilon(Val(:forward), vecx[i], relstep, absstep, dir) for i in 1:n]
213+
214+
if f_in isa Nothing
215+
# Include x as the first column so we only call f once
216+
X = repeat(vecx, 1, n + 1)
217+
for i in 1:n
218+
X[i, i + 1] += epsilons[i]
219+
end
220+
FX = f(X)
221+
fx_col = @view FX[:, 1]
222+
J = similar(FX, size(FX, 1), n)
223+
for i in 1:n
224+
@. J[:, i] = (FX[:, i + 1] - fx_col) / epsilons[i]
225+
end
226+
else
227+
X = repeat(vecx, 1, n)
228+
for i in 1:n
229+
X[i, i] += epsilons[i]
230+
end
231+
FX = f(X)
232+
vfx = _vec(f_in)
233+
J = similar(FX, size(FX, 1), n)
234+
for i in 1:n
235+
@. J[:, i] = (FX[:, i] - vfx) / epsilons[i]
236+
end
237+
end
238+
return J
239+
240+
elseif fdtype == Val(:central)
241+
epsilons = [compute_epsilon(Val(:central), vecx[i], relstep, absstep, dir) for i in 1:n]
242+
243+
# Build matrix with 2n columns: [x+eps1*e1, x-eps1*e1, x+eps2*e2, x-eps2*e2, ...]
244+
X = repeat(vecx, 1, 2n)
245+
for i in 1:n
246+
X[i, 2i - 1] += epsilons[i]
247+
X[i, 2i] -= epsilons[i]
248+
end
249+
FX = f(X)
250+
J = similar(FX, size(FX, 1), n)
251+
for i in 1:n
252+
@. J[:, i] = (FX[:, 2i - 1] - FX[:, 2i]) / (2 * epsilons[i])
253+
end
254+
return J
255+
256+
elseif fdtype == Val(:complex) && returntype <: Real
257+
epsilon = eps(eltype(x))
258+
259+
# Build complex matrix with n columns
260+
X = repeat(complex.(vecx), 1, n)
261+
for i in 1:n
262+
X[i, i] += im * epsilon
263+
end
264+
FX = f(X)
265+
J = similar(FX, real(eltype(FX)), size(FX, 1), n)
266+
for i in 1:n
267+
@. J[:, i] = imag(FX[:, i]) / epsilon
268+
end
269+
return J
270+
else
271+
fdtype_error(returntype)
272+
end
273+
end
274+
275+
"""
276+
_finite_difference_jacobian_batch!(J, f, x, fdtype, returntype, f_in; relstep, absstep, dir)
277+
278+
Internal in-place function implementing vectorized/batched finite difference Jacobian computation.
279+
280+
When `batch=true` is passed to `finite_difference_jacobian!`, this function is called instead
281+
of the standard column-by-column approach. The function `f` is expected to accept two matrix
282+
arguments `f(FX, X)` where `X` has columns of input points and `FX` is filled with the
283+
corresponding outputs.
284+
"""
285+
function _finite_difference_jacobian_batch!(J, f, x, fdtype, returntype, f_in;
286+
relstep, absstep, dir)
287+
fdtype isa Type && (fdtype = fdtype())
288+
m, n = size(J)
289+
vecx = _vec(x)
290+
291+
if fdtype == Val(:forward)
292+
epsilons = [compute_epsilon(Val(:forward), vecx[i], relstep, absstep, dir) for i in 1:n]
293+
294+
if f_in isa Nothing
295+
# n+1 columns: base point + n perturbations
296+
X = repeat(vecx, 1, n + 1)
297+
for i in 1:n
298+
X[i, i + 1] += epsilons[i]
299+
end
300+
FX = similar(x, m, n + 1)
301+
f(FX, X)
302+
for i in 1:n
303+
@. J[:, i] = (FX[:, i + 1] - FX[:, 1]) / epsilons[i]
304+
end
305+
else
306+
X = repeat(vecx, 1, n)
307+
for i in 1:n
308+
X[i, i] += epsilons[i]
309+
end
310+
FX = similar(x, m, n)
311+
f(FX, X)
312+
vfx = _vec(f_in)
313+
for i in 1:n
314+
@. J[:, i] = (FX[:, i] - vfx) / epsilons[i]
315+
end
316+
end
317+
318+
elseif fdtype == Val(:central)
319+
epsilons = [compute_epsilon(Val(:central), vecx[i], relstep, absstep, dir) for i in 1:n]
320+
321+
X = repeat(vecx, 1, 2n)
322+
for i in 1:n
323+
X[i, 2i - 1] += epsilons[i]
324+
X[i, 2i] -= epsilons[i]
325+
end
326+
FX = similar(x, m, 2n)
327+
f(FX, X)
328+
for i in 1:n
329+
@. J[:, i] = (FX[:, 2i - 1] - FX[:, 2i]) / (2 * epsilons[i])
330+
end
331+
332+
elseif fdtype == Val(:complex) && returntype <: Real
333+
epsilon = eps(eltype(x))
334+
335+
X = repeat(complex.(vecx), 1, n)
336+
for i in 1:n
337+
X[i, i] += im * epsilon
338+
end
339+
FX = similar(X, Complex{eltype(x)}, m, n)
340+
f(FX, X)
341+
for i in 1:n
342+
@. J[:, i] = imag(FX[:, i]) / epsilon
343+
end
344+
else
345+
fdtype_error(returntype)
346+
end
347+
nothing
348+
end
349+
189350
"""
190351
FiniteDiff.finite_difference_jacobian(
191352
f,
@@ -246,7 +407,12 @@ function finite_difference_jacobian(f, x,
246407
colorvec = 1:length(x),
247408
sparsity = nothing,
248409
jac_prototype = nothing,
249-
dir = true)
410+
dir = true,
411+
batch = false)
412+
if batch
413+
return _finite_difference_jacobian_batch(f, x, fdtype, returntype, f_in;
414+
relstep = relstep, absstep = absstep, dir = dir)
415+
end
250416
if f_in isa Nothing
251417
fx = f(x)
252418
else
@@ -452,7 +618,13 @@ function finite_difference_jacobian!(J,
452618
relstep = default_relstep(fdtype, eltype(x)),
453619
absstep = relstep,
454620
colorvec = 1:length(x),
455-
sparsity = ArrayInterface.has_sparsestruct(J) ? J : nothing)
621+
sparsity = ArrayInterface.has_sparsestruct(J) ? J : nothing,
622+
batch = false)
623+
if batch
624+
_finite_difference_jacobian_batch!(J, f, x, fdtype, returntype, f_in;
625+
relstep = relstep, absstep = absstep, dir = true)
626+
return nothing
627+
end
456628
if f_in isa Nothing && fdtype == Val(:forward)
457629
if size(J, 1) == length(x)
458630
fx = zero(x)

test/finitedifftests.jl

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,3 +612,81 @@ end
612612
@test FiniteDiff.finite_difference_hessian(f, x1, FiniteDiff.HessianCache(x1)) == Diagonal(2*ones(4))
613613
@test FiniteDiff.finite_difference_hessian(f, x1, FiniteDiff.HessianCache(x2)) == Diagonal(2*ones(4))
614614
end
615+
616+
# Batched Jacobian tests (issue #210)
617+
@time @testset "Batched Jacobian tests" begin
618+
# Out-of-place batched function: f(X) where X is n×k, returns m×k
619+
function oopf_scalar(x)
620+
[(x[1] + 3) * (x[2]^3 - 7) + 18,
621+
sin(x[2] * exp(x[1]) - 1)]
622+
end
623+
function oopf_batch(X::AbstractMatrix)
624+
hcat([oopf_scalar(X[:, j]) for j in 1:size(X, 2)]...)
625+
end
626+
# Also handle single vector input for the batch function
627+
oopf_batch(x::AbstractVector) = oopf_scalar(x)
628+
629+
# In-place batched function: f(FX, X) where X is n×k, FX is m×k
630+
function iipf_batch(FX::AbstractMatrix, X::AbstractMatrix)
631+
for j in 1:size(X, 2)
632+
FX[1, j] = (X[1, j] + 3) * (X[2, j]^3 - 7) + 18
633+
FX[2, j] = sin(X[2, j] * exp(X[1, j]) - 1)
634+
end
635+
end
636+
637+
x = [1.5, 0.7]
638+
J_ref = [[-7 + x[2]^3 3 * (3 + x[1]) * x[2]^2];
639+
[exp(x[1]) * x[2] * cos(1 - exp(x[1]) * x[2]) exp(x[1]) * cos(1 - exp(x[1]) * x[2])]]
640+
641+
@testset "Out-of-place batch" begin
642+
J_fwd = FiniteDiff.finite_difference_jacobian(oopf_batch, x, Val{:forward}; batch=true)
643+
@test err_func(J_fwd, J_ref) < 1e-6
644+
645+
# With f_in provided
646+
f_in = oopf_scalar(x)
647+
J_fwd2 = FiniteDiff.finite_difference_jacobian(oopf_batch, x, Val{:forward}, eltype(x), f_in; batch=true)
648+
@test err_func(J_fwd2, J_ref) < 1e-6
649+
650+
J_cen = FiniteDiff.finite_difference_jacobian(oopf_batch, x, Val{:central}; batch=true)
651+
@test err_func(J_cen, J_ref) < 1e-8
652+
653+
J_cpx = FiniteDiff.finite_difference_jacobian(oopf_batch, x, Val{:complex}; batch=true)
654+
@test err_func(J_cpx, J_ref) < 1e-14
655+
end
656+
657+
@testset "In-place batch" begin
658+
J = zero(J_ref)
659+
FiniteDiff.finite_difference_jacobian!(J, iipf_batch, x, Val{:forward}; batch=true)
660+
@test err_func(J, J_ref) < 1e-6
661+
662+
# With f_in provided
663+
f_in = oopf_scalar(x)
664+
J .= 0
665+
FiniteDiff.finite_difference_jacobian!(J, iipf_batch, x, Val{:forward}, eltype(x), f_in; batch=true)
666+
@test err_func(J, J_ref) < 1e-6
667+
668+
J .= 0
669+
FiniteDiff.finite_difference_jacobian!(J, iipf_batch, x, Val{:central}; batch=true)
670+
@test err_func(J, J_ref) < 1e-8
671+
672+
J .= 0
673+
FiniteDiff.finite_difference_jacobian!(J, iipf_batch, x, Val{:complex}; batch=true)
674+
@test err_func(J, J_ref) < 1e-14
675+
end
676+
677+
@testset "Batch matches non-batch" begin
678+
# Test on a larger function to make sure batch and non-batch agree
679+
f_oop(x) = [x[1]^2 + x[2]*x[3], sin(x[1]) + x[3]^2, x[1]*x[2]*x[3]]
680+
function f_batch(X::AbstractMatrix)
681+
hcat([f_oop(X[:, j]) for j in 1:size(X, 2)]...)
682+
end
683+
f_batch(x::AbstractVector) = f_oop(x)
684+
685+
x3 = [2.0, 3.0, 1.5]
686+
for fdtype in (Val{:forward}, Val{:central}, Val{:complex})
687+
J_std = FiniteDiff.finite_difference_jacobian(f_oop, x3, fdtype)
688+
J_bat = FiniteDiff.finite_difference_jacobian(f_batch, x3, fdtype; batch=true)
689+
@test J_std J_bat
690+
end
691+
end
692+
end

0 commit comments

Comments
 (0)