@@ -37,27 +37,29 @@ function qr_pullback!(
3737 ΔA1 = view (ΔA, :, 1 : p)
3838 ΔA2 = view (ΔA, :, (p + 1 ): n)
3939
40- if minmn > p # case where A is rank-deficient
41- Δgauge = abs (zero (eltype (Q)))
42- if ! iszerotangent (ΔQ)
43- # in this case the number Householder reflections will
44- # change upon small variations, and all of the remaining
45- # columns of ΔQ should be zero for a gauge-invariant
46- # cost function
47- ΔQ2 = view (ΔQ, :, (p + 1 ): size (Q, 2 ))
48- Δgauge = max (Δgauge, norm (ΔQ2, Inf ))
49- end
50- if ! iszerotangent (ΔR)
51- ΔR22 = view (ΔR, (p + 1 ): minmn, (p + 1 ): n)
52- Δgauge = max (Δgauge, norm (ΔR22, Inf ))
40+ if isa (ΔA, Array) # not GPU friendly
41+ if minmn > p # case where A is rank-deficient
42+ Δgauge = abs (zero (eltype (Q)))
43+ if ! iszerotangent (ΔQ)
44+ # in this case the number Householder reflections will
45+ # change upon small variations, and all of the remaining
46+ # columns of ΔQ should be zero for a gauge-invariant
47+ # cost function
48+ ΔQ2 = view (ΔQ, :, (p + 1 ): size (Q, 2 ))
49+ Δgauge = max (Δgauge, norm (ΔQ2, Inf ))
50+ end
51+ if ! iszerotangent (ΔR)
52+ ΔR22 = view (ΔR, (p + 1 ): minmn, (p + 1 ): n)
53+ Δgauge = max (Δgauge, norm (ΔR22, Inf ))
54+ end
55+ Δgauge ≤ gauge_atol ||
56+ @warn " `qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge )"
5357 end
54- Δgauge ≤ gauge_atol ||
55- @warn " `qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge )"
5658 end
5759
5860 ΔQ̃ = zero! (similar (Q, (m, p)))
5961 if ! iszerotangent (ΔQ)
60- copy! ( ΔQ̃, view (ΔQ, :, 1 : p) )
62+ ΔQ̃ . = view (ΔQ, :, 1 : p)
6163 if p < size (Q, 2 )
6264 Q2 = view (Q, :, (p + 1 ): size (Q, 2 ))
6365 ΔQ2 = view (ΔQ, :, (p + 1 ): size (Q, 2 ))
@@ -69,9 +71,11 @@ function qr_pullback!(
6971 # how the full Q2 will change, but this we omit for now, and we consider
7072 # Q2' * ΔQ2 as a gauge dependent quantity.
7173 Q1dΔQ2 = Q1' * ΔQ2
72- Δgauge = norm (mul! (copy (ΔQ2), Q1, Q1dΔQ2, - 1 , 1 ), Inf )
73- Δgauge ≤ gauge_atol ||
74- @warn " `qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge )"
74+ if isa (ΔA, Array) # not GPU friendly
75+ Δgauge = norm (mul! (copy (ΔQ2), Q1, Q1dΔQ2, - 1 , 1 ), Inf )
76+ Δgauge ≤ gauge_atol ||
77+ @warn " `qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge )"
78+ end
7579 ΔQ̃ = mul! (ΔQ̃, Q2, Q1dΔQ2' , - 1 , 1 )
7680 end
7781 end
@@ -87,16 +91,18 @@ function qr_pullback!(
8791 M = zero! (similar (R, (p, p)))
8892 if ! iszerotangent (ΔR)
8993 ΔR11 = view (ΔR, 1 : p, 1 : p)
90- M = mul! (M, ΔR11, R11' , 1 , 1 )
94+ M += ΔR11 * R11'
9195 end
92- M = mul! (M, Q1' , ΔQ̃, - 1 , 1 )
96+ M -= Q1' * ΔQ̃
9397 view (M, lowertriangularind (M)) .= conj .(view (M, uppertriangularind (M)))
9498 if eltype (M) <: Complex
9599 Md = diagview (M)
96100 Md .= real .(Md)
97101 end
98- rdiv! (M, UpperTriangular (R11)' )
99- rdiv! (ΔQ̃, UpperTriangular (R11)' )
102+ # not GPU-friendly...
103+ R11arr = typeof (R)(R11)
104+ rdiv! (M, UpperTriangular (R11arr)' )
105+ rdiv! (ΔQ̃, UpperTriangular (R11arr)' )
100106 ΔA1 = mul! (ΔA1, Q1, M, + 1 , 1 )
101107 ΔA1 .+ = ΔQ̃
102108 return ΔA
0 commit comments