Skip to content

Commit af275f5

Browse files
committed
Incremental progress on pb
1 parent 7f2ec86 commit af275f5

1 file changed

Lines changed: 22 additions & 4 deletions

File tree

src/pullbacks/qr.jl

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,18 @@ function check_and_prepare_qr_cotangents(
3131
ΔR₁₁ = UpperTriangular(view(ΔR, 1:p, 1:p))
3232
ΔR₁₂ = view(ΔR, 1:p, (p + 1):n)
3333
ΔR₂₂ = view(ΔR, (p + 1):minmn, (p + 1):n)
34-
Δgauge_R = norm(view(ΔR₂₂, uppertriangularind(ΔR₂₂)), Inf)
35-
Δgauge_R = max(Δgauge_R, norm(view(ΔR₂₂, diagind(ΔR₂₂)), Inf))
36-
Δgauge = max(Δgauge, Δgauge_R)
34+
if p < minmn # otherwise ΔR₂₂ is empty
35+
# uppertriangularind generates linear indices
36+
# compute the appropriate offset in ΔR so we aren't
37+
# operating on a view-of-view, which doesn't work
38+
# for GPU arrays
39+
offset = LinearIndices(ΔR)[p + 1, p + 1]
40+
upper_inds = uppertriangularind(ΔR₂₂) .+ offset
41+
ΔR₂₂upper = view(ΔR, upper_inds)
42+
Δgauge_R = norm(ΔR₂₂upper, Inf)
43+
Δgauge_R = max(Δgauge_R, norm(view(ΔR₂₂, diagind(ΔR₂₂)), Inf))
44+
Δgauge = max(Δgauge, Δgauge_R)
45+
end
3746
else
3847
ΔR₁₁ = nothing
3948
ΔR₁₂ = nothing
@@ -160,7 +169,16 @@ function remove_qr_gauge_dependence!(ΔQ, ΔR, A, Q, R; rank_atol = MatrixAlgebr
160169
end
161170
ΔR₂₂ = view(ΔR, (r + 1):minmn, (r + 1):size(R, 2))
162171
zero!(diagview(ΔR₂₂))
163-
zero!(view(ΔR₂₂, uppertriangularind(ΔR₂₂)))
172+
if r < minmn
173+
# uppertriangularind generates linear indices
174+
# compute the appropriate offset in ΔR so we aren't
175+
# operating on a view-of-view, which doesn't work
176+
# for GPU arrays
177+
offset = LinearIndices(ΔR)[r + 1, r + 1]
178+
upper_inds = uppertriangularind(ΔR₂₂) .+ offset
179+
ΔR₂₂upper = view(ΔR, upper_inds)
180+
zero!(ΔR₂₂upper)
181+
end
164182
return ΔQ, ΔR
165183
end
166184

0 commit comments

Comments
 (0)