From c9aa469828f051018c57ff90e22280f9989395e5 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Tue, 26 May 2026 19:23:29 +0200 Subject: [PATCH 01/15] Add autotuning to Experimental submodule. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Ports the autotune work from the autotune branch onto main's launch infrastructure (`cuTile.cufunction` / `TileKernel`, `cuTileconvert`, `Constant` unwrapping in `unwrap_argtypes`). Lives in `src/Experimental.jl` (the original was under `src/Experimental.jl` + `ext/autotune/autotune.jl` behind a CUDA weakdep; on main, `cuTile.launch` is in `src/` and depends on `CUDACore` directly, so the extension boundary is no longer needed). What this drops vs. the autotune branch: - `ext/CUDAExt.jl` — subsumed by `src/launch.jl` on main. - The `_SCOPED_INF_CACHE`/`create_inf_cache` scoped inference plumbing — caching is now handled by `CompilerCaching` inside `cufunction`. - Manual `emit_function!` calls in `precompile_cfg` — `cufunction` does the full compile+link in one shot, so precompile is just a `cufunction` call per cfg. API preserved: - `Experimental.autotune_launch(f, space, grid_fn, args_fn; ...)` - `Experimental.clear_autotune_cache(; kernel=nothing, key=nothing)` - `Experimental.CartesianSpace`, `FixedSpace`, `AbstractSearchSpace` - All tuning preset / verify / setup / launch_args_fn knobs. Tests: all 31 cases in `test/device/autotune.jl` pass. --- src/Experimental.jl | 44 ++++++ src/autotune.jl | 287 ++++++++++++++++++++++++++++++++++++++++ src/cuTile.jl | 2 + test/device/autotune.jl | 234 ++++++++++++++++++++++++++++++++ 4 files changed, 567 insertions(+) create mode 100644 src/Experimental.jl create mode 100644 src/autotune.jl create mode 100644 test/device/autotune.jl diff --git a/src/Experimental.jl b/src/Experimental.jl new file mode 100644 index 00000000..9e10c7f3 --- /dev/null +++ b/src/Experimental.jl @@ -0,0 +1,44 @@ +module Experimental + +using ..cuTile +using ..cuTile: cuTileconvert, cufunction, default_sm_arch + +using CUDACore: CUDACore + +using Random + +abstract type AbstractSearchSpace end + +Base.length(s::AbstractSearchSpace) = count(_ -> true, s) + +struct FixedSpace{names,NT<:NamedTuple{names}} <: AbstractSearchSpace + elements::Vector{NT} +end + +Base.iterate(space::FixedSpace, args...) = iterate(space.elements, args...) + +struct CartesianSpace{names,NT<:NamedTuple{names,<:Tuple{Vararg{Tuple}}}} <: AbstractSearchSpace + constraint::Function + axes::NT +end + +CartesianSpace(axes::NamedTuple) = CartesianSpace(Returns(true), axes) +CartesianSpace(; axes...) = CartesianSpace(NamedTuple(axes)) +CartesianSpace(constraint::Function; axes...) = CartesianSpace(constraint, NamedTuple(axes)) + +function Base.iterate(space::CartesianSpace{names}, state=nothing) where names + to_cfg = vals -> NamedTuple{names}(vals) + inner = state === nothing ? + Iterators.filter(space.constraint ∘ to_cfg, + Iterators.product(map(Tuple, values(space.axes))...)) : + state.inner + result = isnothing(state) ? iterate(inner) : iterate(inner, state.cursor) + isnothing(result) && return nothing + vals, cursor = result + cfg = to_cfg(vals) + return cfg, (; inner, cursor) +end + +include("autotune.jl") + +end diff --git a/src/autotune.jl b/src/autotune.jl new file mode 100644 index 00000000..dbf34197 --- /dev/null +++ b/src/autotune.jl @@ -0,0 +1,287 @@ +const AUTOTUNE_LOCK = ReentrantLock() +const AUTOTUNE_CACHE = Dict{Any, Dict{Any, Any}}() + +struct VerificationError <: Exception + msg::String +end + +const TUNING_PRESETS = ( + fast = (warmup=1, reps=3, refine_topk=0, refine_reps=2), + default = (warmup=2, reps=5, refine_topk=2, refine_reps=4), + thorough = (warmup=2, reps=7, refine_topk=4, refine_reps=6), +) + +function normalize_tuning(tuning::NamedTuple) + preset = get(tuning, :preset, :default) + preset isa Symbol || throw(ArgumentError("tuning.preset must be a Symbol")) + hasproperty(TUNING_PRESETS, preset) || + throw(ArgumentError("Unknown preset `$preset`; use :fast, :default, or :thorough")) + + base = merge(getproperty(TUNING_PRESETS, preset), + (seed=nothing, force=false, precompile_workers=Threads.nthreads())) + + overrides = NamedTuple(k => v for (k, v) in pairs(tuning) if k !== :preset) + return merge(base, overrides) +end + +# Extract hint fields (occupancy, num_ctas) from a config for cufunction kwargs. +function hints_from_cfg(cfg) + n = hasproperty(cfg, :num_ctas) ? cfg.num_ctas : nothing + o = hasproperty(cfg, :occupancy) ? cfg.occupancy : nothing + return (num_ctas=n, occupancy=o) +end + +# Pre-convert and compile a (cfg, args) pair into a cached TileKernel + the +# converted argument tuple ready for the call. Both the compilation and the +# adapt are reused across warmup/measurement reps. +function _prepare_launch(@nospecialize(f), cfg, args_fn::Function; + sm_arch::VersionNumber, opt_level::Int) + converted = map(cuTileconvert, args_fn(cfg)) + tt = Tuple{map(Core.Typeof, converted)...} + kernel = cufunction(f, tt; sm_arch, opt_level, hints_from_cfg(cfg)...) + return kernel, converted +end + +function time_ms(run_once::Function, get_args::Function; + warmup::Int, reps::Int, verify::Union{Nothing, Function}=nothing, + reset::Union{Nothing, Function}=nothing) + CUDACore.synchronize() + for _ in 1:max(warmup, verify !== nothing ? 1 : 0) + reset !== nothing && reset() + run_once(get_args()) + end + + if verify !== nothing + CUDACore.synchronize() + verify() || throw(VerificationError("config produced incorrect output")) + end + + best_ms = Inf32 + for _ in 1:reps + reset !== nothing && reset() + args = get_args() + CUDACore.synchronize() + elapsed_s = CUDACore.@elapsed run_once(args) + CUDACore.synchronize() + best_ms = min(best_ms, Float32(elapsed_s * 1000)) + end + return best_ms +end + +function eval_cfg(@nospecialize(f), cfg, grid_fn::Function, args_fn::Function; + sm_arch::VersionNumber, opt_level::Int, warmup::Int, reps::Int, + verify::Union{Nothing, Function}=nothing, + reset::Union{Nothing, Function}=nothing) + grid = grid_fn(cfg) + grid_dims = grid isa Integer ? (grid,) : grid + + # Compile once, then convert + call each rep. We `cufunction` outside the + # timed loop so JIT cost doesn't pollute the measurement. + sample_converted = map(cuTileconvert, args_fn(cfg)) + tt = Tuple{map(Core.Typeof, sample_converted)...} + kernel = cufunction(f, tt; sm_arch, opt_level, hints_from_cfg(cfg)...) + + run_once = converted -> kernel(converted...; blocks=grid_dims) + get_args = () -> map(cuTileconvert, args_fn(cfg)) + return time_ms(run_once, get_args; warmup, reps, verify, reset) +end + +function precompile_cfg(@nospecialize(f), cfg, args_fn::Function; + sm_arch::VersionNumber, opt_level::Int) + converted = map(cuTileconvert, args_fn(cfg)) + tt = Tuple{map(Core.Typeof, converted)...} + cufunction(f, tt; sm_arch, opt_level, hints_from_cfg(cfg)...) + return nothing +end + +function precompile_candidates(@nospecialize(f), configs::Vector{Any}, + args_fn::Function; + sm_arch::VersionNumber, opt_level::Int, workers::Int) + isempty(configs) && return configs, nothing + iszero(workers) && return configs, nothing + + workers = min(workers, Threads.nthreads(), length(configs)) + compiled = fill(true, length(configs)) + errors = Vector{Any}(nothing, length(configs)) + sem = Base.Semaphore(workers) + cancelled = Threads.Atomic{Bool}(false) + + try + @sync for (i, cfg) in enumerate(configs) + Threads.@spawn begin + cancelled[] && return + Base.acquire(sem) do + cancelled[] && return + try + precompile_cfg(f, cfg, args_fn; sm_arch, opt_level) + catch err + compiled[i] = false + errors[i] = (cfg, err) + end + end + end + end + catch e + cancelled[] = true + e isa InterruptException || rethrow() + @warn "Precompilation interrupted, waiting for in-flight workers…" + rethrow() + end + + first_err = nothing + for e in errors + if e !== nothing + first_err = e + break + end + end + + return configs[compiled], first_err +end + +function measure_candidates(@nospecialize(f), configs::Vector{Any}, + grid_fn::Function, args_fn::Function; + sm_arch::VersionNumber, opt_level::Int, warmup::Int, reps::Int, + verify::Union{Nothing, Function}=nothing, + reset::Union{Nothing, Function}=nothing) + record = Tuple{Any, Float32}[] + first_error = nothing + for cfg in configs + ms = try + eval_cfg(f, cfg, grid_fn, args_fn; sm_arch, opt_level, warmup, reps, verify, reset) + catch err + if err isa InterruptException + @warn "Benchmarking interrupted after $(length(record)) configs" + break + end + err isa VerificationError && @warn "Config $cfg failed verification, skipping" + first_error === nothing && (first_error = (cfg, err)) + continue + end + push!(record, (cfg, ms)) + end + return record, first_error +end + +function find_or_tune(@nospecialize(f), space::AbstractSearchSpace, rng::AbstractRNG, + grid_fn::Function, args_fn::Function, tuning; + sm_arch::VersionNumber, opt_level::Int, kernel_key, arg_key, + verify::Union{Nothing, Function}=nothing, + setup::Union{Nothing, Function}=nothing) + if !tuning.force + entry = lock(AUTOTUNE_LOCK) do + per_kernel = get(AUTOTUNE_CACHE, kernel_key, nothing) + per_kernel !== nothing ? get(per_kernel, arg_key, nothing) : nothing + end + entry !== nothing && return entry, true, nothing + end + + checker = verify !== nothing ? verify() : nothing + reset = setup !== nothing ? setup() : nothing + + trials = Any[collect(space)...] + + trials, precompile_error = precompile_candidates(f, trials, args_fn; + sm_arch, opt_level, workers=tuning.precompile_workers) + + record, first_error = measure_candidates(f, trials, grid_fn, args_fn; + sm_arch, opt_level, warmup=tuning.warmup, reps=tuning.reps, verify=checker, reset) + + if isempty(record) + err_info = first_error !== nothing ? first_error : precompile_error + if err_info === nothing + throw(ArgumentError("No valid config found in search space")) + else + cfg, err = err_info + throw(ArgumentError( + "No valid config found. First failure for cfg=$cfg: $(sprint(showerror, err))")) + end + end + + if tuning.refine_topk > 0 && length(record) > 1 + sort!(record, by=last) + top_configs = Any[first(r) for r in record[1:min(tuning.refine_topk, length(record))]] + refined, _ = measure_candidates(f, top_configs, grid_fn, args_fn; + sm_arch, opt_level, warmup=tuning.warmup, reps=tuning.refine_reps, reset) + if !isempty(refined) + record = refined + end + end + + _, best_idx = findmin(last, record) + candidate = (; best_config=record[best_idx][1], tuning_record=record) + + entry, _ = lock(AUTOTUNE_LOCK) do + per_kernel = get!(Dict{Any,Any}, AUTOTUNE_CACHE, kernel_key) + if !tuning.force && haskey(per_kernel, arg_key) + per_kernel[arg_key], true + else + per_kernel[arg_key] = candidate + candidate, false + end + end + return entry, false, reset +end + +""" + autotune_launch(f, space, grid_fn, args_fn; key, key_fn, launch_args_fn, + verify, setup, tuning, sm_arch, opt_level) + +Tune `f` over `space` (an [`AbstractSearchSpace`](@ref) or a `Vector`/`NamedTuple` +shorthand) and launch the best config. `grid_fn(cfg)` returns the launch +grid; `args_fn(cfg)` returns the argument tuple. Results are cached per +`(f, sm_arch, opt_level) ⇒ key`. +""" +function autotune_launch(@nospecialize(f), space::AbstractSearchSpace, + grid_fn::Function, args_fn::Function; + key=nothing, + key_fn::Union{Nothing, Function}=nothing, + launch_args_fn::Union{Nothing, Function}=nothing, + verify::Union{Nothing, Function}=nothing, + setup::Union{Nothing, Function}=nothing, + tuning::NamedTuple=NamedTuple(), + sm_arch::VersionNumber=default_sm_arch(), + opt_level::Int=3) + tuning = normalize_tuning(tuning) + rng = tuning.seed !== nothing ? MersenneTwister(tuning.seed) : Random.default_rng() + + kernel_key = (f, sm_arch, opt_level) + arg_key = key !== nothing ? key : (key_fn !== nothing ? key_fn() : nothing) + + entry, cache_hit, reset = find_or_tune(f, space, rng, grid_fn, args_fn, tuning; + sm_arch, opt_level, kernel_key, arg_key, verify, setup) + + cfg = entry.best_config + grid = grid_fn(cfg) + args = launch_args_fn !== nothing ? launch_args_fn(cfg) : args_fn(cfg) + + reset !== nothing && reset() + + cuTile.launch(f, grid, args...; sm_arch, opt_level, hints_from_cfg(cfg)...) + + return (; tuned_config=cfg, grid, tuning_record=copy(entry.tuning_record), cache_hit) +end + +function autotune_launch(@nospecialize(f), configs, grid_fn::Function, args_fn::Function; kwargs...) + space = configs isa NamedTuple ? CartesianSpace(configs) : FixedSpace(configs) + return autotune_launch(f, space, grid_fn, args_fn; kwargs...) +end + +function clear_autotune_cache(; kernel=nothing, key=nothing) + lock(AUTOTUNE_LOCK) do + if kernel === nothing + key === nothing || throw(ArgumentError("`key` requires `kernel`")) + empty!(AUTOTUNE_CACHE) + return nothing + end + + for kernel_key in collect(keys(AUTOTUNE_CACHE)) + kernel_key isa Tuple || continue + kernel_key[1] === kernel || continue + per_kernel = AUTOTUNE_CACHE[kernel_key] + key === nothing ? empty!(per_kernel) : pop!(per_kernel, key, nothing) + isempty(per_kernel) && delete!(AUTOTUNE_CACHE, kernel_key) + end + end + return nothing +end diff --git a/src/cuTile.jl b/src/cuTile.jl index 784e1918..01cadcfd 100644 --- a/src/cuTile.jl +++ b/src/cuTile.jl @@ -123,4 +123,6 @@ end include("precompile.jl") +include("Experimental.jl") + end # module cuTile diff --git a/test/device/autotune.jl b/test/device/autotune.jl new file mode 100644 index 00000000..d6e705e9 --- /dev/null +++ b/test/device/autotune.jl @@ -0,0 +1,234 @@ +using CUDA + +const Exp = ct.Experimental + +@testset "Autotune" begin + + function vadd_kernel(a::ct.TileArray{Float32,1}, + b::ct.TileArray{Float32,1}, + c::ct.TileArray{Float32,1}, + tile::Int) + pid = ct.bid(1) + ta = ct.load(a, pid, (tile[],)) + tb = ct.load(b, pid, (tile[],)) + ct.store(c, pid, ta + tb) + return nothing + end + + function inplace_add_kernel(x::ct.TileArray{Float32,1}, + tile::Int) + pid = ct.bid(1) + tx = ct.load(x, pid, (tile[],)) + ct.store(x, pid, tx .+ 1f0) + return nothing + end + + n = 512 + a = CUDA.fill(1f0, n) + b = CUDA.fill(2f0, n) + c = CUDA.zeros(Float32, n) + + configs = [ + (; tile=16, occupancy=nothing, num_ctas=nothing), + (; tile=32, occupancy=2, num_ctas=nothing), + (; tile=64, occupancy=4, num_ctas=2), + ] + args_fn = cfg -> (a, b, c, ct.Constant(cfg.tile)) + grid_fn = cfg -> cld(n, cfg.tile) + + @testset "basic tuning" begin + Exp.clear_autotune_cache() + result = Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key=(:basic, n), + tuning=(preset=:fast, refine_topk=0), + ) + @test !result.cache_hit + @test result.tuned_config in configs + @test !isempty(result.tuning_record) + @test Array(c) ≈ fill(3f0, n) + end + + @testset "cache hit" begin + fill!(c, 0f0) + result = Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key=(:basic, n), + tuning=(preset=:fast, refine_topk=0), + ) + @test result.cache_hit + @test Array(c) ≈ fill(3f0, n) + end + + @testset "force retune" begin + fill!(c, 0f0) + result = Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key=(:basic, n), + tuning=(preset=:fast, refine_topk=0, force=true), + ) + @test !result.cache_hit + @test Array(c) ≈ fill(3f0, n) + end + + @testset "CartesianSpace" begin + Exp.clear_autotune_cache() + fill!(c, 0f0) + space = Exp.CartesianSpace(; + tile=(16, 32), occupancy=(nothing, 2), num_ctas=(nothing,)) + result = Exp.autotune_launch( + vadd_kernel, space, grid_fn, args_fn; + key=(:cartesian, n), + tuning=(preset=:fast, refine_topk=0), + ) + @test hasproperty(result.tuned_config, :tile) + @test hasproperty(result.tuned_config, :occupancy) + @test Array(c) ≈ fill(3f0, n) + end + + @testset "CartesianSpace with constraint" begin + Exp.clear_autotune_cache() + fill!(c, 0f0) + space = Exp.CartesianSpace( + cfg -> cfg.tile == 16; + tile=(16, 32, 64), occupancy=(nothing,), num_ctas=(nothing,)) + result = Exp.autotune_launch( + vadd_kernel, space, grid_fn, args_fn; + key=(:constrained, n), + tuning=(preset=:fast, refine_topk=0), + ) + @test result.tuned_config.tile == 16 + @test Array(c) ≈ fill(3f0, n) + end + + @testset "NamedTuple convenience → CartesianSpace" begin + Exp.clear_autotune_cache() + fill!(c, 0f0) + result = Exp.autotune_launch( + vadd_kernel, + (tile=(16, 32), occupancy=(nothing,), num_ctas=(nothing,)), + grid_fn, args_fn; + key=(:nt_convenience, n), + tuning=(preset=:fast, refine_topk=0), + ) + @test result.tuned_config.tile in (16, 32) + @test Array(c) ≈ fill(3f0, n) + end + + @testset "launch_args_fn (inplace kernel)" begin + x = CUDA.zeros(Float32, n) + original_x = Array(x) + Exp.clear_autotune_cache() + result = Exp.autotune_launch( + inplace_add_kernel, + [(; tile=16), (; tile=32)], + grid_fn, + cfg -> (copy(x), ct.Constant(cfg.tile)); + launch_args_fn=cfg -> (x, ct.Constant(cfg.tile)), + key=(:inplace, n), + tuning=(preset=:fast, refine_topk=0), + ) + @test !result.cache_hit + @test Array(x) == original_x .+ 1f0 + end + + @testset "refinement" begin + Exp.clear_autotune_cache() + fill!(c, 0f0) + result = Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key=(:refine, n), + tuning=(warmup=1, reps=2, refine_topk=2, refine_reps=4), + ) + @test !result.cache_hit + # Refinement record replaces initial — has at most refine_topk entries + @test length(result.tuning_record) <= 2 + @test Array(c) ≈ fill(3f0, n) + end + + @testset "verify" begin + Exp.clear_autotune_cache() + fill!(c, 0f0) + verify_called = Ref(false) + result = Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key=(:verify, n), + tuning=(preset=:fast, refine_topk=0), + verify=() -> let + ref = Array(a) .+ Array(b) + verify_called[] = true + () -> (CUDA.@allowscalar all(isapprox.(Array(c), ref, atol=1f-5))) + end, + ) + @test verify_called[] + @test Array(c) ≈ fill(3f0, n) + end + + @testset "clear cache per-kernel per-key" begin + Exp.clear_autotune_cache() + Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key=(:k1, n), tuning=(preset=:fast, refine_topk=0)) + Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key=(:k2, n), tuning=(preset=:fast, refine_topk=0)) + + # Clear only one key + Exp.clear_autotune_cache(kernel=vadd_kernel, key=(:k1, n)) + fill!(c, 0f0) + r1 = Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key=(:k1, n), tuning=(preset=:fast, refine_topk=0)) + @test !r1.cache_hit # was cleared + + fill!(c, 0f0) + r2 = Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key=(:k2, n), tuning=(preset=:fast, refine_topk=0)) + @test r2.cache_hit # still cached + end + + @testset "shared key across shapes" begin + Exp.clear_autotune_cache() + n2 = 1024 + a2 = CUDA.fill(1f0, n2) + b2 = CUDA.fill(2f0, n2) + c2 = CUDA.zeros(Float32, n2) + shared_key = (:shape_agnostic, eltype(a)) + + Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key=shared_key, tuning=(preset=:fast, refine_topk=0)) + + fill!(c2, 0f0) + result = Exp.autotune_launch( + vadd_kernel, configs, + cfg -> cld(n2, cfg.tile), + cfg -> (a2, b2, c2, ct.Constant(cfg.tile)); + key=shared_key, tuning=(preset=:fast, refine_topk=0)) + @test result.cache_hit + @test result.grid == cld(n2, result.tuned_config.tile) + @test Array(c2) ≈ fill(3f0, n2) + end + + @testset "key_fn" begin + Exp.clear_autotune_cache() + call_count = Ref(0) + my_key_fn = () -> begin + call_count[] += 1 + return (:dynamic, Float32) + end + + fill!(c, 0f0) + r1 = Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key_fn=my_key_fn, tuning=(preset=:fast, refine_topk=0)) + r2 = Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key_fn=my_key_fn, tuning=(preset=:fast, refine_topk=0)) + @test !r1.cache_hit + @test r2.cache_hit + @test call_count[] == 2 + @test Array(c) ≈ fill(3f0, n) + end +end From 0549723cd85ce8784067484c6a042f137a0083af Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Tue, 26 May 2026 23:03:26 +0200 Subject: [PATCH 02/15] Lock cufunction's codegen pipeline against concurrent compiles. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `emit_tile!` and everything below it (inference, structured IR, tile-IR emission) mutates shared state — `CacheView` entries, `CuTileResults` fields, the inference cache, and CompilerCaching's per-CI const_entries vector — none of which is thread-safe. Without this, concurrent `cufunction` calls (e.g. autotuning's precompile fan-out across `Threads.@spawn` workers) silently race. Two acquisitions, brief each: - `ensure_compiled` inside `compile` (lookup hit just briefly contends) - `emit_tile!` inside `emit_binary!` The tileiras subprocess still runs unlocked, so concurrent compiles can still overlap their shell-outs — the original rationale behind the autotune branch's `EMIT_TILE_LOCK`. --- src/launch.jl | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/launch.jl b/src/launch.jl index c4c097eb..bf480663 100644 --- a/src/launch.jl +++ b/src/launch.jl @@ -394,6 +394,16 @@ end Compilation: bytecode → CUBIN → CuFunction. =============================================================================# +# Serializes the Julia-side codegen pipeline (inference, structured IR, +# tile-IR emission). `emit_tile!` and everything it calls into mutates shared +# state: `CacheView` entries, `CuTileResults` fields, the inference cache, +# and CompilerCaching's per-CI const_entries vector. None of that is +# thread-safe. We hold the lock only across `emit_tile!`; the tileiras +# subprocess below runs unlocked so concurrent `cufunction` calls (e.g. +# from autotuning's precompile fan-out) can still overlap their tileiras +# shell-outs. +const EMIT_TILE_LOCK = ReentrantLock() + """ emit_binary!(cache, mi, ci, res; const_argtypes=nothing) -> Vector{UInt8} @@ -405,7 +415,7 @@ function emit_binary!(cache::CacheView, mi::Core.MethodInstance, # Recurse first — emit_structured! at the bottom of the chain fires # `compile_hook` for `@device_code_*` reflection, which must run on every # launch even when downstream artifacts are fully cached. - bytecode = emit_tile!(cache, mi, ci, res; const_argtypes) + bytecode = Base.@lock EMIT_TILE_LOCK emit_tile!(cache, mi, ci, res; const_argtypes) res.cuda_bin !== nothing && return res.cuda_bin @@ -584,7 +594,12 @@ function compile(@nospecialize(f), @nospecialize(argtypes), # underlying `CodeInstance` via CompilerCaching; the `TileKernel` wrapper # rides along in the same `CuTileResults`, so kernel-instance lifecycle # follows the CI's instead of needing a separate global Dict. - ci, res = ensure_compiled(cache, mi, const_argtypes) + # + # Held under EMIT_TILE_LOCK so concurrent compiles (e.g. autotuning's + # precompile fan-out) don't race on inference / CompilerCaching state. + # The lock-protected region also includes the lookup fast path; that + # path is just a hashtable read, so brief contention here is fine. + ci, res = Base.@lock EMIT_TILE_LOCK ensure_compiled(cache, mi, const_argtypes) # Always walk the emit chain (each phase short-circuits on its own cached # field, but `emit_structured!` also fires `compile_hook` for reflection, From 8a9b1b87fa211d6058c373b0f20aeb77633a5cc5 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Tue, 26 May 2026 23:03:39 +0200 Subject: [PATCH 03/15] Restore shared inference cache via ScopedValue. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds `_SCOPED_INF_CACHE::ScopedValue` to the cuTile interpreter so a caller can opt into reusing one `CC.InferenceCache` across many `cuTileInterpreter(cache)` constructions. Autotuning's `find_or_tune` now wraps its precompile+measure pass in `with(_SCOPED_INF_CACHE => fresh)`, so the 25-config sweep over a single kernel shares inference results across all const-seeded variants instead of paying the slow paths once per config. Microbenchmark on a TileArray load with `order=(1,2)` (the original motivating case), 25 configs: Fresh: 9.7 ms/cfg Shared: 2.4 ms/cfg (~4x speedup) The ScopedValue is unassigned by default, so non-autotune callers (plain `@cuda backend=cuTile`) behave exactly as before — each `cuTileInterpreter` allocates its own fresh cache. --- src/Experimental.jl | 14 +++++++++++++- src/autotune.jl | 18 +++++++++++++----- src/compiler/interpreter.jl | 19 ++++++++++++++++--- 3 files changed, 42 insertions(+), 9 deletions(-) diff --git a/src/Experimental.jl b/src/Experimental.jl index 9e10c7f3..14568e69 100644 --- a/src/Experimental.jl +++ b/src/Experimental.jl @@ -1,12 +1,24 @@ module Experimental using ..cuTile -using ..cuTile: cuTileconvert, cufunction, default_sm_arch +using ..cuTile: cuTileconvert, cufunction, default_sm_arch, _SCOPED_INF_CACHE using CUDACore: CUDACore +using Base.ScopedValues: with +import Core.Compiler as CC using Random +# Builds a fresh inference cache compatible with the running Julia version. +# Used to wrap an autotune pass in `with(_SCOPED_INF_CACHE => …)` so all the +# per-config const-seeded inference calls share results instead of paying +# the slow paths (e.g. `ct.load(..., order=…)`) once per config. +@inline _fresh_inf_cache() = @static if isdefined(CC, :InferenceCache) + CC.InferenceCache() +else + Vector{CC.InferenceResult}() +end + abstract type AbstractSearchSpace end Base.length(s::AbstractSearchSpace) = count(_ -> true, s) diff --git a/src/autotune.jl b/src/autotune.jl index dbf34197..83fc6b2a 100644 --- a/src/autotune.jl +++ b/src/autotune.jl @@ -181,11 +181,19 @@ function find_or_tune(@nospecialize(f), space::AbstractSearchSpace, rng::Abstrac trials = Any[collect(space)...] - trials, precompile_error = precompile_candidates(f, trials, args_fn; - sm_arch, opt_level, workers=tuning.precompile_workers) - - record, first_error = measure_candidates(f, trials, grid_fn, args_fn; - sm_arch, opt_level, warmup=tuning.warmup, reps=tuning.reps, verify=checker, reset) + # Share the inference cache across all per-cfg const-seeded compiles. + # Each cfg differs only in `Constant{T,V}` values, so the generic + # inference graph is identical — without sharing, kernels with slow + # inference paths (e.g. `ct.load(..., order=…)`) pay that cost N times. + trials, precompile_error, record, first_error = + with(_SCOPED_INF_CACHE => _fresh_inf_cache()) do + t, pe = precompile_candidates(f, trials, args_fn; + sm_arch, opt_level, workers=tuning.precompile_workers) + r, fe = measure_candidates(f, t, grid_fn, args_fn; + sm_arch, opt_level, warmup=tuning.warmup, reps=tuning.reps, + verify=checker, reset) + (t, pe, r, fe) + end if isempty(record) err_info = first_error !== nothing ? first_error : precompile_error diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index c5e03515..c7ad2256 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -1,8 +1,17 @@ # Integration with Julia's abstract interpreter +using Base.ScopedValues: ScopedValue Base.Experimental.@MethodTable cuTileMethodTable +# When assigned, every `cuTileInterpreter(cache)` constructed within +# the dynamic scope reuses this inference cache instead of allocating +# a fresh one. Lets batched inference passes (autotuning over many +# const-seeded variants of the same kernel) share work; without it, +# kernels that hit slow inference paths (e.g. `ct.load(..., order=...)`) +# pay the cost on every config. +const _SCOPED_INF_CACHE = ScopedValue{Any}() + function get_method_table_view(world::UInt) CC.CachedMethodTable(CC.OverlayMethodTable(world, cuTileMethodTable)) end @@ -21,10 +30,14 @@ end function cuTileInterpreter(cache::CacheView; always_inline::Bool=true) method_table = get_method_table_view(cache.world) - @static if isdefined(CC, :InferenceCache) - inf_cache = CC.InferenceCache() + inf_cache = if isassigned(_SCOPED_INF_CACHE) + _SCOPED_INF_CACHE[] else - inf_cache = Vector{CC.InferenceResult}() + @static if isdefined(CC, :InferenceCache) + CC.InferenceCache() + else + Vector{CC.InferenceResult}() + end end inf_params = CC.InferenceParams() opt_params = if always_inline From 4b7176ff5d658455943668a93150fdc1b5b98d95 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Tue, 26 May 2026 23:03:49 +0200 Subject: [PATCH 04/15] Add hygienic @cutile macro. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Shorthand for `CUDACore.@cuda backend=cuTile …`. Module references are interpolated as values (`$CUDACore`, `$cuTile`) rather than symbols, so callers don't need `using CUDACore` (or even `using cuTile`) at the call site — the expanded form points directly at the module objects. Verified working from a module with only `using cuTile`: module NoSetup using cuTile function run(a, b, c) cuTile.@cutile blocks=cld(length(a), 16) vadd(a, b, c) end end --- src/cuTile.jl | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/src/cuTile.jl b/src/cuTile.jl index 01cadcfd..8d29cff4 100644 --- a/src/cuTile.jl +++ b/src/cuTile.jl @@ -88,9 +88,30 @@ include("cache.jl") include("launch.jl") public launch, TileBackend, DefaultBackend, Tiled, ByTarget, - @compiler_options, @fpmode, @., + @compiler_options, @fpmode, @., @cutile, bytecode_version +""" + @cutile [kwargs...] kernel(args...) + +Shorthand for `CUDACore.@cuda backend=cuTile [kwargs...] kernel(args...)`. + +Works from any module — does not require the caller to `using CUDACore`, +since the macro expands to a fully-qualified reference to the actual +`CUDACore` module object. + +```julia +@cutile blocks=N kernel(a, b, c) +@cutile blocks=N occupancy=4 kernel(a, b, c) +``` +""" +macro cutile(args...) + # Interpolate the actual module values rather than the symbols so the + # caller doesn't need `CUDACore`/`cuTile` in scope — the expanded form + # references the module objects directly. + esc(:($CUDACore.@cuda backend=$cuTile $(args...))) +end + # World age captured at __init__ time. The compilation pipeline # (typeinf!, codegen, bytecode emission) is invoked in this world via # `invoke_frozen` so that precompiled native code stays usable even after From 4d5d315032d39cf93de430e6e39f37c17b8ac64e Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Tue, 26 May 2026 23:31:29 +0200 Subject: [PATCH 05/15] Add `@autotune` macro and static num_ctas/occupancy hints. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `@autotune` is a thin surface over `autotune_launch`. `$X` inside `blocks=` or the kernel-call args is rewritten to `cfg.X` (the macro intercepts `Expr(:$, :X)` nodes before lowering rejects them): @autotune( key = (eltype(A), size(A, 2)), space = (TILE_M=(64, 128), TILE_N=(64, 128), occupancy=(1, 2, 4)), blocks = (cld(M, $TILE_M), cld(N, $TILE_N)), matmul(A, B, C, Constant($TILE_M), Constant($TILE_N)) ) `space=` accepts a NamedTuple literal (→ CartesianSpace), a Vector of NamedTuples (→ FixedSpace), or any `AbstractSearchSpace` (pass-through — use this for `CartesianSpace(constraint; ...)`). Kernel kwargs aren't supported (rejected at expansion); pass values positionally via `Constant(...)`. `autotune_launch` now also accepts `num_ctas`/`occupancy` as **static** kwargs (applied uniformly to every cfg, e.g. for `ByTarget(...)` per-arch dispatch). They may not coexist with a same-named axis in `space`: - The macro flags the conflict at expansion time when `space` is a literal NamedTuple. - `autotune_launch` flags it at run time otherwise (opaque spaces). Cache key now includes the static hints, so cfgs tuned under different `num_ctas`/`occupancy` settings are kept separate. Tests: 48 cases pass (31 existing + 17 covering static hints, the macro, $X interpolation in tuple-blocks, NT vs Vector space, required-kwarg errors, and the macro-time + run-time conflict errors). --- src/Experimental.jl | 1 + src/autotune.jl | 90 ++++++++++++++------- src/autotune_macro.jl | 173 ++++++++++++++++++++++++++++++++++++++++ test/device/autotune.jl | 106 ++++++++++++++++++++++++ 4 files changed, 342 insertions(+), 28 deletions(-) create mode 100644 src/autotune_macro.jl diff --git a/src/Experimental.jl b/src/Experimental.jl index 14568e69..9f3a56a0 100644 --- a/src/Experimental.jl +++ b/src/Experimental.jl @@ -52,5 +52,6 @@ function Base.iterate(space::CartesianSpace{names}, state=nothing) where names end include("autotune.jl") +include("autotune_macro.jl") end diff --git a/src/autotune.jl b/src/autotune.jl index 83fc6b2a..271ff0e4 100644 --- a/src/autotune.jl +++ b/src/autotune.jl @@ -24,24 +24,16 @@ function normalize_tuning(tuning::NamedTuple) return merge(base, overrides) end -# Extract hint fields (occupancy, num_ctas) from a config for cufunction kwargs. -function hints_from_cfg(cfg) - n = hasproperty(cfg, :num_ctas) ? cfg.num_ctas : nothing - o = hasproperty(cfg, :occupancy) ? cfg.occupancy : nothing +# Extract hint fields (num_ctas, occupancy) from a config, falling back to +# the static defaults supplied by the caller. cfg takes precedence; the +# caller is expected to have rejected the both-supplied case upstream +# (see `autotune_launch`). +function hints_from_cfg(cfg; static_num_ctas=nothing, static_occupancy=nothing) + n = hasproperty(cfg, :num_ctas) ? cfg.num_ctas : static_num_ctas + o = hasproperty(cfg, :occupancy) ? cfg.occupancy : static_occupancy return (num_ctas=n, occupancy=o) end -# Pre-convert and compile a (cfg, args) pair into a cached TileKernel + the -# converted argument tuple ready for the call. Both the compilation and the -# adapt are reused across warmup/measurement reps. -function _prepare_launch(@nospecialize(f), cfg, args_fn::Function; - sm_arch::VersionNumber, opt_level::Int) - converted = map(cuTileconvert, args_fn(cfg)) - tt = Tuple{map(Core.Typeof, converted)...} - kernel = cufunction(f, tt; sm_arch, opt_level, hints_from_cfg(cfg)...) - return kernel, converted -end - function time_ms(run_once::Function, get_args::Function; warmup::Int, reps::Int, verify::Union{Nothing, Function}=nothing, reset::Union{Nothing, Function}=nothing) @@ -70,6 +62,7 @@ end function eval_cfg(@nospecialize(f), cfg, grid_fn::Function, args_fn::Function; sm_arch::VersionNumber, opt_level::Int, warmup::Int, reps::Int, + static_num_ctas=nothing, static_occupancy=nothing, verify::Union{Nothing, Function}=nothing, reset::Union{Nothing, Function}=nothing) grid = grid_fn(cfg) @@ -79,7 +72,8 @@ function eval_cfg(@nospecialize(f), cfg, grid_fn::Function, args_fn::Function; # timed loop so JIT cost doesn't pollute the measurement. sample_converted = map(cuTileconvert, args_fn(cfg)) tt = Tuple{map(Core.Typeof, sample_converted)...} - kernel = cufunction(f, tt; sm_arch, opt_level, hints_from_cfg(cfg)...) + kernel = cufunction(f, tt; sm_arch, opt_level, + hints_from_cfg(cfg; static_num_ctas, static_occupancy)...) run_once = converted -> kernel(converted...; blocks=grid_dims) get_args = () -> map(cuTileconvert, args_fn(cfg)) @@ -87,16 +81,19 @@ function eval_cfg(@nospecialize(f), cfg, grid_fn::Function, args_fn::Function; end function precompile_cfg(@nospecialize(f), cfg, args_fn::Function; - sm_arch::VersionNumber, opt_level::Int) + sm_arch::VersionNumber, opt_level::Int, + static_num_ctas=nothing, static_occupancy=nothing) converted = map(cuTileconvert, args_fn(cfg)) tt = Tuple{map(Core.Typeof, converted)...} - cufunction(f, tt; sm_arch, opt_level, hints_from_cfg(cfg)...) + cufunction(f, tt; sm_arch, opt_level, + hints_from_cfg(cfg; static_num_ctas, static_occupancy)...) return nothing end function precompile_candidates(@nospecialize(f), configs::Vector{Any}, args_fn::Function; - sm_arch::VersionNumber, opt_level::Int, workers::Int) + sm_arch::VersionNumber, opt_level::Int, workers::Int, + static_num_ctas=nothing, static_occupancy=nothing) isempty(configs) && return configs, nothing iszero(workers) && return configs, nothing @@ -113,7 +110,8 @@ function precompile_candidates(@nospecialize(f), configs::Vector{Any}, Base.acquire(sem) do cancelled[] && return try - precompile_cfg(f, cfg, args_fn; sm_arch, opt_level) + precompile_cfg(f, cfg, args_fn; sm_arch, opt_level, + static_num_ctas, static_occupancy) catch err compiled[i] = false errors[i] = (cfg, err) @@ -142,13 +140,15 @@ end function measure_candidates(@nospecialize(f), configs::Vector{Any}, grid_fn::Function, args_fn::Function; sm_arch::VersionNumber, opt_level::Int, warmup::Int, reps::Int, + static_num_ctas=nothing, static_occupancy=nothing, verify::Union{Nothing, Function}=nothing, reset::Union{Nothing, Function}=nothing) record = Tuple{Any, Float32}[] first_error = nothing for cfg in configs ms = try - eval_cfg(f, cfg, grid_fn, args_fn; sm_arch, opt_level, warmup, reps, verify, reset) + eval_cfg(f, cfg, grid_fn, args_fn; sm_arch, opt_level, warmup, reps, + static_num_ctas, static_occupancy, verify, reset) catch err if err isa InterruptException @warn "Benchmarking interrupted after $(length(record)) configs" @@ -166,6 +166,7 @@ end function find_or_tune(@nospecialize(f), space::AbstractSearchSpace, rng::AbstractRNG, grid_fn::Function, args_fn::Function, tuning; sm_arch::VersionNumber, opt_level::Int, kernel_key, arg_key, + static_num_ctas=nothing, static_occupancy=nothing, verify::Union{Nothing, Function}=nothing, setup::Union{Nothing, Function}=nothing) if !tuning.force @@ -181,6 +182,24 @@ function find_or_tune(@nospecialize(f), space::AbstractSearchSpace, rng::Abstrac trials = Any[collect(space)...] + # Conflict check: if the cfg carries a `num_ctas`/`occupancy` field AND + # the caller also provided a static value, error rather than silently + # ignoring one. (Handles the case where `space` is opaque to the macro + # — a user-built `CartesianSpace(...)` or `FixedSpace([(...),...])`.) + if !isempty(trials) + sample = first(trials) + if static_num_ctas !== nothing && hasproperty(sample, :num_ctas) + throw(ArgumentError( + "`num_ctas` is both a static kwarg and an axis in the search space. " * + "Pick one.")) + end + if static_occupancy !== nothing && hasproperty(sample, :occupancy) + throw(ArgumentError( + "`occupancy` is both a static kwarg and an axis in the search space. " * + "Pick one.")) + end + end + # Share the inference cache across all per-cfg const-seeded compiles. # Each cfg differs only in `Constant{T,V}` values, so the generic # inference graph is identical — without sharing, kernels with slow @@ -188,9 +207,11 @@ function find_or_tune(@nospecialize(f), space::AbstractSearchSpace, rng::Abstrac trials, precompile_error, record, first_error = with(_SCOPED_INF_CACHE => _fresh_inf_cache()) do t, pe = precompile_candidates(f, trials, args_fn; - sm_arch, opt_level, workers=tuning.precompile_workers) + sm_arch, opt_level, workers=tuning.precompile_workers, + static_num_ctas, static_occupancy) r, fe = measure_candidates(f, t, grid_fn, args_fn; sm_arch, opt_level, warmup=tuning.warmup, reps=tuning.reps, + static_num_ctas, static_occupancy, verify=checker, reset) (t, pe, r, fe) end @@ -210,7 +231,8 @@ function find_or_tune(@nospecialize(f), space::AbstractSearchSpace, rng::Abstrac sort!(record, by=last) top_configs = Any[first(r) for r in record[1:min(tuning.refine_topk, length(record))]] refined, _ = measure_candidates(f, top_configs, grid_fn, args_fn; - sm_arch, opt_level, warmup=tuning.warmup, reps=tuning.refine_reps, reset) + sm_arch, opt_level, warmup=tuning.warmup, reps=tuning.refine_reps, + static_num_ctas, static_occupancy, reset) if !isempty(refined) record = refined end @@ -233,12 +255,18 @@ end """ autotune_launch(f, space, grid_fn, args_fn; key, key_fn, launch_args_fn, - verify, setup, tuning, sm_arch, opt_level) + verify, setup, tuning, sm_arch, opt_level, + num_ctas=nothing, occupancy=nothing) Tune `f` over `space` (an [`AbstractSearchSpace`](@ref) or a `Vector`/`NamedTuple` shorthand) and launch the best config. `grid_fn(cfg)` returns the launch grid; `args_fn(cfg)` returns the argument tuple. Results are cached per `(f, sm_arch, opt_level) ⇒ key`. + +`num_ctas` and `occupancy` may be supplied as **static** kwargs (applied +uniformly to every cfg — useful for `ByTarget(...)`-style per-arch dispatch) +OR as **axes** inside `space` (tuned per cfg), but not both. Specifying +both throws an `ArgumentError`. """ function autotune_launch(@nospecialize(f), space::AbstractSearchSpace, grid_fn::Function, args_fn::Function; @@ -249,15 +277,19 @@ function autotune_launch(@nospecialize(f), space::AbstractSearchSpace, setup::Union{Nothing, Function}=nothing, tuning::NamedTuple=NamedTuple(), sm_arch::VersionNumber=default_sm_arch(), - opt_level::Int=3) + opt_level::Int=3, + num_ctas=nothing, + occupancy=nothing) tuning = normalize_tuning(tuning) rng = tuning.seed !== nothing ? MersenneTwister(tuning.seed) : Random.default_rng() - kernel_key = (f, sm_arch, opt_level) + kernel_key = (f, sm_arch, opt_level, num_ctas, occupancy) arg_key = key !== nothing ? key : (key_fn !== nothing ? key_fn() : nothing) entry, cache_hit, reset = find_or_tune(f, space, rng, grid_fn, args_fn, tuning; - sm_arch, opt_level, kernel_key, arg_key, verify, setup) + sm_arch, opt_level, kernel_key, arg_key, + static_num_ctas=num_ctas, static_occupancy=occupancy, + verify, setup) cfg = entry.best_config grid = grid_fn(cfg) @@ -265,7 +297,9 @@ function autotune_launch(@nospecialize(f), space::AbstractSearchSpace, reset !== nothing && reset() - cuTile.launch(f, grid, args...; sm_arch, opt_level, hints_from_cfg(cfg)...) + cuTile.launch(f, grid, args...; sm_arch, opt_level, + hints_from_cfg(cfg; static_num_ctas=num_ctas, + static_occupancy=occupancy)...) return (; tuned_config=cfg, grid, tuning_record=copy(entry.tuning_record), cache_hit) end diff --git a/src/autotune_macro.jl b/src/autotune_macro.jl new file mode 100644 index 00000000..bac633cf --- /dev/null +++ b/src/autotune_macro.jl @@ -0,0 +1,173 @@ +# `@autotune` — surface syntax for `autotune_launch`. +# +# Desugars +# +# @autotune( +# key = (T, Dk_pow2), +# space = (TILE_M=(32, 64), TILE_N=(32, 64), occupancy=(1, 2)), +# blocks = (cld(SeqLen_Q, $TILE_M), Heads * Batch), +# mha_fwd(Q, K, V, ..., Constant($TILE_M), Constant($TILE_N)), +# ) +# +# into a call to `autotune_launch`, with `$X` interpolated as `cfg.X` in +# `blocks` and the kernel-call args. `$X` is the parser's `Expr(:$, :X)` +# node, which is normally rejected by lowering; the macro intercepts and +# rewrites it before that happens. + +# Walk an expression replacing `Expr(:$, :X)` with `cfg.X`. +function _autotune_interp(ex, cfg::Symbol) + if Meta.isexpr(ex, :$) + length(ex.args) == 1 || + error("@autotune: \$ takes exactly one argument") + sym = ex.args[1] + sym isa Symbol || + error("@autotune: \$ argument must be a symbol, got `$sym`") + return :($cfg.$sym) + elseif ex isa Expr + return Expr(ex.head, [_autotune_interp(a, cfg) for a in ex.args]...) + else + return ex + end +end + +# Try to extract axis names from a literal NT space expression. +# Returns Vector{Symbol} on success, or `nothing` if the expression +# isn't a recognizable NT literal (e.g. it's a variable reference, a +# function call, or a constructor like `FixedSpace([...])`). +function _autotune_space_axes(space_expr) + Meta.isexpr(space_expr, :tuple) || return nothing + axes = Symbol[] + for a in space_expr.args + if Meta.isexpr(a, :(=)) && a.args[1] isa Symbol + push!(axes, a.args[1]) + elseif Meta.isexpr(a, :parameters) + for kv in a.args + if Meta.isexpr(kv, :kw) && kv.args[1] isa Symbol + push!(axes, kv.args[1]) + else + return nothing + end + end + else + return nothing + end + end + return axes +end + +const _AUTOTUNE_KWARGS = (:key, :space, :blocks, :tuning, :verify, :setup, + :sm_arch, :opt_level, :key_fn, :launch_args_fn, + :num_ctas, :occupancy) + +""" + @autotune key=... space=... blocks=... [kwargs...] kernel(args...) + +Tune `kernel` over `space` and launch the best config. Inside `blocks=` +and the kernel-call args, `\$X` interpolates `cfg.X` (where `cfg` is the +tuning configuration being evaluated). + +# Required kwargs +- `space` — a `NamedTuple` like `(A=(...), B=(...))` (becomes a + `CartesianSpace`), a `Vector` of `NamedTuple`s (becomes a `FixedSpace`), + or any `AbstractSearchSpace` (passed through — useful for + `CartesianSpace(constraint; ...)`). +- `blocks` — grid dimensions, an `Int` or `Tuple`. May reference `\$X`. + +# Optional kwargs +- `key` — eager cache key (any value) +- `key_fn` — lazy alternative to `key` +- `tuning` — `NamedTuple` of tuning knobs (`preset`, `force`, etc.) +- `verify` — `() -> (() -> Bool)` factory; the returned checker is + called after each warmup pass to reject incorrect cfgs +- `setup` — `() -> (() -> Nothing)` factory; reset between reps +- `launch_args_fn` — final-launch arg builder (defaults to the kernel-call + args); use this when the timed args are throwaway + copies (in-place kernels) and the final launch should + hit the real buffers +- `sm_arch`, `opt_level` — forwarded to `cufunction` +- `num_ctas`, `occupancy` — **static** hints applied uniformly to every + cfg. May not coexist with same-named axes in `space` + (the macro flags the conflict at expansion time when + `space` is a literal NT; otherwise `autotune_launch` + catches it at run time) + +# Example + +```julia +@autotune( + key = (eltype(A), size(A, 2)), + space = (TILE_M=(64, 128), TILE_N=(64, 128), occupancy=(1, 2, 4)), + blocks = (cld(M, \$TILE_M), cld(N, \$TILE_N)), + matmul(A, B, C, Constant(\$TILE_M), Constant(\$TILE_N)) +) +``` +""" +macro autotune(args...) + kwargs = Dict{Symbol, Any}() + call = nothing + for arg in args + if Meta.isexpr(arg, :(=)) || Meta.isexpr(arg, :kw) + k, v = arg.args + k isa Symbol || error("@autotune: kwarg key must be a symbol, got `$k`") + k in _AUTOTUNE_KWARGS || + error("@autotune: unknown kwarg `$k`. Valid: $(join(_AUTOTUNE_KWARGS, ", "))") + haskey(kwargs, k) && error("@autotune: duplicate kwarg `$k`") + kwargs[k] = v + elseif Meta.isexpr(arg, :call) + call === nothing || error("@autotune: only one kernel call allowed") + call = arg + else + error("@autotune: unexpected argument `$arg` — expected `kwarg=val` or a kernel call") + end + end + + call === nothing && error("@autotune: missing kernel call (e.g. `kernel(args...)`)") + haskey(kwargs, :space) || error("@autotune: missing required kwarg `space=`") + haskey(kwargs, :blocks) || error("@autotune: missing required kwarg `blocks=`") + + space_expr = kwargs[:space] + blocks_expr = kwargs[:blocks] + + # Macro-time conflict check: if `space` is a literal NT and contains + # `num_ctas`/`occupancy` as an axis, the same name can't also appear + # as a static kwarg. (Opaque spaces are caught at run time.) + space_axes = _autotune_space_axes(space_expr) + if space_axes !== nothing + for hint in (:num_ctas, :occupancy) + if hint in space_axes && haskey(kwargs, hint) + error("@autotune: `$hint` appears both as an axis in `space=` and " * + "as a static kwarg. Pick one.") + end + end + end + + # Extract the kernel call (positional only — no kernel kwargs). + Meta.isexpr(call, :call) || + error("@autotune: kernel must be a function-call expression") + f_expr = call.args[1] + arg_exprs = call.args[2:end] + for a in arg_exprs + if Meta.isexpr(a, :parameters) || Meta.isexpr(a, :kw) + error("@autotune: kernel-side kwargs are not supported; pass values " * + "as positional args (typically wrapped in `Constant(...)`)") + end + end + + # Substitute `$X` -> `cfg.X` inside blocks/args. + cfg = gensym(:cfg) + grid_body = _autotune_interp(blocks_expr, cfg) + arg_subs = [_autotune_interp(a, cfg) for a in arg_exprs] + + grid_fn = :($cfg -> $grid_body) + args_fn = :($cfg -> ($(arg_subs...),)) + + # Forward all macro kwargs (except space/blocks, which are positional + # / lifted into the closures) to `autotune_launch`. + forwarded_keys = (:key, :tuning, :verify, :setup, :sm_arch, :opt_level, + :key_fn, :launch_args_fn, :num_ctas, :occupancy) + kw_exprs = [Expr(:kw, k, kwargs[k]) for k in forwarded_keys if haskey(kwargs, k)] + + return esc(quote + $autotune_launch($f_expr, $space_expr, $grid_fn, $args_fn; $(kw_exprs...)) + end) +end diff --git a/test/device/autotune.jl b/test/device/autotune.jl index d6e705e9..3cac1253 100644 --- a/test/device/autotune.jl +++ b/test/device/autotune.jl @@ -231,4 +231,110 @@ const Exp = ct.Experimental @test call_count[] == 2 @test Array(c) ≈ fill(3f0, n) end + + @testset "static num_ctas / occupancy as kwargs" begin + Exp.clear_autotune_cache() + fill!(c, 0f0) + # `space` has no num_ctas/occupancy axes — they're static kwargs. + result = Exp.autotune_launch( + vadd_kernel, + [(; tile=16), (; tile=32)], + cfg -> cld(n, cfg.tile), + cfg -> (a, b, c, ct.Constant(cfg.tile)); + key=(:static_hints, n), + occupancy=2, + tuning=(preset=:fast, refine_topk=0)) + @test !result.cache_hit + @test Array(c) ≈ fill(3f0, n) + end + + @testset "conflict: static + space axis" begin + Exp.clear_autotune_cache() + # Run-time path (opaque space): autotune_launch should reject. + @test_throws ArgumentError Exp.autotune_launch( + vadd_kernel, + [(; tile=16, occupancy=2)], + cfg -> cld(n, cfg.tile), + cfg -> (a, b, c, ct.Constant(cfg.tile)); + key=(:conflict, n), + occupancy=4, + tuning=(preset=:fast, refine_topk=0)) + end + + @testset "@autotune macro: NT space" begin + Exp.clear_autotune_cache() + fill!(c, 0f0) + result = Exp.@autotune( + key = (:macro_nt, n), + space = (tile=(16, 32, 64),), + blocks = cld(n, $tile), + tuning = (preset=:fast, refine_topk=0), + vadd_kernel(a, b, c, ct.Constant($tile)), + ) + @test !result.cache_hit + @test result.tuned_config.tile in (16, 32, 64) + @test Array(c) ≈ fill(3f0, n) + end + + @testset "@autotune macro: vector space" begin + Exp.clear_autotune_cache() + fill!(c, 0f0) + result = Exp.@autotune( + key = (:macro_vec, n), + space = [(; tile=16), (; tile=32)], + blocks = cld(n, $tile), + tuning = (preset=:fast, refine_topk=0), + vadd_kernel(a, b, c, ct.Constant($tile)), + ) + @test !result.cache_hit + @test result.tuned_config.tile in (16, 32) + @test Array(c) ≈ fill(3f0, n) + end + + @testset "@autotune macro: tuple blocks + 2D \$interp" begin + Exp.clear_autotune_cache() + fill!(c, 0f0) + # Use a 1D kernel but pass a Tuple blocks=(N, 1) to exercise the + # tuple-grid + $X-interp-in-blocks path. + result = Exp.@autotune( + key = (:macro_tuple, n), + space = (tile=(16, 32),), + blocks = (cld(n, $tile), 1), + tuning = (preset=:fast, refine_topk=0), + vadd_kernel(a, b, c, ct.Constant($tile)), + ) + @test result.grid == (cld(n, result.tuned_config.tile), 1) + @test Array(c) ≈ fill(3f0, n) + end + + @testset "@autotune macro: static num_ctas as kwarg" begin + Exp.clear_autotune_cache() + fill!(c, 0f0) + result = Exp.@autotune( + key = (:macro_static, n), + space = (tile=(16, 32),), + blocks = cld(n, $tile), + occupancy = 2, + tuning = (preset=:fast, refine_topk=0), + vadd_kernel(a, b, c, ct.Constant($tile)), + ) + @test !result.cache_hit + @test Array(c) ≈ fill(3f0, n) + end + + @testset "@autotune macro: macro-time conflict error" begin + # Should error at macro expansion (not run time). + @test_throws LoadError @eval Exp.@autotune( + space = (tile=(16,), num_ctas=(1, 2)), + blocks = 1, + num_ctas = 4, + kernel(a), + ) + end + + @testset "@autotune macro: required kwargs" begin + @test_throws LoadError @eval Exp.@autotune(blocks = 1, kernel(a)) + @test_throws LoadError @eval Exp.@autotune(space = (tile=(16,),), kernel(a)) + @test_throws LoadError @eval Exp.@autotune(space = (tile=(16,),), blocks = 1) + end end From fa5c5bed1060f6510e38f84adf8ea8b1e9cc5ead Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Tue, 26 May 2026 23:48:31 +0200 Subject: [PATCH 06/15] Preserve `cache_hit=true` when a race fills the cache mid-tune. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `find_or_tune`'s final write block already handles the case where two threads race into tuning the same key: only one wins the `per_kernel[arg_key] = candidate` write, the other reads the winner's entry. The original autotune branch returned `cache_hit=true` for the race-loser; my port dropped that signal and hardcoded `false`, so the result NT misreported provenance in races. Cosmetic — the right entry is still returned — but worth fixing while the audit found it. --- src/autotune.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/autotune.jl b/src/autotune.jl index 271ff0e4..fb28f20c 100644 --- a/src/autotune.jl +++ b/src/autotune.jl @@ -241,7 +241,10 @@ function find_or_tune(@nospecialize(f), space::AbstractSearchSpace, rng::Abstrac _, best_idx = findmin(last, record) candidate = (; best_config=record[best_idx][1], tuning_record=record) - entry, _ = lock(AUTOTUNE_LOCK) do + # Race: another thread may have populated the cache while we were + # tuning. If so, return their result and report `cache_hit=true` so + # the caller's accounting stays accurate. + entry, cache_hit = lock(AUTOTUNE_LOCK) do per_kernel = get!(Dict{Any,Any}, AUTOTUNE_CACHE, kernel_key) if !tuning.force && haskey(per_kernel, arg_key) per_kernel[arg_key], true @@ -250,7 +253,7 @@ function find_or_tune(@nospecialize(f), space::AbstractSearchSpace, rng::Abstrac candidate, false end end - return entry, false, reset + return entry, cache_hit, reset end """ From dc34ccf8e414e97f00caa6ee40394944d29ee555 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Wed, 27 May 2026 00:05:25 +0200 Subject: [PATCH 07/15] =?UTF-8?q?Pipeline=20compile=E2=86=92measure=20duri?= =?UTF-8?q?ng=20autotune.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the two-phase all-compile-then-all-measure flow with a producer-consumer pipeline: compile workers push each finished cfg onto a Channel, the master task pulls them off in arrival order and times them on the GPU. The first cfg's measurement starts the moment that cfg's tileiras subprocess returns, overlapping with the remaining cfgs still compiling in the background. Master is the consumer by design — `eval_cfg`'s `CUDACore.synchronize` / `@elapsed` rely on task-local CUDA state, and we want the timed context to match the caller's. Producer sub-tasks (which only run `cufunction` / load the CUBIN) tolerate fresh task-local state. Falls back to `measure_candidates` (serial compile+measure) when `precompile_workers=0` or only one cfg is in the search space. Cancellation: master sets `cancelled[]=true` on `InterruptException`, drains the channel so producers don't block on `put!`, then waits for the producer driver to wind down. Record is now in completion order rather than trial order — the only behavior change visible to callers. Refinement (`sort!(record, by=last)`) doesn't care; tests don't check ordering. Sanity check on a 16-cfg sweep with 8 threads: ~2× wall-time vs. `precompile_workers=0` (combined effect of compile parallelism + pipelining). --- src/autotune.jl | 129 +++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 122 insertions(+), 7 deletions(-) diff --git a/src/autotune.jl b/src/autotune.jl index fb28f20c..f0fdca25 100644 --- a/src/autotune.jl +++ b/src/autotune.jl @@ -163,6 +163,119 @@ function measure_candidates(@nospecialize(f), configs::Vector{Any}, return record, first_error end +""" + pipelined_tune(f, configs, grid_fn, args_fn; ...) -> (record, precompile_error, first_error) + +Overlap parallel compile with sequential measure. Compile workers (`Threads.@spawn`) +push each finished cfg onto a `Channel`; the master task pulls them off in +arrival order and runs `eval_cfg` on each. Net wall time is roughly +`max(parallel_compile_time, total_measure_time + first_compile)` instead of +the all-compile-then-all-measure sum. + +The master task is the consumer by design: it inherits the CUDA context from +the caller, which the timed `eval_cfg` needs (CUDA state is task-local). +`workers=0` skips parallel compile entirely and falls back to a +compile-then-measure cycle on the master. + +`record` is in completion order, not trial order — callers that care about +deterministic ordering should sort the result. +""" +function pipelined_tune(@nospecialize(f), configs::Vector{Any}, + grid_fn::Function, args_fn::Function; + sm_arch::VersionNumber, opt_level::Int, + warmup::Int, reps::Int, workers::Int, + static_num_ctas=nothing, static_occupancy=nothing, + verify::Union{Nothing, Function}=nothing, + reset::Union{Nothing, Function}=nothing) + record = Tuple{Any, Float32}[] + if isempty(configs) + return record, nothing, nothing + end + + # Serial fallback: avoids channel + extra task overhead when there's + # nothing to overlap (workers=0) or only one cfg to evaluate. + if iszero(workers) || length(configs) == 1 + rec, ferr = measure_candidates(f, configs, grid_fn, args_fn; + sm_arch, opt_level, warmup, reps, + static_num_ctas, static_occupancy, verify, reset) + return rec, nothing, ferr + end + + workers = min(workers, Threads.nthreads(), length(configs)) + cancelled = Threads.Atomic{Bool}(false) + sem = Base.Semaphore(workers) + precompile_error = Ref{Any}(nothing) + err_lock = ReentrantLock() + + # Buffer == n: producers never block on put!. Channel carries + # (trial_index, cfg) pairs; the index is preserved so callers that want + # trial-order can recover it. + ch = Channel{Tuple{Int, Any}}(length(configs)) + + # Producer driver: runs in its own task so the master can start consuming + # the moment the first cfg lands. `@sync` inside ensures we don't close + # the channel until every spawned compiler has either pushed or recorded + # an error. + producer = Threads.@spawn begin + try + @sync for (i, cfg) in enumerate(configs) + cancelled[] && break + Threads.@spawn begin + cancelled[] && return + Base.acquire(sem) do + cancelled[] && return + try + precompile_cfg(f, cfg, args_fn; sm_arch, opt_level, + static_num_ctas, static_occupancy) + cancelled[] || put!(ch, (i, cfg)) + catch err + lock(err_lock) do + precompile_error[] === nothing && + (precompile_error[] = (cfg, err)) + end + end + end + end + end + finally + close(ch) + end + end + + # Master consumes (and times) on this task — keeps the CUDA context for + # `eval_cfg` consistent with the caller's. + first_error = nothing + try + for (_, cfg) in ch + ms = try + eval_cfg(f, cfg, grid_fn, args_fn; sm_arch, opt_level, warmup, reps, + static_num_ctas, static_occupancy, verify, reset) + catch err + if err isa InterruptException + @warn "Benchmarking interrupted after $(length(record)) configs" + cancelled[] = true + break + end + err isa VerificationError && + @warn "Config $cfg failed verification, skipping" + first_error === nothing && (first_error = (cfg, err)) + continue + end + push!(record, (cfg, ms)) + end + catch err + cancelled[] = true + # Drain any items already in the channel so producers don't block on + # put! while we wait for them to notice the cancel flag. + while isready(ch); take!(ch); end + wait(producer) + rethrow() + end + + wait(producer) + return record, precompile_error[], first_error +end + function find_or_tune(@nospecialize(f), space::AbstractSearchSpace, rng::AbstractRNG, grid_fn::Function, args_fn::Function, tuning; sm_arch::VersionNumber, opt_level::Int, kernel_key, arg_key, @@ -204,16 +317,18 @@ function find_or_tune(@nospecialize(f), space::AbstractSearchSpace, rng::Abstrac # Each cfg differs only in `Constant{T,V}` values, so the generic # inference graph is identical — without sharing, kernels with slow # inference paths (e.g. `ct.load(..., order=…)`) pay that cost N times. - trials, precompile_error, record, first_error = + # + # `pipelined_tune` overlaps the parallel compile fan-out with sequential + # GPU measurement: as soon as the first compiler finishes, the master + # starts timing while the remaining compilers continue in the background. + record, precompile_error, first_error = with(_SCOPED_INF_CACHE => _fresh_inf_cache()) do - t, pe = precompile_candidates(f, trials, args_fn; - sm_arch, opt_level, workers=tuning.precompile_workers, - static_num_ctas, static_occupancy) - r, fe = measure_candidates(f, t, grid_fn, args_fn; - sm_arch, opt_level, warmup=tuning.warmup, reps=tuning.reps, + pipelined_tune(f, trials, grid_fn, args_fn; + sm_arch, opt_level, + warmup=tuning.warmup, reps=tuning.reps, + workers=tuning.precompile_workers, static_num_ctas, static_occupancy, verify=checker, reset) - (t, pe, r, fe) end if isempty(record) From 0d423e9ccf02b334b9755195231c5ecb2b254733 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Wed, 27 May 2026 09:35:51 +0200 Subject: [PATCH 08/15] Move to src/experimental/ --- src/cuTile.jl | 5 +---- src/{ => experimental}/Experimental.jl | 0 src/{ => experimental}/autotune.jl | 0 src/{ => experimental}/autotune_macro.jl | 0 4 files changed, 1 insertion(+), 4 deletions(-) rename src/{ => experimental}/Experimental.jl (100%) rename src/{ => experimental}/autotune.jl (100%) rename src/{ => experimental}/autotune_macro.jl (100%) diff --git a/src/cuTile.jl b/src/cuTile.jl index 8d29cff4..6a488c0c 100644 --- a/src/cuTile.jl +++ b/src/cuTile.jl @@ -106,9 +106,6 @@ since the macro expands to a fully-qualified reference to the actual ``` """ macro cutile(args...) - # Interpolate the actual module values rather than the symbols so the - # caller doesn't need `CUDACore`/`cuTile` in scope — the expanded form - # references the module objects directly. esc(:($CUDACore.@cuda backend=$cuTile $(args...))) end @@ -144,6 +141,6 @@ end include("precompile.jl") -include("Experimental.jl") +include("experimental/Experimental.jl") end # module cuTile diff --git a/src/Experimental.jl b/src/experimental/Experimental.jl similarity index 100% rename from src/Experimental.jl rename to src/experimental/Experimental.jl diff --git a/src/autotune.jl b/src/experimental/autotune.jl similarity index 100% rename from src/autotune.jl rename to src/experimental/autotune.jl diff --git a/src/autotune_macro.jl b/src/experimental/autotune_macro.jl similarity index 100% rename from src/autotune_macro.jl rename to src/experimental/autotune_macro.jl From 51f05891675e3767ec14fc580886deff779cb742 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Wed, 27 May 2026 09:48:58 +0200 Subject: [PATCH 09/15] Trim autotune API: drop key_fn, accept literal grid/args, untype kwargs. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - `key_fn` removed. `arg_key` was always built eagerly inside `autotune_launch` (no laziness benefit), so `key_fn=f` and `key=f()` were identical from the caller's perspective. Use `key=`. - `grid` and `args` (formerly `grid_fn`/`args_fn`) now accept either a `cfg -> value` callable OR a plain value (wrapped in `Returns(...)`). Lets direct `autotune_launch` callers pass `cld(n, 16)` and `(a, b, c, Constant(16))` without writing trivial closures. The `@autotune` macro keeps emitting closures (because `$X` may appear). - `launch_args_fn` renamed to `launch_args` (same fn-or-value treatment). - Strip the `::Union{Nothing, Function}=nothing` annotations from kwargs throughout — they pinned the type for no reason and added noise. Implicit `Any` with `=nothing` does the same job. Tests: drop the `key_fn` case (3 assertions), add a "literal grid/args" case. Net 2388 total pass; autotune set goes 48 → 47. --- src/experimental/autotune.jl | 78 ++++++++++++++++-------------- src/experimental/autotune_macro.jl | 31 ++++++------ test/device/autotune.jl | 33 ++++++------- 3 files changed, 72 insertions(+), 70 deletions(-) diff --git a/src/experimental/autotune.jl b/src/experimental/autotune.jl index f0fdca25..e8feb461 100644 --- a/src/experimental/autotune.jl +++ b/src/experimental/autotune.jl @@ -34,9 +34,8 @@ function hints_from_cfg(cfg; static_num_ctas=nothing, static_occupancy=nothing) return (num_ctas=n, occupancy=o) end -function time_ms(run_once::Function, get_args::Function; - warmup::Int, reps::Int, verify::Union{Nothing, Function}=nothing, - reset::Union{Nothing, Function}=nothing) +function time_ms(run_once, get_args; + warmup::Int, reps::Int, verify=nothing, reset=nothing) CUDACore.synchronize() for _ in 1:max(warmup, verify !== nothing ? 1 : 0) reset !== nothing && reset() @@ -60,11 +59,10 @@ function time_ms(run_once::Function, get_args::Function; return best_ms end -function eval_cfg(@nospecialize(f), cfg, grid_fn::Function, args_fn::Function; +function eval_cfg(@nospecialize(f), cfg, grid_fn, args_fn; sm_arch::VersionNumber, opt_level::Int, warmup::Int, reps::Int, static_num_ctas=nothing, static_occupancy=nothing, - verify::Union{Nothing, Function}=nothing, - reset::Union{Nothing, Function}=nothing) + verify=nothing, reset=nothing) grid = grid_fn(cfg) grid_dims = grid isa Integer ? (grid,) : grid @@ -80,7 +78,7 @@ function eval_cfg(@nospecialize(f), cfg, grid_fn::Function, args_fn::Function; return time_ms(run_once, get_args; warmup, reps, verify, reset) end -function precompile_cfg(@nospecialize(f), cfg, args_fn::Function; +function precompile_cfg(@nospecialize(f), cfg, args_fn; sm_arch::VersionNumber, opt_level::Int, static_num_ctas=nothing, static_occupancy=nothing) converted = map(cuTileconvert, args_fn(cfg)) @@ -90,8 +88,7 @@ function precompile_cfg(@nospecialize(f), cfg, args_fn::Function; return nothing end -function precompile_candidates(@nospecialize(f), configs::Vector{Any}, - args_fn::Function; +function precompile_candidates(@nospecialize(f), configs::Vector{Any}, args_fn; sm_arch::VersionNumber, opt_level::Int, workers::Int, static_num_ctas=nothing, static_occupancy=nothing) isempty(configs) && return configs, nothing @@ -137,12 +134,10 @@ function precompile_candidates(@nospecialize(f), configs::Vector{Any}, return configs[compiled], first_err end -function measure_candidates(@nospecialize(f), configs::Vector{Any}, - grid_fn::Function, args_fn::Function; +function measure_candidates(@nospecialize(f), configs::Vector{Any}, grid_fn, args_fn; sm_arch::VersionNumber, opt_level::Int, warmup::Int, reps::Int, static_num_ctas=nothing, static_occupancy=nothing, - verify::Union{Nothing, Function}=nothing, - reset::Union{Nothing, Function}=nothing) + verify=nothing, reset=nothing) record = Tuple{Any, Float32}[] first_error = nothing for cfg in configs @@ -180,13 +175,11 @@ compile-then-measure cycle on the master. `record` is in completion order, not trial order — callers that care about deterministic ordering should sort the result. """ -function pipelined_tune(@nospecialize(f), configs::Vector{Any}, - grid_fn::Function, args_fn::Function; +function pipelined_tune(@nospecialize(f), configs::Vector{Any}, grid_fn, args_fn; sm_arch::VersionNumber, opt_level::Int, warmup::Int, reps::Int, workers::Int, static_num_ctas=nothing, static_occupancy=nothing, - verify::Union{Nothing, Function}=nothing, - reset::Union{Nothing, Function}=nothing) + verify=nothing, reset=nothing) record = Tuple{Any, Float32}[] if isempty(configs) return record, nothing, nothing @@ -277,11 +270,10 @@ function pipelined_tune(@nospecialize(f), configs::Vector{Any}, end function find_or_tune(@nospecialize(f), space::AbstractSearchSpace, rng::AbstractRNG, - grid_fn::Function, args_fn::Function, tuning; + grid_fn, args_fn, tuning; sm_arch::VersionNumber, opt_level::Int, kernel_key, arg_key, static_num_ctas=nothing, static_occupancy=nothing, - verify::Union{Nothing, Function}=nothing, - setup::Union{Nothing, Function}=nothing) + verify=nothing, setup=nothing) if !tuning.force entry = lock(AUTOTUNE_LOCK) do per_kernel = get(AUTOTUNE_CACHE, kernel_key, nothing) @@ -371,15 +363,27 @@ function find_or_tune(@nospecialize(f), space::AbstractSearchSpace, rng::Abstrac return entry, cache_hit, reset end +# Normalize `grid`/`args` to a callable. A `Function` is used as-is; any +# other value is wrapped in `Returns(...)`. Lets `autotune_launch` accept +# `blocks = 1024` or `args = (a, b, c)` directly instead of forcing the +# caller to write `cfg -> 1024` for grids that don't depend on the cfg. +@inline _as_cfg_fn(f::Function) = f +@inline _as_cfg_fn(x) = Returns(x) + """ - autotune_launch(f, space, grid_fn, args_fn; key, key_fn, launch_args_fn, - verify, setup, tuning, sm_arch, opt_level, + autotune_launch(f, space, grid, args; key, launch_args, verify, setup, + tuning, sm_arch, opt_level, num_ctas=nothing, occupancy=nothing) Tune `f` over `space` (an [`AbstractSearchSpace`](@ref) or a `Vector`/`NamedTuple` -shorthand) and launch the best config. `grid_fn(cfg)` returns the launch -grid; `args_fn(cfg)` returns the argument tuple. Results are cached per -`(f, sm_arch, opt_level) ⇒ key`. +shorthand) and launch the best config. + +`grid` and `args` are either functions of the form `cfg -> grid` / +`cfg -> args_tuple` (for grids/args that depend on the cfg), or plain +values (for grids/args that don't). Internally both are normalized to +callables. `launch_args` follows the same convention; defaults to `args`. + +Results are cached per `(f, sm_arch, opt_level, num_ctas, occupancy) ⇒ key`. `num_ctas` and `occupancy` may be supplied as **static** kwargs (applied uniformly to every cfg — useful for `ByTarget(...)`-style per-arch dispatch) @@ -387,12 +391,11 @@ OR as **axes** inside `space` (tuned per cfg), but not both. Specifying both throws an `ArgumentError`. """ function autotune_launch(@nospecialize(f), space::AbstractSearchSpace, - grid_fn::Function, args_fn::Function; + grid, args; key=nothing, - key_fn::Union{Nothing, Function}=nothing, - launch_args_fn::Union{Nothing, Function}=nothing, - verify::Union{Nothing, Function}=nothing, - setup::Union{Nothing, Function}=nothing, + launch_args=nothing, + verify=nothing, + setup=nothing, tuning::NamedTuple=NamedTuple(), sm_arch::VersionNumber=default_sm_arch(), opt_level::Int=3, @@ -401,30 +404,33 @@ function autotune_launch(@nospecialize(f), space::AbstractSearchSpace, tuning = normalize_tuning(tuning) rng = tuning.seed !== nothing ? MersenneTwister(tuning.seed) : Random.default_rng() + grid_fn = _as_cfg_fn(grid) + args_fn = _as_cfg_fn(args) + launch_args_fn = launch_args === nothing ? args_fn : _as_cfg_fn(launch_args) + kernel_key = (f, sm_arch, opt_level, num_ctas, occupancy) - arg_key = key !== nothing ? key : (key_fn !== nothing ? key_fn() : nothing) entry, cache_hit, reset = find_or_tune(f, space, rng, grid_fn, args_fn, tuning; - sm_arch, opt_level, kernel_key, arg_key, + sm_arch, opt_level, kernel_key, arg_key=key, static_num_ctas=num_ctas, static_occupancy=occupancy, verify, setup) cfg = entry.best_config grid = grid_fn(cfg) - args = launch_args_fn !== nothing ? launch_args_fn(cfg) : args_fn(cfg) + launched_args = launch_args_fn(cfg) reset !== nothing && reset() - cuTile.launch(f, grid, args...; sm_arch, opt_level, + cuTile.launch(f, grid, launched_args...; sm_arch, opt_level, hints_from_cfg(cfg; static_num_ctas=num_ctas, static_occupancy=occupancy)...) return (; tuned_config=cfg, grid, tuning_record=copy(entry.tuning_record), cache_hit) end -function autotune_launch(@nospecialize(f), configs, grid_fn::Function, args_fn::Function; kwargs...) +function autotune_launch(@nospecialize(f), configs, grid, args; kwargs...) space = configs isa NamedTuple ? CartesianSpace(configs) : FixedSpace(configs) - return autotune_launch(f, space, grid_fn, args_fn; kwargs...) + return autotune_launch(f, space, grid, args; kwargs...) end function clear_autotune_cache(; kernel=nothing, key=nothing) diff --git a/src/experimental/autotune_macro.jl b/src/experimental/autotune_macro.jl index bac633cf..1b421661 100644 --- a/src/experimental/autotune_macro.jl +++ b/src/experimental/autotune_macro.jl @@ -56,7 +56,7 @@ function _autotune_space_axes(space_expr) end const _AUTOTUNE_KWARGS = (:key, :space, :blocks, :tuning, :verify, :setup, - :sm_arch, :opt_level, :key_fn, :launch_args_fn, + :sm_arch, :opt_level, :launch_args, :num_ctas, :occupancy) """ @@ -74,22 +74,21 @@ tuning configuration being evaluated). - `blocks` — grid dimensions, an `Int` or `Tuple`. May reference `\$X`. # Optional kwargs -- `key` — eager cache key (any value) -- `key_fn` — lazy alternative to `key` -- `tuning` — `NamedTuple` of tuning knobs (`preset`, `force`, etc.) -- `verify` — `() -> (() -> Bool)` factory; the returned checker is - called after each warmup pass to reject incorrect cfgs -- `setup` — `() -> (() -> Nothing)` factory; reset between reps -- `launch_args_fn` — final-launch arg builder (defaults to the kernel-call - args); use this when the timed args are throwaway - copies (in-place kernels) and the final launch should - hit the real buffers +- `key` — cache key (any value) +- `tuning` — `NamedTuple` of tuning knobs (`preset`, `force`, etc.) +- `verify` — `() -> (() -> Bool)` factory; the returned checker is + called after each warmup pass to reject incorrect cfgs +- `setup` — `() -> (() -> Nothing)` factory; reset between reps +- `launch_args` — final-launch args (or `cfg -> args` if it should differ + from the kernel-call args). Use this when the timed args + are throwaway copies (in-place kernels) and the final + launch should hit the real buffers - `sm_arch`, `opt_level` — forwarded to `cufunction` - `num_ctas`, `occupancy` — **static** hints applied uniformly to every - cfg. May not coexist with same-named axes in `space` - (the macro flags the conflict at expansion time when - `space` is a literal NT; otherwise `autotune_launch` - catches it at run time) + cfg. May not coexist with same-named axes in `space` + (the macro flags the conflict at expansion time when + `space` is a literal NT; otherwise `autotune_launch` + catches it at run time) # Example @@ -164,7 +163,7 @@ macro autotune(args...) # Forward all macro kwargs (except space/blocks, which are positional # / lifted into the closures) to `autotune_launch`. forwarded_keys = (:key, :tuning, :verify, :setup, :sm_arch, :opt_level, - :key_fn, :launch_args_fn, :num_ctas, :occupancy) + :launch_args, :num_ctas, :occupancy) kw_exprs = [Expr(:kw, k, kwargs[k]) for k in forwarded_keys if haskey(kwargs, k)] return esc(quote diff --git a/test/device/autotune.jl b/test/device/autotune.jl index 3cac1253..17af6c1f 100644 --- a/test/device/autotune.jl +++ b/test/device/autotune.jl @@ -115,7 +115,7 @@ const Exp = ct.Experimental @test Array(c) ≈ fill(3f0, n) end - @testset "launch_args_fn (inplace kernel)" begin + @testset "launch_args (inplace kernel)" begin x = CUDA.zeros(Float32, n) original_x = Array(x) Exp.clear_autotune_cache() @@ -124,7 +124,7 @@ const Exp = ct.Experimental [(; tile=16), (; tile=32)], grid_fn, cfg -> (copy(x), ct.Constant(cfg.tile)); - launch_args_fn=cfg -> (x, ct.Constant(cfg.tile)), + launch_args=cfg -> (x, ct.Constant(cfg.tile)), key=(:inplace, n), tuning=(preset=:fast, refine_topk=0), ) @@ -211,24 +211,21 @@ const Exp = ct.Experimental @test Array(c2) ≈ fill(3f0, n2) end - @testset "key_fn" begin + @testset "literal grid/args (no closure)" begin Exp.clear_autotune_cache() - call_count = Ref(0) - my_key_fn = () -> begin - call_count[] += 1 - return (:dynamic, Float32) - end - fill!(c, 0f0) - r1 = Exp.autotune_launch( - vadd_kernel, configs, grid_fn, args_fn; - key_fn=my_key_fn, tuning=(preset=:fast, refine_topk=0)) - r2 = Exp.autotune_launch( - vadd_kernel, configs, grid_fn, args_fn; - key_fn=my_key_fn, tuning=(preset=:fast, refine_topk=0)) - @test !r1.cache_hit - @test r2.cache_hit - @test call_count[] == 2 + # Pass `grid` and `args` as values rather than `cfg -> …` closures. + # cfg-independent grid: cld(n, tile) == cld(512, 16) for every cfg + # so the literal `32` happens to be valid here. + result = Exp.autotune_launch( + vadd_kernel, + [(; tile=16)], + cld(n, 16), # literal grid + (a, b, c, ct.Constant(16)); # literal args + key=(:literal, n), + tuning=(preset=:fast, refine_topk=0)) + @test !result.cache_hit + @test result.grid == cld(n, 16) @test Array(c) ≈ fill(3f0, n) end From 574ddac9622786a6b615769d1efc23684752b642 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Wed, 27 May 2026 09:53:28 +0200 Subject: [PATCH 10/15] Drop single-value `nothing` axes from autotune test fixtures. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A CartesianSpace axis with only `(nothing,)` adds the same field to every cfg with value `nothing` — same outcome as omitting the axis, since `hints_from_cfg` already falls through `hasproperty`. Was present in the original autotune branch too; just noise. Keep `occupancy=(nothing, 2)` in the CartesianSpace testset (tunes between "no hint" and 2, which is meaningful) and keep the explicit `nothing` slots on the `configs` Vector (FixedSpace requires uniform NT shape across elements; cfg 1's nothings make the shape match cfgs 2/3 which carry real hint values). --- test/device/autotune.jl | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/test/device/autotune.jl b/test/device/autotune.jl index 17af6c1f..cfe33d0b 100644 --- a/test/device/autotune.jl +++ b/test/device/autotune.jl @@ -28,10 +28,14 @@ const Exp = ct.Experimental b = CUDA.fill(2f0, n) c = CUDA.zeros(Float32, n) + # cfg 1's `occupancy=nothing, num_ctas=nothing` slots are present for + # `FixedSpace` shape uniformity — cfgs 2 and 3 carry real hint values, + # and `FixedSpace{names, NT<:NamedTuple{names}}` requires every element + # to share the same `names` set. configs = [ (; tile=16, occupancy=nothing, num_ctas=nothing), - (; tile=32, occupancy=2, num_ctas=nothing), - (; tile=64, occupancy=4, num_ctas=2), + (; tile=32, occupancy=2, num_ctas=nothing), + (; tile=64, occupancy=4, num_ctas=2), ] args_fn = cfg -> (a, b, c, ct.Constant(cfg.tile)) grid_fn = cfg -> cld(n, cfg.tile) @@ -74,8 +78,12 @@ const Exp = ct.Experimental @testset "CartesianSpace" begin Exp.clear_autotune_cache() fill!(c, 0f0) + # Keep `occupancy=(nothing, 2)` — legitimately tunes between + # "no hint" and 2. Single-value `nothing` axes are noise (see + # other testsets); a 2-value axis with `nothing` as one option + # is meaningful. space = Exp.CartesianSpace(; - tile=(16, 32), occupancy=(nothing, 2), num_ctas=(nothing,)) + tile=(16, 32), occupancy=(nothing, 2)) result = Exp.autotune_launch( vadd_kernel, space, grid_fn, args_fn; key=(:cartesian, n), @@ -91,7 +99,7 @@ const Exp = ct.Experimental fill!(c, 0f0) space = Exp.CartesianSpace( cfg -> cfg.tile == 16; - tile=(16, 32, 64), occupancy=(nothing,), num_ctas=(nothing,)) + tile=(16, 32, 64)) result = Exp.autotune_launch( vadd_kernel, space, grid_fn, args_fn; key=(:constrained, n), @@ -106,7 +114,7 @@ const Exp = ct.Experimental fill!(c, 0f0) result = Exp.autotune_launch( vadd_kernel, - (tile=(16, 32), occupancy=(nothing,), num_ctas=(nothing,)), + (tile=(16, 32),), grid_fn, args_fn; key=(:nt_convenience, n), tuning=(preset=:fast, refine_topk=0), From 7f5b9f827424587bca20d1a2ccb3921069c277f5 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Wed, 27 May 2026 18:46:50 +0200 Subject: [PATCH 11/15] Refactor experimental autotuning internals --- src/experimental/Experimental.jl | 37 +- src/experimental/autotune.jl | 547 +++++++++++++++-------------- src/experimental/autotune_macro.jl | 33 +- src/experimental/search_space.jl | 58 +++ test/device/autotune.jl | 57 +++ 5 files changed, 415 insertions(+), 317 deletions(-) create mode 100644 src/experimental/search_space.jl diff --git a/src/experimental/Experimental.jl b/src/experimental/Experimental.jl index 9f3a56a0..485d33fc 100644 --- a/src/experimental/Experimental.jl +++ b/src/experimental/Experimental.jl @@ -10,47 +10,16 @@ import Core.Compiler as CC using Random # Builds a fresh inference cache compatible with the running Julia version. -# Used to wrap an autotune pass in `with(_SCOPED_INF_CACHE => …)` so all the +# Used to wrap an autotune pass in `with(_SCOPED_INF_CACHE => ...)` so all the # per-config const-seeded inference calls share results instead of paying -# the slow paths (e.g. `ct.load(..., order=…)`) once per config. +# the slow paths (e.g. `ct.load(..., order=...)`) once per config. @inline _fresh_inf_cache() = @static if isdefined(CC, :InferenceCache) CC.InferenceCache() else Vector{CC.InferenceResult}() end -abstract type AbstractSearchSpace end - -Base.length(s::AbstractSearchSpace) = count(_ -> true, s) - -struct FixedSpace{names,NT<:NamedTuple{names}} <: AbstractSearchSpace - elements::Vector{NT} -end - -Base.iterate(space::FixedSpace, args...) = iterate(space.elements, args...) - -struct CartesianSpace{names,NT<:NamedTuple{names,<:Tuple{Vararg{Tuple}}}} <: AbstractSearchSpace - constraint::Function - axes::NT -end - -CartesianSpace(axes::NamedTuple) = CartesianSpace(Returns(true), axes) -CartesianSpace(; axes...) = CartesianSpace(NamedTuple(axes)) -CartesianSpace(constraint::Function; axes...) = CartesianSpace(constraint, NamedTuple(axes)) - -function Base.iterate(space::CartesianSpace{names}, state=nothing) where names - to_cfg = vals -> NamedTuple{names}(vals) - inner = state === nothing ? - Iterators.filter(space.constraint ∘ to_cfg, - Iterators.product(map(Tuple, values(space.axes))...)) : - state.inner - result = isnothing(state) ? iterate(inner) : iterate(inner, state.cursor) - isnothing(result) && return nothing - vals, cursor = result - cfg = to_cfg(vals) - return cfg, (; inner, cursor) -end - +include("search_space.jl") include("autotune.jl") include("autotune_macro.jl") diff --git a/src/experimental/autotune.jl b/src/experimental/autotune.jl index e8feb461..c382bf0c 100644 --- a/src/experimental/autotune.jl +++ b/src/experimental/autotune.jl @@ -1,43 +1,140 @@ -const AUTOTUNE_LOCK = ReentrantLock() -const AUTOTUNE_CACHE = Dict{Any, Dict{Any, Any}}() +public autotune_launch, clear_autotune_cache + +const AUTOTUNE_CACHE = Base.Lockable(Dict{Any, Dict{Any, Any}}()) struct VerificationError <: Exception msg::String end +Base.showerror(io::IO, err::VerificationError) = + print(io, "VerificationError: ", err.msg) + const TUNING_PRESETS = ( - fast = (warmup=1, reps=3, refine_topk=0, refine_reps=2), - default = (warmup=2, reps=5, refine_topk=2, refine_reps=4), - thorough = (warmup=2, reps=7, refine_topk=4, refine_reps=6), + fast = (warmup=1, reps=3, refine_topk=0, refine_reps=2), + default = (warmup=2, reps=5, refine_topk=2, refine_reps=4), + thorough = (warmup=2, reps=7, refine_topk=4, refine_reps=6), +) + +const TUNING_KEYS = ( + :warmup, :reps, :refine_topk, :refine_reps, + :seed, :force, :precompile_workers, +) + +struct TuningOptions + warmup::Int + reps::Int + refine_topk::Int + refine_reps::Int + seed::Union{Nothing, Int} + force::Bool + precompile_workers::Int +end + +_tuning_defaults() = ( + seed=nothing, + force=false, + precompile_workers=Threads.nthreads(), ) +function _check_int(name::Symbol, value; min::Int) + (value isa Integer && !(value isa Bool)) || + throw(ArgumentError("tuning.$name must be an integer, got $(typeof(value))")) + value >= min || + throw(ArgumentError("tuning.$name must be >= $min, got $value")) + return Int(value) +end + +function _check_seed(seed) + seed === nothing && return nothing + (seed isa Integer && !(seed isa Bool)) || + throw(ArgumentError("tuning.seed must be an integer or nothing, got $(typeof(seed))")) + return Int(seed) +end + +function _check_bool(name::Symbol, value) + value isa Bool || + throw(ArgumentError("tuning.$name must be a Bool, got $(typeof(value))")) + return value +end + function normalize_tuning(tuning::NamedTuple) + valid_keys = (:preset, TUNING_KEYS...) + unknown = setdiff(collect(keys(tuning)), collect(valid_keys)) + isempty(unknown) || + throw(ArgumentError("Unknown tuning option(s): $(join(unknown, ", "))")) + preset = get(tuning, :preset, :default) preset isa Symbol || throw(ArgumentError("tuning.preset must be a Symbol")) hasproperty(TUNING_PRESETS, preset) || - throw(ArgumentError("Unknown preset `$preset`; use :fast, :default, or :thorough")) - - base = merge(getproperty(TUNING_PRESETS, preset), - (seed=nothing, force=false, precompile_workers=Threads.nthreads())) + throw(ArgumentError("Unknown tuning preset `$preset`; use :fast, :default, or :thorough")) overrides = NamedTuple(k => v for (k, v) in pairs(tuning) if k !== :preset) - return merge(base, overrides) + values = merge(_tuning_defaults(), getproperty(TUNING_PRESETS, preset), overrides) + + return TuningOptions( + _check_int(:warmup, values.warmup; min=0), + _check_int(:reps, values.reps; min=1), + _check_int(:refine_topk, values.refine_topk; min=0), + _check_int(:refine_reps, values.refine_reps; min=1), + _check_seed(values.seed), + _check_bool(:force, values.force), + _check_int(:precompile_workers, values.precompile_workers; min=0), + ) end -# Extract hint fields (num_ctas, occupancy) from a config, falling back to -# the static defaults supplied by the caller. cfg takes precedence; the -# caller is expected to have rejected the both-supplied case upstream -# (see `autotune_launch`). +@inline _hint_from_cfg(cfg, name::Symbol, fallback) = + hasproperty(cfg, name) ? getproperty(cfg, name) : fallback + function hints_from_cfg(cfg; static_num_ctas=nothing, static_occupancy=nothing) - n = hasproperty(cfg, :num_ctas) ? cfg.num_ctas : static_num_ctas - o = hasproperty(cfg, :occupancy) ? cfg.occupancy : static_occupancy - return (num_ctas=n, occupancy=o) + return ( + num_ctas=_hint_from_cfg(cfg, :num_ctas, static_num_ctas), + occupancy=_hint_from_cfg(cfg, :occupancy, static_occupancy), + ) +end + +function _check_static_hint_conflict(configs, hint::Symbol, static_value) + static_value === nothing && return nothing + for cfg in configs + if hasproperty(cfg, hint) + throw(ArgumentError( + "`$hint` is both a static kwarg and an axis in the search space. " * + "Pick one.")) + end + end + return nothing end -function time_ms(run_once, get_args; - warmup::Int, reps::Int, verify=nothing, reset=nothing) +function _check_static_hint_conflicts(configs; static_num_ctas=nothing, + static_occupancy=nothing) + _check_static_hint_conflict(configs, :num_ctas, static_num_ctas) + _check_static_hint_conflict(configs, :occupancy, static_occupancy) + return nothing +end + +function _collect_trials(space::AbstractSearchSpace, seed) + trials = Any[cfg for cfg in space] + if seed !== nothing && length(trials) > 1 + shuffle!(MersenneTwister(seed), trials) + end + return trials +end + +@inline _grid_dims(grid) = grid isa Integer ? (grid,) : grid +@inline _converted_args(args_fn, cfg) = map(cuTileconvert, args_fn(cfg)) +@inline _argtypes(args) = Tuple{map(Core.Typeof, args)...} + +function _compile_cfg(@nospecialize(f), cfg, args_fn; + sm_arch::VersionNumber, opt_level::Int, + static_num_ctas=nothing, static_occupancy=nothing) + converted = _converted_args(args_fn, cfg) + return cufunction(f, _argtypes(converted); sm_arch, opt_level, + hints_from_cfg(cfg; static_num_ctas, static_occupancy)...) +end + +function _time_ms(run_once, get_args; + warmup::Int, reps::Int, verify=nothing, reset=nothing) CUDACore.synchronize() - for _ in 1:max(warmup, verify !== nothing ? 1 : 0) + for _ in 1:max(warmup, verify === nothing ? 0 : 1) reset !== nothing && reset() run_once(get_args()) end @@ -63,256 +160,212 @@ function eval_cfg(@nospecialize(f), cfg, grid_fn, args_fn; sm_arch::VersionNumber, opt_level::Int, warmup::Int, reps::Int, static_num_ctas=nothing, static_occupancy=nothing, verify=nothing, reset=nothing) - grid = grid_fn(cfg) - grid_dims = grid isa Integer ? (grid,) : grid - - # Compile once, then convert + call each rep. We `cufunction` outside the - # timed loop so JIT cost doesn't pollute the measurement. - sample_converted = map(cuTileconvert, args_fn(cfg)) - tt = Tuple{map(Core.Typeof, sample_converted)...} - kernel = cufunction(f, tt; sm_arch, opt_level, - hints_from_cfg(cfg; static_num_ctas, static_occupancy)...) - - run_once = converted -> kernel(converted...; blocks=grid_dims) - get_args = () -> map(cuTileconvert, args_fn(cfg)) - return time_ms(run_once, get_args; warmup, reps, verify, reset) + grid = _grid_dims(grid_fn(cfg)) + kernel = _compile_cfg(f, cfg, args_fn; sm_arch, opt_level, + static_num_ctas, static_occupancy) + + run_once = converted -> kernel(converted...; blocks=grid) + get_args = () -> _converted_args(args_fn, cfg) + return _time_ms(run_once, get_args; warmup, reps, verify, reset) end function precompile_cfg(@nospecialize(f), cfg, args_fn; sm_arch::VersionNumber, opt_level::Int, static_num_ctas=nothing, static_occupancy=nothing) - converted = map(cuTileconvert, args_fn(cfg)) - tt = Tuple{map(Core.Typeof, converted)...} - cufunction(f, tt; sm_arch, opt_level, - hints_from_cfg(cfg; static_num_ctas, static_occupancy)...) + _compile_cfg(f, cfg, args_fn; sm_arch, opt_level, + static_num_ctas, static_occupancy) return nothing end -function precompile_candidates(@nospecialize(f), configs::Vector{Any}, args_fn; - sm_arch::VersionNumber, opt_level::Int, workers::Int, - static_num_ctas=nothing, static_occupancy=nothing) - isempty(configs) && return configs, nothing - iszero(workers) && return configs, nothing - - workers = min(workers, Threads.nthreads(), length(configs)) - compiled = fill(true, length(configs)) - errors = Vector{Any}(nothing, length(configs)) - sem = Base.Semaphore(workers) - cancelled = Threads.Atomic{Bool}(false) - - try - @sync for (i, cfg) in enumerate(configs) - Threads.@spawn begin - cancelled[] && return - Base.acquire(sem) do - cancelled[] && return - try - precompile_cfg(f, cfg, args_fn; sm_arch, opt_level, - static_num_ctas, static_occupancy) - catch err - compiled[i] = false - errors[i] = (cfg, err) - end - end - end - end - catch e - cancelled[] = true - e isa InterruptException || rethrow() - @warn "Precompilation interrupted, waiting for in-flight workers…" - rethrow() - end +const TimingRecord = Tuple{Any, Float32} - first_err = nothing - for e in errors - if e !== nothing - first_err = e - break +function _measure_cfg!(record::Vector{TimingRecord}, first_error::Base.RefValue, + @nospecialize(f), cfg, grid_fn, args_fn; kwargs...) + ms = try + eval_cfg(f, cfg, grid_fn, args_fn; kwargs...) + catch err + err isa InterruptException && rethrow() + if err isa VerificationError + @warn "Config failed verification; skipping" cfg + else + bt = catch_backtrace() + @debug "Config failed during autotuning; skipping" cfg exception=(err, bt) end + first_error[] === nothing && (first_error[] = (cfg, err)) + return nothing end - - return configs[compiled], first_err + push!(record, (cfg, ms)) + return nothing end function measure_candidates(@nospecialize(f), configs::Vector{Any}, grid_fn, args_fn; - sm_arch::VersionNumber, opt_level::Int, warmup::Int, reps::Int, + sm_arch::VersionNumber, opt_level::Int, + warmup::Int, reps::Int, static_num_ctas=nothing, static_occupancy=nothing, verify=nothing, reset=nothing) - record = Tuple{Any, Float32}[] - first_error = nothing + record = TimingRecord[] + first_error = Ref{Any}(nothing) for cfg in configs - ms = try - eval_cfg(f, cfg, grid_fn, args_fn; sm_arch, opt_level, warmup, reps, - static_num_ctas, static_occupancy, verify, reset) - catch err - if err isa InterruptException - @warn "Benchmarking interrupted after $(length(record)) configs" - break - end - err isa VerificationError && @warn "Config $cfg failed verification, skipping" - first_error === nothing && (first_error = (cfg, err)) - continue - end - push!(record, (cfg, ms)) + _measure_cfg!(record, first_error, f, cfg, grid_fn, args_fn; + sm_arch, opt_level, warmup, reps, + static_num_ctas, static_occupancy, verify, reset) end - return record, first_error + return record, first_error[] end """ pipelined_tune(f, configs, grid_fn, args_fn; ...) -> (record, precompile_error, first_error) -Overlap parallel compile with sequential measure. Compile workers (`Threads.@spawn`) -push each finished cfg onto a `Channel`; the master task pulls them off in -arrival order and runs `eval_cfg` on each. Net wall time is roughly -`max(parallel_compile_time, total_measure_time + first_compile)` instead of -the all-compile-then-all-measure sum. - -The master task is the consumer by design: it inherits the CUDA context from -the caller, which the timed `eval_cfg` needs (CUDA state is task-local). -`workers=0` skips parallel compile entirely and falls back to a -compile-then-measure cycle on the master. - -`record` is in completion order, not trial order — callers that care about -deterministic ordering should sort the result. +Compile candidate configurations on worker tasks while the caller task measures +completed candidates. Measurement stays on the caller task because CUDA state is +task-local; workers only run the untimed `cufunction` precompile path. """ function pipelined_tune(@nospecialize(f), configs::Vector{Any}, grid_fn, args_fn; sm_arch::VersionNumber, opt_level::Int, warmup::Int, reps::Int, workers::Int, static_num_ctas=nothing, static_occupancy=nothing, verify=nothing, reset=nothing) - record = Tuple{Any, Float32}[] - if isempty(configs) - return record, nothing, nothing - end + isempty(configs) && return TimingRecord[], nothing, nothing - # Serial fallback: avoids channel + extra task overhead when there's - # nothing to overlap (workers=0) or only one cfg to evaluate. if iszero(workers) || length(configs) == 1 - rec, ferr = measure_candidates(f, configs, grid_fn, args_fn; + record, first_error = measure_candidates(f, configs, grid_fn, args_fn; sm_arch, opt_level, warmup, reps, static_num_ctas, static_occupancy, verify, reset) - return rec, nothing, ferr + return record, nothing, first_error end workers = min(workers, Threads.nthreads(), length(configs)) + ready = Channel{Any}(length(configs)) + jobs = Channel{Any}(length(configs)) + foreach(cfg -> put!(jobs, cfg), configs) + close(jobs) + cancelled = Threads.Atomic{Bool}(false) - sem = Base.Semaphore(workers) precompile_error = Ref{Any}(nothing) - err_lock = ReentrantLock() - - # Buffer == n: producers never block on put!. Channel carries - # (trial_index, cfg) pairs; the index is preserved so callers that want - # trial-order can recover it. - ch = Channel{Tuple{Int, Any}}(length(configs)) - - # Producer driver: runs in its own task so the master can start consuming - # the moment the first cfg lands. `@sync` inside ensures we don't close - # the channel until every spawned compiler has either pushed or recorded - # an error. - producer = Threads.@spawn begin - try - @sync for (i, cfg) in enumerate(configs) + error_lock = ReentrantLock() + + producer = Threads.@spawn try + @sync for _ in 1:workers + Threads.@spawn for cfg in jobs cancelled[] && break - Threads.@spawn begin - cancelled[] && return - Base.acquire(sem) do - cancelled[] && return - try - precompile_cfg(f, cfg, args_fn; sm_arch, opt_level, - static_num_ctas, static_occupancy) - cancelled[] || put!(ch, (i, cfg)) - catch err - lock(err_lock) do - precompile_error[] === nothing && - (precompile_error[] = (cfg, err)) - end - end + try + precompile_cfg(f, cfg, args_fn; sm_arch, opt_level, + static_num_ctas, static_occupancy) + cancelled[] || put!(ready, cfg) + catch err + if err isa InterruptException + cancelled[] = true + rethrow() + end + lock(error_lock) do + precompile_error[] === nothing && + (precompile_error[] = (cfg, err)) end end end - finally - close(ch) end + finally + close(ready) end - # Master consumes (and times) on this task — keeps the CUDA context for - # `eval_cfg` consistent with the caller's. - first_error = nothing + record = TimingRecord[] + first_error = Ref{Any}(nothing) try - for (_, cfg) in ch - ms = try - eval_cfg(f, cfg, grid_fn, args_fn; sm_arch, opt_level, warmup, reps, - static_num_ctas, static_occupancy, verify, reset) - catch err - if err isa InterruptException - @warn "Benchmarking interrupted after $(length(record)) configs" - cancelled[] = true - break - end - err isa VerificationError && - @warn "Config $cfg failed verification, skipping" - first_error === nothing && (first_error = (cfg, err)) - continue - end - push!(record, (cfg, ms)) + for cfg in ready + _measure_cfg!(record, first_error, f, cfg, grid_fn, args_fn; + sm_arch, opt_level, warmup, reps, + static_num_ctas, static_occupancy, verify, reset) end - catch err - cancelled[] = true - # Drain any items already in the channel so producers don't block on - # put! while we wait for them to notice the cancel flag. - while isready(ch); take!(ch); end wait(producer) + catch + cancelled[] = true + while isready(ready) + take!(ready) + end + try + wait(producer) + catch + end rethrow() end - wait(producer) - return record, precompile_error[], first_error + return record, precompile_error[], first_error[] end -function find_or_tune(@nospecialize(f), space::AbstractSearchSpace, rng::AbstractRNG, - grid_fn, args_fn, tuning; +@inline _entry_in_trials(entry, trials) = + any(cfg -> cfg == entry.best_config, trials) + +function _cached_entry(kernel_key, arg_key, trials) + entry = Base.@lock AUTOTUNE_CACHE begin + per_kernel = get(AUTOTUNE_CACHE[], kernel_key, nothing) + per_kernel === nothing ? nothing : get(per_kernel, arg_key, nothing) + end + entry !== nothing && _entry_in_trials(entry, trials) ? entry : nothing +end + +function _cache_candidate!(candidate, kernel_key, arg_key, trials; force::Bool) + Base.@lock AUTOTUNE_CACHE begin + per_kernel = get!(AUTOTUNE_CACHE[], kernel_key) do + Dict{Any, Any}() + end + if !force + entry = get(per_kernel, arg_key, nothing) + if entry !== nothing && _entry_in_trials(entry, trials) + return entry, true + end + end + per_kernel[arg_key] = candidate + return candidate, false + end +end + +function _no_valid_config_error(first_error, precompile_error) + err_info = first_error !== nothing ? first_error : precompile_error + if err_info === nothing + throw(ArgumentError("No valid config found in search space")) + end + + cfg, err = err_info + throw(ArgumentError( + "No valid config found. First failure for cfg=$cfg: $(sprint(showerror, err))")) +end + +function _refine_record(@nospecialize(f), record::Vector{TimingRecord}, tuning::TuningOptions, + grid_fn, args_fn; sm_arch::VersionNumber, opt_level::Int, + static_num_ctas=nothing, static_occupancy=nothing, + verify=nothing, reset=nothing) + (tuning.refine_topk > 0 && length(record) > 1) || return record + + sort!(record, by=last) + top = Any[first(r) for r in record[1:min(tuning.refine_topk, length(record))]] + refined, _ = measure_candidates(f, top, grid_fn, args_fn; + sm_arch, opt_level, warmup=tuning.warmup, reps=tuning.refine_reps, + static_num_ctas, static_occupancy, verify, reset) + return isempty(refined) ? record : refined +end + +function _best_candidate(record::Vector{TimingRecord}) + _, best_idx = findmin(last, record) + return (; best_config=record[best_idx][1], tuning_record=record) +end + +function find_or_tune(@nospecialize(f), space::AbstractSearchSpace, + grid_fn, args_fn, tuning::TuningOptions; sm_arch::VersionNumber, opt_level::Int, kernel_key, arg_key, static_num_ctas=nothing, static_occupancy=nothing, verify=nothing, setup=nothing) + trials = _collect_trials(space, tuning.seed) + isempty(trials) && throw(ArgumentError("No valid config found in search space")) + _check_static_hint_conflicts(trials; static_num_ctas, static_occupancy) + if !tuning.force - entry = lock(AUTOTUNE_LOCK) do - per_kernel = get(AUTOTUNE_CACHE, kernel_key, nothing) - per_kernel !== nothing ? get(per_kernel, arg_key, nothing) : nothing - end + entry = _cached_entry(kernel_key, arg_key, trials) entry !== nothing && return entry, true, nothing end checker = verify !== nothing ? verify() : nothing reset = setup !== nothing ? setup() : nothing - trials = Any[collect(space)...] - - # Conflict check: if the cfg carries a `num_ctas`/`occupancy` field AND - # the caller also provided a static value, error rather than silently - # ignoring one. (Handles the case where `space` is opaque to the macro - # — a user-built `CartesianSpace(...)` or `FixedSpace([(...),...])`.) - if !isempty(trials) - sample = first(trials) - if static_num_ctas !== nothing && hasproperty(sample, :num_ctas) - throw(ArgumentError( - "`num_ctas` is both a static kwarg and an axis in the search space. " * - "Pick one.")) - end - if static_occupancy !== nothing && hasproperty(sample, :occupancy) - throw(ArgumentError( - "`occupancy` is both a static kwarg and an axis in the search space. " * - "Pick one.")) - end - end - - # Share the inference cache across all per-cfg const-seeded compiles. - # Each cfg differs only in `Constant{T,V}` values, so the generic - # inference graph is identical — without sharing, kernels with slow - # inference paths (e.g. `ct.load(..., order=…)`) pay that cost N times. - # - # `pipelined_tune` overlaps the parallel compile fan-out with sequential - # GPU measurement: as soon as the first compiler finishes, the master - # starts timing while the remaining compilers continue in the background. record, precompile_error, first_error = with(_SCOPED_INF_CACHE => _fresh_inf_cache()) do pipelined_tune(f, trials, grid_fn, args_fn; @@ -323,50 +376,19 @@ function find_or_tune(@nospecialize(f), space::AbstractSearchSpace, rng::Abstrac verify=checker, reset) end - if isempty(record) - err_info = first_error !== nothing ? first_error : precompile_error - if err_info === nothing - throw(ArgumentError("No valid config found in search space")) - else - cfg, err = err_info - throw(ArgumentError( - "No valid config found. First failure for cfg=$cfg: $(sprint(showerror, err))")) - end - end + isempty(record) && _no_valid_config_error(first_error, precompile_error) - if tuning.refine_topk > 0 && length(record) > 1 - sort!(record, by=last) - top_configs = Any[first(r) for r in record[1:min(tuning.refine_topk, length(record))]] - refined, _ = measure_candidates(f, top_configs, grid_fn, args_fn; - sm_arch, opt_level, warmup=tuning.warmup, reps=tuning.refine_reps, - static_num_ctas, static_occupancy, reset) - if !isempty(refined) - record = refined - end - end + record = _refine_record(f, record, tuning, grid_fn, args_fn; + sm_arch, opt_level, + static_num_ctas, static_occupancy, + verify=checker, reset) - _, best_idx = findmin(last, record) - candidate = (; best_config=record[best_idx][1], tuning_record=record) - - # Race: another thread may have populated the cache while we were - # tuning. If so, return their result and report `cache_hit=true` so - # the caller's accounting stays accurate. - entry, cache_hit = lock(AUTOTUNE_LOCK) do - per_kernel = get!(Dict{Any,Any}, AUTOTUNE_CACHE, kernel_key) - if !tuning.force && haskey(per_kernel, arg_key) - per_kernel[arg_key], true - else - per_kernel[arg_key] = candidate - candidate, false - end - end + candidate = _best_candidate(record) + entry, cache_hit = _cache_candidate!(candidate, kernel_key, arg_key, trials; + force=tuning.force) return entry, cache_hit, reset end -# Normalize `grid`/`args` to a callable. A `Function` is used as-is; any -# other value is wrapped in `Returns(...)`. Lets `autotune_launch` accept -# `blocks = 1024` or `args = (a, b, c)` directly instead of forcing the -# caller to write `cfg -> 1024` for grids that don't depend on the cfg. @inline _as_cfg_fn(f::Function) = f @inline _as_cfg_fn(x) = Returns(x) @@ -375,20 +397,12 @@ end tuning, sm_arch, opt_level, num_ctas=nothing, occupancy=nothing) -Tune `f` over `space` (an [`AbstractSearchSpace`](@ref) or a `Vector`/`NamedTuple` -shorthand) and launch the best config. +Tune `f` over `space` and launch the fastest valid config. -`grid` and `args` are either functions of the form `cfg -> grid` / -`cfg -> args_tuple` (for grids/args that depend on the cfg), or plain -values (for grids/args that don't). Internally both are normalized to -callables. `launch_args` follows the same convention; defaults to `args`. - -Results are cached per `(f, sm_arch, opt_level, num_ctas, occupancy) ⇒ key`. - -`num_ctas` and `occupancy` may be supplied as **static** kwargs (applied -uniformly to every cfg — useful for `ByTarget(...)`-style per-arch dispatch) -OR as **axes** inside `space` (tuned per cfg), but not both. Specifying -both throws an `ArgumentError`. +`space` can be an `AbstractSearchSpace`, a `NamedTuple` of cartesian axes, or +an iterable of `NamedTuple` configs. `grid`, `args`, and `launch_args` can be +plain values or `cfg -> value` functions. Results are cached per +`(f, sm_arch, opt_level, num_ctas, occupancy)` and user `key`. """ function autotune_launch(@nospecialize(f), space::AbstractSearchSpace, grid, args; @@ -402,15 +416,13 @@ function autotune_launch(@nospecialize(f), space::AbstractSearchSpace, num_ctas=nothing, occupancy=nothing) tuning = normalize_tuning(tuning) - rng = tuning.seed !== nothing ? MersenneTwister(tuning.seed) : Random.default_rng() grid_fn = _as_cfg_fn(grid) args_fn = _as_cfg_fn(args) launch_args_fn = launch_args === nothing ? args_fn : _as_cfg_fn(launch_args) kernel_key = (f, sm_arch, opt_level, num_ctas, occupancy) - - entry, cache_hit, reset = find_or_tune(f, space, rng, grid_fn, args_fn, tuning; + entry, cache_hit, reset = find_or_tune(f, space, grid_fn, args_fn, tuning; sm_arch, opt_level, kernel_key, arg_key=key, static_num_ctas=num_ctas, static_occupancy=occupancy, verify, setup) @@ -434,19 +446,20 @@ function autotune_launch(@nospecialize(f), configs, grid, args; kwargs...) end function clear_autotune_cache(; kernel=nothing, key=nothing) - lock(AUTOTUNE_LOCK) do + Base.@lock AUTOTUNE_CACHE begin + cache = AUTOTUNE_CACHE[] if kernel === nothing key === nothing || throw(ArgumentError("`key` requires `kernel`")) - empty!(AUTOTUNE_CACHE) + empty!(cache) return nothing end - for kernel_key in collect(keys(AUTOTUNE_CACHE)) + for kernel_key in collect(keys(cache)) kernel_key isa Tuple || continue kernel_key[1] === kernel || continue - per_kernel = AUTOTUNE_CACHE[kernel_key] + per_kernel = cache[kernel_key] key === nothing ? empty!(per_kernel) : pop!(per_kernel, key, nothing) - isempty(per_kernel) && delete!(AUTOTUNE_CACHE, kernel_key) + isempty(per_kernel) && delete!(cache, kernel_key) end end return nothing diff --git a/src/experimental/autotune_macro.jl b/src/experimental/autotune_macro.jl index 1b421661..9cdf9b27 100644 --- a/src/experimental/autotune_macro.jl +++ b/src/experimental/autotune_macro.jl @@ -1,4 +1,6 @@ -# `@autotune` — surface syntax for `autotune_launch`. +public @autotune + +# `@autotune`: surface syntax for `autotune_launch`. # # Desugars # @@ -67,24 +69,24 @@ and the kernel-call args, `\$X` interpolates `cfg.X` (where `cfg` is the tuning configuration being evaluated). # Required kwargs -- `space` — a `NamedTuple` like `(A=(...), B=(...))` (becomes a +- `space` - a `NamedTuple` like `(A=(...), B=(...))` (becomes a `CartesianSpace`), a `Vector` of `NamedTuple`s (becomes a `FixedSpace`), - or any `AbstractSearchSpace` (passed through — useful for + or any `AbstractSearchSpace` (passed through; useful for `CartesianSpace(constraint; ...)`). -- `blocks` — grid dimensions, an `Int` or `Tuple`. May reference `\$X`. +- `blocks` - grid dimensions, an `Int` or `Tuple`. May reference `\$X`. # Optional kwargs -- `key` — cache key (any value) -- `tuning` — `NamedTuple` of tuning knobs (`preset`, `force`, etc.) -- `verify` — `() -> (() -> Bool)` factory; the returned checker is +- `key` - cache key (any value) +- `tuning` - `NamedTuple` of tuning knobs (`preset`, `force`, etc.) +- `verify` - `() -> (() -> Bool)` factory; the returned checker is called after each warmup pass to reject incorrect cfgs -- `setup` — `() -> (() -> Nothing)` factory; reset between reps -- `launch_args` — final-launch args (or `cfg -> args` if it should differ +- `setup` - `() -> (() -> Nothing)` factory; reset between reps +- `launch_args` - final-launch args (or `cfg -> args` if it should differ from the kernel-call args). Use this when the timed args are throwaway copies (in-place kernels) and the final launch should hit the real buffers -- `sm_arch`, `opt_level` — forwarded to `cufunction` -- `num_ctas`, `occupancy` — **static** hints applied uniformly to every +- `sm_arch`, `opt_level` - forwarded to `cufunction` +- `num_ctas`, `occupancy` - **static** hints applied uniformly to every cfg. May not coexist with same-named axes in `space` (the macro flags the conflict at expansion time when `space` is a literal NT; otherwise `autotune_launch` @@ -116,7 +118,7 @@ macro autotune(args...) call === nothing || error("@autotune: only one kernel call allowed") call = arg else - error("@autotune: unexpected argument `$arg` — expected `kwarg=val` or a kernel call") + error("@autotune: unexpected argument `$arg`; expected `kwarg=val` or a kernel call") end end @@ -140,7 +142,7 @@ macro autotune(args...) end end - # Extract the kernel call (positional only — no kernel kwargs). + # Extract the kernel call (positional only; no kernel kwargs). Meta.isexpr(call, :call) || error("@autotune: kernel must be a function-call expression") f_expr = call.args[1] @@ -166,7 +168,6 @@ macro autotune(args...) :launch_args, :num_ctas, :occupancy) kw_exprs = [Expr(:kw, k, kwargs[k]) for k in forwarded_keys if haskey(kwargs, k)] - return esc(quote - $autotune_launch($f_expr, $space_expr, $grid_fn, $args_fn; $(kw_exprs...)) - end) + launch = GlobalRef(@__MODULE__, :autotune_launch) + return esc(:($launch($f_expr, $space_expr, $grid_fn, $args_fn; $(kw_exprs...)))) end diff --git a/src/experimental/search_space.jl b/src/experimental/search_space.jl new file mode 100644 index 00000000..d23a3840 --- /dev/null +++ b/src/experimental/search_space.jl @@ -0,0 +1,58 @@ +public AbstractSearchSpace, FixedSpace, CartesianSpace + +abstract type AbstractSearchSpace end + +Base.length(space::AbstractSearchSpace) = count(Returns(true), space) + +struct FixedSpace{T<:NamedTuple} <: AbstractSearchSpace + elements::Vector{T} +end + +FixedSpace(elements::AbstractVector{T}) where {T<:NamedTuple} = + FixedSpace{T}(collect(elements)) + +function FixedSpace(configs) + elements = collect(configs) + all(config -> config isa NamedTuple, elements) || + throw(ArgumentError("FixedSpace requires NamedTuple configs")) + return FixedSpace(NamedTuple[elements...]) +end + +Base.eltype(::Type{<:FixedSpace{T}}) where {T} = T +Base.length(space::FixedSpace) = length(space.elements) +Base.iterate(space::FixedSpace, args...) = iterate(space.elements, args...) + +struct CartesianSpace{names,F,Axes<:NamedTuple{names}} <: AbstractSearchSpace + constraint::F + axes::Axes +end + +_axis_tuple(axis::Tuple) = axis +_axis_tuple(axis) = Tuple(axis) + +function _cartesian_space(constraint::F, axes::NamedTuple) where {F} + tuple_axes = map(_axis_tuple, axes) + return CartesianSpace{keys(tuple_axes),F,typeof(tuple_axes)}(constraint, tuple_axes) +end + +CartesianSpace(axes::NamedTuple) = _cartesian_space(Returns(true), axes) +CartesianSpace(; axes...) = CartesianSpace(NamedTuple(axes)) +CartesianSpace(constraint::Function; axes...) = + _cartesian_space(constraint, NamedTuple(axes)) + +Base.eltype(::Type{<:CartesianSpace{names}}) where {names} = NamedTuple{names} + +function Base.iterate(space::CartesianSpace{names}, state=nothing) where {names} + product, cursor = state === nothing ? + (Iterators.product(values(space.axes)...), nothing) : + (state.product, state.cursor) + + result = cursor === nothing ? iterate(product) : iterate(product, cursor) + while result !== nothing + values, cursor = result + cfg = NamedTuple{names}(values) + space.constraint(cfg) && return cfg, (; product, cursor) + result = iterate(product, cursor) + end + return nothing +end diff --git a/test/device/autotune.jl b/test/device/autotune.jl index cfe33d0b..06c00c80 100644 --- a/test/device/autotune.jl +++ b/test/device/autotune.jl @@ -109,6 +109,11 @@ const Exp = ct.Experimental @test Array(c) ≈ fill(3f0, n) end + @testset "CartesianSpace range axis" begin + @test collect(Exp.CartesianSpace(tile=16:16:32)) == + [(; tile=16), (; tile=32)] + end + @testset "NamedTuple convenience → CartesianSpace" begin Exp.clear_autotune_cache() fill!(c, 0f0) @@ -266,6 +271,58 @@ const Exp = ct.Experimental tuning=(preset=:fast, refine_topk=0)) end + @testset "conflict scans every config" begin + space = Exp.FixedSpace(Any[(; tile=16), (; tile=32, occupancy=2)]) + @test_throws ArgumentError Exp.autotune_launch( + vadd_kernel, + space, + cfg -> cld(n, cfg.tile), + cfg -> (a, b, c, ct.Constant(cfg.tile)); + key=(:conflict_late, n), + occupancy=4, + tuning=(preset=:fast, refine_topk=0)) + end + + @testset "tuning validation" begin + @test_throws ArgumentError Exp.autotune_launch( + vadd_kernel, + [(; tile=16)], + grid_fn, args_fn; + key=(:bad_reps, n), + tuning=(preset=:fast, reps=0)) + + @test_throws ArgumentError Exp.autotune_launch( + vadd_kernel, + [(; tile=16)], + grid_fn, args_fn; + key=(:bad_key, n), + tuning=(preset=:fast, typo=1)) + end + + @testset "cached config must belong to current space" begin + Exp.clear_autotune_cache() + fill!(c, 0f0) + r1 = Exp.autotune_launch( + vadd_kernel, + [(; tile=16)], + grid_fn, args_fn; + key=(:space_sensitive, n), + tuning=(preset=:fast, refine_topk=0)) + @test !r1.cache_hit + @test r1.tuned_config.tile == 16 + + fill!(c, 0f0) + r2 = Exp.autotune_launch( + vadd_kernel, + [(; tile=32)], + grid_fn, args_fn; + key=(:space_sensitive, n), + tuning=(preset=:fast, refine_topk=0)) + @test !r2.cache_hit + @test r2.tuned_config.tile == 32 + @test Array(c) ≈ fill(3f0, n) + end + @testset "@autotune macro: NT space" begin Exp.clear_autotune_cache() fill!(c, 0f0) From 221d3956bb0ed18cd782b01a228581e847e2e6ef Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Wed, 27 May 2026 19:45:44 +0200 Subject: [PATCH 12/15] Keep autotune candidate compiles temporary --- src/compiler/driver.jl | 2 +- src/experimental/Experimental.jl | 3 +- src/experimental/autotune.jl | 55 +++++++++++++++++++----------- src/launch.jl | 49 +++++++++++++++++++++------ test/device/autotune.jl | 58 ++++++++++++++++++++++++++++++-- 5 files changed, 132 insertions(+), 35 deletions(-) diff --git a/src/compiler/driver.jl b/src/compiler/driver.jl index f4c93ac6..48276d09 100644 --- a/src/compiler/driver.jl +++ b/src/compiler/driver.jl @@ -307,7 +307,7 @@ function emit_tile!(cache::CacheView, mi::Core.MethodInstance, # Compute bytecode via driver sci, rettype, kernel_meta = ir_result - key = cache.owner::TileCacheKey + key = tile_cache_key(cache.owner) opts = CGOpts((sm_arch=unpack_version(key.sm_arch), opt_level=unpack_hint(key.opt_level), num_ctas=unpack_hint(key.num_ctas), diff --git a/src/experimental/Experimental.jl b/src/experimental/Experimental.jl index 485d33fc..162cf67b 100644 --- a/src/experimental/Experimental.jl +++ b/src/experimental/Experimental.jl @@ -1,7 +1,8 @@ module Experimental using ..cuTile -using ..cuTile: cuTileconvert, cufunction, default_sm_arch, _SCOPED_INF_CACHE +using ..cuTile: cuTileconvert, default_sm_arch, temporary_cufunction, + _SCOPED_INF_CACHE using CUDACore: CUDACore diff --git a/src/experimental/autotune.jl b/src/experimental/autotune.jl index c382bf0c..37ce776f 100644 --- a/src/experimental/autotune.jl +++ b/src/experimental/autotune.jl @@ -30,6 +30,15 @@ struct TuningOptions precompile_workers::Int end +struct TuningSession + id::UInt +end + +const NEXT_TUNING_SESSION_ID = Threads.Atomic{UInt}(0) + +TuningSession() = + TuningSession(Threads.atomic_add!(NEXT_TUNING_SESSION_ID, UInt(1)) + UInt(1)) + _tuning_defaults() = ( seed=nothing, force=false, @@ -123,12 +132,13 @@ end @inline _converted_args(args_fn, cfg) = map(cuTileconvert, args_fn(cfg)) @inline _argtypes(args) = Tuple{map(Core.Typeof, args)...} -function _compile_cfg(@nospecialize(f), cfg, args_fn; +function _compile_cfg(@nospecialize(f), cfg, args_fn, session::TuningSession; sm_arch::VersionNumber, opt_level::Int, static_num_ctas=nothing, static_occupancy=nothing) converted = _converted_args(args_fn, cfg) - return cufunction(f, _argtypes(converted); sm_arch, opt_level, - hints_from_cfg(cfg; static_num_ctas, static_occupancy)...) + return temporary_cufunction(f, _argtypes(converted), session.id; + sm_arch, opt_level, + hints_from_cfg(cfg; static_num_ctas, static_occupancy)...) end function _time_ms(run_once, get_args; @@ -156,12 +166,12 @@ function _time_ms(run_once, get_args; return best_ms end -function eval_cfg(@nospecialize(f), cfg, grid_fn, args_fn; +function eval_cfg(@nospecialize(f), cfg, grid_fn, args_fn, session::TuningSession; sm_arch::VersionNumber, opt_level::Int, warmup::Int, reps::Int, static_num_ctas=nothing, static_occupancy=nothing, verify=nothing, reset=nothing) grid = _grid_dims(grid_fn(cfg)) - kernel = _compile_cfg(f, cfg, args_fn; sm_arch, opt_level, + kernel = _compile_cfg(f, cfg, args_fn, session; sm_arch, opt_level, static_num_ctas, static_occupancy) run_once = converted -> kernel(converted...; blocks=grid) @@ -169,10 +179,10 @@ function eval_cfg(@nospecialize(f), cfg, grid_fn, args_fn; return _time_ms(run_once, get_args; warmup, reps, verify, reset) end -function precompile_cfg(@nospecialize(f), cfg, args_fn; +function precompile_cfg(@nospecialize(f), cfg, args_fn, session::TuningSession; sm_arch::VersionNumber, opt_level::Int, static_num_ctas=nothing, static_occupancy=nothing) - _compile_cfg(f, cfg, args_fn; sm_arch, opt_level, + _compile_cfg(f, cfg, args_fn, session; sm_arch, opt_level, static_num_ctas, static_occupancy) return nothing end @@ -180,9 +190,10 @@ end const TimingRecord = Tuple{Any, Float32} function _measure_cfg!(record::Vector{TimingRecord}, first_error::Base.RefValue, - @nospecialize(f), cfg, grid_fn, args_fn; kwargs...) + @nospecialize(f), cfg, grid_fn, args_fn, + session::TuningSession; kwargs...) ms = try - eval_cfg(f, cfg, grid_fn, args_fn; kwargs...) + eval_cfg(f, cfg, grid_fn, args_fn, session; kwargs...) catch err err isa InterruptException && rethrow() if err isa VerificationError @@ -198,7 +209,8 @@ function _measure_cfg!(record::Vector{TimingRecord}, first_error::Base.RefValue, return nothing end -function measure_candidates(@nospecialize(f), configs::Vector{Any}, grid_fn, args_fn; +function measure_candidates(@nospecialize(f), configs::Vector{Any}, grid_fn, args_fn, + session::TuningSession; sm_arch::VersionNumber, opt_level::Int, warmup::Int, reps::Int, static_num_ctas=nothing, static_occupancy=nothing, @@ -206,7 +218,7 @@ function measure_candidates(@nospecialize(f), configs::Vector{Any}, grid_fn, arg record = TimingRecord[] first_error = Ref{Any}(nothing) for cfg in configs - _measure_cfg!(record, first_error, f, cfg, grid_fn, args_fn; + _measure_cfg!(record, first_error, f, cfg, grid_fn, args_fn, session; sm_arch, opt_level, warmup, reps, static_num_ctas, static_occupancy, verify, reset) end @@ -218,9 +230,10 @@ end Compile candidate configurations on worker tasks while the caller task measures completed candidates. Measurement stays on the caller task because CUDA state is -task-local; workers only run the untimed `cufunction` precompile path. +task-local; workers only run the untimed temporary compile path. """ -function pipelined_tune(@nospecialize(f), configs::Vector{Any}, grid_fn, args_fn; +function pipelined_tune(@nospecialize(f), configs::Vector{Any}, grid_fn, args_fn, + session::TuningSession; sm_arch::VersionNumber, opt_level::Int, warmup::Int, reps::Int, workers::Int, static_num_ctas=nothing, static_occupancy=nothing, @@ -228,7 +241,7 @@ function pipelined_tune(@nospecialize(f), configs::Vector{Any}, grid_fn, args_fn isempty(configs) && return TimingRecord[], nothing, nothing if iszero(workers) || length(configs) == 1 - record, first_error = measure_candidates(f, configs, grid_fn, args_fn; + record, first_error = measure_candidates(f, configs, grid_fn, args_fn, session; sm_arch, opt_level, warmup, reps, static_num_ctas, static_occupancy, verify, reset) return record, nothing, first_error @@ -249,7 +262,7 @@ function pipelined_tune(@nospecialize(f), configs::Vector{Any}, grid_fn, args_fn Threads.@spawn for cfg in jobs cancelled[] && break try - precompile_cfg(f, cfg, args_fn; sm_arch, opt_level, + precompile_cfg(f, cfg, args_fn, session; sm_arch, opt_level, static_num_ctas, static_occupancy) cancelled[] || put!(ready, cfg) catch err @@ -272,7 +285,7 @@ function pipelined_tune(@nospecialize(f), configs::Vector{Any}, grid_fn, args_fn first_error = Ref{Any}(nothing) try for cfg in ready - _measure_cfg!(record, first_error, f, cfg, grid_fn, args_fn; + _measure_cfg!(record, first_error, f, cfg, grid_fn, args_fn, session; sm_arch, opt_level, warmup, reps, static_num_ctas, static_occupancy, verify, reset) end @@ -331,14 +344,15 @@ function _no_valid_config_error(first_error, precompile_error) end function _refine_record(@nospecialize(f), record::Vector{TimingRecord}, tuning::TuningOptions, - grid_fn, args_fn; sm_arch::VersionNumber, opt_level::Int, + grid_fn, args_fn, session::TuningSession; + sm_arch::VersionNumber, opt_level::Int, static_num_ctas=nothing, static_occupancy=nothing, verify=nothing, reset=nothing) (tuning.refine_topk > 0 && length(record) > 1) || return record sort!(record, by=last) top = Any[first(r) for r in record[1:min(tuning.refine_topk, length(record))]] - refined, _ = measure_candidates(f, top, grid_fn, args_fn; + refined, _ = measure_candidates(f, top, grid_fn, args_fn, session; sm_arch, opt_level, warmup=tuning.warmup, reps=tuning.refine_reps, static_num_ctas, static_occupancy, verify, reset) return isempty(refined) ? record : refined @@ -365,10 +379,11 @@ function find_or_tune(@nospecialize(f), space::AbstractSearchSpace, checker = verify !== nothing ? verify() : nothing reset = setup !== nothing ? setup() : nothing + session = TuningSession() record, precompile_error, first_error = with(_SCOPED_INF_CACHE => _fresh_inf_cache()) do - pipelined_tune(f, trials, grid_fn, args_fn; + pipelined_tune(f, trials, grid_fn, args_fn, session; sm_arch, opt_level, warmup=tuning.warmup, reps=tuning.reps, workers=tuning.precompile_workers, @@ -378,7 +393,7 @@ function find_or_tune(@nospecialize(f), space::AbstractSearchSpace, isempty(record) && _no_valid_config_error(first_error, precompile_error) - record = _refine_record(f, record, tuning, grid_fn, args_fn; + record = _refine_record(f, record, tuning, grid_fn, args_fn, session; sm_arch, opt_level, static_num_ctas, static_occupancy, verify=checker, reset) diff --git a/src/launch.jl b/src/launch.jl index bf480663..45814640 100644 --- a/src/launch.jl +++ b/src/launch.jl @@ -123,6 +123,17 @@ TileCacheKey(sm_arch::VersionNumber, bytecode_version::VersionNumber, pack_hint(opt_level), pack_hint(num_ctas), pack_hint(occupancy), pack_hint(num_worker_warps)) +struct TemporaryTileCacheKey + key::TileCacheKey + session_id::UInt +end + +# Autotune candidates use this owner so they can share work within one tuning +# pass without becoming visible to the normal `cufunction` cache keyed by +# `TileCacheKey`. The winning config is promoted by the final normal launch. +@inline tile_cache_key(key::TileCacheKey) = key +@inline tile_cache_key(key::TemporaryTileCacheKey) = key.key + #============================================================================= Toolkit / device validation (cached: once per `(capability, cuda_version)`). @@ -405,13 +416,15 @@ end const EMIT_TILE_LOCK = ReentrantLock() """ - emit_binary!(cache, mi, ci, res; const_argtypes=nothing) -> Vector{UInt8} + emit_binary!(cache, mi, ci, res; const_argtypes=nothing, + store_disk=true) -> Vector{UInt8} Cached binary phase: compile Tile IR bytecode to CUBIN using tileiras. """ function emit_binary!(cache::CacheView, mi::Core.MethodInstance, ci::Core.CodeInstance, res::CuTileResults; - const_argtypes::Union{Vector{Any}, Nothing}=nothing) + const_argtypes::Union{Vector{Any}, Nothing}=nothing, + store_disk::Bool=true) # Recurse first — emit_structured! at the bottom of the chain fires # `compile_hook` for `@device_code_*` reflection, which must run on every # launch even when downstream artifacts are fully cached. @@ -419,12 +432,13 @@ function emit_binary!(cache::CacheView, mi::Core.MethodInstance, res.cuda_bin !== nothing && return res.cuda_bin - sm_arch = unpack_version(cache.owner.sm_arch) + owner = tile_cache_key(cache.owner) + sm_arch = unpack_version(owner.sm_arch) # Resolve opt_level here (not in emit_tile) because it's a tileiras flag, not bytecode. # num_ctas/occupancy/num_worker_warps are resolved in emit_tile because they're encoded in bytecode. _, _, kernel_meta = res.julia_ir - opt_level = something(resolve_hint(unpack_hint(cache.owner.opt_level), + opt_level = something(resolve_hint(unpack_hint(owner.opt_level), kernel_meta, :opt_level, sm_arch), 3) # Disk cache lookup. The hash covers every input that changes the CUBIN @@ -477,7 +491,7 @@ function emit_binary!(cache::CacheView, mi::Core.MethodInstance, rm(output_path, force=true) end - if cache_key !== nothing + if store_disk && cache_key !== nothing try DiskCache.put!(dc, cache_key, res.cuda_bin) catch err @@ -565,7 +579,22 @@ function cufunction(@nospecialize(f), tt::Type{<:Tuple}=Tuple{}; # invalidated by any package that defines methods on Base.Compiler hooks # like `OptimizationParams(::AbstractInterpreter)`. To reuse precompiled # native code, run the pipeline in the world captured at __init__. - invoke_frozen(cufunction_compile, f, tt, argtypes, const_argtypes, key)::TileKernel{Core.Typeof(f), tt} + invoke_frozen(cufunction_compile, f, tt, argtypes, const_argtypes, + key, true)::TileKernel{Core.Typeof(f), tt} +end + +function temporary_cufunction(@nospecialize(f), tt::Type{<:Tuple}, session_id::UInt; + sm_arch::Union{VersionNumber, Nothing}=nothing, + opt_level::Union{Int, Nothing}=nothing, + num_ctas::Union{Int, Nothing}=nothing, + occupancy::Union{Int, Nothing}=nothing) + bytecode_version = check_tile_ir_support() + resolved_sm_arch = sm_arch !== nothing ? sm_arch : default_sm_arch() + key = TileCacheKey(resolved_sm_arch, bytecode_version, opt_level, num_ctas, occupancy) + owner = TemporaryTileCacheKey(key, session_id) + argtypes, const_argtypes = unwrap_argtypes(f, tt) + return invoke_frozen(cufunction_compile, f, tt, argtypes, const_argtypes, + owner, false)::TileKernel{Core.Typeof(f), tt} end """ @@ -577,7 +606,7 @@ by [`link`](@ref) to load the result onto the GPU. No CUDA context required. """ function compile(@nospecialize(f), @nospecialize(argtypes), const_argtypes::Union{Vector{Any}, Nothing}, - key::TileCacheKey) + key, store_disk::Bool=true) world = Base.get_world_counter() mi = method_instance(f, argtypes; world) mi === nothing && throw(MethodError(f, argtypes)) @@ -604,7 +633,7 @@ function compile(@nospecialize(f), @nospecialize(argtypes), # Always walk the emit chain (each phase short-circuits on its own cached # field, but `emit_structured!` also fires `compile_hook` for reflection, # which has to run on every launch even when the cube/cufunc is cached). - emit_binary!(cache, mi, ci, res; const_argtypes) + emit_binary!(cache, mi, ci, res; const_argtypes, store_disk) return cache, mi, ci, res end @@ -613,8 +642,8 @@ end # even when later-loaded packages would otherwise have invalidated it. function cufunction_compile(@nospecialize(f), @nospecialize(tt), @nospecialize(argtypes), const_argtypes::Union{Vector{Any}, Nothing}, - key::TileCacheKey) - cache, mi, ci, res = compile(f, argtypes, const_argtypes, key) + key, store_disk::Bool=true) + cache, mi, ci, res = compile(f, argtypes, const_argtypes, key, store_disk) cufunc = link(cache, mi, ci, res) diff --git a/test/device/autotune.jl b/test/device/autotune.jl index 06c00c80..0738603f 100644 --- a/test/device/autotune.jl +++ b/test/device/autotune.jl @@ -9,8 +9,8 @@ const Exp = ct.Experimental c::ct.TileArray{Float32,1}, tile::Int) pid = ct.bid(1) - ta = ct.load(a, pid, (tile[],)) - tb = ct.load(b, pid, (tile[],)) + ta = ct.load(a, pid, (tile,)) + tb = ct.load(b, pid, (tile,)) ct.store(c, pid, ta + tb) return nothing end @@ -18,11 +18,43 @@ const Exp = ct.Experimental function inplace_add_kernel(x::ct.TileArray{Float32,1}, tile::Int) pid = ct.bid(1) - tx = ct.load(x, pid, (tile[],)) + tx = ct.load(x, pid, (tile,)) ct.store(x, pid, tx .+ 1f0) return nothing end + function cache_probe_kernel(a::ct.TileArray{Float32,1}, + b::ct.TileArray{Float32,1}, + c::ct.TileArray{Float32,1}, + tile::Int) + pid = ct.bid(1) + ta = ct.load(a, pid, (tile,)) + tb = ct.load(b, pid, (tile,)) + ct.store(c, pid, ta + tb) + return nothing + end + + function normal_const_entry_count(f, args) + converted = map(ct.cuTileconvert, args) + tt = Tuple{map(Core.Typeof, converted)...} + argtypes, _ = ct.unwrap_argtypes(f, tt) + + world = Base.get_world_counter() + key = ct.TileCacheKey(ct.default_sm_arch(), ct.bytecode_version(), + 3, nothing, nothing) + cache = ct.CompilerCaching.CacheView{ct.CuTileResults}(key, world) + mi = ct.CompilerCaching.method_instance(f, argtypes; world) + ci = get(cache, mi, nothing) + ci === nothing && return 0 + + cached = ct.CC.traverse_analysis_results(ci) do result + result isa ct.CompilerCaching.CachedResult{ct.CuTileResults} ? + result : nothing + end + cached === nothing && return 0 + return length(cached.const_entries) + end + n = 512 a = CUDA.fill(1f0, n) b = CUDA.fill(2f0, n) @@ -323,6 +355,26 @@ const Exp = ct.Experimental @test Array(c) ≈ fill(3f0, n) end + @testset "only winner enters normal compiler cache" begin + Exp.clear_autotune_cache() + probe_c = CUDA.zeros(Float32, n) + probe_configs = [(; tile=16), (; tile=32), (; tile=64)] + probe_args = cfg -> (a, b, probe_c, ct.Constant(cfg.tile)) + @test normal_const_entry_count(cache_probe_kernel, probe_args(probe_configs[1])) == 0 + + result = Exp.autotune_launch( + cache_probe_kernel, + probe_configs, + cfg -> cld(n, cfg.tile), + probe_args; + key=(:temporary_candidates, n), + tuning=(preset=:fast, refine_topk=0)) + + @test result.tuned_config in probe_configs + @test normal_const_entry_count(cache_probe_kernel, probe_args(probe_configs[1])) == 1 + @test Array(probe_c) ≈ fill(3f0, n) + end + @testset "@autotune macro: NT space" begin Exp.clear_autotune_cache() fill!(c, 0f0) From 943c12f8c38756993f46be8ede50e115c0018a9f Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Tue, 2 Jun 2026 14:09:35 +0200 Subject: [PATCH 13/15] Fix TileCacheKey call --- src/launch.jl | 6 ++++-- test/device/autotune.jl | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/launch.jl b/src/launch.jl index 45814640..f76f317c 100644 --- a/src/launch.jl +++ b/src/launch.jl @@ -587,10 +587,12 @@ function temporary_cufunction(@nospecialize(f), tt::Type{<:Tuple}, session_id::U sm_arch::Union{VersionNumber, Nothing}=nothing, opt_level::Union{Int, Nothing}=nothing, num_ctas::Union{Int, Nothing}=nothing, - occupancy::Union{Int, Nothing}=nothing) + occupancy::Union{Int, Nothing}=nothing, + num_worker_warps::Union{Int, Nothing}=nothing) bytecode_version = check_tile_ir_support() resolved_sm_arch = sm_arch !== nothing ? sm_arch : default_sm_arch() - key = TileCacheKey(resolved_sm_arch, bytecode_version, opt_level, num_ctas, occupancy) + key = TileCacheKey(resolved_sm_arch, bytecode_version, opt_level, num_ctas, occupancy, + num_worker_warps) owner = TemporaryTileCacheKey(key, session_id) argtypes, const_argtypes = unwrap_argtypes(f, tt) return invoke_frozen(cufunction_compile, f, tt, argtypes, const_argtypes, diff --git a/test/device/autotune.jl b/test/device/autotune.jl index 0738603f..6fecf8e0 100644 --- a/test/device/autotune.jl +++ b/test/device/autotune.jl @@ -41,7 +41,7 @@ const Exp = ct.Experimental world = Base.get_world_counter() key = ct.TileCacheKey(ct.default_sm_arch(), ct.bytecode_version(), - 3, nothing, nothing) + 3, nothing, nothing, nothing) cache = ct.CompilerCaching.CacheView{ct.CuTileResults}(key, world) mi = ct.CompilerCaching.method_instance(f, argtypes; world) ci = get(cache, mi, nothing) From 3f2a2ecd6dab6e212405e6a2fceabfd360a1eab7 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Wed, 3 Jun 2026 17:58:40 +0200 Subject: [PATCH 14/15] Retain context --- src/experimental/autotune.jl | 36 ++++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/src/experimental/autotune.jl b/src/experimental/autotune.jl index 37ce776f..e3f0486c 100644 --- a/src/experimental/autotune.jl +++ b/src/experimental/autotune.jl @@ -252,27 +252,31 @@ function pipelined_tune(@nospecialize(f), configs::Vector{Any}, grid_fn, args_fn jobs = Channel{Any}(length(configs)) foreach(cfg -> put!(jobs, cfg), configs) close(jobs) - + cancelled = Threads.Atomic{Bool}(false) precompile_error = Ref{Any}(nothing) - error_lock = ReentrantLock() + error_lock = ReentrantLock() + ctx = CUDACore.context() producer = Threads.@spawn try @sync for _ in 1:workers - Threads.@spawn for cfg in jobs - cancelled[] && break - try - precompile_cfg(f, cfg, args_fn, session; sm_arch, opt_level, - static_num_ctas, static_occupancy) - cancelled[] || put!(ready, cfg) - catch err - if err isa InterruptException - cancelled[] = true - rethrow() - end - lock(error_lock) do - precompile_error[] === nothing && - (precompile_error[] = (cfg, err)) + Threads.@spawn begin + CUDACore.context!(ctx) + for cfg in jobs + cancelled[] && break + try + precompile_cfg(f, cfg, args_fn, session; sm_arch, opt_level, + static_num_ctas, static_occupancy) + cancelled[] || put!(ready, cfg) + catch err + if err isa InterruptException + cancelled[] = true + rethrow() + end + lock(error_lock) do + precompile_error[] === nothing && + (precompile_error[] = (cfg, err)) + end end end end From a097938237b4a2667eaa428954b3e56cabc89651 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Wed, 3 Jun 2026 19:01:58 +0200 Subject: [PATCH 15/15] cleanup --- src/experimental/autotune.jl | 122 +++++++++++------------------------ src/launch.jl | 13 ++-- 2 files changed, 46 insertions(+), 89 deletions(-) diff --git a/src/experimental/autotune.jl b/src/experimental/autotune.jl index e3f0486c..25e656e4 100644 --- a/src/experimental/autotune.jl +++ b/src/experimental/autotune.jl @@ -15,11 +15,6 @@ const TUNING_PRESETS = ( thorough = (warmup=2, reps=7, refine_topk=4, refine_reps=6), ) -const TUNING_KEYS = ( - :warmup, :reps, :refine_topk, :refine_reps, - :seed, :force, :precompile_workers, -) - struct TuningOptions warmup::Int reps::Int @@ -30,65 +25,38 @@ struct TuningOptions precompile_workers::Int end -struct TuningSession - id::UInt -end - -const NEXT_TUNING_SESSION_ID = Threads.Atomic{UInt}(0) - -TuningSession() = - TuningSession(Threads.atomic_add!(NEXT_TUNING_SESSION_ID, UInt(1)) + UInt(1)) - _tuning_defaults() = ( seed=nothing, force=false, precompile_workers=Threads.nthreads(), ) -function _check_int(name::Symbol, value; min::Int) - (value isa Integer && !(value isa Bool)) || - throw(ArgumentError("tuning.$name must be an integer, got $(typeof(value))")) - value >= min || - throw(ArgumentError("tuning.$name must be >= $min, got $value")) - return Int(value) -end - -function _check_seed(seed) - seed === nothing && return nothing - (seed isa Integer && !(seed isa Bool)) || - throw(ArgumentError("tuning.seed must be an integer or nothing, got $(typeof(seed))")) - return Int(seed) -end - -function _check_bool(name::Symbol, value) - value isa Bool || - throw(ArgumentError("tuning.$name must be a Bool, got $(typeof(value))")) - return value -end +# Lower bound for each count field; the others have no minimum. +const _TUNING_MINIMA = (warmup=0, reps=1, refine_topk=0, refine_reps=1, + precompile_workers=0) function normalize_tuning(tuning::NamedTuple) - valid_keys = (:preset, TUNING_KEYS...) - unknown = setdiff(collect(keys(tuning)), collect(valid_keys)) + valid_keys = (:preset, fieldnames(TuningOptions)...) + unknown = setdiff(keys(tuning), valid_keys) isempty(unknown) || throw(ArgumentError("Unknown tuning option(s): $(join(unknown, ", "))")) preset = get(tuning, :preset, :default) - preset isa Symbol || throw(ArgumentError("tuning.preset must be a Symbol")) - hasproperty(TUNING_PRESETS, preset) || + preset isa Symbol && hasproperty(TUNING_PRESETS, preset) || throw(ArgumentError("Unknown tuning preset `$preset`; use :fast, :default, or :thorough")) overrides = NamedTuple(k => v for (k, v) in pairs(tuning) if k !== :preset) values = merge(_tuning_defaults(), getproperty(TUNING_PRESETS, preset), overrides) - return TuningOptions( - _check_int(:warmup, values.warmup; min=0), - _check_int(:reps, values.reps; min=1), - _check_int(:refine_topk, values.refine_topk; min=0), - _check_int(:refine_reps, values.refine_reps; min=1), - _check_seed(values.seed), - _check_bool(:force, values.force), - _check_int(:precompile_workers, values.precompile_workers; min=0), - ) + # The struct's field types coerce/reject bad value types; we only enforce + # the lower bounds that the types can't. Pull fields by name since `values` + # is in merge order, not struct-field order. + opts = TuningOptions((getproperty(values, f) for f in fieldnames(TuningOptions))...) + for (name, lo) in pairs(_TUNING_MINIMA) + getfield(opts, name) >= lo || + throw(ArgumentError("tuning.$name must be >= $lo, got $(getfield(opts, name))")) + end + return opts end @inline _hint_from_cfg(cfg, name::Symbol, fallback) = @@ -101,22 +69,14 @@ function hints_from_cfg(cfg; static_num_ctas=nothing, static_occupancy=nothing) ) end -function _check_static_hint_conflict(configs, hint::Symbol, static_value) - static_value === nothing && return nothing - for cfg in configs - if hasproperty(cfg, hint) - throw(ArgumentError( - "`$hint` is both a static kwarg and an axis in the search space. " * - "Pick one.")) - end - end - return nothing -end - function _check_static_hint_conflicts(configs; static_num_ctas=nothing, static_occupancy=nothing) - _check_static_hint_conflict(configs, :num_ctas, static_num_ctas) - _check_static_hint_conflict(configs, :occupancy, static_occupancy) + statics = (num_ctas=static_num_ctas, occupancy=static_occupancy) + for (hint, static_value) in pairs(statics) + static_value === nothing && continue + any(cfg -> hasproperty(cfg, hint), configs) && throw(ArgumentError( + "`$hint` is both a static kwarg and an axis in the search space. Pick one.")) + end return nothing end @@ -132,11 +92,11 @@ end @inline _converted_args(args_fn, cfg) = map(cuTileconvert, args_fn(cfg)) @inline _argtypes(args) = Tuple{map(Core.Typeof, args)...} -function _compile_cfg(@nospecialize(f), cfg, args_fn, session::TuningSession; +function _compile_cfg(@nospecialize(f), cfg, args_fn; sm_arch::VersionNumber, opt_level::Int, static_num_ctas=nothing, static_occupancy=nothing) converted = _converted_args(args_fn, cfg) - return temporary_cufunction(f, _argtypes(converted), session.id; + return temporary_cufunction(f, _argtypes(converted); sm_arch, opt_level, hints_from_cfg(cfg; static_num_ctas, static_occupancy)...) end @@ -166,12 +126,12 @@ function _time_ms(run_once, get_args; return best_ms end -function eval_cfg(@nospecialize(f), cfg, grid_fn, args_fn, session::TuningSession; +function eval_cfg(@nospecialize(f), cfg, grid_fn, args_fn; sm_arch::VersionNumber, opt_level::Int, warmup::Int, reps::Int, static_num_ctas=nothing, static_occupancy=nothing, verify=nothing, reset=nothing) grid = _grid_dims(grid_fn(cfg)) - kernel = _compile_cfg(f, cfg, args_fn, session; sm_arch, opt_level, + kernel = _compile_cfg(f, cfg, args_fn; sm_arch, opt_level, static_num_ctas, static_occupancy) run_once = converted -> kernel(converted...; blocks=grid) @@ -179,10 +139,10 @@ function eval_cfg(@nospecialize(f), cfg, grid_fn, args_fn, session::TuningSessio return _time_ms(run_once, get_args; warmup, reps, verify, reset) end -function precompile_cfg(@nospecialize(f), cfg, args_fn, session::TuningSession; +function precompile_cfg(@nospecialize(f), cfg, args_fn; sm_arch::VersionNumber, opt_level::Int, static_num_ctas=nothing, static_occupancy=nothing) - _compile_cfg(f, cfg, args_fn, session; sm_arch, opt_level, + _compile_cfg(f, cfg, args_fn; sm_arch, opt_level, static_num_ctas, static_occupancy) return nothing end @@ -190,10 +150,9 @@ end const TimingRecord = Tuple{Any, Float32} function _measure_cfg!(record::Vector{TimingRecord}, first_error::Base.RefValue, - @nospecialize(f), cfg, grid_fn, args_fn, - session::TuningSession; kwargs...) + @nospecialize(f), cfg, grid_fn, args_fn; kwargs...) ms = try - eval_cfg(f, cfg, grid_fn, args_fn, session; kwargs...) + eval_cfg(f, cfg, grid_fn, args_fn; kwargs...) catch err err isa InterruptException && rethrow() if err isa VerificationError @@ -209,8 +168,7 @@ function _measure_cfg!(record::Vector{TimingRecord}, first_error::Base.RefValue, return nothing end -function measure_candidates(@nospecialize(f), configs::Vector{Any}, grid_fn, args_fn, - session::TuningSession; +function measure_candidates(@nospecialize(f), configs::Vector{Any}, grid_fn, args_fn; sm_arch::VersionNumber, opt_level::Int, warmup::Int, reps::Int, static_num_ctas=nothing, static_occupancy=nothing, @@ -218,7 +176,7 @@ function measure_candidates(@nospecialize(f), configs::Vector{Any}, grid_fn, arg record = TimingRecord[] first_error = Ref{Any}(nothing) for cfg in configs - _measure_cfg!(record, first_error, f, cfg, grid_fn, args_fn, session; + _measure_cfg!(record, first_error, f, cfg, grid_fn, args_fn; sm_arch, opt_level, warmup, reps, static_num_ctas, static_occupancy, verify, reset) end @@ -232,8 +190,7 @@ Compile candidate configurations on worker tasks while the caller task measures completed candidates. Measurement stays on the caller task because CUDA state is task-local; workers only run the untimed temporary compile path. """ -function pipelined_tune(@nospecialize(f), configs::Vector{Any}, grid_fn, args_fn, - session::TuningSession; +function pipelined_tune(@nospecialize(f), configs::Vector{Any}, grid_fn, args_fn; sm_arch::VersionNumber, opt_level::Int, warmup::Int, reps::Int, workers::Int, static_num_ctas=nothing, static_occupancy=nothing, @@ -241,7 +198,7 @@ function pipelined_tune(@nospecialize(f), configs::Vector{Any}, grid_fn, args_fn isempty(configs) && return TimingRecord[], nothing, nothing if iszero(workers) || length(configs) == 1 - record, first_error = measure_candidates(f, configs, grid_fn, args_fn, session; + record, first_error = measure_candidates(f, configs, grid_fn, args_fn; sm_arch, opt_level, warmup, reps, static_num_ctas, static_occupancy, verify, reset) return record, nothing, first_error @@ -265,7 +222,7 @@ function pipelined_tune(@nospecialize(f), configs::Vector{Any}, grid_fn, args_fn for cfg in jobs cancelled[] && break try - precompile_cfg(f, cfg, args_fn, session; sm_arch, opt_level, + precompile_cfg(f, cfg, args_fn; sm_arch, opt_level, static_num_ctas, static_occupancy) cancelled[] || put!(ready, cfg) catch err @@ -289,7 +246,7 @@ function pipelined_tune(@nospecialize(f), configs::Vector{Any}, grid_fn, args_fn first_error = Ref{Any}(nothing) try for cfg in ready - _measure_cfg!(record, first_error, f, cfg, grid_fn, args_fn, session; + _measure_cfg!(record, first_error, f, cfg, grid_fn, args_fn; sm_arch, opt_level, warmup, reps, static_num_ctas, static_occupancy, verify, reset) end @@ -348,7 +305,7 @@ function _no_valid_config_error(first_error, precompile_error) end function _refine_record(@nospecialize(f), record::Vector{TimingRecord}, tuning::TuningOptions, - grid_fn, args_fn, session::TuningSession; + grid_fn, args_fn; sm_arch::VersionNumber, opt_level::Int, static_num_ctas=nothing, static_occupancy=nothing, verify=nothing, reset=nothing) @@ -356,7 +313,7 @@ function _refine_record(@nospecialize(f), record::Vector{TimingRecord}, tuning:: sort!(record, by=last) top = Any[first(r) for r in record[1:min(tuning.refine_topk, length(record))]] - refined, _ = measure_candidates(f, top, grid_fn, args_fn, session; + refined, _ = measure_candidates(f, top, grid_fn, args_fn; sm_arch, opt_level, warmup=tuning.warmup, reps=tuning.refine_reps, static_num_ctas, static_occupancy, verify, reset) return isempty(refined) ? record : refined @@ -383,11 +340,10 @@ function find_or_tune(@nospecialize(f), space::AbstractSearchSpace, checker = verify !== nothing ? verify() : nothing reset = setup !== nothing ? setup() : nothing - session = TuningSession() record, precompile_error, first_error = with(_SCOPED_INF_CACHE => _fresh_inf_cache()) do - pipelined_tune(f, trials, grid_fn, args_fn, session; + pipelined_tune(f, trials, grid_fn, args_fn; sm_arch, opt_level, warmup=tuning.warmup, reps=tuning.reps, workers=tuning.precompile_workers, @@ -397,7 +353,7 @@ function find_or_tune(@nospecialize(f), space::AbstractSearchSpace, isempty(record) && _no_valid_config_error(first_error, precompile_error) - record = _refine_record(f, record, tuning, grid_fn, args_fn, session; + record = _refine_record(f, record, tuning, grid_fn, args_fn; sm_arch, opt_level, static_num_ctas, static_occupancy, verify=checker, reset) diff --git a/src/launch.jl b/src/launch.jl index f76f317c..04a1cd1f 100644 --- a/src/launch.jl +++ b/src/launch.jl @@ -125,12 +125,13 @@ TileCacheKey(sm_arch::VersionNumber, bytecode_version::VersionNumber, struct TemporaryTileCacheKey key::TileCacheKey - session_id::UInt end -# Autotune candidates use this owner so they can share work within one tuning -# pass without becoming visible to the normal `cufunction` cache keyed by -# `TileCacheKey`. The winning config is promoted by the final normal launch. +# Autotune candidates compile under this owner so they share inference/codegen +# work without becoming visible to (or polluting) the normal `cufunction` cache +# keyed by `TileCacheKey`. It's a plain marker wrapper: candidate CIs are keyed +# by content, so the cached set is bounded by the search space, not by how many +# times you tune. The winning config is promoted by the final normal launch. @inline tile_cache_key(key::TileCacheKey) = key @inline tile_cache_key(key::TemporaryTileCacheKey) = key.key @@ -583,7 +584,7 @@ function cufunction(@nospecialize(f), tt::Type{<:Tuple}=Tuple{}; key, true)::TileKernel{Core.Typeof(f), tt} end -function temporary_cufunction(@nospecialize(f), tt::Type{<:Tuple}, session_id::UInt; +function temporary_cufunction(@nospecialize(f), tt::Type{<:Tuple}; sm_arch::Union{VersionNumber, Nothing}=nothing, opt_level::Union{Int, Nothing}=nothing, num_ctas::Union{Int, Nothing}=nothing, @@ -593,7 +594,7 @@ function temporary_cufunction(@nospecialize(f), tt::Type{<:Tuple}, session_id::U resolved_sm_arch = sm_arch !== nothing ? sm_arch : default_sm_arch() key = TileCacheKey(resolved_sm_arch, bytecode_version, opt_level, num_ctas, occupancy, num_worker_warps) - owner = TemporaryTileCacheKey(key, session_id) + owner = TemporaryTileCacheKey(key) argtypes, const_argtypes = unwrap_argtypes(f, tt) return invoke_frozen(cufunction_compile, f, tt, argtypes, const_argtypes, owner, false)::TileKernel{Core.Typeof(f), tt}