@@ -350,21 +350,34 @@ function _to_tiled_bc(bc::Broadcasted)
350350end
351351
352352# The generic broadcast kernel: evaluates the Broadcasted tree on tiles
353- function _tiled_bc_kernel_1d (dest:: TileArray{T, 1} , bc, tile_size) where T
354- bid = cuTile. bid (1 )
355- result = _eval_bc (bc, bid, tile_size)
356- result_converted = convert (cuTile. Tile{T}, result)
357- cuTile. store (dest, bid, result_converted)
358- return
359- end
353+ @generated function _tiled_bc_kernel (dest:: TileArray{T, N} , bc, tile_size, overflow_grids) where {T, N}
354+ body = Expr[]
355+ bid_vars = [Symbol (" bid_$d " ) for d in 1 : N]
360356
361- function _tiled_bc_kernel_2d (dest:: TileArray{T, 2} , bc, tile_size) where T
362- bid_x = cuTile. bid (1 )
363- bid_y = cuTile. bid (2 )
364- result = _eval_bc (bc, (bid_x, bid_y), tile_size)
365- result_converted = convert (cuTile. Tile{T}, result)
366- cuTile. store (dest, (bid_x, bid_y), result_converted)
367- return
357+ if N <= 3
358+ for d in 1 : N
359+ push! (body, :($ (bid_vars[d]) = cuTile. bid ($ d)))
360+ end
361+ else
362+ push! (body, :($ (bid_vars[1 ]) = cuTile. bid (1 )))
363+ push! (body, :($ (bid_vars[2 ]) = cuTile. bid (2 )))
364+ push! (body, :(_rem = cuTile. bid (3 ) - Int32 (1 )))
365+ for d in 3 : N
366+ if d < N
367+ push! (body, :($ (bid_vars[d]) = rem (_rem, Int32 (overflow_grids[$ (d- 2 )])) + Int32 (1 )))
368+ push! (body, :(_rem = fld (_rem, Int32 (overflow_grids[$ (d- 2 )]))))
369+ else
370+ push! (body, :($ (bid_vars[d]) = _rem + Int32 (1 )))
371+ end
372+ end
373+ end
374+
375+ idx = N == 1 ? bid_vars[1 ] : Expr (:tuple , bid_vars... )
376+ push! (body, :(result = _eval_bc (bc, $ idx, tile_size)))
377+ push! (body, :(result_converted = convert (cuTile. Tile{$ T}, result)))
378+ push! (body, :(cuTile. store (dest, $ idx, result_converted)))
379+ push! (body, :(return ))
380+ Expr (:block , body... )
368381end
369382
370383# Recursive tree evaluation inside kernel
@@ -391,17 +404,14 @@ function _tiled_broadcast!(dest::CuArray{T,N}, bc::Broadcasted; tile_size::Int=6
391404 dest_ta = TileArray (dest)
392405 tiled_bc = _to_tiled_bc (bc)
393406
394- if N == 1
395- ts = (tile_size,)
396- grid = (cld (size (dest, 1 ), tile_size),)
397- cuTile. launch (_tiled_bc_kernel_1d, grid, dest_ta, tiled_bc, Constant (ts))
398- elseif N == 2
399- ts = (tile_size, tile_size)
400- grid = (cld (size (dest, 1 ), tile_size), cld (size (dest, 2 ), tile_size))
401- cuTile. launch (_tiled_bc_kernel_2d, grid, dest_ta, tiled_bc, Constant (ts))
402- else
403- error (" Tiled broadcast not yet supported for $N dimensions" )
404- end
407+ ts = ntuple (i -> i <= min (N, 2 ) ? tile_size : 1 , N)
408+ grid = ntuple (i -> cld (size (dest, i), ts[i]), N)
409+
410+ launch_grid = N <= 3 ? grid : (grid[1 ], grid[2 ], prod (grid[i] for i in 3 : N))
411+ overflow = N > 3 ? grid[3 : end ] : ()
412+
413+ cuTile. launch (_tiled_bc_kernel, launch_grid, dest_ta, tiled_bc,
414+ Constant (ts), Constant (overflow))
405415end
406416
407417end
0 commit comments