@@ -15,11 +15,6 @@ const TUNING_PRESETS = (
1515 thorough = (warmup= 2 , reps= 7 , refine_topk= 4 , refine_reps= 6 ),
1616)
1717
18- const TUNING_KEYS = (
19- :warmup , :reps , :refine_topk , :refine_reps ,
20- :seed , :force , :precompile_workers ,
21- )
22-
2318struct TuningOptions
2419 warmup:: Int
2520 reps:: Int
@@ -30,65 +25,38 @@ struct TuningOptions
3025 precompile_workers:: Int
3126end
3227
33- struct TuningSession
34- id:: UInt
35- end
36-
37- const NEXT_TUNING_SESSION_ID = Threads. Atomic {UInt} (0 )
38-
39- TuningSession () =
40- TuningSession (Threads. atomic_add! (NEXT_TUNING_SESSION_ID, UInt (1 )) + UInt (1 ))
41-
4228_tuning_defaults () = (
4329 seed= nothing ,
4430 force= false ,
4531 precompile_workers= Threads. nthreads (),
4632)
4733
48- function _check_int (name:: Symbol , value; min:: Int )
49- (value isa Integer && ! (value isa Bool)) ||
50- throw (ArgumentError (" tuning.$name must be an integer, got $(typeof (value)) " ))
51- value >= min ||
52- throw (ArgumentError (" tuning.$name must be >= $min , got $value " ))
53- return Int (value)
54- end
55-
56- function _check_seed (seed)
57- seed === nothing && return nothing
58- (seed isa Integer && ! (seed isa Bool)) ||
59- throw (ArgumentError (" tuning.seed must be an integer or nothing, got $(typeof (seed)) " ))
60- return Int (seed)
61- end
62-
63- function _check_bool (name:: Symbol , value)
64- value isa Bool ||
65- throw (ArgumentError (" tuning.$name must be a Bool, got $(typeof (value)) " ))
66- return value
67- end
34+ # Lower bound for each count field; the others have no minimum.
35+ const _TUNING_MINIMA = (warmup= 0 , reps= 1 , refine_topk= 0 , refine_reps= 1 ,
36+ precompile_workers= 0 )
6837
6938function normalize_tuning (tuning:: NamedTuple )
70- valid_keys = (:preset , TUNING_KEYS ... )
71- unknown = setdiff (collect ( keys (tuning)), collect ( valid_keys) )
39+ valid_keys = (:preset , fieldnames (TuningOptions) ... )
40+ unknown = setdiff (keys (tuning), valid_keys)
7241 isempty (unknown) ||
7342 throw (ArgumentError (" Unknown tuning option(s): $(join (unknown, " , " )) " ))
7443
7544 preset = get (tuning, :preset , :default )
76- preset isa Symbol || throw (ArgumentError (" tuning.preset must be a Symbol" ))
77- hasproperty (TUNING_PRESETS, preset) ||
45+ preset isa Symbol && hasproperty (TUNING_PRESETS, preset) ||
7846 throw (ArgumentError (" Unknown tuning preset `$preset `; use :fast, :default, or :thorough" ))
7947
8048 overrides = NamedTuple (k => v for (k, v) in pairs (tuning) if k != = :preset )
8149 values = merge (_tuning_defaults (), getproperty (TUNING_PRESETS, preset), overrides)
8250
83- return TuningOptions (
84- _check_int ( :warmup , values . warmup; min = 0 ),
85- _check_int ( :reps , values . reps; min = 1 ),
86- _check_int ( :refine_topk , values . refine_topk; min = 0 ),
87- _check_int ( :refine_reps , values . refine_reps; min = 1 ),
88- _check_seed (values . seed),
89- _check_bool ( :force , values . force),
90- _check_int ( :precompile_workers , values . precompile_workers; min = 0 ),
91- )
51+ # The struct's field types coerce/reject bad value types; we only enforce
52+ # the lower bounds that the types can't. Pull fields by name since `values`
53+ # is in merge order, not struct-field order.
54+ opts = TuningOptions (( getproperty (values, f) for f in fieldnames (TuningOptions)) . .. )
55+ for (name, lo) in pairs (_TUNING_MINIMA)
56+ getfield (opts, name) >= lo ||
57+ throw ( ArgumentError ( " tuning. $name must be >= $lo , got $( getfield (opts, name)) " ))
58+ end
59+ return opts
9260end
9361
9462@inline _hint_from_cfg (cfg, name:: Symbol , fallback) =
@@ -101,22 +69,14 @@ function hints_from_cfg(cfg; static_num_ctas=nothing, static_occupancy=nothing)
10169 )
10270end
10371
104- function _check_static_hint_conflict (configs, hint:: Symbol , static_value)
105- static_value === nothing && return nothing
106- for cfg in configs
107- if hasproperty (cfg, hint)
108- throw (ArgumentError (
109- " `$hint ` is both a static kwarg and an axis in the search space. " *
110- " Pick one." ))
111- end
112- end
113- return nothing
114- end
115-
11672function _check_static_hint_conflicts (configs; static_num_ctas= nothing ,
11773 static_occupancy= nothing )
118- _check_static_hint_conflict (configs, :num_ctas , static_num_ctas)
119- _check_static_hint_conflict (configs, :occupancy , static_occupancy)
74+ statics = (num_ctas= static_num_ctas, occupancy= static_occupancy)
75+ for (hint, static_value) in pairs (statics)
76+ static_value === nothing && continue
77+ any (cfg -> hasproperty (cfg, hint), configs) && throw (ArgumentError (
78+ " `$hint ` is both a static kwarg and an axis in the search space. Pick one." ))
79+ end
12080 return nothing
12181end
12282
13292@inline _converted_args (args_fn, cfg) = map (cuTileconvert, args_fn (cfg))
13393@inline _argtypes (args) = Tuple{map (Core. Typeof, args)... }
13494
135- function _compile_cfg (@nospecialize (f), cfg, args_fn, session :: TuningSession ;
95+ function _compile_cfg (@nospecialize (f), cfg, args_fn;
13696 sm_arch:: VersionNumber , opt_level:: Int ,
13797 static_num_ctas= nothing , static_occupancy= nothing )
13898 converted = _converted_args (args_fn, cfg)
139- return temporary_cufunction (f, _argtypes (converted), session . id ;
99+ return temporary_cufunction (f, _argtypes (converted);
140100 sm_arch, opt_level,
141101 hints_from_cfg (cfg; static_num_ctas, static_occupancy)... )
142102end
@@ -166,34 +126,33 @@ function _time_ms(run_once, get_args;
166126 return best_ms
167127end
168128
169- function eval_cfg (@nospecialize (f), cfg, grid_fn, args_fn, session :: TuningSession ;
129+ function eval_cfg (@nospecialize (f), cfg, grid_fn, args_fn;
170130 sm_arch:: VersionNumber , opt_level:: Int , warmup:: Int , reps:: Int ,
171131 static_num_ctas= nothing , static_occupancy= nothing ,
172132 verify= nothing , reset= nothing )
173133 grid = _grid_dims (grid_fn (cfg))
174- kernel = _compile_cfg (f, cfg, args_fn, session ; sm_arch, opt_level,
134+ kernel = _compile_cfg (f, cfg, args_fn; sm_arch, opt_level,
175135 static_num_ctas, static_occupancy)
176136
177137 run_once = converted -> kernel (converted... ; blocks= grid)
178138 get_args = () -> _converted_args (args_fn, cfg)
179139 return _time_ms (run_once, get_args; warmup, reps, verify, reset)
180140end
181141
182- function precompile_cfg (@nospecialize (f), cfg, args_fn, session :: TuningSession ;
142+ function precompile_cfg (@nospecialize (f), cfg, args_fn;
183143 sm_arch:: VersionNumber , opt_level:: Int ,
184144 static_num_ctas= nothing , static_occupancy= nothing )
185- _compile_cfg (f, cfg, args_fn, session ; sm_arch, opt_level,
145+ _compile_cfg (f, cfg, args_fn; sm_arch, opt_level,
186146 static_num_ctas, static_occupancy)
187147 return nothing
188148end
189149
190150const TimingRecord = Tuple{Any, Float32}
191151
192152function _measure_cfg! (record:: Vector{TimingRecord} , first_error:: Base.RefValue ,
193- @nospecialize (f), cfg, grid_fn, args_fn,
194- session:: TuningSession ; kwargs... )
153+ @nospecialize (f), cfg, grid_fn, args_fn; kwargs... )
195154 ms = try
196- eval_cfg (f, cfg, grid_fn, args_fn, session ; kwargs... )
155+ eval_cfg (f, cfg, grid_fn, args_fn; kwargs... )
197156 catch err
198157 err isa InterruptException && rethrow ()
199158 if err isa VerificationError
@@ -209,16 +168,15 @@ function _measure_cfg!(record::Vector{TimingRecord}, first_error::Base.RefValue,
209168 return nothing
210169end
211170
212- function measure_candidates (@nospecialize (f), configs:: Vector{Any} , grid_fn, args_fn,
213- session:: TuningSession ;
171+ function measure_candidates (@nospecialize (f), configs:: Vector{Any} , grid_fn, args_fn;
214172 sm_arch:: VersionNumber , opt_level:: Int ,
215173 warmup:: Int , reps:: Int ,
216174 static_num_ctas= nothing , static_occupancy= nothing ,
217175 verify= nothing , reset= nothing )
218176 record = TimingRecord[]
219177 first_error = Ref {Any} (nothing )
220178 for cfg in configs
221- _measure_cfg! (record, first_error, f, cfg, grid_fn, args_fn, session ;
179+ _measure_cfg! (record, first_error, f, cfg, grid_fn, args_fn;
222180 sm_arch, opt_level, warmup, reps,
223181 static_num_ctas, static_occupancy, verify, reset)
224182 end
@@ -232,16 +190,15 @@ Compile candidate configurations on worker tasks while the caller task measures
232190completed candidates. Measurement stays on the caller task because CUDA state is
233191task-local; workers only run the untimed temporary compile path.
234192"""
235- function pipelined_tune (@nospecialize (f), configs:: Vector{Any} , grid_fn, args_fn,
236- session:: TuningSession ;
193+ function pipelined_tune (@nospecialize (f), configs:: Vector{Any} , grid_fn, args_fn;
237194 sm_arch:: VersionNumber , opt_level:: Int ,
238195 warmup:: Int , reps:: Int , workers:: Int ,
239196 static_num_ctas= nothing , static_occupancy= nothing ,
240197 verify= nothing , reset= nothing )
241198 isempty (configs) && return TimingRecord[], nothing , nothing
242199
243200 if iszero (workers) || length (configs) == 1
244- record, first_error = measure_candidates (f, configs, grid_fn, args_fn, session ;
201+ record, first_error = measure_candidates (f, configs, grid_fn, args_fn;
245202 sm_arch, opt_level, warmup, reps,
246203 static_num_ctas, static_occupancy, verify, reset)
247204 return record, nothing , first_error
@@ -265,7 +222,7 @@ function pipelined_tune(@nospecialize(f), configs::Vector{Any}, grid_fn, args_fn
265222 for cfg in jobs
266223 cancelled[] && break
267224 try
268- precompile_cfg (f, cfg, args_fn, session ; sm_arch, opt_level,
225+ precompile_cfg (f, cfg, args_fn; sm_arch, opt_level,
269226 static_num_ctas, static_occupancy)
270227 cancelled[] || put! (ready, cfg)
271228 catch err
@@ -289,7 +246,7 @@ function pipelined_tune(@nospecialize(f), configs::Vector{Any}, grid_fn, args_fn
289246 first_error = Ref {Any} (nothing )
290247 try
291248 for cfg in ready
292- _measure_cfg! (record, first_error, f, cfg, grid_fn, args_fn, session ;
249+ _measure_cfg! (record, first_error, f, cfg, grid_fn, args_fn;
293250 sm_arch, opt_level, warmup, reps,
294251 static_num_ctas, static_occupancy, verify, reset)
295252 end
@@ -348,15 +305,15 @@ function _no_valid_config_error(first_error, precompile_error)
348305end
349306
350307function _refine_record (@nospecialize (f), record:: Vector{TimingRecord} , tuning:: TuningOptions ,
351- grid_fn, args_fn, session :: TuningSession ;
308+ grid_fn, args_fn;
352309 sm_arch:: VersionNumber , opt_level:: Int ,
353310 static_num_ctas= nothing , static_occupancy= nothing ,
354311 verify= nothing , reset= nothing )
355312 (tuning. refine_topk > 0 && length (record) > 1 ) || return record
356313
357314 sort! (record, by= last)
358315 top = Any[first (r) for r in record[1 : min (tuning. refine_topk, length (record))]]
359- refined, _ = measure_candidates (f, top, grid_fn, args_fn, session ;
316+ refined, _ = measure_candidates (f, top, grid_fn, args_fn;
360317 sm_arch, opt_level, warmup= tuning. warmup, reps= tuning. refine_reps,
361318 static_num_ctas, static_occupancy, verify, reset)
362319 return isempty (refined) ? record : refined
@@ -383,11 +340,10 @@ function find_or_tune(@nospecialize(f), space::AbstractSearchSpace,
383340
384341 checker = verify != = nothing ? verify () : nothing
385342 reset = setup != = nothing ? setup () : nothing
386- session = TuningSession ()
387343
388344 record, precompile_error, first_error =
389345 with (_SCOPED_INF_CACHE => _fresh_inf_cache ()) do
390- pipelined_tune (f, trials, grid_fn, args_fn, session ;
346+ pipelined_tune (f, trials, grid_fn, args_fn;
391347 sm_arch, opt_level,
392348 warmup= tuning. warmup, reps= tuning. reps,
393349 workers= tuning. precompile_workers,
@@ -397,7 +353,7 @@ function find_or_tune(@nospecialize(f), space::AbstractSearchSpace,
397353
398354 isempty (record) && _no_valid_config_error (first_error, precompile_error)
399355
400- record = _refine_record (f, record, tuning, grid_fn, args_fn, session ;
356+ record = _refine_record (f, record, tuning, grid_fn, args_fn;
401357 sm_arch, opt_level,
402358 static_num_ctas, static_occupancy,
403359 verify= checker, reset)
0 commit comments