Skip to content

Commit 5d9ab82

Browse files
committed
Faster matmul
1 parent 5a83c70 commit 5d9ab82

1 file changed

Lines changed: 81 additions & 0 deletions

File tree

src/host/linalg.jl

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,11 +436,92 @@ function LinearAlgebra.ldiv!(B::AbstractGPUVecOrMat,
436436
B
437437
end
438438

439+
# Edge length of the square workgroup used when launching the tiled matmul kernel;
# the launch range is rounded up to a multiple of this value.
# Declared `const` so that functions reading it do not go through an untyped global.
# XXX: figure out how to choose this dynamically
const MAX_TILE_DIM = 16
439441

440442
## matrix multiplication
441443
# Legacy entry point: accepts the two scalars `a` and `b` separately, wraps them in a
# `MulAddMul`, and forwards to the `MulAddMul`-based method.
function generic_matmatmul!(
    C::AbstractArray, A::AbstractArray, B::AbstractArray, a::Number, b::Number
)
    return generic_matmatmul!(C, A, B, MulAddMul(a, b))
end
446+
# Tiled matrix multiplication for GPU matrices.
#
# Combines the product `A * B` with the existing contents of `C` through `add`
# (each output element is `add(product_value, C[i, j])`) and returns `C`.
#
# Arguments:
# - `C`:   output matrix of size `(size(A, 1), size(B, 2))`, updated in place
# - `A`:   left operand, `N x Q`
# - `B`:   right operand, `Q x M`
# - `add`: `MulAddMul` combining the product with the previous value of `C`
#
# Throws `DimensionMismatch` when the inner dimensions of `A` and `B` differ or when
# `C` does not have size `(N, M)`.
function generic_matmatmul!(C::AbstractGPUMatrix{R}, A::AbstractGPUMatrix{T}, B::AbstractGPUMatrix{S}, add::MulAddMul) where {T<:Number,S<:Number,R<:Number}
    N = size(A,1)
    Q = size(A,2)
    M = size(B,2)
    if Q != size(B,1)
        throw(DimensionMismatch("matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))"))
    end
    if size(C,1) != N || size(C,2) != M
        throw(DimensionMismatch("result C has dimensions $(size(C)), needs $((N,M))"))
    end
    if isempty(A) || isempty(B)
        # An empty product contributes nothing, so only the `beta * C` term of the
        # multiply-add remains. Zero-filling unconditionally would discard the
        # previous contents of `C` even when `beta` is non-zero.
        if iszero(add.beta)
            fill!(C, zero(R))
        else
            C .= C .* add.beta
        end
        return C
    end

    @kernel unsafe_indices=true function coalesced_matmul_kernel!(
            output, @Const(input1), @Const(input2), N, Q, M,
            ::Val{BANK} = Val(1),
        ) where {BANK}
        grow, gcol = @index(Group, NTuple)
        tile_row, tile_col = @index(Local, NTuple)

        TILE_DIM = @uniform @groupsize()[1]

        # leading dimension padded by BANK (default 1) to avoid bank conflicts on
        # shared memory
        tile1 = @localmem(R, (TILE_DIM + BANK, TILE_DIM))
        tile2 = @localmem(R, (TILE_DIM + BANK, TILE_DIM))

        # private variable for tile output
        outval = @private R 1
        @inbounds outval[1] = -zero(R)

        # number of tiles depends on inner dimension
        @uniform NUM_TILES = div(Q + TILE_DIM - 1, TILE_DIM)

        # loop over all tiles needed for this calculation
        for t in 0:(NUM_TILES - 1)
            I = (grow - 1) * TILE_DIM + tile_row
            J = (gcol - 1) * TILE_DIM + tile_col

            # load inputs into tiles, with bounds checking for non-square matrices;
            # out-of-range slots are zero-filled so they do not affect the sum
            if I <= N && t * TILE_DIM + tile_col <= Q
                @inbounds tile1[tile_row, tile_col] = input1[I, t * TILE_DIM + tile_col]
            else
                @inbounds tile1[tile_row, tile_col] = zero(R)
            end
            if J <= M && t * TILE_DIM + tile_row <= Q
                @inbounds tile2[tile_row, tile_col] = input2[t * TILE_DIM + tile_row, J]
            else
                @inbounds tile2[tile_row, tile_col] = zero(R)
            end

            # wait for all tiles to be loaded
            @synchronize

            # NOTE(review): I and J are recomputed after every @synchronize,
            # presumably because plain locals are not guaranteed to survive a
            # synchronization point on every backend -- confirm before hoisting.
            I = (grow - 1) * TILE_DIM + tile_row
            J = (gcol - 1) * TILE_DIM + tile_col

            # calculate value of spot in output, use temporary value to allow for vectorization
            out = zero(R)
            @simd for k in 1:TILE_DIM
                @inbounds out += tile1[tile_row, k] * tile2[k, tile_col]
            end
            outval[1] += out

            # make sure every work-item is done with the tiles before they are overwritten
            @synchronize
        end

        I = (grow - 1) * TILE_DIM + tile_row
        J = (gcol - 1) * TILE_DIM + tile_col

        # save if inbounds
        if I <= N && J <= M
            @inbounds output[I, J] = add(outval[1], output[I, J])
        end
    end

    # Launch with square workgroups; the range is rounded up to a whole number of
    # tiles, and the kernel itself masks out-of-range rows and columns.
    tile_dim = MAX_TILE_DIM
    padded_range = map(x -> cld(x, tile_dim) * tile_dim, size(C))
    coalesced_matmul_kernel!(get_backend(C), (tile_dim, tile_dim))(C, A, B, N, Q, M; ndrange=padded_range)
    C
end
444525
function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::AbstractArray{S}, add::MulAddMul) where {T,S,R}
445526
if size(A,2) != size(B,1)
446527
throw(DimensionMismatch("matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))"))

0 commit comments

Comments
 (0)