@@ -9,10 +9,10 @@ using CUDA_Compiler_jll
99public launch
1010
1111# Compilation cache - stores CuFunction directly to avoid re-loading CuModule
12- const _compilation_cache = Dict {Any, Any} () # (f, argtypes, sm_arch, opt_level) => CuFunction
12+ const _compilation_cache = Dict {Any, Any} () # (f, argtypes, sm_arch, opt_level, num_ctas, occupancy ) => CuFunction
1313
1414"""
15- launch(f, grid, args...; name=nothing, sm_arch=default_sm_arch(), opt_level=3)
15+ launch(f, grid, args...; name=nothing, sm_arch=default_sm_arch(), opt_level=3, num_ctas=nothing, occupancy=nothing )
1616
1717Compile and launch a kernel function with the given grid size and arguments.
1818
@@ -26,6 +26,8 @@ are expanded to their constituent ptr, sizes, and strides parameters.
2626- `name`: Optional kernel name for debugging
2727- `sm_arch`: Target GPU architecture (default: current device's capability)
2828- `opt_level`: Optimization level 0-3 (default: 3)
29+ - `num_ctas`: Number of CTAs in a CGA; must be a power of 2 in the range 1-16 (default: nothing)
30+ - `occupancy`: Expected number of active CTAs per SM, in the range 1-32 (default: nothing)
2931
3032# Example
3133```julia
@@ -51,7 +53,9 @@ cuTile.launch(vadd_kernel, 64, a, b, c)
5153function cuTile. launch (@nospecialize (f), grid, args... ;
5254 name:: Union{String, Nothing} = nothing ,
5355 sm_arch:: String = default_sm_arch (),
54- opt_level:: Int = 3 )
56+ opt_level:: Int = 3 ,
57+ num_ctas:: Union{Int, Nothing} = nothing ,
58+ occupancy:: Union{Int, Nothing} = nothing )
5559 # Convert CuArray -> TileArray (and other conversions)
5660 tile_args = map (to_tile_arg, args)
5761
@@ -62,10 +66,10 @@ function cuTile.launch(@nospecialize(f), grid, args...;
6266 kernel_name = name != = nothing ? name : string (nameof (f))
6367
6468 # Check compilation cache - returns CuFunction directly
65- cache_key = (f, argtypes, sm_arch, opt_level)
69+ cache_key = (f, argtypes, sm_arch, opt_level, num_ctas, occupancy )
6670 cufunc = get (_compilation_cache, cache_key, nothing )
6771 if cufunc === nothing || cuTile. compile_hook[] != = nothing
68- cubin = compile (f, argtypes; name, sm_arch, opt_level)
72+ cubin = compile (f, argtypes; name, sm_arch, opt_level, num_ctas, occupancy )
6973 if cufunc === nothing
7074 cumod = CuModule (cubin)
7175 cufunc = CuFunction (cumod, kernel_name)
@@ -98,15 +102,18 @@ function cuTile.launch(@nospecialize(f), grid, args...;
98102end
99103
100104"""
101- compile(f, argtypes; name=nothing, sm_arch=default_sm_arch(), opt_level=3) -> Vector{UInt8}
105+ compile(f, argtypes; name=nothing, sm_arch=default_sm_arch(), opt_level=3, num_ctas=nothing, occupancy=nothing ) -> Vector{UInt8}
102106
103107Compile a Julia kernel function to a CUDA binary.
104108"""
105109function compile (@nospecialize (f), @nospecialize (argtypes);
106110 name:: Union{String, Nothing} = nothing ,
107111 sm_arch:: String = default_sm_arch (),
108- opt_level:: Int = 3 )
109- tile_bytecode = emit_tileir (f, argtypes; name)
112+ opt_level:: Int = 3 ,
113+ num_ctas:: Union{Int, Nothing} = nothing ,
114+ occupancy:: Union{Int, Nothing} = nothing )
115+ tile_bytecode = emit_tileir (f, argtypes; name, sm_arch,
116+ num_ctas, occupancy)
110117
111118 # Dump bytecode if JULIA_CUTILE_DUMP_BYTECODE is set
112119 dump_dir = get (ENV , " JULIA_CUTILE_DUMP_BYTECODE" , nothing )
0 commit comments