@@ -436,11 +436,92 @@ function LinearAlgebra.ldiv!(B::AbstractGPUVecOrMat,
436436 B
437437end
438438
439+ # XXX : figure out how to do dynamically
440+ MAX_TILE_DIM = 16
439441
440442# # matrix multiplication
441443# legacy method
442444generic_matmatmul! (C:: AbstractArray , A:: AbstractArray , B:: AbstractArray , a:: Number , b:: Number ) =
443445 generic_matmatmul! (C, A, B, MulAddMul (a, b))
446+ function generic_matmatmul! (C:: AbstractGPUMatrix{R} , A:: AbstractGPUMatrix{T} , B:: AbstractGPUMatrix{S} , add:: MulAddMul ) where {T<: Number ,S<: Number ,R<: Number }
447+ N = size (A,1 )
448+ Q = size (A,2 )
449+ M = size (B,2 )
450+ if Q != size (B,1 )
451+ throw (DimensionMismatch (" matrix A has dimensions $(size (A)) , matrix B has dimensions $(size (B)) " ))
452+ end
453+ if size (C,1 ) != N || size (C,2 ) != M
454+ throw (DimensionMismatch (" result C has dimensions $(size (C)) , needs $((N,M)) " ))
455+ end
456+ if isempty (A) || isempty (B)
457+ return fill! (C, zero (R))
458+ end
459+
460+ @kernel unsafe_indices= true function coalesced_matmul_kernel! (
461+ output, @Const (input1), @Const (input2), N, Q, M,
462+ :: Val{BANK} = Val (1 ),
463+ ) where {BANK}
464+ grow, gcol = @index (Group, NTuple)
465+ tile_row, tile_col = @index (Local, NTuple)
466+
467+ TILE_DIM = @uniform @groupsize ()[1 ]
468+
469+ # +1 to avoid bank conflicts on shared memory
470+ tile1 = @localmem (R, (TILE_DIM + BANK, TILE_DIM))
471+ tile2 = @localmem (R, (TILE_DIM + BANK, TILE_DIM))
472+
473+ # private variable for tile output
474+ outval = @private R 1
475+ @inbounds outval[1 ] = - zero (R)
476+
477+ # number of tiles depends on inner dimension
478+ @uniform NUM_TILES = div (Q + TILE_DIM - 1 , TILE_DIM)
479+
480+ # loop over all tiles needed for this calculation
481+ for t in 0 : (NUM_TILES - 1 )
482+ I = (grow - 1 ) * TILE_DIM + tile_row
483+ J = (gcol - 1 ) * TILE_DIM + tile_col
484+
485+ # load inputs into tiles, with bounds checking for non-square matrices
486+ if I <= N && t * TILE_DIM + tile_col <= Q
487+ @inbounds tile1[tile_row, tile_col] = input1[I, t * TILE_DIM + tile_col]
488+ else
489+ @inbounds tile1[tile_row, tile_col] = zero (R)
490+ end
491+ if J <= M && t * TILE_DIM + tile_row <= Q
492+ @inbounds tile2[tile_row, tile_col] = input2[t * TILE_DIM + tile_row, J]
493+ else
494+ @inbounds tile2[tile_row, tile_col] = zero (R)
495+ end
496+
497+ # wait for all tiles to be loaded
498+ @synchronize
499+
500+ I = (grow - 1 ) * TILE_DIM + tile_row
501+ J = (gcol - 1 ) * TILE_DIM + tile_col
502+
503+ # calculate value of spot in output, use temporary value to allow for vectorization
504+ out = zero (R)
505+ @simd for k in 1 : TILE_DIM
506+ @inbounds out += tile1[tile_row, k] * tile2[k, tile_col]
507+ end
508+ outval[1 ] += out
509+
510+ @synchronize
511+ end
512+
513+ I = (grow - 1 ) * TILE_DIM + tile_row
514+ J = (gcol - 1 ) * TILE_DIM + tile_col
515+
516+ # save if inbounds
517+ if I <= N && J <= M
518+ @inbounds output[I, J] = add (outval[1 ], output[I, J])
519+ end
520+ end
521+
522+ coalesced_matmul_kernel! (get_backend (C), (MAX_TILE_DIM, MAX_TILE_DIM))(C, A, B, N, Q, M;ndrange= map (x -> ceil (Int,x/ MAX_TILE_DIM)* MAX_TILE_DIM, size (C)))
523+ C
524+ end
444525function generic_matmatmul! (C:: AbstractArray{R} , A:: AbstractArray{T} , B:: AbstractArray{S} , add:: MulAddMul ) where {T,S,R}
445526 if size (A,2 ) != size (B,1 )
446527 throw (DimensionMismatch (" matrix A has dimensions $(size (A)) , matrix B has dimensions $(size (B)) " ))
0 commit comments