Skip to content

Commit 1a3958e

Browse files
Add num_worker_warps entry hint (#245)
* Add `num_worker_warps` entry hint * Update src/bytecode/writer.jl Co-authored-by: Tim Besard <tim.besard@gmail.com> --------- Co-authored-by: Tim Besard <tim.besard@gmail.com>
1 parent 53bd1a7 commit 1a3958e

10 files changed

Lines changed: 76 additions & 21 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,7 @@ end
361361
| `num_ctas` | Number of CTAs in a CGA | Powers of 2 |
362362
| `occupancy` | Target concurrent CTAs per SM | 1–32 |
363363
| `opt_level` | Optimization level | 0–3 |
364+
| `num_worker_warps` | Worker warps per CTA in a warp-specialized kernel | 4 or 8 |
364365

365366
Values can be plain scalars or `ct.ByTarget(...)` for per-architecture dispatch.
366367
`ByTarget` maps compute capabilities to values, with an optional default:

src/bytecode/writer.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -741,17 +741,19 @@ function encode_load_store_hints_dict!(cb::CodeBuilder, hints::LoadStoreHints)
741741
end
742742

743743
"""
744-
Kernel-level compilation hints (num_ctas, occupancy).
744+
Kernel-level compilation hints (num_ctas, occupancy, num_worker_warps).
745745
Encoded as a dictionary attribute in bytecode.
746746
"""
747747
@kwdef struct EntryHints
748-
num_ctas::Union{Int, Nothing} = nothing # 1, 2, 4, 8, 16
749-
occupancy::Union{Int, Nothing} = nothing # 1-32
748+
num_ctas::Union{Int, Nothing} = nothing # 1, 2, 4, 8, 16
749+
occupancy::Union{Int, Nothing} = nothing # 1-32
750+
num_worker_warps::Union{Int, Nothing} = nothing # 4 or 8
750751
end
751752

752753
function encode_entry_hints(writer::BytecodeWriter, sm_arch::Union{VersionNumber, Nothing}, hints::EntryHints)
753754
validate_hint(:num_ctas, hints.num_ctas)
754755
validate_hint(:occupancy, hints.occupancy)
756+
validate_hint(:num_worker_warps, hints.num_worker_warps)
755757

756758
# CTA clusters (num_cta_in_cga > 1) are a Blackwell feature. Older tileiras
757759
# versions rejected the bytecode for non-Blackwell targets; tileiras 13.3
@@ -763,10 +765,16 @@ function encode_entry_hints(writer::BytecodeWriter, sm_arch::Union{VersionNumber
763765
"$(format_sm_arch(sm_arch))"))
764766
end
765767

768+
if hints.num_worker_warps !== nothing && writer.version < v"13.3"
769+
throw(ArgumentError(
770+
"num_worker_warps requires Tile IR bytecode v13.3+, got v$(writer.version)"))
771+
end
772+
766773
# Build items list (only non-nothing values)
767774
items = Tuple{String, Int}[]
768775
isnothing(hints.num_ctas) || push!(items, ("num_cta_in_cga", hints.num_ctas))
769776
isnothing(hints.occupancy) || push!(items, ("occupancy", hints.occupancy))
777+
isnothing(hints.num_worker_warps) || push!(items, ("num_worker_warps_per_cta", hints.num_worker_warps))
770778

771779
# Always emit optimization hints when sm_arch is specified, even with an empty
772780
# dict. Python cuTile emits `optimization_hints=<sm_NNN = {}>` unconditionally

src/compiler/codegen/kernel.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# kernel and argument handling
22

33
"""
4-
emit_kernel!(writer, func_buf, sci, rettype; name, sm_arch=nothing, is_entry=true, num_ctas=nothing, occupancy=nothing, const_argtypes=nothing)
4+
emit_kernel!(writer, func_buf, sci, rettype; name, sm_arch=nothing, is_entry=true, num_ctas=nothing, occupancy=nothing, num_worker_warps=nothing, const_argtypes=nothing)
55
66
Compile a StructuredIRCode to Tile IR bytecode.
77
@@ -17,6 +17,7 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
1717
is_entry::Bool = true,
1818
num_ctas::Union{Int, Nothing} = nothing,
1919
occupancy::Union{Int, Nothing} = nothing,
20+
num_worker_warps::Union{Int, Nothing} = nothing,
2021
cache::CacheView,
2122
const_argtypes::Union{Vector{Any}, Nothing} = nothing)
2223
tt = writer.type_table
@@ -82,7 +83,7 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
8283
end
8384

8485
# Create entry hints if provided
85-
entry_hints = encode_entry_hints(writer, sm_arch, EntryHints(; num_ctas, occupancy))
86+
entry_hints = encode_entry_hints(writer, sm_arch, EntryHints(; num_ctas, occupancy, num_worker_warps))
8687

8788
# Create function-level debug attribute
8889
func_debug_attr = make_func_debug_attr(debug_emitter, sci; linkage_name=name)

src/compiler/driver.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@ const compile_hook = Ref{Union{Nothing,Function}}(nothing)
77
=============================================================================#
88

99
# Compilation options for cache sharding.
10-
# Hint fields (opt_level, num_ctas, occupancy) represent explicit overrides only;
11-
# `nothing` means "consult @compiler_options meta nodes in the IR during compilation."
10+
# Hint fields (opt_level, num_ctas, occupancy, num_worker_warps) represent explicit
11+
# overrides only; `nothing` means "consult @compiler_options meta nodes in the IR
12+
# during compilation."
1213
const CGOpts = @NamedTuple{
1314
sm_arch::Union{VersionNumber, Nothing},
1415
opt_level::Union{Int, Nothing},
1516
num_ctas::Union{Int, Nothing},
1617
occupancy::Union{Int, Nothing},
18+
num_worker_warps::Union{Int, Nothing},
1719
bytecode_version::VersionNumber
1820
}
1921

@@ -168,6 +170,8 @@ function emit_tile(sci::StructuredIRCode, rettype, kernel_meta::Dict{Symbol,Any}
168170
# Resolve hints: launch()/code_tiled() kwargs > @compiler_options meta > defaults
169171
resolved_num_ctas = resolve_hint(opts.num_ctas, kernel_meta, :num_ctas, opts.sm_arch)
170172
resolved_occupancy = resolve_hint(opts.occupancy, kernel_meta, :occupancy, opts.sm_arch)
173+
resolved_num_worker_warps = resolve_hint(opts.num_worker_warps, kernel_meta,
174+
:num_worker_warps, opts.sm_arch)
171175

172176
# Generate Tile IR bytecode
173177
bytecode = write_bytecode!(1; version=opts.bytecode_version) do writer, func_buf
@@ -176,6 +180,7 @@ function emit_tile(sci::StructuredIRCode, rettype, kernel_meta::Dict{Symbol,Any}
176180
sm_arch = opts.sm_arch,
177181
num_ctas = resolved_num_ctas,
178182
occupancy = resolved_occupancy,
183+
num_worker_warps = resolved_num_worker_warps,
179184
cache,
180185
const_argtypes
181186
)
@@ -307,6 +312,7 @@ function emit_tile!(cache::CacheView, mi::Core.MethodInstance,
307312
opt_level=unpack_hint(key.opt_level),
308313
num_ctas=unpack_hint(key.num_ctas),
309314
occupancy=unpack_hint(key.occupancy),
315+
num_worker_warps=unpack_hint(key.num_worker_warps),
310316
bytecode_version=unpack_version(key.bytecode_version)))
311317
bytecode = emit_tile(sci, rettype, kernel_meta;
312318
name=sanitize_name(string(mi.def.name)),

src/compiler/reflection.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ constant_eltype(::Type{Constant{T,V}}) where {T,V} = T
109109
constant_value(::Type{Constant{T,V}}) where {T,V} = V
110110

111111
"""
112-
code_tiled([io::IO], f, argtypes; sm_arch, opt_level, num_ctas, occupancy)
112+
code_tiled([io::IO], f, argtypes; sm_arch, opt_level, num_ctas, occupancy, num_worker_warps)
113113
114114
Print the CUDA Tile IR for a Julia function as a textual MLIR representation.
115115
Analogous to `code_llvm`/`code_native`. Calls the driver directly without
@@ -120,14 +120,15 @@ function code_tiled(io::IO, @nospecialize(f), @nospecialize(argtypes);
120120
opt_level::Union{Int, Nothing}=nothing,
121121
num_ctas::Union{Int, Nothing}=nothing,
122122
occupancy::Union{Int, Nothing}=nothing,
123+
num_worker_warps::Union{Int, Nothing}=nothing,
123124
bytecode_version::VersionNumber=cuTile.bytecode_version(),
124125
debuginfo::Bool=false,
125126
world::UInt=Base.get_world_counter())
126127
stripped, const_argtypes = process_const_argtypes(f, argtypes)
127128
mi = lookup_method_instance(f, stripped; world)
128129

129130
opts = CGOpts((sm_arch=sm_arch, opt_level=opt_level, num_ctas=num_ctas, occupancy=occupancy,
130-
bytecode_version=bytecode_version))
131+
num_worker_warps=num_worker_warps, bytecode_version=bytecode_version))
131132
cache = CacheView{CuTileResults}(:cuTile, world)
132133
ir, rettype = emit_julia(cache, mi; const_argtypes)
133134
sci, rettype, kernel_meta = emit_structured(ir, rettype)

