@@ -6,43 +6,54 @@ using EnzymeCore
66using EnzymeCore. EnzymeRules
77
# =============================================================================
# Forward mode rules — generalized to arbitrary batch width W
# =============================================================================
1111
# Shadow only (Forward mode, no primal)
# Forward rule for a `Const`-annotated `FunctionWrappersWrapper` when the config
# requests the shadow but not the primal (`FwdConfig{false, true, W, ...}`).
# The wrapper is unwrapped and the underlying callable is differentiated with
# `Enzyme.autodiff(Forward, ...)`; the first element of that result is returned.
# The result is asserted to be `T` for width 1 and `NTuple{W, T}` otherwise.
function EnzymeRules.forward(
        ::EnzymeRules.FwdConfig{false, true, W, RuntimeActivity, StrongZero},
        func::EnzymeCore.Const{<:FunctionWrappersWrapper},
        RT::Type{<:EnzymeCore.Annotation{T}},
        args::Vararg{EnzymeCore.Annotation, N}
    ) where {T, W, N, RuntimeActivity, StrongZero}
    # Differentiate the wrapped callable directly instead of the wrapper itself.
    f_orig = unwrap(func.val)
    # `W` is a static parameter of this method, so the branch depends only on the
    # config type, not on runtime values.
    if W == 1
        shadow_result = Enzyme.autodiff(Forward, Const(f_orig), Duplicated{T}, args...)
        return shadow_result[1]::T
    else
        shadow_result = Enzyme.autodiff(
            Forward, Const(f_orig), BatchDuplicated{T, W}, args...
        )
        return shadow_result[1]::NTuple{W, T}
    end
end
2328
# Both primal and shadow (ForwardWithPrimal mode)
# Forward rule for a `Const`-annotated `FunctionWrappersWrapper` when the config
# requests both primal and shadow (`FwdConfig{true, true, W, ...}`).
# The primal is computed by calling the unwrapped callable on the `.val` of every
# argument; the shadow comes from a separate `Enzyme.autodiff(Forward, ...)` call.
# Returns `Duplicated(primal, shadow)` for width 1 and
# `BatchDuplicated(primal, shadows)` otherwise.
function EnzymeRules.forward(
        ::EnzymeRules.FwdConfig{true, true, W, RuntimeActivity, StrongZero},
        func::EnzymeCore.Const{<:FunctionWrappersWrapper},
        RT::Type{<:EnzymeCore.Annotation{T}},
        args::Vararg{EnzymeCore.Annotation, N}
    ) where {T, W, N, RuntimeActivity, StrongZero}
    # Differentiate the wrapped callable directly instead of the wrapper itself.
    f_orig = unwrap(func.val)
    # Strip the annotations to obtain the plain primal arguments.
    pargs = ntuple(i -> args[i].val, Val(N))
    primal = f_orig(pargs...)::T
    # `W` is a static parameter of this method, so the branch depends only on the
    # config type, not on runtime values.
    if W == 1
        shadow_result = Enzyme.autodiff(Forward, Const(f_orig), Duplicated{T}, args...)
        shadow = shadow_result[1]::T
        return Duplicated(primal, shadow)
    else
        shadow_result = Enzyme.autodiff(
            Forward, Const(f_orig), BatchDuplicated{T, W}, args...
        )
        shadows = shadow_result[1]::NTuple{W, T}
        return BatchDuplicated(primal, shadows)
    end
end
3849
# Primal only (Const return type) — width-independent
4051function EnzymeRules. forward (
41- :: EnzymeRules.FwdConfig{true, false, 1 , RuntimeActivity, StrongZero} ,
52+ :: EnzymeRules.FwdConfig{true, false, W , RuntimeActivity, StrongZero} ,
4253 func:: EnzymeCore.Const{<:FunctionWrappersWrapper} ,
4354 RT:: Type{<:EnzymeCore.Annotation} ,
4455 args:: Vararg{EnzymeCore.Annotation, N}
45- ) where {N, RuntimeActivity, StrongZero}
56+ ) where {W, N, RuntimeActivity, StrongZero}
4657 f_orig = unwrap (func. val)
4758 pargs = ntuple (i -> args[i]. val, Val (N))
4859 return f_orig (pargs... )
0 commit comments