@@ -36,22 +36,24 @@ function lq_pullback!(
3636 ΔA1 = view (ΔA, 1 : p, :)
3737 ΔA2 = view (ΔA, (p + 1 ): m, :)
3838
39- if minmn > p # case where A is rank-deficient
40- Δgauge = abs (zero (eltype (Q)))
41- if ! iszerotangent (ΔQ)
42- # in this case the number Householder reflections will
43- # change upon small variations, and all of the remaining
44- # columns of ΔQ should be zero for a gauge-invariant
45- # cost function
46- ΔQ2 = view (ΔQ, (p + 1 ): size (Q, 1 ), :)
47- Δgauge = max (Δgauge, norm (ΔQ2, Inf ))
48- end
49- if ! iszerotangent (ΔL)
50- ΔL22 = view (ΔL, (p + 1 ): m, (p + 1 ): minmn)
51- Δgauge = max (Δgauge, norm (ΔL22, Inf ))
39+ if isa (ΔA, Array) # not GPU friendly
40+ if minmn > p # case where A is rank-deficient
41+ Δgauge = abs (zero (eltype (Q)))
42+ if ! iszerotangent (ΔQ)
43+ # in this case the number Householder reflections will
44+ # change upon small variations, and all of the remaining
45+ # columns of ΔQ should be zero for a gauge-invariant
46+ # cost function
47+ ΔQ2 = view (ΔQ, (p + 1 ): size (Q, 1 ), :)
48+ Δgauge = max (Δgauge, norm (ΔQ2, Inf ))
49+ end
50+ if ! iszerotangent (ΔL)
51+ ΔL22 = view (ΔL, (p + 1 ): m, (p + 1 ): minmn)
52+ Δgauge = max (Δgauge, norm (ΔL22, Inf ))
53+ end
54+ Δgauge ≤ gauge_atol ||
55+ @warn " `lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge )"
5256 end
53- Δgauge ≤ gauge_atol ||
54- @warn " `lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge )"
5557 end
5658
5759 ΔQ̃ = zero! (similar (Q, (p, n)))
@@ -69,9 +71,11 @@ function lq_pullback!(
6971 # how the full Q2 will change, but this we omit for now, and we consider
7072 # Q2' * ΔQ2 as a gauge dependent quantity.
7173 ΔQ2Q1ᴴ = ΔQ2 * Q1'
72- Δgauge = norm (mul! (copy (ΔQ2), ΔQ2Q1ᴴ, Q1, - 1 , 1 ), Inf )
73- Δgauge ≤ gauge_atol ||
74- @warn " `lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge )"
74+ if isa (ΔA, Array) # not GPU friendly
75+ Δgauge = norm (mul! (copy (ΔQ2), ΔQ2Q1ᴴ, Q1, - 1 , 1 ), Inf )
76+ Δgauge ≤ gauge_atol ||
77+ @warn " `lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge )"
78+ end
7579 ΔQ̃ = mul! (ΔQ̃, ΔQ2Q1ᴴ' , Q2, - 1 , 1 )
7680 end
7781 end
@@ -95,8 +99,10 @@ function lq_pullback!(
9599 Md = diagview (M)
96100 Md .= real .(Md)
97101 end
98- ldiv! (LowerTriangular (L11)' , M)
99- ldiv! (LowerTriangular (L11)' , ΔQ̃)
102+ # not GPU friendly...
103+ L11arr = typeof (L)(L11)
104+ ldiv! (LowerTriangular (L11arr)' , M)
105+ ldiv! (LowerTriangular (L11arr)' , ΔQ̃)
100106 ΔA1 = mul! (ΔA1, M, Q1, + 1 , 1 )
101107 ΔA1 .+ = ΔQ̃
102108 return ΔA
0 commit comments