Skip to content

Commit 29ba61e

Browse files
committed
revert changes
1 parent 6872c31 commit 29ba61e

2 files changed

Lines changed: 2 additions & 58 deletions

File tree

src/common/pullbacks.jl

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,5 @@ function iszerotangent end
1111
iszerotangent(::Any) = false
1212
iszerotangent(::Nothing) = true
1313

14-
# Solve the Sylvester equation A*X + X*B + C = 0.
15-
# When A === B (same Hermitian PD matrix, as in polar pullbacks), use an
16-
# eigendecomposition-based solver to avoid LAPACK's trsyl! failing with
17-
# LAPACKException(1) for close eigenvalues.
18-
function _sylvester(A, B, C)
19-
if A === B
20-
return _sylvester_symm(A, C)
21-
end
22-
return LinearAlgebra.sylvester(A, B, C)
23-
end
24-
25-
function _sylvester_symm(P, C)
26-
D, Q = LinearAlgebra.eigen(LinearAlgebra.Hermitian(P))
27-
Y = Q' * C * Q
28-
@inbounds for j in axes(Y, 2), i in axes(Y, 1)
29-
Y[i, j] = -Y[i, j] / (D[i] + D[j])
30-
end
31-
return Q * Y * Q'
32-
end
14+
# fallback
15+
_sylvester(A, B, C) = LinearAlgebra.sylvester(A, B, C)

test/testsuite/enzyme.jl

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,6 @@ function test_enzyme(T::Type, sz; kwargs...)
105105
test_enzyme_polar(T, sz; kwargs...)
106106
test_enzyme_orthnull(T, sz; kwargs...)
107107
end
108-
test_enzyme_projections(T, sz; kwargs...)
109108
end
110109
end
111110

@@ -463,41 +462,3 @@ function test_enzyme_orthnull(
463462
end
464463
end
465464
end
466-
467-
function test_enzyme_projections(
468-
T::Type, sz;
469-
atol::Real = 0, rtol::Real = precision(T),
470-
kwargs...
471-
)
472-
summary_str = testargs_summary(T, sz)
473-
return @testset "Projections Enzyme AD rules $summary_str" begin
474-
A = instantiate_matrix(T, sz)
475-
m, n = size(A)
476-
fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
477-
if m == n
478-
@testset "project_hermitian" begin
479-
@testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
480-
Aₕ, ΔAₕ = ad_project_hermitian_setup(A)
481-
eltype(T) <: BlasFloat && test_reverse(project_hermitian, RT, (A, TA); atol, rtol, output_tangent = ΔAₕ, fdm)
482-
is_cpu(A) && enz_test_pullbacks_match(rng, project_hermitian!, project_hermitian, A, Aₕ, ΔAₕ)
483-
end
484-
end
485-
@testset "project_antihermitian" begin
486-
@testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
487-
Aₐ, ΔAₐ = ad_project_antihermitian_setup(A)
488-
eltype(T) <: BlasFloat && test_reverse(project_antihermitian, RT, (A, TA); atol, rtol, output_tangent = ΔAₐ, fdm)
489-
is_cpu(A) && enz_test_pullbacks_match(rng, project_antihermitian!, project_antihermitian, A, Aₐ, ΔAₐ)
490-
end
491-
end
492-
end
493-
if m > n
494-
@testset "project_isometric" begin
495-
@testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
496-
W, ΔW = ad_project_isometric_setup(A)
497-
eltype(T) <: BlasFloat && test_reverse(project_isometric, RT, (A, TA); atol, rtol, output_tangent = ΔW, fdm)
498-
is_cpu(A) && enz_test_pullbacks_match(rng, project_isometric!, project_isometric, A, W, ΔW)
499-
end
500-
end
501-
end
502-
end
503-
end

0 commit comments

Comments
 (0)