@@ -5,11 +5,25 @@ struct DiffractorRuleConfig <: RuleConfig{Union{HasReverseMode,HasForwardsMode}}
55@Base . constprop :aggressive accum (a:: Tuple , b:: Tuple ) = map (accum, a, b)
66@Base . constprop :aggressive @generated function accum (x:: NamedTuple , y:: NamedTuple )
77 fnames = union (fieldnames (x), fieldnames (y))
8+ isempty (fnames) && return :((;)) # code below makes () instead
89 gradx (f) = f in fieldnames (x) ? :(getfield (x, $ (quot (f)))) : :(ZeroTangent ())
910 grady (f) = f in fieldnames (y) ? :(getfield (y, $ (quot (f)))) : :(ZeroTangent ())
1011 Expr (:tuple , [:($ f= accum ($ (gradx (f)), $ (grady (f)))) for f in fnames]. .. )
1112end
1213@Base . constprop :aggressive accum (a, b, c, args... ) = accum (accum (a, b), c, args... )
13- @Base . constprop :aggressive accum (a:: NoTangent , b) = b
14- @Base . constprop :aggressive accum (a, b:: NoTangent ) = a
15- @Base . constprop :aggressive accum (a:: NoTangent , b:: NoTangent ) = NoTangent ()
14+ @Base . constprop :aggressive accum (a:: AbstractZero , b) = b
15+ @Base . constprop :aggressive accum (a, b:: AbstractZero ) = a
16+ @Base . constprop :aggressive accum (a:: AbstractZero , b:: AbstractZero ) = NoTangent ()
17+
18+ using ChainRulesCore: Tangent, backing
19+
20+ function accum (x:: Tangent{T} , y:: NamedTuple ) where T
21+ # @warn "gradient is both a Tangent and a NamedTuple" x y
22+ _tangent (T, accum (backing (x), y))
23+ end
24+ accum (x:: NamedTuple , y:: Tangent ) = accum (y, x)
25+ # This solves an ambiguity, but also avoids Tangent{ZeroTangent}() which + does not:
26+ accum (x:: Tangent{T} , y:: Tangent ) where T = _tangent (T, accum (backing (x), backing (y)))
27+
28+ _tangent (:: Type{T} , z) where T = Tangent {T,typeof(z)} (z)
29+ _tangent (:: Type , :: NamedTuple{()} ) = NoTangent ()
0 commit comments