src/language/operations.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1441,7 +1441,7 @@ end
14411441
@compiler_options macro
14421442
=============================================================================#
14431443

1444-
const _COMPILER_OPTION_NAMES = Set([:num_ctas, :occupancy, :opt_level])
1444+
const _COMPILER_OPTION_NAMES = Set([:num_ctas, :occupancy, :opt_level, :num_worker_warps])
14451445

14461446
"""
14471447
@compiler_options key=val...
@@ -1450,7 +1450,7 @@ Specify per-architecture optimization hints inside a kernel function body.
14501450
Hints are embedded as `:meta` nodes and resolved at compile time based on
14511451
the target `sm_arch`.
14521452
1453-
Supported options: `num_ctas`, `occupancy`, `opt_level`.
1453+
Supported options: `num_ctas`, `occupancy`, `opt_level`, `num_worker_warps`.
14541454
14551455
Values can be plain scalars or `ByTarget(...)` for per-architecture dispatch.
14561456

src/language/types.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,7 @@ Validate a kernel optimization hint value. Throws `ArgumentError` for invalid va
561561
- `num_ctas`: power of 2 in [1, 16]
562562
- `occupancy`: integer in [1, 32]
563563
- `opt_level`: integer in [0, 3]
564+
- `num_worker_warps`: either 4 or 8
564565
"""
565566
function validate_hint(key::Symbol, val)
566567
val === nothing && return
@@ -574,6 +575,9 @@ function validate_hint(key::Symbol, val)
574575
elseif key === :opt_level
575576
val isa Integer || throw(ArgumentError("opt_level must be an integer, got $(typeof(val))"))
576577
0 <= val <= 3 || throw(ArgumentError("opt_level must be between 0 and 3, got $val"))
578+
elseif key === :num_worker_warps
579+
val isa Integer || throw(ArgumentError("num_worker_warps must be an integer, got $(typeof(val))"))
580+
val in (4, 8) || throw(ArgumentError("num_worker_warps must be either 4 or 8, got $val"))
577581
else
578582
throw(ArgumentError("unknown hint key: $key"))
579583
end

