diff --git a/src/compiler/driver.jl b/src/compiler/driver.jl index f4c93ac6..48276d09 100644 --- a/src/compiler/driver.jl +++ b/src/compiler/driver.jl @@ -307,7 +307,7 @@ function emit_tile!(cache::CacheView, mi::Core.MethodInstance, # Compute bytecode via driver sci, rettype, kernel_meta = ir_result - key = cache.owner::TileCacheKey + key = tile_cache_key(cache.owner) opts = CGOpts((sm_arch=unpack_version(key.sm_arch), opt_level=unpack_hint(key.opt_level), num_ctas=unpack_hint(key.num_ctas), diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index c5e03515..c7ad2256 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -1,8 +1,17 @@ # Integration with Julia's abstract interpreter +using Base.ScopedValues: ScopedValue Base.Experimental.@MethodTable cuTileMethodTable +# When assigned, every `cuTileInterpreter(cache)` constructed within +# the dynamic scope reuses this inference cache instead of allocating +# a fresh one. Lets batched inference passes (autotuning over many +# const-seeded variants of the same kernel) share work; without it, +# kernels that hit slow inference paths (e.g. `ct.load(..., order=...)`) +# pay the cost on every config. +const _SCOPED_INF_CACHE = ScopedValue{Any}() + function get_method_table_view(world::UInt) CC.CachedMethodTable(CC.OverlayMethodTable(world, cuTileMethodTable)) end @@ -21,10 +30,14 @@ end function cuTileInterpreter(cache::CacheView; always_inline::Bool=true) method_table = get_method_table_view(cache.world) - @static if isdefined(CC, :InferenceCache) - inf_cache = CC.InferenceCache() + inf_cache = if isassigned(_SCOPED_INF_CACHE) + _SCOPED_INF_CACHE[] else - inf_cache = Vector{CC.InferenceResult}() + @static if isdefined(CC, :InferenceCache) + CC.InferenceCache() + else + Vector{CC.InferenceResult}() + end end inf_params = CC.InferenceParams() opt_params = if always_inline diff --git a/src/cuTile.jl b/src/cuTile.jl index 784e1918..6a488c0c 100644 --- a/src/cuTile.jl +++ b/src/cuTile.jl @@ -88,9 +88,27 @@ include("cache.jl") include("launch.jl") public launch, TileBackend, DefaultBackend, Tiled, ByTarget, - @compiler_options, @fpmode, @., + @compiler_options, @fpmode, @., @cutile, bytecode_version +""" + @cutile [kwargs...] kernel(args...) + +Shorthand for `CUDACore.@cuda backend=cuTile [kwargs...] kernel(args...)`. + +Works from any module — does not require the caller to `using CUDACore`, +since the macro expands to a fully-qualified reference to the actual +`CUDACore` module object. + +```julia +@cutile blocks=N kernel(a, b, c) +@cutile blocks=N occupancy=4 kernel(a, b, c) +``` +""" +macro cutile(args...) + esc(:($CUDACore.@cuda backend=$cuTile $(args...))) +end + # World age captured at __init__ time. The compilation pipeline # (typeinf!, codegen, bytecode emission) is invoked in this world via # `invoke_frozen` so that precompiled native code stays usable even after @@ -123,4 +141,6 @@ end include("precompile.jl") +include("experimental/Experimental.jl") + end # module cuTile diff --git a/src/experimental/Experimental.jl b/src/experimental/Experimental.jl new file mode 100644 index 00000000..162cf67b --- /dev/null +++ b/src/experimental/Experimental.jl @@ -0,0 +1,27 @@ +module Experimental + +using ..cuTile +using ..cuTile: cuTileconvert, default_sm_arch, temporary_cufunction, + _SCOPED_INF_CACHE + +using CUDACore: CUDACore + +using Base.ScopedValues: with +import Core.Compiler as CC +using Random + +# Builds a fresh inference cache compatible with the running Julia version. +# Used to wrap an autotune pass in `with(_SCOPED_INF_CACHE => ...)` so all the +# per-config const-seeded inference calls share results instead of paying +# the slow paths (e.g. `ct.load(..., order=...)`) once per config. +@inline _fresh_inf_cache() = @static if isdefined(CC, :InferenceCache) + CC.InferenceCache() +else + Vector{CC.InferenceResult}() +end + +include("search_space.jl") +include("autotune.jl") +include("autotune_macro.jl") + +end diff --git a/src/experimental/autotune.jl b/src/experimental/autotune.jl new file mode 100644 index 00000000..25e656e4 --- /dev/null +++ b/src/experimental/autotune.jl @@ -0,0 +1,441 @@ +public autotune_launch, clear_autotune_cache + +const AUTOTUNE_CACHE = Base.Lockable(Dict{Any, Dict{Any, Any}}()) + +struct VerificationError <: Exception + msg::String +end + +Base.showerror(io::IO, err::VerificationError) = + print(io, "VerificationError: ", err.msg) + +const TUNING_PRESETS = ( + fast = (warmup=1, reps=3, refine_topk=0, refine_reps=2), + default = (warmup=2, reps=5, refine_topk=2, refine_reps=4), + thorough = (warmup=2, reps=7, refine_topk=4, refine_reps=6), +) + +struct TuningOptions + warmup::Int + reps::Int + refine_topk::Int + refine_reps::Int + seed::Union{Nothing, Int} + force::Bool + precompile_workers::Int +end + +_tuning_defaults() = ( + seed=nothing, + force=false, + precompile_workers=Threads.nthreads(), +) + +# Lower bound for each count field; the others have no minimum. +const _TUNING_MINIMA = (warmup=0, reps=1, refine_topk=0, refine_reps=1, + precompile_workers=0) + +function normalize_tuning(tuning::NamedTuple) + valid_keys = (:preset, fieldnames(TuningOptions)...) + unknown = setdiff(keys(tuning), valid_keys) + isempty(unknown) || + throw(ArgumentError("Unknown tuning option(s): $(join(unknown, ", "))")) + + preset = get(tuning, :preset, :default) + preset isa Symbol && hasproperty(TUNING_PRESETS, preset) || + throw(ArgumentError("Unknown tuning preset `$preset`; use :fast, :default, or :thorough")) + + overrides = NamedTuple(k => v for (k, v) in pairs(tuning) if k !== :preset) + values = merge(_tuning_defaults(), getproperty(TUNING_PRESETS, preset), overrides) + + # The struct's field types coerce/reject bad value types; we only enforce + # the lower bounds that the types can't. Pull fields by name since `values` + # is in merge order, not struct-field order. + opts = TuningOptions((getproperty(values, f) for f in fieldnames(TuningOptions))...) + for (name, lo) in pairs(_TUNING_MINIMA) + getfield(opts, name) >= lo || + throw(ArgumentError("tuning.$name must be >= $lo, got $(getfield(opts, name))")) + end + return opts +end + +@inline _hint_from_cfg(cfg, name::Symbol, fallback) = + hasproperty(cfg, name) ? getproperty(cfg, name) : fallback + +function hints_from_cfg(cfg; static_num_ctas=nothing, static_occupancy=nothing) + return ( + num_ctas=_hint_from_cfg(cfg, :num_ctas, static_num_ctas), + occupancy=_hint_from_cfg(cfg, :occupancy, static_occupancy), + ) +end + +function _check_static_hint_conflicts(configs; static_num_ctas=nothing, + static_occupancy=nothing) + statics = (num_ctas=static_num_ctas, occupancy=static_occupancy) + for (hint, static_value) in pairs(statics) + static_value === nothing && continue + any(cfg -> hasproperty(cfg, hint), configs) && throw(ArgumentError( + "`$hint` is both a static kwarg and an axis in the search space. Pick one.")) + end + return nothing +end + +function _collect_trials(space::AbstractSearchSpace, seed) + trials = Any[cfg for cfg in space] + if seed !== nothing && length(trials) > 1 + shuffle!(MersenneTwister(seed), trials) + end + return trials +end + +@inline _grid_dims(grid) = grid isa Integer ? (grid,) : grid +@inline _converted_args(args_fn, cfg) = map(cuTileconvert, args_fn(cfg)) +@inline _argtypes(args) = Tuple{map(Core.Typeof, args)...} + +function _compile_cfg(@nospecialize(f), cfg, args_fn; + sm_arch::VersionNumber, opt_level::Int, + static_num_ctas=nothing, static_occupancy=nothing) + converted = _converted_args(args_fn, cfg) + return temporary_cufunction(f, _argtypes(converted); + sm_arch, opt_level, + hints_from_cfg(cfg; static_num_ctas, static_occupancy)...) +end + +function _time_ms(run_once, get_args; + warmup::Int, reps::Int, verify=nothing, reset=nothing) + CUDACore.synchronize() + for _ in 1:max(warmup, verify === nothing ? 0 : 1) + reset !== nothing && reset() + run_once(get_args()) + end + + if verify !== nothing + CUDACore.synchronize() + verify() || throw(VerificationError("config produced incorrect output")) + end + + best_ms = Inf32 + for _ in 1:reps + reset !== nothing && reset() + args = get_args() + CUDACore.synchronize() + elapsed_s = CUDACore.@elapsed run_once(args) + CUDACore.synchronize() + best_ms = min(best_ms, Float32(elapsed_s * 1000)) + end + return best_ms +end + +function eval_cfg(@nospecialize(f), cfg, grid_fn, args_fn; + sm_arch::VersionNumber, opt_level::Int, warmup::Int, reps::Int, + static_num_ctas=nothing, static_occupancy=nothing, + verify=nothing, reset=nothing) + grid = _grid_dims(grid_fn(cfg)) + kernel = _compile_cfg(f, cfg, args_fn; sm_arch, opt_level, + static_num_ctas, static_occupancy) + + run_once = converted -> kernel(converted...; blocks=grid) + get_args = () -> _converted_args(args_fn, cfg) + return _time_ms(run_once, get_args; warmup, reps, verify, reset) +end + +function precompile_cfg(@nospecialize(f), cfg, args_fn; + sm_arch::VersionNumber, opt_level::Int, + static_num_ctas=nothing, static_occupancy=nothing) + _compile_cfg(f, cfg, args_fn; sm_arch, opt_level, + static_num_ctas, static_occupancy) + return nothing +end + +const TimingRecord = Tuple{Any, Float32} + +function _measure_cfg!(record::Vector{TimingRecord}, first_error::Base.RefValue, + @nospecialize(f), cfg, grid_fn, args_fn; kwargs...) + ms = try + eval_cfg(f, cfg, grid_fn, args_fn; kwargs...) + catch err + err isa InterruptException && rethrow() + if err isa VerificationError + @warn "Config failed verification; skipping" cfg + else + bt = catch_backtrace() + @debug "Config failed during autotuning; skipping" cfg exception=(err, bt) + end + first_error[] === nothing && (first_error[] = (cfg, err)) + return nothing + end + push!(record, (cfg, ms)) + return nothing +end + +function measure_candidates(@nospecialize(f), configs::Vector{Any}, grid_fn, args_fn; + sm_arch::VersionNumber, opt_level::Int, + warmup::Int, reps::Int, + static_num_ctas=nothing, static_occupancy=nothing, + verify=nothing, reset=nothing) + record = TimingRecord[] + first_error = Ref{Any}(nothing) + for cfg in configs + _measure_cfg!(record, first_error, f, cfg, grid_fn, args_fn; + sm_arch, opt_level, warmup, reps, + static_num_ctas, static_occupancy, verify, reset) + end + return record, first_error[] +end + +""" + pipelined_tune(f, configs, grid_fn, args_fn; ...) -> (record, precompile_error, first_error) + +Compile candidate configurations on worker tasks while the caller task measures +completed candidates. Measurement stays on the caller task because CUDA state is +task-local; workers only run the untimed temporary compile path. +""" +function pipelined_tune(@nospecialize(f), configs::Vector{Any}, grid_fn, args_fn; + sm_arch::VersionNumber, opt_level::Int, + warmup::Int, reps::Int, workers::Int, + static_num_ctas=nothing, static_occupancy=nothing, + verify=nothing, reset=nothing) + isempty(configs) && return TimingRecord[], nothing, nothing + + if iszero(workers) || length(configs) == 1 + record, first_error = measure_candidates(f, configs, grid_fn, args_fn; + sm_arch, opt_level, warmup, reps, + static_num_ctas, static_occupancy, verify, reset) + return record, nothing, first_error + end + + workers = min(workers, Threads.nthreads(), length(configs)) + ready = Channel{Any}(length(configs)) + jobs = Channel{Any}(length(configs)) + foreach(cfg -> put!(jobs, cfg), configs) + close(jobs) + + cancelled = Threads.Atomic{Bool}(false) + precompile_error = Ref{Any}(nothing) + error_lock = ReentrantLock() + ctx = CUDACore.context() + + producer = Threads.@spawn try + @sync for _ in 1:workers + Threads.@spawn begin + CUDACore.context!(ctx) + for cfg in jobs + cancelled[] && break + try + precompile_cfg(f, cfg, args_fn; sm_arch, opt_level, + static_num_ctas, static_occupancy) + cancelled[] || put!(ready, cfg) + catch err + if err isa InterruptException + cancelled[] = true + rethrow() + end + lock(error_lock) do + precompile_error[] === nothing && + (precompile_error[] = (cfg, err)) + end + end + end + end + end + finally + close(ready) + end + + record = TimingRecord[] + first_error = Ref{Any}(nothing) + try + for cfg in ready + _measure_cfg!(record, first_error, f, cfg, grid_fn, args_fn; + sm_arch, opt_level, warmup, reps, + static_num_ctas, static_occupancy, verify, reset) + end + wait(producer) + catch + cancelled[] = true + while isready(ready) + take!(ready) + end + try + wait(producer) + catch + end + rethrow() + end + + return record, precompile_error[], first_error[] +end + +@inline _entry_in_trials(entry, trials) = + any(cfg -> cfg == entry.best_config, trials) + +function _cached_entry(kernel_key, arg_key, trials) + entry = Base.@lock AUTOTUNE_CACHE begin + per_kernel = get(AUTOTUNE_CACHE[], kernel_key, nothing) + per_kernel === nothing ? nothing : get(per_kernel, arg_key, nothing) + end + entry !== nothing && _entry_in_trials(entry, trials) ? entry : nothing +end + +function _cache_candidate!(candidate, kernel_key, arg_key, trials; force::Bool) + Base.@lock AUTOTUNE_CACHE begin + per_kernel = get!(AUTOTUNE_CACHE[], kernel_key) do + Dict{Any, Any}() + end + if !force + entry = get(per_kernel, arg_key, nothing) + if entry !== nothing && _entry_in_trials(entry, trials) + return entry, true + end + end + per_kernel[arg_key] = candidate + return candidate, false + end +end + +function _no_valid_config_error(first_error, precompile_error) + err_info = first_error !== nothing ? first_error : precompile_error + if err_info === nothing + throw(ArgumentError("No valid config found in search space")) + end + + cfg, err = err_info + throw(ArgumentError( + "No valid config found. First failure for cfg=$cfg: $(sprint(showerror, err))")) +end + +function _refine_record(@nospecialize(f), record::Vector{TimingRecord}, tuning::TuningOptions, + grid_fn, args_fn; + sm_arch::VersionNumber, opt_level::Int, + static_num_ctas=nothing, static_occupancy=nothing, + verify=nothing, reset=nothing) + (tuning.refine_topk > 0 && length(record) > 1) || return record + + sort!(record, by=last) + top = Any[first(r) for r in record[1:min(tuning.refine_topk, length(record))]] + refined, _ = measure_candidates(f, top, grid_fn, args_fn; + sm_arch, opt_level, warmup=tuning.warmup, reps=tuning.refine_reps, + static_num_ctas, static_occupancy, verify, reset) + return isempty(refined) ? record : refined +end + +function _best_candidate(record::Vector{TimingRecord}) + _, best_idx = findmin(last, record) + return (; best_config=record[best_idx][1], tuning_record=record) +end + +function find_or_tune(@nospecialize(f), space::AbstractSearchSpace, + grid_fn, args_fn, tuning::TuningOptions; + sm_arch::VersionNumber, opt_level::Int, kernel_key, arg_key, + static_num_ctas=nothing, static_occupancy=nothing, + verify=nothing, setup=nothing) + trials = _collect_trials(space, tuning.seed) + isempty(trials) && throw(ArgumentError("No valid config found in search space")) + _check_static_hint_conflicts(trials; static_num_ctas, static_occupancy) + + if !tuning.force + entry = _cached_entry(kernel_key, arg_key, trials) + entry !== nothing && return entry, true, nothing + end + + checker = verify !== nothing ? verify() : nothing + reset = setup !== nothing ? setup() : nothing + + record, precompile_error, first_error = + with(_SCOPED_INF_CACHE => _fresh_inf_cache()) do + pipelined_tune(f, trials, grid_fn, args_fn; + sm_arch, opt_level, + warmup=tuning.warmup, reps=tuning.reps, + workers=tuning.precompile_workers, + static_num_ctas, static_occupancy, + verify=checker, reset) + end + + isempty(record) && _no_valid_config_error(first_error, precompile_error) + + record = _refine_record(f, record, tuning, grid_fn, args_fn; + sm_arch, opt_level, + static_num_ctas, static_occupancy, + verify=checker, reset) + + candidate = _best_candidate(record) + entry, cache_hit = _cache_candidate!(candidate, kernel_key, arg_key, trials; + force=tuning.force) + return entry, cache_hit, reset +end + +@inline _as_cfg_fn(f::Function) = f +@inline _as_cfg_fn(x) = Returns(x) + +""" + autotune_launch(f, space, grid, args; key, launch_args, verify, setup, + tuning, sm_arch, opt_level, + num_ctas=nothing, occupancy=nothing) + +Tune `f` over `space` and launch the fastest valid config. + +`space` can be an `AbstractSearchSpace`, a `NamedTuple` of cartesian axes, or +an iterable of `NamedTuple` configs. `grid`, `args`, and `launch_args` can be +plain values or `cfg -> value` functions. Results are cached per +`(f, sm_arch, opt_level, num_ctas, occupancy)` and user `key`. +""" +function autotune_launch(@nospecialize(f), space::AbstractSearchSpace, + grid, args; + key=nothing, + launch_args=nothing, + verify=nothing, + setup=nothing, + tuning::NamedTuple=NamedTuple(), + sm_arch::VersionNumber=default_sm_arch(), + opt_level::Int=3, + num_ctas=nothing, + occupancy=nothing) + tuning = normalize_tuning(tuning) + + grid_fn = _as_cfg_fn(grid) + args_fn = _as_cfg_fn(args) + launch_args_fn = launch_args === nothing ? args_fn : _as_cfg_fn(launch_args) + + kernel_key = (f, sm_arch, opt_level, num_ctas, occupancy) + entry, cache_hit, reset = find_or_tune(f, space, grid_fn, args_fn, tuning; + sm_arch, opt_level, kernel_key, arg_key=key, + static_num_ctas=num_ctas, static_occupancy=occupancy, + verify, setup) + + cfg = entry.best_config + grid = grid_fn(cfg) + launched_args = launch_args_fn(cfg) + + reset !== nothing && reset() + + cuTile.launch(f, grid, launched_args...; sm_arch, opt_level, + hints_from_cfg(cfg; static_num_ctas=num_ctas, + static_occupancy=occupancy)...) + + return (; tuned_config=cfg, grid, tuning_record=copy(entry.tuning_record), cache_hit) +end + +function autotune_launch(@nospecialize(f), configs, grid, args; kwargs...) + space = configs isa NamedTuple ? CartesianSpace(configs) : FixedSpace(configs) + return autotune_launch(f, space, grid, args; kwargs...) +end + +function clear_autotune_cache(; kernel=nothing, key=nothing) + Base.@lock AUTOTUNE_CACHE begin + cache = AUTOTUNE_CACHE[] + if kernel === nothing + key === nothing || throw(ArgumentError("`key` requires `kernel`")) + empty!(cache) + return nothing + end + + for kernel_key in collect(keys(cache)) + kernel_key isa Tuple || continue + kernel_key[1] === kernel || continue + per_kernel = cache[kernel_key] + key === nothing ? empty!(per_kernel) : pop!(per_kernel, key, nothing) + isempty(per_kernel) && delete!(cache, kernel_key) + end + end + return nothing +end diff --git a/src/experimental/autotune_macro.jl b/src/experimental/autotune_macro.jl new file mode 100644 index 00000000..9cdf9b27 --- /dev/null +++ b/src/experimental/autotune_macro.jl @@ -0,0 +1,173 @@ +public @autotune + +# `@autotune`: surface syntax for `autotune_launch`. +# +# Desugars +# +# @autotune( +# key = (T, Dk_pow2), +# space = (TILE_M=(32, 64), TILE_N=(32, 64), occupancy=(1, 2)), +# blocks = (cld(SeqLen_Q, $TILE_M), Heads * Batch), +# mha_fwd(Q, K, V, ..., Constant($TILE_M), Constant($TILE_N)), +# ) +# +# into a call to `autotune_launch`, with `$X` interpolated as `cfg.X` in +# `blocks` and the kernel-call args. `$X` is the parser's `Expr(:$, :X)` +# node, which is normally rejected by lowering; the macro intercepts and +# rewrites it before that happens. + +# Walk an expression replacing `Expr(:$, :X)` with `cfg.X`. +function _autotune_interp(ex, cfg::Symbol) + if Meta.isexpr(ex, :$) + length(ex.args) == 1 || + error("@autotune: \$ takes exactly one argument") + sym = ex.args[1] + sym isa Symbol || + error("@autotune: \$ argument must be a symbol, got `$sym`") + return :($cfg.$sym) + elseif ex isa Expr + return Expr(ex.head, [_autotune_interp(a, cfg) for a in ex.args]...) + else + return ex + end +end + +# Try to extract axis names from a literal NT space expression. +# Returns Vector{Symbol} on success, or `nothing` if the expression +# isn't a recognizable NT literal (e.g. it's a variable reference, a +# function call, or a constructor like `FixedSpace([...])`). +function _autotune_space_axes(space_expr) + Meta.isexpr(space_expr, :tuple) || return nothing + axes = Symbol[] + for a in space_expr.args + if Meta.isexpr(a, :(=)) && a.args[1] isa Symbol + push!(axes, a.args[1]) + elseif Meta.isexpr(a, :parameters) + for kv in a.args + if Meta.isexpr(kv, :kw) && kv.args[1] isa Symbol + push!(axes, kv.args[1]) + else + return nothing + end + end + else + return nothing + end + end + return axes +end + +const _AUTOTUNE_KWARGS = (:key, :space, :blocks, :tuning, :verify, :setup, + :sm_arch, :opt_level, :launch_args, + :num_ctas, :occupancy) + +""" + @autotune key=... space=... blocks=... [kwargs...] kernel(args...) + +Tune `kernel` over `space` and launch the best config. Inside `blocks=` +and the kernel-call args, `\$X` interpolates `cfg.X` (where `cfg` is the +tuning configuration being evaluated). + +# Required kwargs +- `space` - a `NamedTuple` like `(A=(...), B=(...))` (becomes a + `CartesianSpace`), a `Vector` of `NamedTuple`s (becomes a `FixedSpace`), + or any `AbstractSearchSpace` (passed through; useful for + `CartesianSpace(constraint; ...)`). +- `blocks` - grid dimensions, an `Int` or `Tuple`. May reference `\$X`. + +# Optional kwargs +- `key` - cache key (any value) +- `tuning` - `NamedTuple` of tuning knobs (`preset`, `force`, etc.) +- `verify` - `() -> (() -> Bool)` factory; the returned checker is + called after each warmup pass to reject incorrect cfgs +- `setup` - `() -> (() -> Nothing)` factory; reset between reps +- `launch_args` - final-launch args (or `cfg -> args` if it should differ + from the kernel-call args). Use this when the timed args + are throwaway copies (in-place kernels) and the final + launch should hit the real buffers +- `sm_arch`, `opt_level` - forwarded to `cufunction` +- `num_ctas`, `occupancy` - **static** hints applied uniformly to every + cfg. May not coexist with same-named axes in `space` + (the macro flags the conflict at expansion time when + `space` is a literal NT; otherwise `autotune_launch` + catches it at run time) + +# Example + +```julia +@autotune( + key = (eltype(A), size(A, 2)), + space = (TILE_M=(64, 128), TILE_N=(64, 128), occupancy=(1, 2, 4)), + blocks = (cld(M, \$TILE_M), cld(N, \$TILE_N)), + matmul(A, B, C, Constant(\$TILE_M), Constant(\$TILE_N)) +) +``` +""" +macro autotune(args...) + kwargs = Dict{Symbol, Any}() + call = nothing + for arg in args + if Meta.isexpr(arg, :(=)) || Meta.isexpr(arg, :kw) + k, v = arg.args + k isa Symbol || error("@autotune: kwarg key must be a symbol, got `$k`") + k in _AUTOTUNE_KWARGS || + error("@autotune: unknown kwarg `$k`. Valid: $(join(_AUTOTUNE_KWARGS, ", "))") + haskey(kwargs, k) && error("@autotune: duplicate kwarg `$k`") + kwargs[k] = v + elseif Meta.isexpr(arg, :call) + call === nothing || error("@autotune: only one kernel call allowed") + call = arg + else + error("@autotune: unexpected argument `$arg`; expected `kwarg=val` or a kernel call") + end + end + + call === nothing && error("@autotune: missing kernel call (e.g. `kernel(args...)`)") + haskey(kwargs, :space) || error("@autotune: missing required kwarg `space=`") + haskey(kwargs, :blocks) || error("@autotune: missing required kwarg `blocks=`") + + space_expr = kwargs[:space] + blocks_expr = kwargs[:blocks] + + # Macro-time conflict check: if `space` is a literal NT and contains + # `num_ctas`/`occupancy` as an axis, the same name can't also appear + # as a static kwarg. (Opaque spaces are caught at run time.) + space_axes = _autotune_space_axes(space_expr) + if space_axes !== nothing + for hint in (:num_ctas, :occupancy) + if hint in space_axes && haskey(kwargs, hint) + error("@autotune: `$hint` appears both as an axis in `space=` and " * + "as a static kwarg. Pick one.") + end + end + end + + # Extract the kernel call (positional only; no kernel kwargs). + Meta.isexpr(call, :call) || + error("@autotune: kernel must be a function-call expression") + f_expr = call.args[1] + arg_exprs = call.args[2:end] + for a in arg_exprs + if Meta.isexpr(a, :parameters) || Meta.isexpr(a, :kw) + error("@autotune: kernel-side kwargs are not supported; pass values " * + "as positional args (typically wrapped in `Constant(...)`)") + end + end + + # Substitute `$X` -> `cfg.X` inside blocks/args. + cfg = gensym(:cfg) + grid_body = _autotune_interp(blocks_expr, cfg) + arg_subs = [_autotune_interp(a, cfg) for a in arg_exprs] + + grid_fn = :($cfg -> $grid_body) + args_fn = :($cfg -> ($(arg_subs...),)) + + # Forward all macro kwargs (except space/blocks, which are positional + # / lifted into the closures) to `autotune_launch`. + forwarded_keys = (:key, :tuning, :verify, :setup, :sm_arch, :opt_level, + :launch_args, :num_ctas, :occupancy) + kw_exprs = [Expr(:kw, k, kwargs[k]) for k in forwarded_keys if haskey(kwargs, k)] + + launch = GlobalRef(@__MODULE__, :autotune_launch) + return esc(:($launch($f_expr, $space_expr, $grid_fn, $args_fn; $(kw_exprs...)))) +end diff --git a/src/experimental/search_space.jl b/src/experimental/search_space.jl new file mode 100644 index 00000000..d23a3840 --- /dev/null +++ b/src/experimental/search_space.jl @@ -0,0 +1,58 @@ +public AbstractSearchSpace, FixedSpace, CartesianSpace + +abstract type AbstractSearchSpace end + +Base.length(space::AbstractSearchSpace) = count(Returns(true), space) + +struct FixedSpace{T<:NamedTuple} <: AbstractSearchSpace + elements::Vector{T} +end + +FixedSpace(elements::AbstractVector{T}) where {T<:NamedTuple} = + FixedSpace{T}(collect(elements)) + +function FixedSpace(configs) + elements = collect(configs) + all(config -> config isa NamedTuple, elements) || + throw(ArgumentError("FixedSpace requires NamedTuple configs")) + return FixedSpace(NamedTuple[elements...]) +end + +Base.eltype(::Type{<:FixedSpace{T}}) where {T} = T +Base.length(space::FixedSpace) = length(space.elements) +Base.iterate(space::FixedSpace, args...) = iterate(space.elements, args...) + +struct CartesianSpace{names,F,Axes<:NamedTuple{names}} <: AbstractSearchSpace + constraint::F + axes::Axes +end + +_axis_tuple(axis::Tuple) = axis +_axis_tuple(axis) = Tuple(axis) + +function _cartesian_space(constraint::F, axes::NamedTuple) where {F} + tuple_axes = map(_axis_tuple, axes) + return CartesianSpace{keys(tuple_axes),F,typeof(tuple_axes)}(constraint, tuple_axes) +end + +CartesianSpace(axes::NamedTuple) = _cartesian_space(Returns(true), axes) +CartesianSpace(; axes...) = CartesianSpace(NamedTuple(axes)) +CartesianSpace(constraint::Function; axes...) = + _cartesian_space(constraint, NamedTuple(axes)) + +Base.eltype(::Type{<:CartesianSpace{names}}) where {names} = NamedTuple{names} + +function Base.iterate(space::CartesianSpace{names}, state=nothing) where {names} + product, cursor = state === nothing ? + (Iterators.product(values(space.axes)...), nothing) : + (state.product, state.cursor) + + result = cursor === nothing ? iterate(product) : iterate(product, cursor) + while result !== nothing + values, cursor = result + cfg = NamedTuple{names}(values) + space.constraint(cfg) && return cfg, (; product, cursor) + result = iterate(product, cursor) + end + return nothing +end diff --git a/src/launch.jl b/src/launch.jl index c4c097eb..04a1cd1f 100644 --- a/src/launch.jl +++ b/src/launch.jl @@ -123,6 +123,18 @@ TileCacheKey(sm_arch::VersionNumber, bytecode_version::VersionNumber, pack_hint(opt_level), pack_hint(num_ctas), pack_hint(occupancy), pack_hint(num_worker_warps)) +struct TemporaryTileCacheKey + key::TileCacheKey +end + +# Autotune candidates compile under this owner so they share inference/codegen +# work without becoming visible to (or polluting) the normal `cufunction` cache +# keyed by `TileCacheKey`. It's a plain marker wrapper: candidate CIs are keyed +# by content, so the cached set is bounded by the search space, not by how many +# times you tune. The winning config is promoted by the final normal launch. +@inline tile_cache_key(key::TileCacheKey) = key +@inline tile_cache_key(key::TemporaryTileCacheKey) = key.key + #============================================================================= Toolkit / device validation (cached: once per `(capability, cuda_version)`). @@ -394,27 +406,40 @@ end Compilation: bytecode → CUBIN → CuFunction. =============================================================================# +# Serializes the Julia-side codegen pipeline (inference, structured IR, +# tile-IR emission). `emit_tile!` and everything it calls into mutates shared +# state: `CacheView` entries, `CuTileResults` fields, the inference cache, +# and CompilerCaching's per-CI const_entries vector. None of that is +# thread-safe. We hold the lock only across `emit_tile!`; the tileiras +# subprocess below runs unlocked so concurrent `cufunction` calls (e.g. +# from autotuning's precompile fan-out) can still overlap their tileiras +# shell-outs. +const EMIT_TILE_LOCK = ReentrantLock() + """ - emit_binary!(cache, mi, ci, res; const_argtypes=nothing) -> Vector{UInt8} + emit_binary!(cache, mi, ci, res; const_argtypes=nothing, + store_disk=true) -> Vector{UInt8} Cached binary phase: compile Tile IR bytecode to CUBIN using tileiras. """ function emit_binary!(cache::CacheView, mi::Core.MethodInstance, ci::Core.CodeInstance, res::CuTileResults; - const_argtypes::Union{Vector{Any}, Nothing}=nothing) + const_argtypes::Union{Vector{Any}, Nothing}=nothing, + store_disk::Bool=true) # Recurse first — emit_structured! at the bottom of the chain fires # `compile_hook` for `@device_code_*` reflection, which must run on every # launch even when downstream artifacts are fully cached. - bytecode = emit_tile!(cache, mi, ci, res; const_argtypes) + bytecode = Base.@lock EMIT_TILE_LOCK emit_tile!(cache, mi, ci, res; const_argtypes) res.cuda_bin !== nothing && return res.cuda_bin - sm_arch = unpack_version(cache.owner.sm_arch) + owner = tile_cache_key(cache.owner) + sm_arch = unpack_version(owner.sm_arch) # Resolve opt_level here (not in emit_tile) because it's a tileiras flag, not bytecode. # num_ctas/occupancy/num_worker_warps are resolved in emit_tile because they're encoded in bytecode. _, _, kernel_meta = res.julia_ir - opt_level = something(resolve_hint(unpack_hint(cache.owner.opt_level), + opt_level = something(resolve_hint(unpack_hint(owner.opt_level), kernel_meta, :opt_level, sm_arch), 3) # Disk cache lookup. The hash covers every input that changes the CUBIN @@ -467,7 +492,7 @@ function emit_binary!(cache::CacheView, mi::Core.MethodInstance, rm(output_path, force=true) end - if cache_key !== nothing + if store_disk && cache_key !== nothing try DiskCache.put!(dc, cache_key, res.cuda_bin) catch err @@ -555,7 +580,24 @@ function cufunction(@nospecialize(f), tt::Type{<:Tuple}=Tuple{}; # invalidated by any package that defines methods on Base.Compiler hooks # like `OptimizationParams(::AbstractInterpreter)`. To reuse precompiled # native code, run the pipeline in the world captured at __init__. - invoke_frozen(cufunction_compile, f, tt, argtypes, const_argtypes, key)::TileKernel{Core.Typeof(f), tt} + invoke_frozen(cufunction_compile, f, tt, argtypes, const_argtypes, + key, true)::TileKernel{Core.Typeof(f), tt} +end + +function temporary_cufunction(@nospecialize(f), tt::Type{<:Tuple}; + sm_arch::Union{VersionNumber, Nothing}=nothing, + opt_level::Union{Int, Nothing}=nothing, + num_ctas::Union{Int, Nothing}=nothing, + occupancy::Union{Int, Nothing}=nothing, + num_worker_warps::Union{Int, Nothing}=nothing) + bytecode_version = check_tile_ir_support() + resolved_sm_arch = sm_arch !== nothing ? sm_arch : default_sm_arch() + key = TileCacheKey(resolved_sm_arch, bytecode_version, opt_level, num_ctas, occupancy, + num_worker_warps) + owner = TemporaryTileCacheKey(key) + argtypes, const_argtypes = unwrap_argtypes(f, tt) + return invoke_frozen(cufunction_compile, f, tt, argtypes, const_argtypes, + owner, false)::TileKernel{Core.Typeof(f), tt} end """ @@ -567,7 +609,7 @@ by [`link`](@ref) to load the result onto the GPU. No CUDA context required. """ function compile(@nospecialize(f), @nospecialize(argtypes), const_argtypes::Union{Vector{Any}, Nothing}, - key::TileCacheKey) + key, store_disk::Bool=true) world = Base.get_world_counter() mi = method_instance(f, argtypes; world) mi === nothing && throw(MethodError(f, argtypes)) @@ -584,12 +626,17 @@ function compile(@nospecialize(f), @nospecialize(argtypes), # underlying `CodeInstance` via CompilerCaching; the `TileKernel` wrapper # rides along in the same `CuTileResults`, so kernel-instance lifecycle # follows the CI's instead of needing a separate global Dict. - ci, res = ensure_compiled(cache, mi, const_argtypes) + # + # Held under EMIT_TILE_LOCK so concurrent compiles (e.g. autotuning's + # precompile fan-out) don't race on inference / CompilerCaching state. + # The lock-protected region also includes the lookup fast path; that + # path is just a hashtable read, so brief contention here is fine. + ci, res = Base.@lock EMIT_TILE_LOCK ensure_compiled(cache, mi, const_argtypes) # Always walk the emit chain (each phase short-circuits on its own cached # field, but `emit_structured!` also fires `compile_hook` for reflection, # which has to run on every launch even when the cube/cufunc is cached). - emit_binary!(cache, mi, ci, res; const_argtypes) + emit_binary!(cache, mi, ci, res; const_argtypes, store_disk) return cache, mi, ci, res end @@ -598,8 +645,8 @@ end # even when later-loaded packages would otherwise have invalidated it. function cufunction_compile(@nospecialize(f), @nospecialize(tt), @nospecialize(argtypes), const_argtypes::Union{Vector{Any}, Nothing}, - key::TileCacheKey) - cache, mi, ci, res = compile(f, argtypes, const_argtypes, key) + key, store_disk::Bool=true) + cache, mi, ci, res = compile(f, argtypes, const_argtypes, key, store_disk) cufunc = link(cache, mi, ci, res) diff --git a/test/device/autotune.jl b/test/device/autotune.jl new file mode 100644 index 00000000..6fecf8e0 --- /dev/null +++ b/test/device/autotune.jl @@ -0,0 +1,454 @@ +using CUDA + +const Exp = ct.Experimental + +@testset "Autotune" begin + + function vadd_kernel(a::ct.TileArray{Float32,1}, + b::ct.TileArray{Float32,1}, + c::ct.TileArray{Float32,1}, + tile::Int) + pid = ct.bid(1) + ta = ct.load(a, pid, (tile,)) + tb = ct.load(b, pid, (tile,)) + ct.store(c, pid, ta + tb) + return nothing + end + + function inplace_add_kernel(x::ct.TileArray{Float32,1}, + tile::Int) + pid = ct.bid(1) + tx = ct.load(x, pid, (tile,)) + ct.store(x, pid, tx .+ 1f0) + return nothing + end + + function cache_probe_kernel(a::ct.TileArray{Float32,1}, + b::ct.TileArray{Float32,1}, + c::ct.TileArray{Float32,1}, + tile::Int) + pid = ct.bid(1) + ta = ct.load(a, pid, (tile,)) + tb = ct.load(b, pid, (tile,)) + ct.store(c, pid, ta + tb) + return nothing + end + + function normal_const_entry_count(f, args) + converted = map(ct.cuTileconvert, args) + tt = Tuple{map(Core.Typeof, converted)...} + argtypes, _ = ct.unwrap_argtypes(f, tt) + + world = Base.get_world_counter() + key = ct.TileCacheKey(ct.default_sm_arch(), ct.bytecode_version(), + 3, nothing, nothing, nothing) + cache = ct.CompilerCaching.CacheView{ct.CuTileResults}(key, world) + mi = ct.CompilerCaching.method_instance(f, argtypes; world) + ci = get(cache, mi, nothing) + ci === nothing && return 0 + + cached = ct.CC.traverse_analysis_results(ci) do result + result isa ct.CompilerCaching.CachedResult{ct.CuTileResults} ? + result : nothing + end + cached === nothing && return 0 + return length(cached.const_entries) + end + + n = 512 + a = CUDA.fill(1f0, n) + b = CUDA.fill(2f0, n) + c = CUDA.zeros(Float32, n) + + # cfg 1's `occupancy=nothing, num_ctas=nothing` slots are present for + # `FixedSpace` shape uniformity — cfgs 2 and 3 carry real hint values, + # and `FixedSpace{names, NT<:NamedTuple{names}}` requires every element + # to share the same `names` set. + configs = [ + (; tile=16, occupancy=nothing, num_ctas=nothing), + (; tile=32, occupancy=2, num_ctas=nothing), + (; tile=64, occupancy=4, num_ctas=2), + ] + args_fn = cfg -> (a, b, c, ct.Constant(cfg.tile)) + grid_fn = cfg -> cld(n, cfg.tile) + + @testset "basic tuning" begin + Exp.clear_autotune_cache() + result = Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key=(:basic, n), + tuning=(preset=:fast, refine_topk=0), + ) + @test !result.cache_hit + @test result.tuned_config in configs + @test !isempty(result.tuning_record) + @test Array(c) ≈ fill(3f0, n) + end + + @testset "cache hit" begin + fill!(c, 0f0) + result = Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key=(:basic, n), + tuning=(preset=:fast, refine_topk=0), + ) + @test result.cache_hit + @test Array(c) ≈ fill(3f0, n) + end + + @testset "force retune" begin + fill!(c, 0f0) + result = Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key=(:basic, n), + tuning=(preset=:fast, refine_topk=0, force=true), + ) + @test !result.cache_hit + @test Array(c) ≈ fill(3f0, n) + end + + @testset "CartesianSpace" begin + Exp.clear_autotune_cache() + fill!(c, 0f0) + # Keep `occupancy=(nothing, 2)` — legitimately tunes between + # "no hint" and 2. Single-value `nothing` axes are noise (see + # other testsets); a 2-value axis with `nothing` as one option + # is meaningful. + space = Exp.CartesianSpace(; + tile=(16, 32), occupancy=(nothing, 2)) + result = Exp.autotune_launch( + vadd_kernel, space, grid_fn, args_fn; + key=(:cartesian, n), + tuning=(preset=:fast, refine_topk=0), + ) + @test hasproperty(result.tuned_config, :tile) + @test hasproperty(result.tuned_config, :occupancy) + @test Array(c) ≈ fill(3f0, n) + end + + @testset "CartesianSpace with constraint" begin + Exp.clear_autotune_cache() + fill!(c, 0f0) + space = Exp.CartesianSpace( + cfg -> cfg.tile == 16; + tile=(16, 32, 64)) + result = Exp.autotune_launch( + vadd_kernel, space, grid_fn, args_fn; + key=(:constrained, n), + tuning=(preset=:fast, refine_topk=0), + ) + @test result.tuned_config.tile == 16 + @test Array(c) ≈ fill(3f0, n) + end + + @testset "CartesianSpace range axis" begin + @test collect(Exp.CartesianSpace(tile=16:16:32)) == + [(; tile=16), (; tile=32)] + end + + @testset "NamedTuple convenience → CartesianSpace" begin + Exp.clear_autotune_cache() + fill!(c, 0f0) + result = Exp.autotune_launch( + vadd_kernel, + (tile=(16, 32),), + grid_fn, args_fn; + key=(:nt_convenience, n), + tuning=(preset=:fast, refine_topk=0), + ) + @test result.tuned_config.tile in (16, 32) + @test Array(c) ≈ fill(3f0, n) + end + + @testset "launch_args (inplace kernel)" begin + x = CUDA.zeros(Float32, n) + original_x = Array(x) + Exp.clear_autotune_cache() + result = Exp.autotune_launch( + inplace_add_kernel, + [(; tile=16), (; tile=32)], + grid_fn, + cfg -> (copy(x), ct.Constant(cfg.tile)); + launch_args=cfg -> (x, ct.Constant(cfg.tile)), + key=(:inplace, n), + tuning=(preset=:fast, refine_topk=0), + ) + @test !result.cache_hit + @test Array(x) == original_x .+ 1f0 + end + + @testset "refinement" begin + Exp.clear_autotune_cache() + fill!(c, 0f0) + result = Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key=(:refine, n), + tuning=(warmup=1, reps=2, refine_topk=2, refine_reps=4), + ) + @test !result.cache_hit + # Refinement record replaces initial — has at most refine_topk entries + @test length(result.tuning_record) <= 2 + @test Array(c) ≈ fill(3f0, n) + end + + @testset "verify" begin + Exp.clear_autotune_cache() + fill!(c, 0f0) + verify_called = Ref(false) + result = Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key=(:verify, n), + tuning=(preset=:fast, refine_topk=0), + verify=() -> let + ref = Array(a) .+ Array(b) + verify_called[] = true + () -> (CUDA.@allowscalar all(isapprox.(Array(c), ref, atol=1f-5))) + end, + ) + @test verify_called[] + @test Array(c) ≈ fill(3f0, n) + end + + @testset "clear cache per-kernel per-key" begin + Exp.clear_autotune_cache() + Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key=(:k1, n), tuning=(preset=:fast, refine_topk=0)) + Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key=(:k2, n), tuning=(preset=:fast, refine_topk=0)) + + # Clear only one key + Exp.clear_autotune_cache(kernel=vadd_kernel, key=(:k1, n)) + fill!(c, 0f0) + r1 = Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key=(:k1, n), tuning=(preset=:fast, refine_topk=0)) + @test !r1.cache_hit # was cleared + + fill!(c, 0f0) + r2 = Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key=(:k2, n), tuning=(preset=:fast, refine_topk=0)) + @test r2.cache_hit # still cached + end + + @testset "shared key across shapes" begin + Exp.clear_autotune_cache() + n2 = 1024 + a2 = CUDA.fill(1f0, n2) + b2 = CUDA.fill(2f0, n2) + c2 = CUDA.zeros(Float32, n2) + shared_key = (:shape_agnostic, eltype(a)) + + Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key=shared_key, tuning=(preset=:fast, refine_topk=0)) + + fill!(c2, 0f0) + result = Exp.autotune_launch( + vadd_kernel, configs, + cfg -> cld(n2, cfg.tile), + cfg -> (a2, b2, c2, ct.Constant(cfg.tile)); + key=shared_key, tuning=(preset=:fast, refine_topk=0)) + @test result.cache_hit + @test result.grid == cld(n2, result.tuned_config.tile) + @test Array(c2) ≈ fill(3f0, n2) + end + + @testset "literal grid/args (no closure)" begin + Exp.clear_autotune_cache() + fill!(c, 0f0) + # Pass `grid` and `args` as values rather than `cfg -> …` closures. + # cfg-independent grid: cld(n, tile) == cld(512, 16) for every cfg + # so the literal `32` happens to be valid here. + result = Exp.autotune_launch( + vadd_kernel, + [(; tile=16)], + cld(n, 16), # literal grid + (a, b, c, ct.Constant(16)); # literal args + key=(:literal, n), + tuning=(preset=:fast, refine_topk=0)) + @test !result.cache_hit + @test result.grid == cld(n, 16) + @test Array(c) ≈ fill(3f0, n) + end + + @testset "static num_ctas / occupancy as kwargs" begin + Exp.clear_autotune_cache() + fill!(c, 0f0) + # `space` has no num_ctas/occupancy axes — they're static kwargs. + result = Exp.autotune_launch( + vadd_kernel, + [(; tile=16), (; tile=32)], + cfg -> cld(n, cfg.tile), + cfg -> (a, b, c, ct.Constant(cfg.tile)); + key=(:static_hints, n), + occupancy=2, + tuning=(preset=:fast, refine_topk=0)) + @test !result.cache_hit + @test Array(c) ≈ fill(3f0, n) + end + + @testset "conflict: static + space axis" begin + Exp.clear_autotune_cache() + # Run-time path (opaque space): autotune_launch should reject. + @test_throws ArgumentError Exp.autotune_launch( + vadd_kernel, + [(; tile=16, occupancy=2)], + cfg -> cld(n, cfg.tile), + cfg -> (a, b, c, ct.Constant(cfg.tile)); + key=(:conflict, n), + occupancy=4, + tuning=(preset=:fast, refine_topk=0)) + end + + @testset "conflict scans every config" begin + space = Exp.FixedSpace(Any[(; tile=16), (; tile=32, occupancy=2)]) + @test_throws ArgumentError Exp.autotune_launch( + vadd_kernel, + space, + cfg -> cld(n, cfg.tile), + cfg -> (a, b, c, ct.Constant(cfg.tile)); + key=(:conflict_late, n), + occupancy=4, + tuning=(preset=:fast, refine_topk=0)) + end + + @testset "tuning validation" begin + @test_throws ArgumentError Exp.autotune_launch( + vadd_kernel, + [(; tile=16)], + grid_fn, args_fn; + key=(:bad_reps, n), + tuning=(preset=:fast, reps=0)) + + @test_throws ArgumentError Exp.autotune_launch( + vadd_kernel, + [(; tile=16)], + grid_fn, args_fn; + key=(:bad_key, n), + tuning=(preset=:fast, typo=1)) + end + + @testset "cached config must belong to current space" begin + Exp.clear_autotune_cache() + fill!(c, 0f0) + r1 = Exp.autotune_launch( + vadd_kernel, + [(; tile=16)], + grid_fn, args_fn; + key=(:space_sensitive, n), + tuning=(preset=:fast, refine_topk=0)) + @test !r1.cache_hit + @test r1.tuned_config.tile == 16 + + fill!(c, 0f0) + r2 = Exp.autotune_launch( + vadd_kernel, + [(; tile=32)], + grid_fn, args_fn; + key=(:space_sensitive, n), + tuning=(preset=:fast, refine_topk=0)) + @test !r2.cache_hit + @test r2.tuned_config.tile == 32 + @test Array(c) ≈ fill(3f0, n) + end + + @testset "only winner enters normal compiler cache" begin + Exp.clear_autotune_cache() + probe_c = CUDA.zeros(Float32, n) + probe_configs = [(; tile=16), (; tile=32), (; tile=64)] + probe_args = cfg -> (a, b, probe_c, ct.Constant(cfg.tile)) + @test normal_const_entry_count(cache_probe_kernel, probe_args(probe_configs[1])) == 0 + + result = Exp.autotune_launch( + cache_probe_kernel, + probe_configs, + cfg -> cld(n, cfg.tile), + probe_args; + key=(:temporary_candidates, n), + tuning=(preset=:fast, refine_topk=0)) + + @test result.tuned_config in probe_configs + @test normal_const_entry_count(cache_probe_kernel, probe_args(probe_configs[1])) == 1 + @test Array(probe_c) ≈ fill(3f0, n) + end + + @testset "@autotune macro: NT space" begin + Exp.clear_autotune_cache() + fill!(c, 0f0) + result = Exp.@autotune( + key = (:macro_nt, n), + space = (tile=(16, 32, 64),), + blocks = cld(n, $tile), + tuning = (preset=:fast, refine_topk=0), + vadd_kernel(a, b, c, ct.Constant($tile)), + ) + @test !result.cache_hit + @test result.tuned_config.tile in (16, 32, 64) + @test Array(c) ≈ fill(3f0, n) + end + + @testset "@autotune macro: vector space" begin + Exp.clear_autotune_cache() + fill!(c, 0f0) + result = Exp.@autotune( + key = (:macro_vec, n), + space = [(; tile=16), (; tile=32)], + blocks = cld(n, $tile), + tuning = (preset=:fast, refine_topk=0), + vadd_kernel(a, b, c, ct.Constant($tile)), + ) + @test !result.cache_hit + @test result.tuned_config.tile in (16, 32) + @test Array(c) ≈ fill(3f0, n) + end + + @testset "@autotune macro: tuple blocks + 2D \$interp" begin + Exp.clear_autotune_cache() + fill!(c, 0f0) + # Use a 1D kernel but pass a Tuple blocks=(N, 1) to exercise the + # tuple-grid + $X-interp-in-blocks path. + result = Exp.@autotune( + key = (:macro_tuple, n), + space = (tile=(16, 32),), + blocks = (cld(n, $tile), 1), + tuning = (preset=:fast, refine_topk=0), + vadd_kernel(a, b, c, ct.Constant($tile)), + ) + @test result.grid == (cld(n, result.tuned_config.tile), 1) + @test Array(c) ≈ fill(3f0, n) + end + + @testset "@autotune macro: static num_ctas as kwarg" begin + Exp.clear_autotune_cache() + fill!(c, 0f0) + result = Exp.@autotune( + key = (:macro_static, n), + space = (tile=(16, 32),), + blocks = cld(n, $tile), + occupancy = 2, + tuning = (preset=:fast, refine_topk=0), + vadd_kernel(a, b, c, ct.Constant($tile)), + ) + @test !result.cache_hit + @test Array(c) ≈ fill(3f0, n) + end + + @testset "@autotune macro: macro-time conflict error" begin + # Should error at macro expansion (not run time). + @test_throws LoadError @eval Exp.@autotune( + space = (tile=(16,), num_ctas=(1, 2)), + blocks = 1, + num_ctas = 4, + kernel(a), + ) + end + + @testset "@autotune macro: required kwargs" begin + @test_throws LoadError @eval Exp.@autotune(blocks = 1, kernel(a)) + @test_throws LoadError @eval Exp.@autotune(space = (tile=(16,),), kernel(a)) + @test_throws LoadError @eval Exp.@autotune(space = (tile=(16,),), blocks = 1) + end +end