@@ -5,9 +5,13 @@ using ChainRulesCore
55using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad
66using Symbolics
77using LinearAlgebra
8-
98using Test
109
10+ const fwd = Diffractor. PrimeDerivativeFwd
11+ const bwd = Diffractor. PrimeDerivativeBack
12+
13+ @testset verbose= true " Diffractor.jl" begin # overall testset, ensures all tests run
14+
1115# Unit tests
1216function tup2 (f)
1317 a, b = ∂⃖ {2} ()(f, 1 )
@@ -88,9 +92,9 @@ let var"'" = Diffractor.PrimeDerivativeBack
8892 @test @inferred (sin' (1.0 )) == cos (1.0 )
8993 @test @inferred (sin'' (1.0 )) == - sin (1.0 )
9094 @test sin''' (1.0 ) == - cos (1.0 )
91- @test sin'''' (1.0 ) == sin (1.0 )
92- @test sin''''' (1.0 ) == cos (1.0 )
93- @test sin'''''' (1.0 ) == - sin (1.0 )
95+ @test sin'''' (1.0 ) == sin (1.0 ) broken = VERSION >= v " 1.8 "
96+ @test sin''''' (1.0 ) == cos (1.0 ) broken = VERSION >= v " 1.8 "
97+ @test sin'''''' (1.0 ) == - sin (1.0 ) broken = VERSION >= v " 1.8 "
9498
9599 f_getfield (x) = getfield ((x,), 1 )
96100 @test f_getfield' (1 ) == 1
@@ -101,9 +105,9 @@ let var"'" = Diffractor.PrimeDerivativeBack
101105
102106 complicated_2sin (x) = (x = map (sin, Diffractor. xfill (x, 2 )); x[1 ] + x[2 ])
103107 @test @inferred (complicated_2sin' (1.0 )) == 2 sin' (1.0 )
104- @test @inferred (complicated_2sin'' (1.0 )) == 2 sin'' (1.0 )
105- @test @inferred (complicated_2sin''' (1.0 )) == 2 sin''' (1.0 )
106- @test @inferred (complicated_2sin'''' (1.0 )) == 2 sin'''' (1.0 )
108+ @test @inferred (complicated_2sin'' (1.0 )) == 2 sin'' (1.0 ) broken = true
109+ @test @inferred (complicated_2sin''' (1.0 )) == 2 sin''' (1.0 ) broken = true
110+ @test @inferred (complicated_2sin'''' (1.0 )) == 2 sin'''' (1.0 ) broken = true
107111
108112 # Control flow cases
109113 @test @inferred ((x-> simple_control_flow (true , x))' (1.0 )) == sin' (1.0 )
149153# Regression tests
150154@test gradient (x -> sum (abs2, x .+ 1.0 ), zeros (3 ))[1 ] == [2.0 , 2.0 , 2.0 ]
151155
152- const fwd = Diffractor. PrimeDerivativeFwd
153- const bwd = Diffractor. PrimeDerivativeBack
154-
155156function f_broadcast (a)
156157 l = a / 2.0 * [[0. 1. 1. ]; [1. 0. 1. ]; [1. 1. 0. ]]
157158 return sum (l)
161162# Make sure that there's no infinite recursion in kwarg calls
162163g_kw (;x= 1.0 ) = sin (x)
163164f_kw (x) = g_kw (;x)
164- @test bwd (f_kw)(1.0 ) == bwd (sin)(1.0 )
165+ @test bwd (f_kw)(1.0 ) == bwd (sin)(1.0 ) broken= true
166+ #=
167+ MethodError: no method matching +(::Tangent{var"#g_kw#47"{var"#g_kw#11#48"}, NamedTuple{(Symbol("#g_kw#11"),), Tuple{ZeroTangent}}}, ::Tangent{Diffractor.KwFunc{var"#g_kw#47"{var"#g_kw#11#48"}, var"#g_kw#47##kw"}, NamedTuple{(:kwf,), Tuple{ZeroTangent}}})
168+ ...
169+ [2] elementwise_add(a::NamedTuple{(:contents,), Tuple{Tangent{var"#g_kw#47"{var"#g_kw#11#48"}, NamedTuple{(Symbol("#g_kw#11"),), Tuple{ZeroTangent}}}}}, b::NamedTuple{(:contents,), Tuple{Tangent{Diffractor.KwFunc{var"#g_kw#47"{var"#g_kw#11#48"}, var"#g_kw#47##kw"}, NamedTuple{(:kwf,), Tuple{ZeroTangent}}}}})
170+ @ ChainRulesCore ~/.julia/packages/ChainRulesCore/ctmSK/src/tangent_types/tangent.jl:287
171+ [3] +(a::Tangent{Core.Box, NamedTuple{(:contents,), Tuple{Tangent{var"#g_kw#47"{var"#g_kw#11#48"}, NamedTuple{(Symbol("#g_kw#11"),), Tuple{ZeroTangent}}}}}}, b::Tangent{Core.Box, NamedTuple{(:contents,), Tuple{Tangent{Diffractor.KwFunc{var"#g_kw#47"{var"#g_kw#11#48"}, var"#g_kw#47##kw"}, NamedTuple{(:kwf,), Tuple{ZeroTangent}}}}}})
172+ @ ChainRulesCore ~/.julia/packages/ChainRulesCore/ctmSK/src/tangent_arithmetic.jl:130
173+ =#
165174
166175function f_crit_edge (a, b, c, x)
167176 # A function with two critical edges. This used to trigger an issue where
@@ -220,3 +229,5 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2)
220229
221230# Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24)
222231# include("pinn.jl")
232+
233+ end # overall testset
0 commit comments