src/launch.jl

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Compiles a Julia function with `TileArray` arguments to Tile IR bytecode,
44
# runs `tileiras` to lower bytecode → CUBIN, loads the cubin into the active
55
# CUDA context, and launches it via `cudacall`. Compilation is cached per
6-
# `(MethodInstance, sm_arch, opt_level, num_ctas, occupancy, bytecode_version)`.
6+
# `(MethodInstance, sm_arch, opt_level, num_ctas, occupancy, num_worker_warps, bytecode_version)`.
77

88
using CUDACore: CUDACore, CuArray, CuModule, CuFunction, cudacall, device, capability,
99
AbstractBackend, AbstractKernel, kernel_convert, kernel_compile, PerDevice
@@ -88,7 +88,8 @@ CUDACore.kernel_compile(::TileBackend, f::F, tt::TT=Tuple{}; kwargs...) where {F
8888
@inline unpack_version(x::UInt16) = VersionNumber(Int(x >> 8), Int(x & 0xff))
8989

9090
# isbits sentinel codec for `Union{Int, Nothing}` hint fields (`opt_level`,
91-
# `num_ctas`, `occupancy`). `-1` is unused as a value, so we use it for `nothing`.
91+
# `num_ctas`, `occupancy`, `num_worker_warps`). `-1` is unused as a value, so we
92+
# use it for `nothing`.
9293
const _UNSET = -1
9394
@inline pack_hint(x::Union{Int, Nothing}) = x === nothing ? _UNSET : x
9495
@inline unpack_hint(x::Int) = x == _UNSET ? nothing : x
@@ -113,12 +114,14 @@ struct TileCacheKey
113114
opt_level::Int
114115
num_ctas::Int
115116
occupancy::Int
117+
num_worker_warps::Int
116118
end
117119
TileCacheKey(sm_arch::VersionNumber, bytecode_version::VersionNumber,
118120
opt_level::Union{Int, Nothing}, num_ctas::Union{Int, Nothing},
119-
occupancy::Union{Int, Nothing}) =
121+
occupancy::Union{Int, Nothing}, num_worker_warps::Union{Int, Nothing}) =
120122
TileCacheKey(pack_version(sm_arch), pack_version(bytecode_version),
121-
pack_hint(opt_level), pack_hint(num_ctas), pack_hint(occupancy))
123+
pack_hint(opt_level), pack_hint(num_ctas), pack_hint(occupancy),
124+
pack_hint(num_worker_warps))
122125

