Skip to content

Commit 763327a

Browse files
broadcast: split ComposedFunction on AbstractGPUArrayStyle (#719)
Rewrites `(f ∘ g).(args...)` as `f.(g.(args...))` so the kernel closure doesn't hit `ComposedFunction`'s kwarg-accepting call, whose kwsorter GPUCompiler can't resolve (e.g. NNlib.tanh_fast on CUDA). Co-authored-by: Tim Besard <tim.besard@gmail.com>
1 parent d65ce69 commit 763327a

2 files changed

Lines changed: 24 additions & 0 deletions

File tree

src/host/broadcast.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@ using Base.Broadcast
44

55
using Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle, instantiate
66

7+
# split `ComposedFunction` so the kernel closure doesn't hit its kwarg-accepting
8+
# call, whose kwsorter dispatch GPUCompiler can't resolve statically.
9+
@inline Broadcast.broadcasted(S::AbstractGPUArrayStyle, c::ComposedFunction, args...) =
10+
Broadcast.broadcasted(S, c.outer, Broadcast.broadcasted(S, c.inner, args...))
11+
712
# but make sure we don't dispatch to the optimized copy method that directly indexes
813
function Broadcast.copy(bc::Broadcasted{<:AbstractGPUArrayStyle{0}})
914
ElType = Broadcast.combine_eltypes(bc.f, bc.args)

test/testsuite/broadcasting.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
broadcasting(AT, eltypes)
33
vec3(AT, eltypes)
44
unknown_wrapper(AT, eltypes)
5+
composed_function(AT, eltypes)
56
end
67

78
test_idx(idx, A::AbstractArray{T}) where T = A[idx] * T(2)
@@ -228,3 +229,21 @@ function unknown_wrapper(AT, eltypes)
228229
end
229230
end
230231
end
232+
233+
function composed_function(AT, eltypes)
234+
sq(x) = x*x
235+
for ET in eltypes
236+
@testset "ComposedFunction $ET" begin
237+
a = AT(rand(ET, 8))
238+
b = AT(rand(ET, 8))
239+
ca, cb = Array(a), Array(b)
240+
241+
@test Array(broadcast(sq ∘ (+), a, b)) ≈ (ca .+ cb).^2
242+
@test Array((sq ∘ (+)).(a, b)) ≈ (ca .+ cb).^2
243+
@test Array((sq ∘ sq ∘ (+)).(a, b)) ≈ ((ca .+ cb).^2).^2
244+
@test Array((sq ∘ identity).(a)) ≈ ca.^2
245+
@test Array((sq ∘ (+)).(a, Ref(ET(2)))) ≈ (ca .+ ET(2)).^2
246+
@test Array((identity ∘ (-)).(a, b)) ≈ ca .- cb
247+
end
248+
end
249+
end

0 commit comments

Comments
 (0)