Skip to content

Commit 473b7d5

Browse files
authored
Support reductions without dims arg. (#118)
1 parent 130779c commit 473b7d5

3 files changed

Lines changed: 181 additions & 1 deletion

File tree

src/language/operations.jl

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,16 @@ end
200200
store(arr, (index,), tile; kwargs...)
201201
end
202202

203+
# Scalar store: turn `val` into a tile with one element along every dimension of
# `arr`, then hand it to the tile-based `store` with the same index and kwargs.
@inline function store(arr::TileArray{T}, index, val::T; kwargs...) where {T}
    scalar_tile = Intrinsics.from_scalar(val, Tuple{})
    unit_shape = ntuple(_ -> 1, Val(ndims(arr)))
    return store(arr, index, reshape(scalar_tile, unit_shape); kwargs...)
end
209+
# Scalar store with a single integer index: wrap the index in a 1-tuple and
# forward to `store` with the same scalar value and kwargs.
@inline store(arr::TileArray{T}, index::Integer, val::T; kwargs...) where {T} =
    store(arr, (index,), val; kwargs...)
212+
203213
@inline function _store_reshaped(arr::TileArray{T}, tile::Tile{T},
204214
order, latency, allow_tma, indices::NTuple{<:Any, <:Integer}) where {T}
205215
tv = Intrinsics.make_tensor_view(arr)
@@ -221,7 +231,7 @@ end
221231
# would DCE the entire call as a pure function with unused result.
222232
Base.Experimental.@consistent_overlay cuTileMethodTable function Base.setindex!(
        arr::TileArray{T, N}, val::T, indices::Vararg{Integer, N}) where {T, N}
    # Wrap the scalar in a tile of size 1 along each of the N dimensions so the
    # tile-based `store` can write it at `indices`.
    scalar_tile = Intrinsics.from_scalar(val, Tuple{})
    unit_shape = ntuple(_ -> 1, Val(N))
    store(arr, indices, reshape(scalar_tile, unit_shape))
    return nothing
end
@@ -697,6 +707,16 @@ Reduced dimensions become size 1.
697707
# `min`-reduction of `tile` along `dims`, seeded with `typemax(T)`.
@inline function Base.minimum(tile::Tile{T,S}; dims) where {T<:Number, S}
    return reduce(min, tile; dims=dims, init=typemax(T))
end
699709

710+
# Full reductions (no `dims`) for sum/prod/maximum/minimum: collapse every
# dimension of the tile and return a scalar `T`.
# Each step reduces along dimension 1, drops that dimension, and recurses on the
# result; the recursion stops at a 0-D tile, which is unwrapped with `to_scalar`.
for f in (:sum, :prod, :maximum, :minimum)
    @eval begin
        # Base case: 0-D tile.
        @inline Base.$f(tile::Tile{T,Tuple{}}) where {T<:Number} =
            Intrinsics.to_scalar(tile)

        # Recursive case: at least one dimension remains.
        @inline function Base.$f(tile::Tile{T,S}) where {T<:Number, S<:Tuple{Any,Vararg}}
            reduced = $f(tile; dims=1)
            return $f(dropdims(reduced; dims=1))
        end
    end
end
719+
700720
"""
701721
any(tile::Tile{Bool,S}; dims) -> Tile{Bool, reduced_shape}
702722
@@ -730,6 +750,22 @@ n_positive = count(tile .> 0.0f0; dims=1)
730750
sum(convert(Tile{Int32}, tile); dims)
731751
end
732752

753+
# Full reductions (no `dims`) for any/all: collapse every dimension of a Bool
# tile and return a scalar Bool.
# Each step reduces along dimension 1, drops that dimension, and recurses on the
# result; the recursion stops at a 0-D tile, which is unwrapped with `to_scalar`.
for f in (:any, :all)
    @eval begin
        # Base case: 0-D tile.
        @inline Base.$f(tile::Tile{Bool,Tuple{}}) =
            Intrinsics.to_scalar(tile)

        # Recursive case: at least one dimension remains.
        @inline function Base.$f(tile::Tile{Bool,S}) where {S<:Tuple{Any,Vararg}}
            reduced = $f(tile; dims=1)
            return $f(dropdims(reduced; dims=1))
        end
    end
end
761+
762+
# count without dims — return scalar Int32
# N-D method: `count(tile; dims=1)` produces Int32 values (not Bool), so after
# dropping dimension 1 the remaining dimensions are reduced with `sum`.
# 0-D method: the unwrapped value is a Bool, so it is converted to Int32 to keep
# the return type the same as the N-D method.
@inline Base.count(tile::Tile{Bool,Tuple{}}) =
    Int32(Intrinsics.to_scalar(tile))
@inline Base.count(tile::Tile{Bool,S}) where {S<:Tuple{Any,Vararg}} =
    sum(dropdims(count(tile; dims=1); dims=1))
768+
733769
"""
734770
argmax(tile::Tile{T,S}; dims) -> Tile{Int32, reduced_shape}
735771

test/execution/basic.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,22 @@ end
151151
end
152152
end
153153

154+
@testset "scalar store" begin
    # Kernel: load a tile of `tileSz` elements of `a` at index `ct.bid(1)`, reduce
    # it to a scalar with `sum`, and write that scalar to `b` via `ct.store`.
    function scalar_store_1d(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}, tileSz::Int)
        tile = ct.load(a, ct.bid(1), (tileSz,))
        ct.store(b, ct.bid(1), sum(tile))
        return nothing
    end

    sz = 32; N = 1024
    a = CUDA.rand(Float32, N)
    b = CUDA.zeros(Float32, cld(N, sz))
    ct.launch(scalar_store_1d, cld(N, sz), a, b, ct.Constant(sz))

    # Reference: reshape the input to `sz` rows so each column sum is one expected
    # output element.
    a_cpu = reshape(Array(a), sz, :)
    @test Array(b) ≈ vec(sum(a_cpu; dims=1)) rtol=1e-3
