Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion lib/EnzymeTestUtils/src/test_forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ constraints:
- `rtol`: Relative tolerance for `isapprox`.
- `atol`: Absolute tolerance for `isapprox`.
- `testset_name`: Name to use for a testset in which all tests are evaluated.
- `ignore_const_checks::Bool=false`: If `true`, skip the post-call assertion that
`Const` arguments were mutated identically by the rule and the primal function.
Useful when a rule legitimately scribbles on a `Const` scratch buffer.

# Examples

Expand Down Expand Up @@ -61,7 +64,8 @@ function test_forward(
rtol::Real=1e-9,
atol::Real=1e-9,
testset_name=nothing,
runtime_activity::Bool=false
runtime_activity::Bool=false,
ignore_const_checks::Bool=false,
)
call_with_copy = CallWithCopyKWargs(fkwargs)
call_with_kwargs = CallWithKWargs(fkwargs)
Expand Down Expand Up @@ -126,6 +130,7 @@ function test_forward(
rtol,
)
for (i, (act_i, arg_i)) in enumerate(zip(Base.tail(activities), args_copy))
ignore_const_checks && act_i isa Const && continue
test_approx(
act_i.val,
arg_i,
Expand Down
12 changes: 10 additions & 2 deletions lib/EnzymeTestUtils/src/test_reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ additional constraints:
- `rtol`: Relative tolerance for `isapprox`.
- `atol`: Absolute tolerance for `isapprox`.
- `testset_name`: Name to use for a testset in which all tests are evaluated.
- `output_tangent`: Optional final tangent to provide at the beginning of the reverse-mode differentiation
- `output_tangent`: Optional final tangent to provide at the beginning of the reverse-mode differentiation
- `ignore_const_checks::Bool=false`: If `true`, skip the post-call assertion that
`Const` arguments were mutated identically by the rule and the primal function,
and skip the FD-vs-AD derivative comparison for `Const` arguments. Useful when a
rule legitimately scribbles on a `Const` scratch buffer.

# Examples

Expand Down Expand Up @@ -76,7 +80,9 @@ function test_reverse(
atol::Real=1e-9,
testset_name=nothing,
runtime_activity::Bool=false,
output_tangent=nothing)
output_tangent=nothing,
ignore_const_checks::Bool=false,
)
call_with_captured_kwargs = CallWithKWargs(fkwargs)
if testset_name === nothing
testset_name = "test_reverse: $f with return activity $ret_activity on $(_string_activity(args))"
Expand Down Expand Up @@ -119,6 +125,7 @@ function test_reverse(
rtol,
)
for (i, (act_i, arg_i)) in enumerate(zip(Base.tail(activities), args_copy))
ignore_const_checks && act_i isa Const && continue
test_approx(
act_i.val,
arg_i,
Expand Down Expand Up @@ -146,6 +153,7 @@ function test_reverse(
@test length(dx_ad) == length(dx_fdm) == length(activities)
# check all returned derivatives against FiniteDifferences
for (i, (act_i, dx_ad_i, dx_fdm_i)) in enumerate(zip(activities, dx_ad, dx_fdm))
ignore_const_checks && act_i isa Const && continue
target_str = if i == 1
"active derivative for callable"
else
Expand Down
41 changes: 41 additions & 0 deletions lib/EnzymeTestUtils/test/test_forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,30 @@ function f_kwargs_fwd!(x; kwargs...)
return nothing
end

# Function with a scratch buffer that the rule mutates only during AD.
# The primal does not touch `scratch`, so the rule's scribble would normally fail
# the post-call mutation check unless `ignore_const_checks=true`.
f_const_scratch_fwd(x, scratch) = x .^ 2

function EnzymeRules.forward(
config,
func::Const{typeof(f_const_scratch_fwd)},
RT::Type{<:Union{Const,Duplicated,DuplicatedNoNeed}},
x::Union{Const,Duplicated},
scratch::Const,
)
scratch.val .= x.val # AD-only scribble on Const scratch buffer
if RT <: Const
return func.val(x.val, scratch.val)
end
dval = x isa Duplicated ? 2 .* x.val .* x.dval : zero(x.val)
if RT <: DuplicatedNoNeed
return dval
else
return Duplicated(func.val(x.val, scratch.val), dval)
end
end

function EnzymeRules.forward(
config,
func::Const{typeof(f_kwargs_fwd)},
Expand Down Expand Up @@ -214,6 +238,23 @@ end
end
end

@testset "ignore_const_checks" begin
x = randn(3)
scratch = zeros(3)
@test fails() do
test_forward(f_const_scratch_fwd, Duplicated, (x, Duplicated), (scratch, Const))
end
@test !fails() do
test_forward(
f_const_scratch_fwd,
Duplicated,
(x, Duplicated),
(scratch, Const);
ignore_const_checks=true,
)
end
end

@testset "mutated callable" begin
n = 3
@testset for Tret in (Const, Duplicated, BatchDuplicated),
Expand Down
49 changes: 49 additions & 0 deletions lib/EnzymeTestUtils/test/test_reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,38 @@ function f_kwargs_rev!(x; kwargs...)
return nothing
end

# Function with a scratch buffer that the rule mutates only during AD.
# The primal does not touch `scratch`, so the rule's scribble would normally fail
# the post-call mutation check unless `ignore_const_checks=true`.
f_const_scratch_rev(x, scratch) = sum(abs2, x)

function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(f_const_scratch_rev)},
RT::Type{<:Union{Const,Active}},
x::Union{Const,Duplicated},
scratch::Const,
)
scratch.val .= x.val # AD-only scribble on Const scratch buffer
primal = EnzymeRules.needs_primal(config) ? func.val(x.val, scratch.val) : nothing
tape = copy(x.val)
return EnzymeRules.AugmentedReturn(primal, nothing, tape)
end

function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(f_const_scratch_rev)},
dret::Union{Active,Type{<:Const}},
tape,
x::Union{Const,Duplicated},
scratch::Const,
)
if !(x isa Const) && dret isa Active
x.dval .+= 2 .* dret.val .* tape
end
return (nothing, nothing)
end

function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(f_kwargs_rev)},
Expand Down Expand Up @@ -234,6 +266,23 @@ end
end
end

@testset "ignore_const_checks" begin
x = randn(3)
scratch = zeros(3)
@test fails() do
test_reverse(f_const_scratch_rev, Active, (x, Duplicated), (scratch, Const))
end
@test !fails() do
test_reverse(
f_const_scratch_rev,
Active,
(x, Duplicated),
(scratch, Const);
ignore_const_checks=true,
)
end
end

@testset "incorrect tangent detected" begin
@testset for Tx in (Duplicated,)
x = randn(3)
Expand Down
Loading