Skip to content

Commit 3f2a2ec

Browse files
committed
Retain context
1 parent 943c12f commit 3f2a2ec

1 file changed

Lines changed: 20 additions & 16 deletions

File tree

src/experimental/autotune.jl

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -252,27 +252,31 @@ function pipelined_tune(@nospecialize(f), configs::Vector{Any}, grid_fn, args_fn
252252
jobs = Channel{Any}(length(configs))
253253
foreach(cfg -> put!(jobs, cfg), configs)
254254
close(jobs)
255-
255+
256256
cancelled = Threads.Atomic{Bool}(false)
257257
precompile_error = Ref{Any}(nothing)
258-
error_lock = ReentrantLock()
258+
error_lock = ReentrantLock()
259+
ctx = CUDACore.context()
259260

260261
producer = Threads.@spawn try
261262
@sync for _ in 1:workers
262-
Threads.@spawn for cfg in jobs
263-
cancelled[] && break
264-
try
265-
precompile_cfg(f, cfg, args_fn, session; sm_arch, opt_level,
266-
static_num_ctas, static_occupancy)
267-
cancelled[] || put!(ready, cfg)
268-
catch err
269-
if err isa InterruptException
270-
cancelled[] = true
271-
rethrow()
272-
end
273-
lock(error_lock) do
274-
precompile_error[] === nothing &&
275-
(precompile_error[] = (cfg, err))
263+
Threads.@spawn begin
264+
CUDACore.context!(ctx)
265+
for cfg in jobs
266+
cancelled[] && break
267+
try
268+
precompile_cfg(f, cfg, args_fn, session; sm_arch, opt_level,
269+
static_num_ctas, static_occupancy)
270+
cancelled[] || put!(ready, cfg)
271+
catch err
272+
if err isa InterruptException
273+
cancelled[] = true
274+
rethrow()
275+
end
276+
lock(error_lock) do
277+
precompile_error[] === nothing &&
278+
(precompile_error[] = (cfg, err))
279+
end
276280
end
277281
end
278282
end

0 commit comments

Comments
 (0)