end
169+
154170
@testset "transpose" begin
155171
function transpose_kernel(x::ct.TileArray{Float32,2}, y::ct.TileArray{Float32,2})
156172
bidx = ct.bid(1)

test/execution/reductions.jl

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,3 +567,131 @@ end
567567
@test Array(b_min_rand)[row, 1] == expected_min
568568
end
569569
end
570+
571+
@testset "sum without dims (1D)" begin
    # Kernel: load a tile of `tileSz` elements of `a` at index `ct.bid(1)` and
    # assign its full (no `dims`) sum to `b[ct.bid(1)]` through scalar `setindex!`.
    function sum_no_dims_1d(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}, tileSz::Int)
        tile = ct.load(a, ct.bid(1), (tileSz,))
        b[ct.bid(1)] = sum(tile)
        return nothing
    end

    sz = 32; N = 1024
    a = CUDA.rand(Float32, N)
    b = CUDA.zeros(Float32, cld(N, sz))
    ct.launch(sum_no_dims_1d, cld(N, sz), a, b, ct.Constant(sz))

    # Reference: reshape the input to `sz` rows so each column sum is one expected
    # output element.
    a_cpu = reshape(Array(a), sz, :)
    @test Array(b) ≈ vec(sum(a_cpu; dims=1)) rtol=1e-3
end
586+
587+
@testset "sum without dims (2D)" begin
    # Kernel: load a (1, 128) tile of `a` at tile index `(pid, 1)` and assign its
    # full (no `dims`) sum to `b[pid]`.
    function sum_no_dims_2d(a::ct.TileArray{Float32,2}, b::ct.TileArray{Float32,1})
        pid = ct.bid(1)
        tile = ct.load(a, (pid, 1), (1, 128))
        b[pid] = sum(tile)
        return nothing
    end

    m, n = 64, 128
    a = CUDA.rand(Float32, m, n)
    b = CUDA.zeros(Float32, m)
    ct.launch(sum_no_dims_2d, m, a, b)

    a_cpu = Array(a)
    b_cpu = Array(b)  # copy the result to the host once, not once per iteration
    for i in 1:m
        @test b_cpu[i] ≈ sum(a_cpu[i, :]) rtol=1e-3
    end
end
605+
606+
@testset "any without dims (1D)" begin
    # Kernel: load a tile of `tileSz` elements of `a` at index `ct.bid(1)` and
    # record whether any element is positive.
    function any_no_dims_1d(a::ct.TileArray{Float32,1}, b::ct.TileArray{Int32,1}, tileSz::Int)
        tile = ct.load(a, ct.bid(1), (tileSz,))
        # Bool scalars can't be stored directly, so the flag is written as Int32.
        has_positive = any(tile .> 0.0f0)
        b[ct.bid(1)] = Int32(has_positive)
        return nothing
    end

    sz = 32; N = 1024
    nblocks = cld(N, sz)
    a = CUDA.rand(Float32, N) .- 0.5f0  # shifted so values can be positive or negative
    b = CUDA.zeros(Int32, nblocks)
    ct.launch(any_no_dims_1d, nblocks, a, b, ct.Constant(sz))

    # Reference: column i of the reshaped input is compared against output element i.
    a_cpu = reshape(Array(a), sz, :)
    b_cpu = Array(b)
    for (i, col) in enumerate(eachcol(a_cpu))
        @test b_cpu[i] == Int32(any(col .> 0.0f0))
    end
