Skip to content

Commit 80bf040

Browse files
committed
Fix JET issues and Mooncake fails
1 parent 8c81e02 commit 80bf040

5 files changed

Lines changed: 9 additions & 8 deletions

File tree

src/implementations/polar.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ end
117117
function _left_polarnewton!(A::AbstractMatrix, W, P = similar(A, (0, 0)); tol = defaulttol(A), maxiter = 10)
118118
m, n = size(A) # we must have m >= n
119119
Rᴴinv = isempty(P) ? similar(P, (n, n)) : P # use P as workspace when available
120+
Q = similar(A, (0, 0))
120121
if m > n # initial QR
121122
Q, R = qr_compact!(A)
122123
Rc = view(A, 1:n, 1:n)
@@ -157,6 +158,7 @@ end
157158
function _right_polarnewton!(A::AbstractMatrix, Wᴴ, P = similar(A, (0, 0)); tol = defaulttol(A), maxiter = 10)
158159
m, n = size(A) # we must have m <= n
159160
Lᴴinv = isempty(P) ? similar(P, (m, m)) : P # use P as workspace when available
161+
Q = similar(A, (0, 0))
160162
if m < n # initial QR
161163
L, Q = lq_compact!(A)
162164
Lc = view(A, 1:m, 1:m)

src/implementations/qr.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ function _lapack_qr!(
144144
throw(ArgumentError("inplace Q only supported if matrix is tall (`m >= n`), R is not required, and using the unblocked algorithm (`blocksize=1`) with `positive=false`"))
145145
end
146146

147+
jpvt = Vector{Int}(undef, 0)
147148
if blocksize > 1
148149
nb = min(minmn, blocksize)
149150
if computeR # first use R as space for T

src/pullbacks/eig.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ function eig_trunc_pullback!(
125125
p == length(D) || throw(DimensionMismatch())
126126
(n, n) == size(ΔA) || throw(DimensionMismatch())
127127
G = V' * V
128-
128+
ΔVperp = similar(V, (0, 0))
129129
if !iszerotangent(ΔV)
130130
(n, p) == size(ΔV) || throw(DimensionMismatch())
131131
VᴴΔV = V' * ΔV

src/pullbacks/svd.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ function svd_pullback!(
4242
ΔU, ΔSmat, ΔVᴴ = ΔUSVᴴ
4343
UΔU = fill!(similar(U, (r, r)), 0)
4444
VΔV = fill!(similar(Vᴴ, (r, r)), 0)
45+
indU = axes(U, 2)
46+
indV = axes(Vᴴ, 1)
4547
if !iszerotangent(ΔU)
4648
m == size(ΔU, 1) || throw(DimensionMismatch("first dimension of ΔU ($(size(ΔU, 1))) does not match first dimension of U ($m)"))
4749
pU = size(ΔU, 2)

test/mooncake.jl

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,9 @@ using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul!
99

1010
include("ad_utils.jl")
1111

12-
make_mooncake_tangent(ΔAelem::T) where {T <: Complex} = Mooncake.build_tangent(T, real(ΔAelem), imag(ΔAelem))
13-
make_mooncake_tangent(ΔA::Matrix{<:Real}) = ΔA
14-
make_mooncake_tangent(ΔA::Vector{<:Real}) = ΔA
15-
make_mooncake_tangent(ΔA::Matrix{T}) where {T <: Complex} = map(make_mooncake_tangent, ΔA)
16-
make_mooncake_tangent(ΔA::Vector{T}) where {T <: Complex} = map(make_mooncake_tangent, ΔA)
17-
make_mooncake_tangent(ΔD::Diagonal{T}) where {T <: Real} = Mooncake.build_tangent(typeof(ΔD), diagview(ΔD))
18-
make_mooncake_tangent(ΔD::Diagonal{T}) where {T <: Complex} = Mooncake.build_tangent(typeof(ΔD), map(make_mooncake_tangent, diagview(ΔD)))
12+
make_mooncake_tangent(ΔA::Matrix) = ΔA
13+
make_mooncake_tangent(ΔA::Vector) = ΔA
14+
make_mooncake_tangent(ΔD::Diagonal) = Mooncake.build_tangent(typeof(ΔD), diagview(ΔD))
1915

2016
make_mooncake_tangent(T::Tuple) = Mooncake.build_tangent(typeof(T), T...)
2117

0 commit comments

Comments
 (0)