Skip to content

Commit b8c9f19

Browse files
Add support for PermutedDimsArray (#48)
Co-authored-by: Tim Besard <tim.besard@gmail.com>
1 parent f7170cf commit b8c9f19

2 files changed

Lines changed: 33 additions & 0 deletions

File tree

src/language/types.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,13 @@ function TileArray(arr::AbstractArray{T, N}) where {T, N}
181181
TileArray(ptr, sizes, strides_val)
182182
end
183183

184+
function TileArray(arr::PermutedDimsArray{T, N}) where {T, N}
185+
ptr = reinterpret(Ptr{T}, pointer(parent(arr)))
186+
sizes = NTuple{N, Int32}(Int32.(size(arr)))
187+
strides_val = NTuple{N, Int32}(Int32.(strides(arr)))
188+
TileArray(ptr, sizes, strides_val)
189+
end
190+
184191

185192
"""
186193
Tile{T, Shape}

test/execution.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,32 @@ end
418418
end
419419
end
420420

421+
@testset "strided" begin
422+
@testset "PermutedDimsArray" begin
423+
function copy_kernel_2d(
424+
src::ct.TileArray{Float32, 2}, dst::ct.TileArray{Float32, 2},
425+
tile_x::ct.Constant{Int}, tile_y::ct.Constant{Int}
426+
)
427+
bid_x = ct.bid(1)
428+
bid_y = ct.bid(2)
429+
tile = ct.load(src, (bid_x, bid_y), (tile_x[], tile_y[]))
430+
ct.store(dst, (bid_x, bid_y), tile)
431+
return
432+
end
433+
434+
m, n = 64, 32
435+
tm, tn = 16, 16
436+
A = CuArray(Float32.(reshape(1:n*m, n, m)))
437+
P = PermutedDimsArray(A, (2, 1))
438+
out = CUDA.zeros(Float32, m, n)
439+
440+
grid = (cld(m, tm), cld(n, tn))
441+
ct.launch(copy_kernel_2d, grid, P, out, ct.Constant(tm), ct.Constant(tn))
442+
443+
@test out == permutedims(A, (2, 1))
444+
end
445+
end
446+
421447
@testset "extract" begin
422448
@testset "extract identity (0,0) full shape" begin
423449
function extract_identity_kernel(x::ct.TileArray{Float32,2}, y::ct.TileArray{Float32,2})

0 commit comments

Comments
 (0)