Skip to content

Commit 978effe

Browse files
authored
fix: LQ pullback modifies input cotangents (#226)
* make sure LQ pullback does not modify input * remove unnecessary line * copy of view to make sure this is copied * add tests for not modifying output tangents in chainrules
1 parent 4b19746 commit 978effe

3 files changed

Lines changed: 180 additions & 26 deletions

File tree

src/pullbacks/lq.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ function check_and_prepare_lq_cotangents(
1313
size(ΔQ) == size(Q) || throw(DimensionMismatch("ΔQ must have the same size as Q"))
1414
ΔQ₁ .= view(ΔQ, 1:p, 1:n)
1515
if p == minmn # full rank case, ΔQ₃ contains gauge-invariant information along Q₁
16+
ΔQ₃ = copy(view(ΔQ, (minmn + 1):size(Q, 1), :)) # extra columns in the case of qr_full
1617
Q₃ = view(Q, (minmn + 1):size(Q, 1), :)
17-
ΔQ₃ = view(ΔQ, (minmn + 1):size(Q, 1), :)
1818
ΔQ₃Q₁ᴴ = ΔQ₃ * Q₁'
1919
mul!(ΔQ₃, ΔQ₃Q₁ᴴ, Q₁, -1, 1)
2020
Δgauge_Q = norm(ΔQ₃, Inf)

src/pullbacks/qr.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@ function check_and_prepare_qr_cotangents(
1414
size(ΔQ) == size(Q) || throw(DimensionMismatch("ΔQ must have the same size as Q"))
1515
ΔQ₁ .= view(ΔQ, 1:m, 1:p)
1616
if p == minmn # full rank case, ΔQ₃ contains gauge-invariant information along Q₁
17-
ΔQ₃ = ΔQ[:, (minmn + 1):size(Q, 2)] # extra columns in the case of qr_full
18-
Q₁ = view(Q, :, 1:minmn)
17+
ΔQ₃ = copy(view(ΔQ, :, (minmn + 1):size(Q, 2))) # extra columns in the case of qr_full
1918
Q₃ = view(Q, :, (minmn + 1):size(Q, 2))
2019
Q₁ᴴΔQ₃ = Q₁' * ΔQ₃
2120
mul!(ΔQ₃, Q₁, Q₁ᴴΔQ₃, -1, 1)

0 commit comments

Comments
 (0)