123126

124127
#=============================================================================
@@ -409,7 +412,7 @@ function emit_binary!(cache::CacheView, mi::Core.MethodInstance,
409412
sm_arch = unpack_version(cache.owner.sm_arch)
410413

411414
# Resolve opt_level here (not in emit_tile) because it's a tileiras flag, not bytecode.
412-
# num_ctas/occupancy are resolved in emit_tile because they're encoded in bytecode.
415+
# num_ctas/occupancy/num_worker_warps are resolved in emit_tile because they're encoded in bytecode.
413416
_, _, kernel_meta = res.julia_ir
414417
opt_level = something(resolve_hint(unpack_hint(cache.owner.opt_level),
415418
kernel_meta, :opt_level, sm_arch), 3)
@@ -516,7 +519,8 @@ end
516519

517520
"""
518521
cuTile.cufunction(f, tt=Tuple{}; sm_arch=nothing, opt_level=nothing,
519-
num_ctas=nothing, occupancy=nothing, name=nothing) -> TileKernel
522+
num_ctas=nothing, occupancy=nothing, num_worker_warps=nothing,
523+
name=nothing) -> TileKernel
520524
521525
Compile `f` for the cuTile backend. `tt` is the tuple of *converted*
522526
argument types (i.e. after `cuTileconvert`/`Adapt.adapt(KernelAdaptor(), …)`).
@@ -533,11 +537,13 @@ function cufunction(@nospecialize(f), tt::Type{<:Tuple}=Tuple{};
533537
opt_level::Union{Int, Nothing}=nothing,
534538
num_ctas::Union{Int, Nothing}=nothing,
535539
occupancy::Union{Int, Nothing}=nothing,
540+
num_worker_warps::Union{Int, Nothing}=nothing,
536541
name::Union{String, Nothing}=nothing)
537542
bytecode_version = check_tile_ir_support()
538543
resolved_sm_arch = sm_arch !== nothing ? sm_arch : default_sm_arch()
539544

540-
key = TileCacheKey(resolved_sm_arch, bytecode_version, opt_level, num_ctas, occupancy)
545+
key = TileCacheKey(resolved_sm_arch, bytecode_version, opt_level, num_ctas, occupancy,
546+
num_worker_warps)
541547

542548
# Single pass over `tt.parameters`: build the unwrapped argtypes tuple
543549
# (Constant{T,V} → T for MI lookup) and the const_argtypes vector
@@ -685,7 +691,7 @@ end
685691

686692
"""
687693
launch(f, grid, args...; sm_arch=nothing, opt_level=nothing,
688-
num_ctas=nothing, occupancy=nothing, name=nothing)
694+
num_ctas=nothing, occupancy=nothing, num_worker_warps=nothing, name=nothing)
689695
690696
Compile and launch a Tile IR kernel. `args` are converted via
691697
`cuTileconvert` (CuArray → TileArray, Type → Constant). Equivalent to
@@ -715,10 +721,11 @@ function launch(@nospecialize(f), grid, args...;
715721
opt_level::Union{Int, Nothing}=nothing,
716722
num_ctas::Union{Int, Nothing}=nothing,
717723
occupancy::Union{Int, Nothing}=nothing,
724+
num_worker_warps::Union{Int, Nothing}=nothing,
718725
name::Union{String, Nothing}=nothing)
719726
converted = map(cuTileconvert, args)
720727
tt = Tuple{map(Core.Typeof, converted)...}
721-
kernel = cufunction(f, tt; sm_arch, opt_level, num_ctas, occupancy, name)
728+
kernel = cufunction(f, tt; sm_arch, opt_level, num_ctas, occupancy, num_worker_warps, name)
722729
kernel(converted...; blocks=grid)
723730
return nothing
724731
end

