Skip to content

Commit e93c641

Browse files
authored
Fix scalar indexing on TileArrays and add codegen test (#94)
1 parent 8a6350c commit e93c641

3 files changed

Lines changed: 53 additions & 1 deletion

File tree

src/language/operations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ end
135135
@overlay function Base.getindex(arr::TileArray{T, N}, indices::Vararg{Integer, N}) where {T, N}
136136
tv = Intrinsics.make_tensor_view(arr)
137137
shape = ntuple(_ -> 1, Val(N))
138-
pv = Intrinsics.make_partition_view(tv, Val(shape), PaddingMode.Undetermined)
138+
pv = Intrinsics.make_partition_view(tv, shape, PaddingMode.Undetermined, nothing)
139139
tile = Intrinsics.load_partition_view(pv, nothing, true, promote(indices...) .- One())
140140
Intrinsics.to_scalar(reshape(tile, ()))
141141
end

test/codegen/operations.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1529,6 +1529,21 @@
15291529
end
15301530
end
15311531

1532+
@testset "TileArray scalar getindex" begin
1533+
@test @filecheck begin
1534+
@check_label "entry"
1535+
code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,1,spec1d}}) do lengths, out
1536+
bid = ct.bid(1)
1537+
@check "make_partition_view"
1538+
@check "load_view_tko"
1539+
@check "reshape"
1540+
len = lengths[bid]
1541+
ct.store(out, bid, ct.broadcast_to(ct.Tile(len), (16,)))
1542+
return
1543+
end
1544+
end
1545+
end
1546+
15321547
@testset "num_tiles helper" begin
15331548
spec = ct.ArraySpec{2}(16, true)
15341549
@test @filecheck begin

test/execution/basic.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,3 +1098,40 @@ end
10981098

10991099
@test Array(b) ≈ Array(a)
11001100
end
1101+
1102+
@testset "scalar indexing as loop bound" begin
1103+
function scalar_index_loop_kernel(data::ct.TileArray{Float32,1},
1104+
lengths::ct.TileArray{Int32,1},
1105+
out::ct.TileArray{Float32,1})
1106+
bid = ct.bid(1)
1107+
len = lengths[bid]
1108+
acc = ct.zeros((16,), Float32)
1109+
j = Int32(1)
1110+
while j <= len
1111+
tile = ct.load(data, j, (16,))
1112+
acc = acc .+ tile
1113+
j += Int32(1)
1114+
end
1115+
ct.store(out, bid, acc)
1116+
return
1117+
end
1118+
1119+
# 3 blocks, each sums a different number of tiles
1120+
n_tiles = Int32[2, 3, 1]
1121+
data = CUDA.rand(Float32, 48) # 3 tiles of 16
1122+
lengths = CuArray(n_tiles)
1123+
out = CUDA.zeros(Float32, 48)
1124+
1125+
ct.launch(scalar_index_loop_kernel, 3, data, lengths, out)
1126+
1127+
data_cpu = Array(data)
1128+
out_cpu = Array(out)
1129+
for bid in 1:3
1130+
expected = zeros(Float32, 16)
1131+
for j in 1:n_tiles[bid]
1132+
expected .+= data_cpu[(j-1)*16+1 : j*16]
1133+
end
1134+
@test out_cpu[(bid-1)*16+1 : bid*16] ≈ expected
1135+
end
1136+
end
1137+

0 commit comments

Comments
 (0)