Skip to content

Commit 387d870

Browse files
authored
Expose entry hints through launch (#27)
1 parent b419422 commit 387d870

6 files changed

Lines changed: 285 additions & 17 deletions

File tree

ext/CUDAExt.jl

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ using CUDA_Compiler_jll
99
public launch
1010

1111
# Compilation cache - stores CuFunction directly to avoid re-loading CuModule
12-
const _compilation_cache = Dict{Any, Any}() # (f, argtypes, sm_arch, opt_level) => CuFunction
12+
const _compilation_cache = Dict{Any, Any}() # (f, argtypes, sm_arch, opt_level, num_ctas, occupancy) => CuFunction
1313

1414
"""
15-
launch(f, grid, args...; name=nothing, sm_arch=default_sm_arch(), opt_level=3)
15+
launch(f, grid, args...; name=nothing, sm_arch=default_sm_arch(), opt_level=3, num_ctas=nothing, occupancy=nothing)
1616
1717
Compile and launch a kernel function with the given grid size and arguments.
1818
@@ -26,6 +26,8 @@ are expanded to their constituent ptr, sizes, and strides parameters.
2626
- `name`: Optional kernel name for debugging
2727
- `sm_arch`: Target GPU architecture (default: current device's capability)
2828
- `opt_level`: Optimization level 0-3 (default: 3)
29+
- `num_ctas`: Number of CTAs in a CGA, 1-16, must be power of 2 (default: nothing)
30+
- `occupancy`: Expected active CTAs per SM, 1-32 (default: nothing)
2931
3032
# Example
3133
```julia
@@ -51,7 +53,9 @@ cuTile.launch(vadd_kernel, 64, a, b, c)
5153
function cuTile.launch(@nospecialize(f), grid, args...;
5254
name::Union{String, Nothing}=nothing,
5355
sm_arch::String=default_sm_arch(),
54-
opt_level::Int=3)
56+
opt_level::Int=3,
57+
num_ctas::Union{Int, Nothing}=nothing,
58+
occupancy::Union{Int, Nothing}=nothing)
5559
# Convert CuArray -> TileArray (and other conversions)
5660
tile_args = map(to_tile_arg, args)
5761

@@ -62,10 +66,10 @@ function cuTile.launch(@nospecialize(f), grid, args...;
6266
kernel_name = name !== nothing ? name : string(nameof(f))
6367

6468
# Check compilation cache - returns CuFunction directly
65-
cache_key = (f, argtypes, sm_arch, opt_level)
69+
cache_key = (f, argtypes, sm_arch, opt_level, num_ctas, occupancy)
6670
cufunc = get(_compilation_cache, cache_key, nothing)
6771
if cufunc === nothing || cuTile.compile_hook[] !== nothing
68-
cubin = compile(f, argtypes; name, sm_arch, opt_level)
72+
cubin = compile(f, argtypes; name, sm_arch, opt_level, num_ctas, occupancy)
6973
if cufunc === nothing
7074
cumod = CuModule(cubin)
7175
cufunc = CuFunction(cumod, kernel_name)
@@ -98,15 +102,18 @@ function cuTile.launch(@nospecialize(f), grid, args...;
98102
end
99103

100104
"""
101-
compile(f, argtypes; name=nothing, sm_arch=default_sm_arch(), opt_level=3) -> Vector{UInt8}
105+
compile(f, argtypes; name=nothing, sm_arch=default_sm_arch(), opt_level=3, num_ctas=nothing, occupancy=nothing) -> Vector{UInt8}
102106
103107
Compile a Julia kernel function to a CUDA binary.
104108
"""
105109
function compile(@nospecialize(f), @nospecialize(argtypes);
106110
name::Union{String, Nothing}=nothing,
107111
sm_arch::String=default_sm_arch(),
108-
opt_level::Int=3)
109-
tile_bytecode = emit_tileir(f, argtypes; name)
112+
opt_level::Int=3,
113+
num_ctas::Union{Int, Nothing}=nothing,
114+
occupancy::Union{Int, Nothing}=nothing)
115+
tile_bytecode = emit_tileir(f, argtypes; name, sm_arch,
116+
num_ctas, occupancy)
110117

111118
# Dump bytecode if JULIA_CUTILE_DUMP_BYTECODE is set
112119
dump_dir = get(ENV, "JULIA_CUTILE_DUMP_BYTECODE", nothing)

src/bytecode/writer.jl

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,3 +542,71 @@ function finalize_function!(func_buf::Vector{UInt8}, cb::CodeBuilder,
542542
encode_varint!(func_buf, length(cb.buf))
543543
append!(func_buf, cb.buf)
544544
end
545+
546+
#=============================================================================
547+
EntryHints: Kernel-level compilation hints
548+
=============================================================================#
549+
550+
"""
551+
Kernel-level compilation hints (num_ctas, occupancy).
552+
Encoded as a dictionary attribute in bytecode.
553+
"""
554+
@kwdef struct EntryHints
555+
num_ctas::Union{Int, Nothing} = nothing # 1, 2, 4, 8, 16
556+
occupancy::Union{Int, Nothing} = nothing # 1-32
557+
end
558+
559+
function validate_num_ctas(num_ctas::Union{Int, Nothing})
560+
isnothing(num_ctas) && return
561+
1 <= num_ctas <= 16 || throw(ArgumentError("num_ctas must be between 1 and 16, got $num_ctas"))
562+
ispow2(num_ctas) || throw(ArgumentError("num_ctas must be a power of 2, got $num_ctas"))
563+
end
564+
565+
function validate_occupancy(occupancy::Union{Int, Nothing})
566+
isnothing(occupancy) && return
567+
1 <= occupancy <= 32 || throw(ArgumentError("occupancy must be between 1 and 32, got $occupancy"))
568+
end
569+
570+
"""
571+
Encode EntryHints as OptimizationHints format.
572+
Returns raw bytes for entry_hints parameter or nothing.
573+
"""
574+
function encode_entry_hints(writer::BytecodeWriter, sm_arch::Union{String, Nothing}, hints::EntryHints)
575+
validate_num_ctas(hints.num_ctas)
576+
validate_occupancy(hints.occupancy)
577+
578+
# Build items list (only non-nothing values)
579+
items = Tuple{String, Int}[]
580+
isnothing(hints.num_ctas) || push!(items, ("num_cta_in_cga", hints.num_ctas))
581+
isnothing(hints.occupancy) || push!(items, ("occupancy", hints.occupancy))
582+
isempty(items) && return nothing
583+
584+
# Use default architecture if not specified and hints are present
585+
arch = @something sm_arch throw(ArgumentError("sm_arch must be specified when entry hints are present"))
586+
587+
buf = UInt8[]
588+
589+
# Start with OptimizationHints tag
590+
push!(buf, AttributeTag.OptimizationHints)
591+
592+
# Encode as architecture-specific dictionary
593+
# Format: num_archs, then for each arch: arch_id, dictionary
594+
encode_varint!(buf, 1) # 1 architecture
595+
596+
# Architecture string ID
597+
arch_id = writer.string_table[arch]
598+
encode_varint!(buf, arch_id.id)
599+
600+
# Encode dictionary
601+
push!(buf, AttributeTag.Dictionary)
602+
encode_varint!(buf, length(items))
603+
for (key, value) in items
604+
key_id = writer.string_table[key]
605+
encode_varint!(buf, key_id.id)
606+
push!(buf, AttributeTag.Integer)
607+
encode_typeid!(buf, I32(writer.type_table))
608+
encode_varint!(buf, UInt32(value))
609+
end
610+
611+
return buf
612+
end

src/compiler/codegen/kernel.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
# kernel and argument handling
22

33
"""
4-
emit_kernel!(writer, func_buf, target; name, is_entry=true)
4+
emit_kernel!(writer, func_buf, target; name, sm_arch=nothing, is_entry=true, num_ctas=nothing, occupancy=nothing)
55
66
Compile a TileTarget to Tile IR bytecode.
77
"""
88
function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
99
target::TileTarget;
1010
name::String = string(target.mi.def.name),
11-
is_entry::Bool = true)
11+
sm_arch::Union{String, Nothing} = nothing,
12+
is_entry::Bool = true,
13+
num_ctas::Union{Int, Nothing} = nothing,
14+
occupancy::Union{Int, Nothing} = nothing)
1215
ctx = CGCtx(writer, target)
1316
tt = ctx.tt
1417

@@ -58,8 +61,12 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
5861
push!(result_types, tile_type_for_julia!(ctx, target.rettype))
5962
end
6063

64+
# Create entry hints if provided
65+
entry_hints = encode_entry_hints(writer, sm_arch, EntryHints(; num_ctas, occupancy))
66+
6167
# Create function
62-
cb = add_function!(writer, func_buf, name, param_types, result_types; is_entry)
68+
cb = add_function!(writer, func_buf, name, param_types, result_types;
69+
is_entry, entry_hints)
6370
ctx.cb = cb
6471

6572
# Set up argument values

src/compiler/reflection.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
export code_tiled, @code_tiled
22

33
"""
4-
emit_tileir(f, argtypes; name=nothing) -> Vector{UInt8}
4+
emit_tileir(f, argtypes; name, sm_arch, num_ctas, occupancy) -> Vector{UInt8}
55
66
Compile a Julia function to Tile IR bytecode.
77
"""
88
function emit_tileir(@nospecialize(f), @nospecialize(argtypes);
9-
name::Union{String, Nothing} = nothing)
9+
name::Union{String, Nothing} = nothing,
10+
sm_arch::Union{String, Nothing} = nothing,
11+
num_ctas::Union{Int, Nothing} = nothing,
12+
occupancy::Union{Int, Nothing} = nothing)
1013
target = TileTarget(f, argtypes)
1114
kernel_name = name === nothing ? string(target.mi.def.name) : name
1215

@@ -15,7 +18,8 @@ function emit_tileir(@nospecialize(f), @nospecialize(argtypes);
1518
end
1619

1720
buf = write_bytecode!(1) do writer, func_buf
18-
emit_kernel!(writer, func_buf, target; name=kernel_name)
21+
emit_kernel!(writer, func_buf, target; name=kernel_name, sm_arch,
22+
num_ctas, occupancy)
1923
end
2024

2125
return buf
@@ -31,14 +35,14 @@ function disassemble_tileir(bytecode::Vector{UInt8})::String
3135
end
3236

3337
"""
34-
code_tiled(f, argtypes; name=nothing) -> String
38+
code_tiled(f, argtypes; name, sm_arch, num_ctas, occupancy) -> String
3539
3640
Return the CUDA Tile IR for a Julia function as a textual MLIR representation.
3741
Analogous to `code_typed` or `code_structured`.
3842
"""
3943
function code_tiled(@nospecialize(f), @nospecialize(argtypes);
40-
name::Union{String, Nothing} = nothing)
41-
bytecode = emit_tileir(f, argtypes; name)
44+
kwargs...)
45+
bytecode = emit_tileir(f, argtypes; kwargs...)
4246
disassemble_tileir(bytecode)
4347
end
4448

test/codegen.jl

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1846,3 +1846,118 @@ end
18461846
end
18471847
end
18481848
end
1849+
1850+
#=============================================================================
1851+
Entry Hints (optimization_hints attribute)
1852+
=============================================================================#
1853+
1854+
@testset "Entry Hints" begin
1855+
# Common ArraySpecs for tests
1856+
spec1d = ct.ArraySpec{1}(16, true)
1857+
1858+
@testset "num_ctas only" begin
1859+
@test @filecheck begin
1860+
@check "optimization_hints=<sm_100 = {num_cta_in_cga = 4}>"
1861+
ct.code_tiled(Tuple{ct.TileArray{Float32, 1, spec1d}}; sm_arch="sm_100", num_ctas=4) do a
1862+
pid = ct.bid(1)
1863+
t = ct.load(a, pid, (16,))
1864+
ct.store(a, pid, t)
1865+
return nothing
1866+
end
1867+
end
1868+
end
1869+
1870+
@testset "occupancy only" begin
1871+
@test @filecheck begin
1872+
@check "optimization_hints=<sm_100 = {occupancy = 8}>"
1873+
ct.code_tiled(Tuple{ct.TileArray{Float32, 1, spec1d}}; sm_arch="sm_100", occupancy=8) do a
1874+
pid = ct.bid(1)
1875+
t = ct.load(a, pid, (16,))
1876+
ct.store(a, pid, t)
1877+
return nothing
1878+
end
1879+
end
1880+
end
1881+
1882+
@testset "both hints" begin
1883+
@test @filecheck begin
1884+
@check "optimization_hints=<sm_120 = {num_cta_in_cga = 2, occupancy = 4}"
1885+
ct.code_tiled(Tuple{ct.TileArray{Float32, 1, spec1d}}; sm_arch="sm_120", num_ctas=2, occupancy=4) do a
1886+
pid = ct.bid(1)
1887+
t = ct.load(a, pid, (16,))
1888+
ct.store(a, pid, t)
1889+
return nothing
1890+
end
1891+
end
1892+
end
1893+
1894+
@testset "no hints" begin
1895+
@test @filecheck begin
1896+
@check_not "optimization_hints"
1897+
ct.code_tiled(Tuple{ct.TileArray{Float32, 1, spec1d}}; sm_arch="sm_100") do a
1898+
pid = ct.bid(1)
1899+
t = ct.load(a, pid, (16,))
1900+
ct.store(a, pid, t)
1901+
return nothing
1902+
end
1903+
end
1904+
end
1905+
1906+
@testset "architecture parameter" begin
1907+
@test @filecheck begin
1908+
@check "optimization_hints=<sm_120 = {num_cta_in_cga = 4}>"
1909+
ct.code_tiled(Tuple{ct.TileArray{Float32, 1, spec1d}}; sm_arch="sm_120", num_ctas=4) do a
1910+
pid = ct.bid(1)
1911+
t = ct.load(a, pid, (16,))
1912+
ct.store(a, pid, t)
1913+
return nothing
1914+
end
1915+
end
1916+
end
1917+
1918+
@testset "num_ctas validation" begin
1919+
# Too small
1920+
@test_throws "num_ctas must be between 1 and 16" begin
1921+
code_tiled((a) -> nothing, Tuple{ct.TileArray{Float32, 1, spec1d}}; sm_arch="sm_100", num_ctas=0)
1922+
end
1923+
1924+
# Too large
1925+
@test_throws "num_ctas must be between 1 and 16" begin
1926+
code_tiled((a) -> nothing, Tuple{ct.TileArray{Float32, 1, spec1d}}; sm_arch="sm_100", num_ctas=17)
1927+
end
1928+
1929+
# Not power of 2
1930+
@test_throws "num_ctas must be a power of 2" begin
1931+
code_tiled((a) -> nothing, Tuple{ct.TileArray{Float32, 1, spec1d}}; sm_arch="sm_100", num_ctas=3)
1932+
end
1933+
1934+
@test_throws "num_ctas must be a power of 2" begin
1935+
code_tiled((a) -> nothing, Tuple{ct.TileArray{Float32, 1, spec1d}}; sm_arch="sm_100", num_ctas=5)
1936+
end
1937+
1938+
# Valid values should succeed
1939+
for num_ctas in [1, 2, 4, 8, 16]
1940+
bytecode = code_tiled((a) -> nothing, Tuple{ct.TileArray{Float32, 1, spec1d}}; sm_arch="sm_100", num_ctas)
1941+
@test !isempty(bytecode)
1942+
end
1943+
end
1944+
1945+
@testset "occupancy validation" begin
1946+
# Too small
1947+
@test_throws "occupancy must be between 1 and 32" begin
1948+
code_tiled((a) -> nothing, Tuple{ct.TileArray{Float32, 1, spec1d}}; sm_arch="sm_100", occupancy=0)
1949+
end
1950+
1951+
# Too large
1952+
@test_throws "occupancy must be between 1 and 32" begin
1953+
code_tiled((a) -> nothing, Tuple{ct.TileArray{Float32, 1, spec1d}}; sm_arch="sm_100", occupancy=33)
1954+
end
1955+
1956+
# Valid boundaries
1957+
bytecode1 = code_tiled((a) -> nothing, Tuple{ct.TileArray{Float32, 1, spec1d}}; sm_arch="sm_100", occupancy=1)
1958+
@test !isempty(bytecode1)
1959+
1960+
bytecode32 = code_tiled((a) -> nothing, Tuple{ct.TileArray{Float32, 1, spec1d}}; sm_arch="sm_100", occupancy=32)
1961+
@test !isempty(bytecode32)
1962+
end
1963+
end

0 commit comments

Comments
 (0)