src/precompile.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ import REPL
4444
bv = bytecode_version()
4545
for sm_arch in [v"8.0", v"8.6", v"8.7", v"8.9",
4646
v"10.0", v"11.0", v"12.0", v"12.1"]
47-
key = TileCacheKey(sm_arch, bv, nothing, nothing, nothing)
47+
key = TileCacheKey(sm_arch, bv, nothing, nothing, nothing, nothing)
4848
compile(f, argtypes, const_argtypes, key)
4949
end
5050
return

test/device/hints.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,33 @@ end
7171
end
7272
end
7373

74+
@testset "launch with num_worker_warps" begin
75+
function vadd_kernel_worker_warps(a::ct.TileArray{Float32,1},
76+
b::ct.TileArray{Float32,1},
77+
c::ct.TileArray{Float32,1})
78+
pid = ct.bid(1)
79+
ta = ct.load(a, pid, (16,))
80+
tb = ct.load(b, pid, (16,))
81+
ct.store(c, pid, ta + tb)
82+
return nothing
83+
end
84+
85+
n = 1024
86+
a = CUDA.ones(Float32, n)
87+
b = CUDA.ones(Float32, n) .* 2
88+
c = CUDA.zeros(Float32, n)
89+
90+
if cuTile.bytecode_version() >= v"13.3"
91+
@cuda backend=cuTile blocks=64 num_worker_warps=8 vadd_kernel_worker_warps(a, b, c)
92+
@test Array(c) ≈ ones(Float32, n) .* 3
93+
else
94+
@test_throws "num_worker_warps requires" @cuda backend=cuTile blocks=64 num_worker_warps=8 vadd_kernel_worker_warps(a, b, c)
95+
end
96+
97+
# Invalid value rejected before compilation.
98+
@test_throws "num_worker_warps must be either 4 or 8" @cuda backend=cuTile blocks=64 num_worker_warps=3 vadd_kernel_worker_warps(a, b, c)
99+
end
100+
74101
end
75102

76103
@testset "Load / Store Optimization Hints" begin

0 commit comments

Comments
 (0)