Skip to content

Commit 5845c78

Browse files
Merge pull request #41 from ChrisRackauckas-Claude/fix-enzyme-batch-width
Generalize Enzyme forward rules to arbitrary batch width
2 parents 726ed55 + 4d6349c commit 5845c78

2 files changed

Lines changed: 48 additions & 13 deletions

File tree

ext/FunctionWrappersWrappersEnzymeExt.jl

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,43 +6,54 @@ using EnzymeCore
66
using EnzymeCore.EnzymeRules
77

88
# =============================================================================
9-
# Forward mode rules
9+
# Forward mode rules — generalized to arbitrary batch width W
1010
# =============================================================================
1111

1212
# Shadow only (Forward mode, no primal)
1313
function EnzymeRules.forward(
14-
::EnzymeRules.FwdConfig{false, true, 1, RuntimeActivity, StrongZero},
14+
::EnzymeRules.FwdConfig{false, true, W, RuntimeActivity, StrongZero},
1515
func::EnzymeCore.Const{<:FunctionWrappersWrapper},
1616
RT::Type{<:EnzymeCore.Annotation{T}},
1717
args::Vararg{EnzymeCore.Annotation, N}
18-
) where {T, N, RuntimeActivity, StrongZero}
18+
) where {T, W, N, RuntimeActivity, StrongZero}
1919
f_orig = unwrap(func.val)
20-
shadow_result = Enzyme.autodiff(Forward, Const(f_orig), Duplicated{T}, args...)
21-
return shadow_result[1]::T
20+
if W == 1
21+
shadow_result = Enzyme.autodiff(Forward, Const(f_orig), Duplicated{T}, args...)
22+
return shadow_result[1]::T
23+
else
24+
shadow_result = Enzyme.autodiff(Forward, Const(f_orig), BatchDuplicated{T, W}, args...)
25+
return shadow_result[1]::NTuple{W, T}
26+
end
2227
end
2328

2429
# Both primal and shadow (ForwardWithPrimal mode)
2530
function EnzymeRules.forward(
26-
::EnzymeRules.FwdConfig{true, true, 1, RuntimeActivity, StrongZero},
31+
::EnzymeRules.FwdConfig{true, true, W, RuntimeActivity, StrongZero},
2732
func::EnzymeCore.Const{<:FunctionWrappersWrapper},
2833
RT::Type{<:EnzymeCore.Annotation{T}},
2934
args::Vararg{EnzymeCore.Annotation, N}
30-
) where {T, N, RuntimeActivity, StrongZero}
35+
) where {T, W, N, RuntimeActivity, StrongZero}
3136
f_orig = unwrap(func.val)
3237
pargs = ntuple(i -> args[i].val, Val(N))
3338
primal = f_orig(pargs...)::T
34-
shadow_result = Enzyme.autodiff(Forward, Const(f_orig), Duplicated{T}, args...)
35-
shadow = shadow_result[1]::T
36-
return Duplicated(primal, shadow)
39+
if W == 1
40+
shadow_result = Enzyme.autodiff(Forward, Const(f_orig), Duplicated{T}, args...)
41+
shadow = shadow_result[1]::T
42+
return Duplicated(primal, shadow)
43+
else
44+
shadow_result = Enzyme.autodiff(Forward, Const(f_orig), BatchDuplicated{T, W}, args...)
45+
shadows = shadow_result[1]::NTuple{W, T}
46+
return BatchDuplicated(primal, shadows)
47+
end
3748
end
3849

39-
# Primal only (Const return type)
50+
# Primal only (Const return type) — width-independent
4051
function EnzymeRules.forward(
41-
::EnzymeRules.FwdConfig{true, false, 1, RuntimeActivity, StrongZero},
52+
::EnzymeRules.FwdConfig{true, false, W, RuntimeActivity, StrongZero},
4253
func::EnzymeCore.Const{<:FunctionWrappersWrapper},
4354
RT::Type{<:EnzymeCore.Annotation},
4455
args::Vararg{EnzymeCore.Annotation, N}
45-
) where {N, RuntimeActivity, StrongZero}
56+
) where {W, N, RuntimeActivity, StrongZero}
4657
f_orig = unwrap(func.val)
4758
pargs = ntuple(i -> args[i].val, Val(N))
4859
return f_orig(pargs...)

test/enzyme_tests.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,27 @@ end
5050
result = Enzyme.autodiff(Reverse, Const(fww_sin), Active, Active(1.0))
5151
@test result[1][1] cos(1.0)
5252
end
53+
54+
@testset "Enzyme batch forward mode (width > 1)" begin
55+
f(x) = x^2
56+
fww = FunctionWrappersWrapper(f, (Tuple{Float64},), (Float64,))
57+
58+
# Batch width = 2: compute derivatives for two tangent directions simultaneously.
59+
# f(x) = x^2 → f'(x) = 2x; at x=3.0 with tangents (1.0, 2.0) → shadows (6.0, 12.0)
60+
result = Enzyme.autodiff(
61+
Forward, Const(fww), BatchDuplicated,
62+
BatchDuplicated(3.0, (1.0, 2.0))
63+
)
64+
shadows = result[1]
65+
@test shadows[1] 6.0 # f'(3) * 1.0
66+
@test shadows[2] 12.0 # f'(3) * 2.0
67+
68+
# ForwardWithPrimal, batch width = 2
69+
result_wp = Enzyme.autodiff(
70+
ForwardWithPrimal, Const(fww), BatchDuplicated,
71+
BatchDuplicated(3.0, (1.0, 2.0))
72+
)
73+
@test result_wp[1][1] 6.0 # shadow 1
74+
@test result_wp[1][2] 12.0 # shadow 2
75+
@test result_wp[2] 9.0 # primal f(3) = 9
76+
end

0 commit comments

Comments
 (0)