end
624+
625+
@testset "all without dims (1D)" begin
    # Kernel: load a tile of `tileSz` elements of `a` at index `ct.bid(1)` and
    # record, as Int32, whether every element is positive.
    function all_no_dims_1d(a::ct.TileArray{Float32,1}, b::ct.TileArray{Int32,1}, tileSz::Int)
        tile = ct.load(a, ct.bid(1), (tileSz,))
        all_positive = all(tile .> 0.0f0)
        b[ct.bid(1)] = Int32(all_positive)
        return nothing
    end

    sz = 32; N = 1024
    nblocks = cld(N, sz)
    a = CUDA.rand(Float32, N)  # unshifted random input
    b = CUDA.zeros(Int32, nblocks)
    ct.launch(all_no_dims_1d, nblocks, a, b, ct.Constant(sz))

    # Reference: column i of the reshaped input is compared against output element i.
    a_cpu = reshape(Array(a), sz, :)
    b_cpu = Array(b)
    for (i, col) in enumerate(eachcol(a_cpu))
        @test b_cpu[i] == Int32(all(col .> 0.0f0))
    end
end
642+
643+
@testset "sum without dims (3D)" begin
    # Kernel: load a (1, 8, 16) tile of `a` at tile index `(pid, 1, 1)` and assign
    # its full (no `dims`) sum to `b[pid]`.
    function sum_no_dims_3d(a::ct.TileArray{Float32,3}, b::ct.TileArray{Float32,1})
        pid = ct.bid(1)
        tile = ct.load(a, (pid, 1, 1), (1, 8, 16))
        b[pid] = sum(tile)
        return nothing
    end

    d1, d2, d3 = 4, 8, 16
    a = CUDA.rand(Float32, d1, d2, d3)
    b = CUDA.zeros(Float32, d1)
    ct.launch(sum_no_dims_3d, d1, a, b)

    a_cpu = Array(a)
    b_cpu = Array(b)  # copy the result to the host once, not once per iteration
    for i in 1:d1
        @test b_cpu[i] ≈ sum(a_cpu[i, :, :]) rtol=1e-3
    end
end
661+
662+
@testset "maximum without dims (3D)" begin
    # Kernel: load a (1, 8, 16) tile of `a` at tile index `(pid, 1, 1)` and assign
    # its full (no `dims`) maximum to `b[pid]`.
    function maximum_no_dims_3d(a::ct.TileArray{Float32,3}, b::ct.TileArray{Float32,1})
        pid = ct.bid(1)
        tile = ct.load(a, (pid, 1, 1), (1, 8, 16))
        b[pid] = maximum(tile)
        return nothing
    end

    d1, d2, d3 = 4, 8, 16
    a = CUDA.rand(Float32, d1, d2, d3)
    b = CUDA.zeros(Float32, d1)
    ct.launch(maximum_no_dims_3d, d1, a, b)

    a_cpu = Array(a)
    b_cpu = Array(b)  # copy the result to the host once, not once per iteration
    for i in 1:d1
        @test b_cpu[i] ≈ maximum(a_cpu[i, :, :]) rtol=1e-3
    end
end
680+
681+
@testset "count without dims (1D)" begin
    # Kernel: load a tile of `tileSz` elements of `a` at index `ct.bid(1)` and
    # store how many of them are positive.
    function count_no_dims_1d(a::ct.TileArray{Float32,1}, b::ct.TileArray{Int32,1}, tileSz::Int)
        tile = ct.load(a, ct.bid(1), (tileSz,))
        positive = tile .> 0.0f0
        b[ct.bid(1)] = count(positive)
        return nothing
    end

    sz = 32; N = 1024
    nblocks = cld(N, sz)
    a = CUDA.rand(Float32, N) .- 0.5f0
    b = CUDA.zeros(Int32, nblocks)
    ct.launch(count_no_dims_1d, nblocks, a, b, ct.Constant(sz))

    # Reference: column i of the reshaped input is compared against output element i.
    a_cpu = reshape(Array(a), sz, :)
    b_cpu = Array(b)
    for (i, col) in enumerate(eachcol(a_cpu))
        @test b_cpu[i] == count(>(0.0f0), col)
    end
end

0 commit comments

Comments
 (0)