Skip to content

Commit a097938

Browse files
committed
cleanup
1 parent 3f2a2ec commit a097938

2 files changed

Lines changed: 46 additions & 89 deletions

File tree

src/experimental/autotune.jl

Lines changed: 39 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,6 @@ const TUNING_PRESETS = (
1515
thorough = (warmup=2, reps=7, refine_topk=4, refine_reps=6),
1616
)
1717

18-
const TUNING_KEYS = (
19-
:warmup, :reps, :refine_topk, :refine_reps,
20-
:seed, :force, :precompile_workers,
21-
)
22-
2318
struct TuningOptions
2419
warmup::Int
2520
reps::Int
@@ -30,65 +25,38 @@ struct TuningOptions
3025
precompile_workers::Int
3126
end
3227

33-
struct TuningSession
34-
id::UInt
35-
end
36-
37-
const NEXT_TUNING_SESSION_ID = Threads.Atomic{UInt}(0)
38-
39-
TuningSession() =
40-
TuningSession(Threads.atomic_add!(NEXT_TUNING_SESSION_ID, UInt(1)) + UInt(1))
41-
4228
_tuning_defaults() = (
4329
seed=nothing,
4430
force=false,
4531
precompile_workers=Threads.nthreads(),
4632
)
4733

48-
function _check_int(name::Symbol, value; min::Int)
49-
(value isa Integer && !(value isa Bool)) ||
50-
throw(ArgumentError("tuning.$name must be an integer, got $(typeof(value))"))
51-
value >= min ||
52-
throw(ArgumentError("tuning.$name must be >= $min, got $value"))
53-
return Int(value)
54-
end
55-
56-
function _check_seed(seed)
57-
seed === nothing && return nothing
58-
(seed isa Integer && !(seed isa Bool)) ||
59-
throw(ArgumentError("tuning.seed must be an integer or nothing, got $(typeof(seed))"))
60-
return Int(seed)
61-
end
62-
63-
function _check_bool(name::Symbol, value)
64-
value isa Bool ||
65-
throw(ArgumentError("tuning.$name must be a Bool, got $(typeof(value))"))
66-
return value
67-
end
34+
# Lower bound for each count field; `seed` and `force` have no minimum.
35+
const _TUNING_MINIMA = (warmup=0, reps=1, refine_topk=0, refine_reps=1,
36+
precompile_workers=0)
6837

6938
function normalize_tuning(tuning::NamedTuple)
70-
valid_keys = (:preset, TUNING_KEYS...)
71-
unknown = setdiff(collect(keys(tuning)), collect(valid_keys))
39+
valid_keys = (:preset, fieldnames(TuningOptions)...)
40+
unknown = setdiff(keys(tuning), valid_keys)
7241
isempty(unknown) ||
7342
throw(ArgumentError("Unknown tuning option(s): $(join(unknown, ", "))"))
7443

7544
preset = get(tuning, :preset, :default)
76-
preset isa Symbol || throw(ArgumentError("tuning.preset must be a Symbol"))
77-
hasproperty(TUNING_PRESETS, preset) ||
45+
preset isa Symbol && hasproperty(TUNING_PRESETS, preset) ||
7846
throw(ArgumentError("Unknown tuning preset `$preset`; use :fast, :default, or :thorough"))
7947

8048
overrides = NamedTuple(k => v for (k, v) in pairs(tuning) if k !== :preset)
8149
values = merge(_tuning_defaults(), getproperty(TUNING_PRESETS, preset), overrides)
8250

83-
return TuningOptions(
84-
_check_int(:warmup, values.warmup; min=0),
85-
_check_int(:reps, values.reps; min=1),
86-
_check_int(:refine_topk, values.refine_topk; min=0),
87-
_check_int(:refine_reps, values.refine_reps; min=1),
88-
_check_seed(values.seed),
89-
_check_bool(:force, values.force),
90-
_check_int(:precompile_workers, values.precompile_workers; min=0),
91-
)
51+
# The struct's field types coerce/reject bad value types; we only enforce
52+
# the lower bounds that the types can't. Pull fields by name since `values`
53+
# is in merge order, not struct-field order.
54+
opts = TuningOptions((getproperty(values, f) for f in fieldnames(TuningOptions))...)
55+
for (name, lo) in pairs(_TUNING_MINIMA)
56+
getfield(opts, name) >= lo ||
57+
throw(ArgumentError("tuning.$name must be >= $lo, got $(getfield(opts, name))"))
58+
end
59+
return opts
9260
end
9361

