diff --git a/src/internal_rules/core.jl b/src/internal_rules/core.jl index 23024580c4..ccaee4bc52 100644 --- a/src/internal_rules/core.jl +++ b/src/internal_rules/core.jl @@ -411,3 +411,473 @@ function EnzymeRules.reverse(config, ::Const{typeof(Base.finalizer)}, dret, tape # No-op return (nothing, nothing) end + +# -------------------- Base.wait EnzymeRules -------------------- + +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + ::Const{typeof(Base.wait)}, + ::Type{<:Const}, + t::Duplicated{<:Task}, +) + Base.wait(t.val) + Base.wait(t.dval) + return nothing +end + +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + ::Const{typeof(Base.wait)}, + ::Type{<:Const}, + t::BatchDuplicated{T, N}, +) where {T<:Task,N} + Base.wait(t.val) + for i in 1:N + Base.wait(t.dval[i]) + end + return nothing +end + +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + ::Const{typeof(Base.wait)}, + ::Type{<:Const}, + t::Const{<:Task}, +) + Base.wait(t.val) + return nothing +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfig, + ::Const{typeof(Base.wait)}, + ::Type{<:Const}, + t::Annotation{<:Task}, +) + Base.wait(t.val) + return EnzymeRules.AugmentedReturn(nothing, nothing, nothing) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfig, + ::Const{typeof(Base.wait)}, + ::Type{<:Const}, + tape, + t::Duplicated{<:Task}, +) + if isdefined(t, :dval) + try + if !Base.istaskstarted(t.dval) && !Base.istaskdone(t.dval) + Base.schedule(t.dval) + end + catch e + e isa UndefRefError && rethrow() + end + end + return (nothing,) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfig, + ::Const{typeof(Base.wait)}, + ::Type{<:Const}, + tape, + t::BatchDuplicated{T, N}, +) where {T<:Task,N} + for i in 1:N + if isdefined(t.dval, i) + try + if !Base.istaskstarted(t.dval[i]) && !Base.istaskdone(t.dval[i]) + Base.schedule(t.dval[i]) + end + catch e + e isa UndefRefError && rethrow() + end + end + end + return (nothing,) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfig, + ::Const{typeof(Base.wait)}, + ::Type{<:Const}, + tape, + t::Const{<:Task}, +) + return (nothing,) +end + +# -------------------- Base._wait EnzymeRules -------------------- + +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + ::Const{typeof(Base._wait)}, + ::Type{<:Const}, + t::Duplicated{<:Task}, +) + Base._wait(t.val) + Base._wait(t.dval) + return nothing +end + +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + ::Const{typeof(Base._wait)}, + ::Type{<:Const}, + t::BatchDuplicated{T, N}, +) where {T<:Task,N} + Base._wait(t.val) + for i in 1:N + Base._wait(t.dval[i]) + end + return nothing +end + +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + ::Const{typeof(Base._wait)}, + ::Type{<:Const}, + t::Const{<:Task}, +) + Base._wait(t.val) + return nothing +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfig, + ::Const{typeof(Base._wait)}, + ::Type{<:Const}, + t::Annotation{<:Task}, +) + Base._wait(t.val) + return EnzymeRules.AugmentedReturn(nothing, nothing, nothing) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfig, + ::Const{typeof(Base._wait)}, + ::Type{<:Const}, + tape, + t::Duplicated{<:Task}, +) + # the reverse of _wait is to enqueue the shadow + if isdefined(t, :dval) + try + if !Base.istaskstarted(t.dval) && !Base.istaskdone(t.dval) + Base.schedule(t.dval) + end + catch e + e isa UndefRefError && rethrow() + end + end + return (nothing,) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfig, + ::Const{typeof(Base._wait)}, + ::Type{<:Const}, + tape, + t::BatchDuplicated{T, N}, +) where {T<:Task,N} + for i in 1:N + if isdefined(t.dval, i) + try + if !Base.istaskstarted(t.dval[i]) && !Base.istaskdone(t.dval[i]) + Base.schedule(t.dval[i]) + end + catch e + e isa UndefRefError && rethrow() + end + end + end + return (nothing,) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfig, + ::Const{typeof(Base._wait)}, + ::Type{<:Const}, + tape, + t::Const{<:Task}, +) + return (nothing,) +end + +@inline function _fwd_task_return(config::EnzymeRules.FwdConfig, t::Annotation{<:Task}) + needs_primal = EnzymeRules.needs_primal(config) + needs_shadow = EnzymeRules.needs_shadow(config) + + if !needs_shadow + if needs_primal + return t.val + else + return nothing + end + else + if !needs_primal + if t isa BatchDuplicated + return t.dval + else + return t.dval + end + else + return t + end + end +end + +# -------------------- Base.schedule EnzymeRules -------------------- + +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + ::Const{typeof(Base.schedule)}, + ::Type{RT}, + t::Duplicated{<:Task}, +) where RT + try + if !Base.istaskstarted(t.val) && !Base.istaskdone(t.val) + Base.schedule(t.val) + end + catch + end + if isdefined(t, :dval) + try + if !Base.istaskstarted(t.dval) && !Base.istaskdone(t.dval) + Base.schedule(t.dval) + end + catch e + e isa UndefRefError && rethrow() + end + end + return _fwd_task_return(config, t) +end + +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + ::Const{typeof(Base.schedule)}, + ::Type{RT}, + t::BatchDuplicated{T, N}, +) where {RT, T<:Task, N} + if !Base.istaskstarted(t.val) && !Base.istaskdone(t.val) + Base.schedule(t.val) + end + for i in 1:N + if isdefined(t.dval, i) + try + if !Base.istaskstarted(t.dval[i]) && !Base.istaskdone(t.dval[i]) + Base.schedule(t.dval[i]) + end + catch e + e isa UndefRefError || rethrow() + end + end + end + return _fwd_task_return(config, t) +end + +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + ::Const{typeof(Base.schedule)}, + ::Type{RT}, + t::Const{<:Task}, +) where RT + if !Base.istaskstarted(t.val) && !Base.istaskdone(t.val) + Base.schedule(t.val) + end + return _fwd_task_return(config, t) +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfig, + ::Const{typeof(Base.schedule)}, + ::Type{RT}, + t::Annotation{<:Task}, +) where RT + try + if !Base.istaskstarted(t.val) && !Base.istaskdone(t.val) + Base.schedule(t.val) + end + catch + end + primal = EnzymeRules.needs_primal(config) ? t.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? (t isa Const ? nothing : t.dval) : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, nothing) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfig, + ::Const{typeof(Base.schedule)}, + ::Type{RT}, + tape, + t::Duplicated{<:Task}, +) where RT + # the reverse of schedule is to wait for the shadow + if isdefined(t, :dval) + try + if !Base.istaskdone(t.dval) + Base.wait(t.dval) + end + catch e + e isa UndefRefError || rethrow() + end + end + return (nothing,) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfig, + ::Const{typeof(Base.schedule)}, + ::Type{RT}, + tape, + t::BatchDuplicated{T, N}, +) where {RT,T<:Task,N} + for i in 1:N + if isdefined(t.dval, i) + try + if !Base.istaskdone(t.dval[i]) + Base.wait(t.dval[i]) + end + catch e + e isa UndefRefError || rethrow() + end + end + end + return (nothing,) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfig, + ::Const{typeof(Base.schedule)}, + ::Type{RT}, + tape, + t::Const{<:Task}, +) where RT + return (nothing,) +end + +# -------------------- Base.enq_work EnzymeRules -------------------- + +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + ::Const{typeof(Base.enq_work)}, + ::Type{RT}, + t::Duplicated{<:Task}, +) where RT + if !Base.istaskstarted(t.val) && !Base.istaskdone(t.val) + Base.enq_work(t.val) + end + if isdefined(t, :dval) + try + if !Base.istaskstarted(t.dval) && !Base.istaskdone(t.dval) + Base.enq_work(t.dval) + end + catch e + e isa UndefRefError && rethrow() + end + end + return _fwd_task_return(config, t) +end + +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + ::Const{typeof(Base.enq_work)}, + ::Type{RT}, + t::BatchDuplicated{T, N}, +) where {RT, T<:Task, N} + if !Base.istaskstarted(t.val) && !Base.istaskdone(t.val) + Base.enq_work(t.val) + end + for i in 1:N + if isdefined(t.dval, i) + try + if !Base.istaskstarted(t.dval[i]) && !Base.istaskdone(t.dval[i]) + Base.enq_work(t.dval[i]) + end + catch e + e isa UndefRefError && rethrow() + end + end + end + return _fwd_task_return(config, t) +end + +function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, + ::Const{typeof(Base.enq_work)}, + ::Type{RT}, + t::Const{<:Task}, +) where RT + if !Base.istaskstarted(t.val) && !Base.istaskdone(t.val) + Base.enq_work(t.val) + end + return _fwd_task_return(config, t) +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfig, + ::Const{typeof(Base.enq_work)}, + ::Type{RT}, + t::Annotation{<:Task}, +) where RT + try + if !Base.istaskstarted(t.val) && !Base.istaskdone(t.val) + Base.enq_work(t.val) + end + catch + end + primal = EnzymeRules.needs_primal(config) ? t.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? (t isa Const ? nothing : t.dval) : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, nothing) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfig, + ::Const{typeof(Base.enq_work)}, + ::Type{RT}, + tape, + t::Duplicated{<:Task}, +) where RT + if isdefined(t, :dval) + try + if !Base.istaskdone(t.dval) + Base.wait(t.dval) + end + catch e + e isa UndefRefError || rethrow() + end + end + return (nothing,) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfig, + ::Const{typeof(Base.enq_work)}, + ::Type{RT}, + tape, + t::BatchDuplicated{T, N}, +) where {RT,T<:Task,N} + for i in 1:N + if isdefined(t.dval, i) + try + if !Base.istaskdone(t.dval[i]) + Base.wait(t.dval[i]) + end + catch e + e isa UndefRefError || rethrow() + end + end + end + return (nothing,) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfig, + ::Const{typeof(Base.enq_work)}, + ::Type{RT}, + tape, + t::Const{<:Task}, +) where RT + return (nothing,) +end diff --git a/test/threads.jl b/test/threads.jl index 20db7e9203..fd67bb8486 100644 --- a/test/threads.jl +++ b/test/threads.jl @@ -154,3 +154,66 @@ end @test ty[3] ≈ [8.0, 8.0] @test ty[4] ≈ [10.0, 10.0] end + +@testset "Task Rules" begin + function wait_f(t) + Base.wait(t) + nothing + end + function _wait_f(t) + Base._wait(t) + nothing + end + function schedule_f(t) + Base.schedule(t) + nothing + end + function enq_work_f(t) + Base.enq_work(t) + nothing + end + + t1 = Task(()->nothing) + t2 = Task(()->nothing) + t3 = Task(()->nothing) + t4 = Task(()->nothing) + + # Pre-schedule tasks for wait so we don't deadlock + Base.schedule(t1) + Base.schedule(t2) + Base.schedule(t3) + Base.schedule(t4) + Base.wait(t1) + Base.wait(t2) + Base.wait(t3) + Base.wait(t4) + + @test Enzyme.autodiff(Reverse, wait_f, Const(t1)) === () + @test Enzyme.autodiff(Reverse, wait_f, Duplicated(t1, t2)) === () + @test Enzyme.autodiff(Reverse, wait_f, BatchDuplicated(t1, (t2, t3))) === () + + @test Enzyme.autodiff(Forward, wait_f, Const(t1)) === () + @test Enzyme.autodiff(Forward, wait_f, Duplicated(t1, t2)) === () + @test Enzyme.autodiff(Forward, wait_f, BatchDuplicated(t1, (t2, t3))) === () + + @test Enzyme.autodiff(Reverse, _wait_f, Const(t1)) === () + @test Enzyme.autodiff(Reverse, _wait_f, Duplicated(t1, t2)) === () + @test Enzyme.autodiff(Reverse, _wait_f, BatchDuplicated(t1, (t2, t3))) === () + + @test Enzyme.autodiff(Forward, _wait_f, Const(t1)) === () + @test Enzyme.autodiff(Forward, _wait_f, Duplicated(t1, t2)) === () + @test Enzyme.autodiff(Forward, _wait_f, BatchDuplicated(t1, (t2, t3))) === () + + + t5 = Task(()->nothing) + t6 = Task(()->nothing) + t7 = Task(()->nothing) + t8 = Task(()->nothing) + + @test Enzyme.autodiff(Reverse, schedule_f, Const(t5)) === () + @test Enzyme.autodiff(Forward, schedule_f, Const(t6)) === () + + t9 = Task(()->nothing); t10 = Task(()->nothing) + @test Enzyme.autodiff(Reverse, enq_work_f, Const(t9)) === () + @test Enzyme.autodiff(Forward, enq_work_f, Const(t10)) === () +end