@@ -128,13 +128,14 @@ function split_bc_pullbacks(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N}
128128 rrule_via_ad (cfg, f, a... )
129129 end
130130 function back_generic (dys)
131- deltas = unzip_broadcast (backs, unthunk ( dys) ) do back, dy # (could be map, sizes match)
131+ deltas = unzip_broadcast (backs, dys) do back, dy # (could be map, sizes match)
132132 map (unthunk, back (dy))
133133 end
134134 dargs = map (unbroadcast, args, Base. tail (deltas))
135135 df = ProjectTo (f)(sum (first (deltas)))
136136 return (NoTangent (), NoTangent (), df, dargs... )
137137 end
138+ back_generic (dys:: AbstractThunk ) = back_generic (unthunk (dys))
138139 back_generic (z:: AbstractZero ) = (TRI_NO... , map (Returns (z), args)... )
139140 return ys3, back_generic
140141end
@@ -318,7 +319,7 @@ rrule(::typeof(broadcasted), ::typeof(complex), x::Number) = rrule(complex, x) |
318319
319320function unbroadcast (x:: Base.AbstractArrayOrBroadcasted , dx_raw)
320321 dx = unthunk (dx_raw)
321- N = ndims (dx)
322+ N = _ndims (dx)
322323 if length (x) == length (dx)
323324 ProjectTo (x)(dx) # handles trivial reshapes, offsets, structured matrices, row vectors
324325 else
@@ -328,6 +329,9 @@ function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx_raw)
328329end
329330unbroadcast (x:: Base.AbstractArrayOrBroadcasted , dx:: AbstractZero ) = dx
330331
332+ _ndims (x) = ndims (x)
333+ _ndims (:: Tuple ) = 1
334+
331335function unbroadcast (x:: T , dx_raw) where {T<: Tuple{Vararg{Any,N}} } where {N}
332336 dx = unthunk (dx_raw)
333337 val = if N == length (dx)
0 commit comments