9462
@inline _hint_from_cfg(cfg, name::Symbol, fallback) =
@@ -101,22 +69,14 @@ function hints_from_cfg(cfg; static_num_ctas=nothing, static_occupancy=nothing)
10169
)
10270
end
10371

104-
function _check_static_hint_conflict(configs, hint::Symbol, static_value)
105-
static_value === nothing && return nothing
106-
for cfg in configs
107-
if hasproperty(cfg, hint)
108-
throw(ArgumentError(
109-
"`$hint` is both a static kwarg and an axis in the search space. " *
110-
"Pick one."))
111-
end
112-
end
113-
return nothing
114-
end
115-
11672
function _check_static_hint_conflicts(configs; static_num_ctas=nothing,
11773
static_occupancy=nothing)
118-
_check_static_hint_conflict(configs, :num_ctas, static_num_ctas)
119-
_check_static_hint_conflict(configs, :occupancy, static_occupancy)
74+
statics = (num_ctas=static_num_ctas, occupancy=static_occupancy)
75+
for (hint, static_value) in pairs(statics)
76+
static_value === nothing && continue
77+
any(cfg -> hasproperty(cfg, hint), configs) && throw(ArgumentError(
78+
"`$hint` is both a static kwarg and an axis in the search space. Pick one."))
79+
end
12080
return nothing
12181
end
12282

@@ -132,11 +92,11 @@ end
13292
@inline _converted_args(args_fn, cfg) = map(cuTileconvert, args_fn(cfg))
13393
@inline _argtypes(args) = Tuple{map(Core.Typeof, args)...}
13494

