# Compiles a Julia function with `TileArray` arguments to Tile IR bytecode,
# runs `tileiras` to lower bytecode → CUBIN, loads the cubin into the active
# CUDA context, and launches it via `cudacall`. Compilation is cached per
# `(MethodInstance, sm_arch, opt_level, num_ctas, occupancy, num_worker_warps,
# bytecode_version)`.
77
88using CUDACore: CUDACore, CuArray, CuModule, CuFunction, cudacall, device, capability,
99 AbstractBackend, AbstractKernel, kernel_convert, kernel_compile, PerDevice
@@ -88,7 +88,8 @@ CUDACore.kernel_compile(::TileBackend, f::F, tt::TT=Tuple{}; kwargs...) where {F
# Decode a version packed into a `UInt16`: the high byte is the major component,
# the low byte is the minor component. The patch component is left at zero.
@inline function unpack_version(x::UInt16)
    major = Int(x >> 8)
    minor = Int(x & 0xff)
    return VersionNumber(major, minor)
end
8989
# isbits sentinel codec for `Union{Int, Nothing}` hint fields (`opt_level`,
# `num_ctas`, `occupancy`, `num_worker_warps`). `-1` is unused as a value, so we
# use it for `nothing`.
# Sentinel stored in place of `nothing` so that hint fields can be plain `Int`s.
const _UNSET = -1

# Encode an optional hint as an `Int`: `nothing` becomes `_UNSET`, any integer
# is passed through unchanged.
@inline pack_hint(x::Union{Int, Nothing}) = x === nothing ? _UNSET : x

# Decode a packed hint: `_UNSET` maps back to `nothing`, every other value is
# returned as-is. A caller-supplied `-1` therefore also decodes to `nothing`.
@inline unpack_hint(x::Int) = x == _UNSET ? nothing : x
@@ -113,12 +114,14 @@ struct TileCacheKey
113114 opt_level:: Int
114115 num_ctas:: Int
115116 occupancy:: Int
117+ num_worker_warps:: Int
116118end
# Convenience constructor taking unpacked values: both versions go through
# `pack_version` and every optional hint through `pack_hint` before delegating
# to the field-wise `TileCacheKey` constructor. Argument order matches the
# struct's field order.
TileCacheKey(sm_arch::VersionNumber, bytecode_version::VersionNumber,
             opt_level::Union{Int, Nothing}, num_ctas::Union{Int, Nothing},
             occupancy::Union{Int, Nothing}, num_worker_warps::Union{Int, Nothing}) =
    TileCacheKey(pack_version(sm_arch), pack_version(bytecode_version),
                 pack_hint(opt_level), pack_hint(num_ctas), pack_hint(occupancy),
                 pack_hint(num_worker_warps))
122125
123126
124127#= ============================================================================
@@ -409,7 +412,7 @@ function emit_binary!(cache::CacheView, mi::Core.MethodInstance,
409412 sm_arch = unpack_version (cache. owner. sm_arch)
410413
411414 # Resolve opt_level here (not in emit_tile) because it's a tileiras flag, not bytecode.
412- # num_ctas/occupancy are resolved in emit_tile because they're encoded in bytecode.
415+ # num_ctas/occupancy/num_worker_warps are resolved in emit_tile because they're encoded in bytecode.
413416 _, _, kernel_meta = res. julia_ir
414417 opt_level = something (resolve_hint (unpack_hint (cache. owner. opt_level),
415418 kernel_meta, :opt_level , sm_arch), 3 )
516519
517520"""
518521 cuTile.cufunction(f, tt=Tuple{}; sm_arch=nothing, opt_level=nothing,
519- num_ctas=nothing, occupancy=nothing, name=nothing) -> TileKernel
522+ num_ctas=nothing, occupancy=nothing, num_worker_warps=nothing,
523+ name=nothing) -> TileKernel
520524
521525Compile `f` for the cuTile backend. `tt` is the tuple of *converted*
522526argument types (i.e. after `cuTileconvert`/`Adapt.adapt(KernelAdaptor(), …)`).
@@ -533,11 +537,13 @@ function cufunction(@nospecialize(f), tt::Type{<:Tuple}=Tuple{};
533537 opt_level:: Union{Int, Nothing} = nothing ,
534538 num_ctas:: Union{Int, Nothing} = nothing ,
535539 occupancy:: Union{Int, Nothing} = nothing ,
540+ num_worker_warps:: Union{Int, Nothing} = nothing ,
536541 name:: Union{String, Nothing} = nothing )
537542 bytecode_version = check_tile_ir_support ()
538543 resolved_sm_arch = sm_arch != = nothing ? sm_arch : default_sm_arch ()
539544
540- key = TileCacheKey (resolved_sm_arch, bytecode_version, opt_level, num_ctas, occupancy)
545+ key = TileCacheKey (resolved_sm_arch, bytecode_version, opt_level, num_ctas, occupancy,
546+ num_worker_warps)
541547
542548 # Single pass over `tt.parameters`: build the unwrapped argtypes tuple
543549 # (Constant{T,V} → T for MI lookup) and the const_argtypes vector
685691
686692"""
687693 launch(f, grid, args...; sm_arch=nothing, opt_level=nothing,
688- num_ctas=nothing, occupancy=nothing, name=nothing)
694+ num_ctas=nothing, occupancy=nothing, num_worker_warps=nothing, name=nothing)
689695
690696Compile and launch a Tile IR kernel. `args` are converted via
691697`cuTileconvert` (CuArray → TileArray, Type → Constant). Equivalent to
@@ -715,10 +721,11 @@ function launch(@nospecialize(f), grid, args...;
715721 opt_level:: Union{Int, Nothing} = nothing ,
716722 num_ctas:: Union{Int, Nothing} = nothing ,
717723 occupancy:: Union{Int, Nothing} = nothing ,
724+ num_worker_warps:: Union{Int, Nothing} = nothing ,
718725 name:: Union{String, Nothing} = nothing )
719726 converted = map (cuTileconvert, args)
720727 tt = Tuple{map (Core. Typeof, converted)... }
721- kernel = cufunction (f, tt; sm_arch, opt_level, num_ctas, occupancy, name)
728+ kernel = cufunction (f, tt; sm_arch, opt_level, num_ctas, occupancy, num_worker_warps, name)
722729 kernel (converted... ; blocks= grid)
723730 return nothing
724731end
0 commit comments