diff --git a/README.md b/README.md
index aad1066c..5fb8705a 100644
--- a/README.md
+++ b/README.md
@@ -43,7 +43,7 @@ vector_size = 2^20
 tile_size = 16
 blocks = cld(vector_size, tile_size)
-grid = (blocks, 1, 1)
+grid = (blocks, 1, 1)
 a, b = CUDA.rand(Float32, vector_size), CUDA.rand(Float32, vector_size)
 c = CUDA.zeros(Float32, vector_size)
@@ -232,7 +232,6 @@ uses standard Julia syntax and is overlaid on `Base`.
 cuTile.jl follows Julia conventions, which differ from the Python API in several ways:
-
 ### Kernel definition syntax
 Kernels don't need a decorator, but do have to return `nothing`:
@@ -511,6 +510,36 @@ ct.store(arr, (i, j), t)
 ```
+## Host-level operations
+
+cuTile.jl also provides a limited set of host-level APIs to use cuTile without
+writing custom kernels. For example, for element-wise operations on `CuArray`s,
+cuTile can automatically generate and launch a fused kernel using Julia's
+broadcast machinery:
+
+```julia
+using CUDA
+import cuTile as ct
+
+A = CUDA.rand(Float32, 1024)
+B = CUDA.rand(Float32, 1024)
+C = CUDA.zeros(Float32, 1024)
+
+# Wrap arrays in Tiled() to route through cuTile
+ct.Tiled(C) .= ct.Tiled(A) .+ ct.Tiled(B)
+
+# Or use the @. macro for convenience
+ct.@. C = A + sin(B)
+
+# Allocating form (returns a new CuArray)
+D = ct.@. A + B
+```
+
+The entire broadcast expression is fused into a single cuTile kernel. Tile sizes
+are chosen automatically from the array dimensions (power-of-2, budget-based).
+Arrays of any dimensionality, from 1D through N-dimensional, are supported.
+
+
 ## Acknowledgments
 cuTile.jl is inspired by [cuTile-Python](https://github.com/NVIDIA/cutile-python/),
diff --git a/ext/CUDAExt.jl b/ext/CUDAExt.jl
index 0b80c06a..6814b878 100644
--- a/ext/CUDAExt.jl
+++ b/ext/CUDAExt.jl
@@ -9,9 +9,12 @@ using CompilerCaching: CacheView, method_instance, results
 import Core.Compiler as CC
-using CUDA: CuModule, CuFunction, cudacall, device, capability
+using CUDA: CuArray, CuModule, CuFunction, cudacall, device, capability
 using CUDA_Compiler_jll
+import Base.Broadcast: BroadcastStyle
+import CUDA: CuArrayStyle
+
 public launch
 function run_and_collect(cmd)
@@ -255,4 +258,7 @@ Other values pass through unchanged.
 to_tile_arg(x) = x
 to_tile_arg(arr::AbstractArray) = TileArray(arr)
+# Tiled Broadcast — TiledStyle wins over CuArrayStyle
+BroadcastStyle(::cuTile.TiledStyle{N}, ::CuArrayStyle{M}) where {N,M} = cuTile.TiledStyle{max(N,M)}()
+
 end
diff --git a/src/broadcast.jl b/src/broadcast.jl
new file mode 100644
index 00000000..12c84e8a
--- /dev/null
+++ b/src/broadcast.jl
@@ -0,0 +1,255 @@
+import Base.Broadcast: BroadcastStyle, Broadcasted
+
+#=============================================================================
+ Tiled wrapper — routes broadcast expressions through cuTile kernels
+=============================================================================#
+
+"""
+    Tiled(x)
+
+Wrapper that routes broadcast expressions through cuTile kernels.
+
+    Tiled(B) .= A .+ A
+
+Uses Julia's `Base.Broadcast` fusion machinery to build a `Broadcasted` tree,
+then dispatches to a generic cuTile kernel that evaluates the tree on tiles.
+"""
+struct Tiled{A <: AbstractArray}
+    parent::A
+end
+Tiled(x) = x  # passthrough for non-arrays (Numbers, etc.)
+# Shape and element-type queries are answered from the wrapped array
+Base.parent(t::Tiled) = t.parent
+Base.axes(t::Tiled) = axes(parent(t))
+Base.size(t::Tiled) = size(parent(t))
+Base.ndims(::Tiled{A}) where A = ndims(A)
+Base.eltype(::Tiled{A}) where A = eltype(A)
+# A Tiled is used directly as a broadcast argument
+Base.Broadcast.broadcastable(t::Tiled) = t
+
+# Walk dotted AST, wrap value-position leaves in Tiled()
+_wrap_tiled(x) = x  # literals pass through
+_wrap_tiled(s::Symbol) = :($Tiled($s))
+function _wrap_tiled(ex::Expr)
+    if ex.head === :. && length(ex.args) == 2 &&
+       ex.args[2] isa Expr && ex.args[2].head === :tuple
+        # f.(args...) — wrap args, NOT function position
+        new_args = map(_wrap_tiled, ex.args[2].args)
+        return Expr(:., ex.args[1], Expr(:tuple, new_args...))
+    elseif ex.head === :call && !isempty(ex.args)
+        # f(args...), including already-dotted operators (`A .+ B` is a :call
+        # whose callee is the symbol `.+`) — the callee is not a value
+        return Expr(:call, ex.args[1], map(_wrap_tiled, ex.args[2:end])...)
+    elseif ex.head === :comparison
+        # a .< b .<= c — operands are at odd positions, operators at even ones
+        wrapped = [isodd(i) ? _wrap_tiled(a) : a for (i, a) in enumerate(ex.args)]
+        return Expr(:comparison, wrapped...)
+    elseif ex.head === :kw && length(ex.args) == 2
+        # name = value — only the value is in value position
+        return Expr(:kw, ex.args[1], _wrap_tiled(ex.args[2]))
+    else
+        # everything else, including `dest .= rhs`: wrap every child
+        return Expr(ex.head, map(_wrap_tiled, ex.args)...)
+    end
+end
+
+"""
+    @. expr
+
+Like `Base.@.` but wraps every value-position leaf in `Tiled()`, routing
+the broadcast through cuTile kernels.
+
+    using cuTile; const ct = cuTile
+    ct.@. C = A + sin(B)
+    # equivalent to: Tiled(C) .= Tiled(A) .+ sin.(Tiled(B))
+"""
+macro __dot__(ex)
+    esc(_wrap_tiled(Base.Broadcast.__dot__(ex)))
+end
+
+#=============================================================================
+ TiledStyle — routes broadcast through cuTile kernels
+=============================================================================#
+
+struct TiledStyle{N} <: BroadcastStyle end
+TiledStyle{M}(::Val{N}) where {N,M} = TiledStyle{N}()
+
+BroadcastStyle(::Type{<:Tiled{A}}) where A = TiledStyle{ndims(A)}()
+
+# TiledStyle wins over DefaultArrayStyle
+BroadcastStyle(::TiledStyle{N}, ::Base.Broadcast.DefaultArrayStyle{M}) where {N,M} = TiledStyle{max(N,M)}()
+BroadcastStyle(::TiledStyle{N}, ::TiledStyle{M}) where {N,M} = TiledStyle{max(N,M)}()
+
+#=============================================================================
+ materialize! and copy — dispatch to _tiled_broadcast!
+=============================================================================#
+
+"""
+    materialize!(dest::Tiled, bc::Broadcasted)
+
+Evaluate `bc` into the array wrapped by `dest` via `_tiled_broadcast!` and return
+`dest`. Throws `DimensionMismatch` if the arguments of `bc` cannot be broadcast to
+the axes of `dest`.
+"""
+function Base.Broadcast.materialize!(dest::Tiled, bc::Broadcasted)
+    # This method replaces Base's generic `materialize!`, so do the shape check
+    # that path would perform instead of launching a kernel on mismatched sizes.
+    Base.Broadcast.check_broadcast_axes(axes(dest), bc.args...)
+    _tiled_broadcast!(parent(dest), bc)
+    return dest
+end
+
+"""
+    copy(bc::Broadcasted{TiledStyle{N}})
+
+Allocating form: create the destination with `similar` from the first `Tiled` leaf
+of `bc`, fill it with `_tiled_broadcast!`, and return that array (not a `Tiled`).
+Errors if `bc` contains no `Tiled` leaf.
+"""
+function Base.copy(bc::Broadcasted{TiledStyle{N}}) where N
+    arr = @something _find_tiled_array(bc) error("tiled broadcast requires at least one Tiled() argument")
+    ElType = Base.Broadcast.combine_eltypes(bc.f, bc.args)
+    dest = similar(arr, ElType, axes(bc))
+    _tiled_broadcast!(dest, bc)
+    return dest
+end
+
+"""Find the first underlying array from a Tiled leaf in a Broadcasted tree."""
+_find_tiled_array(t::Tiled) = parent(t)
+_find_tiled_array(x) = nothing
+function _find_tiled_array(bc::Broadcasted)
+    for arg in bc.args
+        arr = _find_tiled_array(arg)
+        arr !== nothing && return arr
+    end
+    return nothing
+end
+
+#=============================================================================
+ _tiled_broadcast! — generic AbstractArray implementation
+=============================================================================#
+
+"""
+    _tiled_broadcast!(dest, bc)
+
+Evaluate `bc` into `dest` by launching `_tiled_bc_kernel` with one block per tile.
+
+Tile sizes come from `_compute_tile_sizes(size(dest))`. For more than three
+dimensions the per-dimension grid extents from dimension 3 onwards are multiplied
+into the third launch dimension, and passed to the kernel as `overflow` so that it
+can split the flat block index again. An empty `dest` is a no-op.
+"""
+function _tiled_broadcast!(dest::AbstractArray{T,N}, bc::Broadcasted) where {T, N}
+    # Nothing to write, and the grid computed below would contain zero blocks
+    isempty(dest) && return nothing
+
+    dest_ta = TileArray(dest)
+    tiled_bc = _to_tiled_bc(bc)
+
+    ts = _compute_tile_sizes(size(dest))
+    grid = ntuple(i -> cld(size(dest, i), ts[i]), N)
+
+    launch_grid = N <= 3 ? grid : (grid[1], grid[2], prod(grid[i] for i in 3:N))
+    overflow = N > 3 ? grid[3:end] : ()
+
+    return launch(_tiled_bc_kernel, launch_grid, dest_ta, tiled_bc,
+                  Constant(ts), Constant(overflow))
+end
+
+#=============================================================================
+ Generic tree walk — convert leaves to TileArrays
+=============================================================================#
+
+# Arrays (wrapped or not) become TileArrays; every other leaf is passed as-is
+_to_tiled_bc(t::Tiled) = TileArray(parent(t))
+_to_tiled_bc(arr::AbstractArray) = TileArray(arr)
+_to_tiled_bc(x::Number) = x
+_to_tiled_bc(x) = x  # fallback for other types
+function _to_tiled_bc(bc::Broadcasted)
+    new_args = map(_to_tiled_bc, bc.args)
+    # Rebuilt without a style or axes: only `f` and `args` are read in the kernel
+    Broadcasted{Nothing}(bc.f, new_args, nothing)
+end
+
+#=============================================================================
+ Broadcast kernel — evaluates Broadcasted tree on tiles
+=============================================================================#
+
+# Generated per destination rank N: the body reads one block index per dimension,
+# evaluates the tree for that tile, converts the result to `Tile{T}` and stores it.
+@generated function _tiled_bc_kernel(dest::TileArray{T, N}, bc, tile_size, overflow_grids) where {T, N}
+    body = Expr[]
+    bid_vars = [Symbol("bid_$d") for d in 1:N]
+
+    if N <= 3
+        for d in 1:N
+            push!(body, :($(bid_vars[d]) = cuTile.bid($d)))
+        end
+    else
+        # Dimensions 3..N share launch dimension 3: split its block index (minus
+        # one) back into one index per dimension, dimension 3 varying fastest
+        push!(body, :($(bid_vars[1]) = cuTile.bid(1)))
+        push!(body, :($(bid_vars[2]) = cuTile.bid(2)))
+        push!(body, :(_rem = cuTile.bid(3) - Int32(1)))
+        for d in 3:N
+            if d < N
+                push!(body, :($(bid_vars[d]) = rem(_rem, Int32(overflow_grids[$(d-2)])) + Int32(1)))
+                push!(body, :(_rem = fld(_rem, Int32(overflow_grids[$(d-2)]))))
+            else
+                push!(body, :($(bid_vars[d]) = _rem + Int32(1)))
+            end
+        end
+    end
+
+    idx = N == 1 ? bid_vars[1] : Expr(:tuple, bid_vars...)
+    push!(body, :(result = _eval_bc(bc, $idx, tile_size)))
+    push!(body, :(result_converted = convert(cuTile.Tile{$T}, result)))
+    push!(body, :(cuTile.store(dest, $idx, result_converted)))
+    push!(body, :(return))
+    Expr(:block, body...)
+end
+
+#=============================================================================
+ Recursive tree evaluation inside kernel
+=============================================================================#
+
+@inline _eval_bc(arr::TileArray, bid, tile_size) = cuTile.load(arr, bid, tile_size)
+@inline _eval_bc(x::Number, bid, tile_size) = x
+
+@inline function _eval_bc(bc::Broadcasted, bid, tile_size)
+    args = _eval_bc_args(bc.args, bid, tile_size)
+    # Use broadcast to get element-wise semantics (not direct call, which
+    # would dispatch to e.g. matmul for * on tiles)
+    broadcast(bc.f, args...)
+end
+
+@inline _eval_bc_args(::Tuple{}, bid, tile_size) = ()
+@inline _eval_bc_args(args::Tuple, bid, tile_size) =
+    (_eval_bc(args[1], bid, tile_size), _eval_bc_args(Base.tail(args), bid, tile_size)...)
+
+#=============================================================================
+ Tile sizing
+=============================================================================#
+
+"""
+    _compute_tile_sizes(dest_size; budget=4096)
+
+Distribute a total element budget greedily across dimensions, skipping singleton
+and empty dimensions (their tile size stays 1).
+Each tile dimension is a power of 2, capped by the array size in that dimension.
+"""
+function _compute_tile_sizes(dest_size::NTuple{N,Int}; budget::Int=4096) where N
+    ts = ones(Int, N)
+    remaining = budget
+    for i in 1:N
+        s = dest_size[i]
+        # `prevpow(2, 0)` throws a DomainError, so a zero-length dimension must be
+        # skipped just like a singleton one
+        s <= 1 && continue
+        t = prevpow(2, min(remaining, s))
+        ts[i] = t
+        remaining = remaining ÷ t
+        remaining < 2 && break
+    end
+    return NTuple{N,Int}(ts)
+end
diff --git a/src/cuTile.jl b/src/cuTile.jl
index 2e5fe398..e2d75cf3 100644
--- a/src/cuTile.jl
+++ b/src/cuTile.jl
@@ -38,7 +38,10 @@ include("language/math.jl")
 include("language/operations.jl")
 include("language/atomics.jl")
-public launch, ByTarget, @compiler_options
+# Host-level abstractions
+include("broadcast.jl")
+
+public launch, Tiled, ByTarget, @compiler_options, @.
 launch(args...)
= error("Please import CUDA.jl before using `cuTile.launch`.")
 end # module cuTile
diff --git a/test/execution/atomics.jl b/test/device/atomics.jl
similarity index 100%
rename from test/execution/atomics.jl
rename to test/device/atomics.jl
diff --git a/test/execution/broadcast.jl b/test/device/broadcast.jl
similarity index 100%
rename from test/execution/broadcast.jl
rename to test/device/broadcast.jl
diff --git a/test/execution/control_flow.jl b/test/device/control_flow.jl
similarity index 100%
rename from test/execution/control_flow.jl
rename to test/device/control_flow.jl
diff --git a/test/execution/core.jl b/test/device/core.jl
similarity index 100%
rename from test/execution/core.jl
rename to test/device/core.jl
diff --git a/test/execution/hints.jl b/test/device/hints.jl
similarity index 100%
rename from test/execution/hints.jl
rename to test/device/hints.jl
diff --git a/test/execution/integration.jl b/test/device/integration.jl
similarity index 100%
rename from test/execution/integration.jl
rename to test/device/integration.jl
diff --git a/test/execution/math.jl b/test/device/math.jl
similarity index 100%
rename from test/execution/math.jl
rename to test/device/math.jl
diff --git a/test/execution/reductions.jl b/test/device/reductions.jl
similarity index 100%
rename from test/execution/reductions.jl
rename to test/device/reductions.jl
diff --git a/test/execution/tile.jl b/test/device/tile.jl
similarity index 100%
rename from test/execution/tile.jl
rename to test/device/tile.jl
diff --git a/test/execution/types.jl b/test/device/types.jl
similarity index 100%
rename from test/execution/types.jl
rename to test/device/types.jl
diff --git a/test/host/broadcast.jl b/test/host/broadcast.jl
new file mode 100644
index 00000000..64a625cc
--- /dev/null
+++ b/test/host/broadcast.jl
@@ -0,0 +1,127 @@
+using CUDA
+
+@testset "Tiled broadcast" begin
+    # Explicit Tiled() wrappers writing into a preallocated destination
+    @testset "1D element-wise" begin
+        len = 1024
+        x = CUDA.rand(Float32, len)
+        y = CUDA.rand(Float32, len)
+        out = CUDA.zeros(Float32, len)
+        expected = Array(x) .+ Array(y)
+        ct.Tiled(out) .= ct.Tiled(x) .+ ct.Tiled(y)
+        @test Array(out) ≈ expected
+    end
+
+    @testset "fused multi-op" begin
+        len = 1024
+        x = CUDA.rand(Float32, len) .+ 0.1f0
+        out = CUDA.zeros(Float32, len)
+        host_x = Array(x)
+        ct.Tiled(out) .= ct.Tiled(x) .+ ct.Tiled(x) .* sin.(ct.Tiled(x))
+        @test Array(out) ≈ host_x .+ host_x .* sin.(host_x) rtol=1e-5
+    end
+
+    @testset "scalar broadcast" begin
+        len = 1024
+        x = CUDA.rand(Float32, len)
+        out = CUDA.zeros(Float32, len)
+        expected = Array(x) .+ 1.0f0
+        ct.Tiled(out) .= ct.Tiled(x) .+ 1.0f0
+        @test Array(out) ≈ expected
+    end
+
+    @testset "2D element-wise" begin
+        rows, cols = 128, 256
+        x = CUDA.rand(Float32, rows, cols)
+        y = CUDA.rand(Float32, rows, cols)
+        out = CUDA.zeros(Float32, rows, cols)
+        expected = Array(x) .+ Array(y)
+        ct.Tiled(out) .= ct.Tiled(x) .+ ct.Tiled(y)
+        @test Array(out) ≈ expected
+    end
+
+    @testset "3D element-wise" begin
+        x = CUDA.rand(Float32, 64, 64, 4)
+        y = CUDA.rand(Float32, 64, 64, 4)
+        out = CUDA.zeros(Float32, 64, 64, 4)
+        expected = Array(x) .+ Array(y)
+        ct.Tiled(out) .= ct.Tiled(x) .+ ct.Tiled(y)
+        @test Array(out) ≈ expected
+    end
+
+    # The ct.@. macro
+    @testset "ct.@. expands to Tiled" begin
+        expanded = @macroexpand ct.@. C = A + B
+        # The macro should produce Tiled() wrapping, not plain dotted calls
+        @test occursin("Tiled", string(expanded))
+    end
+
+    @testset "ct.@. in-place" begin
+        len = 1024
+        x = CUDA.rand(Float32, len)
+        y = CUDA.rand(Float32, len)
+        out = CUDA.zeros(Float32, len)
+        expected = Array(x) .+ Array(y)
+        ct.@. out = x + y
+        @test Array(out) ≈ expected
+    end
+
+    @testset "ct.@. with function" begin
+        len = 1024
+        x = CUDA.rand(Float32, len) .+ 0.1f0
+        out = CUDA.zeros(Float32, len)
+        host_x = Array(x)
+        ct.@. out = x + sin(x)
+        @test Array(out) ≈ host_x .+ sin.(host_x) rtol=1e-5
+    end
+
+    @testset "ct.@. with scalar" begin
+        len = 1024
+        x = CUDA.rand(Float32, len)
+        out = CUDA.zeros(Float32, len)
+        expected = Array(x) .+ 2.0f0
+        ct.@. out = x + 2.0f0
+        @test Array(out) ≈ expected
+    end
+
+    # Allocating forms: the result is a freshly created CuArray
+    @testset "allocating copy" begin
+        len = 1024
+        x = CUDA.rand(Float32, len)
+        y = CUDA.rand(Float32, len)
+        result = ct.Tiled(x) .+ ct.Tiled(y)
+        @test result isa CuArray
+        @test Array(result) ≈ Array(x) .+ Array(y)
+    end
+
+    @testset "allocating ct.@." begin
+        len = 1024
+        x = CUDA.rand(Float32, len)
+        y = CUDA.rand(Float32, len)
+        result = ct.@. x + y
+        @test result isa CuArray
+        @test Array(result) ≈ Array(x) .+ Array(y)
+    end
+
+    # Singleton and small leading dimensions
+    @testset "leading singleton dim" begin
+        x = CUDA.rand(Float32, 1, 1024)
+        out = similar(x)
+        ct.Tiled(out) .= ct.Tiled(x) .+ 1.0f0
+        @test Array(out) ≈ Array(x) .+ 1.0f0
+    end
+
+    @testset "double leading singleton" begin
+        x = CUDA.rand(Float32, 1, 1, 512)
+        out = similar(x)
+        ct.Tiled(out) .= ct.Tiled(x) .* 2.0f0
+        @test Array(out) ≈ Array(x) .* 2.0f0
+    end
+
+    @testset "small leading dim" begin
+        x = CUDA.rand(Float32, 4, 1024)
+        out = similar(x)
+        ct.Tiled(out) .= ct.Tiled(x) .+ ct.Tiled(x)
+        @test Array(out) ≈ 2 .* Array(x)
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 9330bab9..7e8a444d 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -43,7 +43,7 @@ args = parse_args(ARGS)
 if filter_tests!(testsuite, args)
     cuda_functional = CUDA.functional()
     filter!(testsuite) do (test, _)
-        if startswith(test, "execution/") || startswith(test, "examples/")
+        if startswith(test, "device/") || startswith(test, "host/") || startswith(test, "examples/")
             return cuda_functional
         else
             return true