135-
function _compile_cfg(@nospecialize(f), cfg, args_fn, session::TuningSession;
95+
function _compile_cfg(@nospecialize(f), cfg, args_fn;
13696
sm_arch::VersionNumber, opt_level::Int,
13797
static_num_ctas=nothing, static_occupancy=nothing)
13898
converted = _converted_args(args_fn, cfg)
139-
return temporary_cufunction(f, _argtypes(converted), session.id;
99+
return temporary_cufunction(f, _argtypes(converted);
140100
sm_arch, opt_level,
141101
hints_from_cfg(cfg; static_num_ctas, static_occupancy)...)
142102
end
@@ -166,34 +126,33 @@ function _time_ms(run_once, get_args;
166126
return best_ms
167127
end
168128

169-
function eval_cfg(@nospecialize(f), cfg, grid_fn, args_fn, session::TuningSession;
129+
function eval_cfg(@nospecialize(f), cfg, grid_fn, args_fn;
170130
sm_arch::VersionNumber, opt_level::Int, warmup::Int, reps::Int,
171131
static_num_ctas=nothing, static_occupancy=nothing,
172132
verify=nothing, reset=nothing)
173133
grid = _grid_dims(grid_fn(cfg))
174-
kernel = _compile_cfg(f, cfg, args_fn, session; sm_arch, opt_level,
134+
kernel = _compile_cfg(f, cfg, args_fn; sm_arch, opt_level,
175135
static_num_ctas, static_occupancy)
176136

177137
run_once = converted -> kernel(converted...; blocks=grid)
178138
get_args = () -> _converted_args(args_fn, cfg)
179139
return _time_ms(run_once, get_args; warmup, reps, verify, reset)
180140
end
181141

182-
function precompile_cfg(@nospecialize(f), cfg, args_fn, session::TuningSession;
142+
function precompile_cfg(@nospecialize(f), cfg, args_fn;
183143
sm_arch::VersionNumber, opt_level::Int,
184144
static_num_ctas=nothing, static_occupancy=nothing)
185-
_compile_cfg(f, cfg, args_fn, session; sm_arch, opt_level,
145+
_compile_cfg(f, cfg, args_fn; sm_arch, opt_level,
186146
static_num_ctas, static_occupancy)
187147
return nothing
188148
end
189149

190150
const TimingRecord = Tuple{Any, Float32}
191151

192152
function _measure_cfg!(record::Vector{TimingRecord}, first_error::Base.RefValue,
193-
@nospecialize(f), cfg, grid_fn, args_fn,
194-
session::TuningSession; kwargs...)
153+
@nospecialize(f), cfg, grid_fn, args_fn; kwargs...)
195154
ms = try
196-
eval_cfg(f, cfg, grid_fn, args_fn, session; kwargs...)
155+
eval_cfg(f, cfg, grid_fn, args_fn; kwargs...)
197156
catch err
198157
err isa InterruptException && rethrow()
199158
if err isa VerificationError
@@ -209,16 +168,15 @@ function _measure_cfg!(record::Vector{TimingRecord}, first_error::Base.RefValue,
209168
return nothing
210169
end
211170

212-
function measure_candidates(@nospecialize(f), configs::Vector{Any}, grid_fn, args_fn,
213-
session::TuningSession;
171+
function measure_candidates(@nospecialize(f), configs::Vector{Any}, grid_fn, args_fn;
214172
sm_arch::VersionNumber, opt_level::Int,
215173
warmup::Int, reps::Int,
216174
static_num_ctas=nothing, static_occupancy=nothing,
217175
verify=nothing, reset=nothing)
218176
record = TimingRecord[]
219177
first_error = Ref{Any}(nothing)
220178
for cfg in configs
221-
_measure_cfg!(record, first_error, f, cfg, grid_fn, args_fn, session;
179+
_measure_cfg!(record, first_error, f, cfg, grid_fn, args_fn;
222180
sm_arch, opt_level, warmup, reps,
223181
static_num_ctas, static_occupancy, verify, reset)
224182
end
@@ -232,16 +190,15 @@ Compile candidate configurations on worker tasks while the caller task measures
232190
completed candidates. Measurement stays on the caller task because CUDA state is
233191
task-local; workers only run the untimed temporary compile path.
234192
"""
235-
function pipelined_tune(@nospecialize(f), configs::Vector{Any}, grid_fn, args_fn,
236-
session::TuningSession;
193+
function pipelined_tune(@nospecialize(f), configs::Vector{Any}, grid_fn, args_fn;
237194
sm_arch::VersionNumber, opt_level::Int,
238195
warmup::Int, reps::Int, workers::Int,
239196
static_num_ctas=nothing, static_occupancy=nothing,
240197
verify=nothing, reset=nothing)
241198
isempty(configs) && return TimingRecord[], nothing, nothing
242199

243200
if iszero(workers) || length(configs) == 1
244-
record, first_error = measure_candidates(f, configs, grid_fn, args_fn, session;
201+
record, first_error = measure_candidates(f, configs, grid_fn, args_fn;
245202
sm_arch, opt_level, warmup, reps,
246203
static_num_ctas, static_occupancy, verify, reset)
247204
return record, nothing, first_error
@@ -265,7 +222,7 @@ function pipelined_tune(@nospecialize(f), configs::Vector{Any}, grid_fn, args_fn
265222
for cfg in jobs
266223
cancelled[] && break
267224
try
268-
precompile_cfg(f, cfg, args_fn, session; sm_arch, opt_level,
225+
precompile_cfg(f, cfg, args_fn; sm_arch, opt_level,
269226
static_num_ctas, static_occupancy)
270227
cancelled[] || put!(ready, cfg)
271228
catch err
@@ -289,7 +246,7 @@ function pipelined_tune(@nospecialize(f), configs::Vector{Any}, grid_fn, args_fn
289246
first_error = Ref{Any}(nothing)
290247
try
291248
for cfg in ready
292-
_measure_cfg!(record, first_error, f, cfg, grid_fn, args_fn, session;
249+
_measure_cfg!(record, first_error, f, cfg, grid_fn, args_fn;
293250
sm_arch, opt_level, warmup, reps,
294251
static_num_ctas, static_occupancy, verify, reset)
295252
end
@@ -348,15 +305,15 @@ function _no_valid_config_error(first_error, precompile_error)
348305
end
349306

350307
function _refine_record(@nospecialize(f), record::Vector{TimingRecord}, tuning::TuningOptions,
351-
grid_fn, args_fn, session::TuningSession;
308+
grid_fn, args_fn;
352309
sm_arch::VersionNumber, opt_level::Int,
353310
static_num_ctas=nothing, static_occupancy=nothing,
354311
verify=nothing, reset=nothing)
355312
(tuning.refine_topk > 0 && length(record) > 1) || return record
356313

357314
sort!(record, by=last)
358315
top = Any[first(r) for r in record[1:min(tuning.refine_topk, length(record))]]
359-
refined, _ = measure_candidates(f, top, grid_fn, args_fn, session;
316+
refined, _ = measure_candidates(f, top, grid_fn, args_fn;
360317
sm_arch, opt_level, warmup=tuning.warmup, reps=tuning.refine_reps,
361318
static_num_ctas, static_occupancy, verify, reset)
362319
return isempty(refined) ? record : refined
@@ -383,11 +340,10 @@ function find_or_tune(@nospecialize(f), space::AbstractSearchSpace,
383340

384341
checker = verify !== nothing ? verify() : nothing
385342
reset = setup !== nothing ? setup() : nothing
386-
session = TuningSession()
387343

388344
record, precompile_error, first_error =
389345
with(_SCOPED_INF_CACHE => _fresh_inf_cache()) do
390-
pipelined_tune(f, trials, grid_fn, args_fn, session;
346+
pipelined_tune(f, trials, grid_fn, args_fn;
391347
sm_arch, opt_level,
392348
warmup=tuning.warmup, reps=tuning.reps,
393349
workers=tuning.precompile_workers,
@@ -397,7 +353,7 @@ function find_or_tune(@nospecialize(f), space::AbstractSearchSpace,
397353

398354
isempty(record) && _no_valid_config_error(first_error, precompile_error)
399355

400-
record = _refine_record(f, record, tuning, grid_fn, args_fn, session;
356+
record = _refine_record(f, record, tuning, grid_fn, args_fn;
401357
sm_arch, opt_level,
402358
static_num_ctas, static_occupancy,
403359
verify=checker, reset)

src/launch.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,12 +125,13 @@ TileCacheKey(sm_arch::VersionNumber, bytecode_version::VersionNumber,
125125

126126
struct TemporaryTileCacheKey
127127
key::TileCacheKey
128-
session_id::UInt
129128
end
130129

131-
# Autotune candidates use this owner so they can share work within one tuning
132-
# pass without becoming visible to the normal `cufunction` cache keyed by
133-
# `TileCacheKey`. The winning config is promoted by the final normal launch.
130+
# Autotune candidates compile under this owner so they share inference/codegen
131+
# work without becoming visible to (or polluting) the normal `cufunction` cache
132+
# keyed by `TileCacheKey`. It's a plain marker wrapper: candidate CodeInstances are keyed
133+
# by content, so the cached set is bounded by the search space, not by how many
134+
# times you tune. The winning config is promoted by the final normal launch.
134135
@inline tile_cache_key(key::TileCacheKey) = key
135136
@inline tile_cache_key(key::TemporaryTileCacheKey) = key.key
136137

@@ -583,7 +584,7 @@ function cufunction(@nospecialize(f), tt::Type{<:Tuple}=Tuple{};
583584
key, true)::TileKernel{Core.Typeof(f), tt}
584585
end
585586

586-
function temporary_cufunction(@nospecialize(f), tt::Type{<:Tuple}, session_id::UInt;
587+
function temporary_cufunction(@nospecialize(f), tt::Type{<:Tuple};
587588
sm_arch::Union{VersionNumber, Nothing}=nothing,
588589
opt_level::Union{Int, Nothing}=nothing,
589590
num_ctas::Union{Int, Nothing}=nothing,
@@ -593,7 +594,7 @@ function temporary_cufunction(@nospecialize(f), tt::Type{<:Tuple}, session_id::U
593594
resolved_sm_arch = sm_arch !== nothing ? sm_arch : default_sm_arch()
594595
key = TileCacheKey(resolved_sm_arch, bytecode_version, opt_level, num_ctas, occupancy,
595596
num_worker_warps)
596-
owner = TemporaryTileCacheKey(key, session_id)
597+
owner = TemporaryTileCacheKey(key)
597598
argtypes, const_argtypes = unwrap_argtypes(f, tt)
598599
return invoke_frozen(cufunction_compile, f, tt, argtypes, const_argtypes,
599600
owner, false)::TileKernel{Core.Typeof(f), tt}

0 commit comments

Comments
 (0)