Skip to content

Commit bc21f31

Browse files
author
Katharine Hyatt
committed
Try to pretty up JET changes a little
1 parent 215fdb5 commit bc21f31

2 files changed

Lines changed: 9 additions & 6 deletions

File tree

src/implementations/polar.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,13 +117,14 @@ end
117117
function _left_polarnewton!(A::AbstractMatrix, W, P = similar(A, (0, 0)); tol = defaulttol(A), maxiter = 10)
118118
m, n = size(A) # we must have m >= n
119119
Rᴴinv = isempty(P) ? similar(P, (n, n)) : P # use P as workspace when available
120-
Q = similar(A, (0, 0))
121-
if m > n # initial QR
120+
istall = m > n
121+
if istall # initial QR
122122
Q, R = qr_compact!(A)
123123
Rc = view(A, 1:n, 1:n)
124124
Rc .= R
125125
Rᴴinv = ldiv!(UpperTriangular(Rc)', one!(Rᴴinv))
126126
else # m == n
127+
Q = similar(A, (0, 0)) # needed for JET
127128
R = A
128129
Rc = view(W, 1:n, 1:n)
129130
Rc .= R
@@ -149,7 +150,7 @@ function _left_polarnewton!(A::AbstractMatrix, W, P = similar(A, (0, 0)); tol =
149150
if conv > tol
150151
@warn "`left_polar!` via Newton iteration did not converge within $maxiter iterations (final residual: $conv)"
151152
end
152-
if m > n
153+
if istall
153154
return mul!(W, Q, Rc)
154155
end
155156
return W
@@ -158,13 +159,14 @@ end
158159
function _right_polarnewton!(A::AbstractMatrix, Wᴴ, P = similar(A, (0, 0)); tol = defaulttol(A), maxiter = 10)
159160
m, n = size(A) # we must have m <= n
160161
Lᴴinv = isempty(P) ? similar(P, (m, m)) : P # use P as workspace when available
161-
Q = similar(A, (0, 0))
162-
if m < n # initial QR
162+
isshort = m < n
163+
if isshort # initial QR
163164
L, Q = lq_compact!(A)
164165
Lc = view(A, 1:m, 1:m)
165166
copy!(Lc, L)
166167
Lᴴinv = ldiv!(LowerTriangular(Lc)', one!(Lᴴinv))
167168
else # m == n
169+
Q = similar(A, (0, 0)) # needed for JET
168170
L = A
169171
Lc = view(Wᴴ, 1:m, 1:m)
170172
Lc .= L
@@ -190,7 +192,7 @@ function _right_polarnewton!(A::AbstractMatrix, Wᴴ, P = similar(A, (0, 0)); to
190192
if conv > tol
191193
@warn "`right_polar!` via Newton iteration did not converge within $maxiter iterations (final residual: $conv)"
192194
end
193-
if m < n
195+
if isshort
194196
return mul!(Wᴴ, Lc, Q)
195197
end
196198
return Wᴴ

src/pullbacks/eig.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ function eig_trunc_pullback!(
137137
ΔVperp = ΔV - V * inv(G) * VᴴΔV
138138
VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))
139139
else
140+
ΔVperp = similar(G, (0, 0)) # needed for JET
140141
VᴴΔV = zero(G)
141142
end
142143

0 commit comments

Comments
 (0)