1+ module MatrixAlgebraKitChainRulesCoreExt
2+
3+ using MatrixAlgebraKit
4+ using MatrixAlgebraKit: copy_input, TruncatedAlgorithm, zero!
5+ using ChainRulesCore
6+ using LinearAlgebra
7+
# TODO: Decide on an interface for forwarding the kwargs of the primal function
# calls to the corresponding pullback functions.
10+
# Teach MatrixAlgebraKit that every ChainRulesCore `AbstractZero` value counts
# as a zero tangent.
function MatrixAlgebraKit.iszerotangent(::AbstractZero)
    return true
end
12+
# Reverse rule for `copy_input(f, A)`.
#
# The primal result is `copy_input(f, A)`. The pullback forwards the incoming
# cotangent to `A` after unthunking it and projecting it with `ProjectTo(A)`;
# the function itself and `f` receive `NoTangent()`.
function ChainRulesCore.rrule(::typeof(copy_input), f, A::AbstractMatrix)
    project_onto_A = ProjectTo(A)
    function copy_input_pullback(ΔA)
        ∂A = project_onto_A(unthunk(ΔA))
        return NoTangent(), NoTangent(), ∂A
    end
    A_copy = copy_input(f, A)
    return A_copy, copy_input_pullback
end
18+
# Generate the reverse rules for `qr_compact!` and `qr_full!`.
#
# Each rule runs the in-place factorization on a copy of `A` (obtained through
# `copy_input`) and returns a pullback that accumulates the cotangent of `A`
# with `qr_compact_pullback!`. The returned tangents correspond to
# (function, A, output buffer, alg).
for f in (:qr_compact, :qr_full)
    f! = Symbol(f, "!")
    @eval function ChainRulesCore.rrule(::typeof($f!), A::AbstractMatrix, QR, alg)
        A_work = copy_input($f, A)
        QR_out = $(f!)(A_work, QR, alg)

        # Shortcut when both output cotangents are structurally zero.
        # is this extra definition useful?
        function qr_pullback(::Tuple{ZeroTangent,ZeroTangent})
            return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
        end
        function qr_pullback(ΔQR)
            ΔA = zero(A)
            ΔQR_plain = unthunk.(ΔQR)
            MatrixAlgebraKit.qr_compact_pullback!(ΔA, QR_out, ΔQR_plain)
            return NoTangent(), ΔA, ZeroTangent(), NoTangent()
        end

        return QR_out, qr_pullback
    end
end
# Reverse rule for `qr_null!`.
#
# The primal is computed from a full QR factorization of a copy of `A`: the
# columns of `Q` after the first `min(m, n)` are copied into `N`. The pullback
# embeds `ΔN` into the matching columns of an otherwise zero `ΔQ` and delegates
# to `qr_compact_pullback!` with a zero cotangent for `R`. The returned tangents
# correspond to (function, A, output buffer, alg).
function ChainRulesCore.rrule(::typeof(qr_null!), A::AbstractMatrix, N, alg)
    # Compute the column range once so the forward slice and the pullback
    # embedding are guaranteed to agree.
    m, n = size(A)
    minmn = min(m, n)
    Ac = copy_input(qr_full, A)
    QR = MatrixAlgebraKit.initialize_output(qr_full!, A, alg)
    Q, R = qr_full!(Ac, QR, alg)
    N = copy!(N, view(Q, 1:m, (minmn + 1):m))
    function qr_null_pullback(ΔN)
        ΔA = zero(A)
        ΔQ = zero!(similar(A, (m, m)))
        # `ΔN` is a single array cotangent (possibly thunked), so it is
        # unthunked as a whole rather than element by element.
        view(ΔQ, 1:m, (minmn + 1):m) .= unthunk(ΔN)
        MatrixAlgebraKit.qr_compact_pullback!(ΔA, (Q, R), (ΔQ, ZeroTangent()))
        return NoTangent(), ΔA, ZeroTangent(), NoTangent()
    end
    function qr_null_pullback(::ZeroTangent) # is this extra definition useful?
        return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
    end
    return N, qr_null_pullback
end
56+
# Generate the reverse rules for `lq_compact!` and `lq_full!`.
#
# Each rule runs the in-place factorization on a copy of `A` (obtained through
# `copy_input`) and returns a pullback that accumulates the cotangent of `A`
# with `lq_compact_pullback!`. The returned tangents correspond to
# (function, A, output buffer, alg).
for f in (:lq_compact, :lq_full)
    f! = Symbol(f, "!")
    @eval function ChainRulesCore.rrule(::typeof($f!), A::AbstractMatrix, LQ, alg)
        A_work = copy_input($f, A)
        LQ_out = $(f!)(A_work, LQ, alg)

        # Shortcut when both output cotangents are structurally zero.
        # is this extra definition useful?
        function lq_pullback(::Tuple{ZeroTangent,ZeroTangent})
            return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
        end
        function lq_pullback(ΔLQ)
            ΔA = zero(A)
            ΔLQ_plain = unthunk.(ΔLQ)
            MatrixAlgebraKit.lq_compact_pullback!(ΔA, LQ_out, ΔLQ_plain)
            return NoTangent(), ΔA, ZeroTangent(), NoTangent()
        end

        return LQ_out, lq_pullback
    end
end
# Reverse rule for `lq_null!`.
#
# The primal is computed from a full LQ factorization of a copy of `A`: the
# rows of `Q` after the first `min(m, n)` are copied into `Nᴴ`. The pullback
# embeds `ΔNᴴ` into the matching rows of an otherwise zero `ΔQ` and delegates
# to `lq_compact_pullback!` with a zero cotangent for `L`. The returned tangents
# correspond to (function, A, output buffer, alg).
function ChainRulesCore.rrule(::typeof(lq_null!), A::AbstractMatrix, Nᴴ, alg)
    # Compute the row range once so the forward slice and the pullback
    # embedding are guaranteed to agree.
    m, n = size(A)
    minmn = min(m, n)
    Ac = copy_input(lq_full, A)
    LQ = MatrixAlgebraKit.initialize_output(lq_full!, A, alg)
    L, Q = lq_full!(Ac, LQ, alg)
    Nᴴ = copy!(Nᴴ, view(Q, (minmn + 1):n, 1:n))
    function lq_null_pullback(ΔNᴴ)
        ΔA = zero(A)
        ΔQ = zero!(similar(A, (n, n)))
        # `ΔNᴴ` is a single array cotangent (possibly thunked), so it is
        # unthunked as a whole rather than element by element.
        view(ΔQ, (minmn + 1):n, 1:n) .= unthunk(ΔNᴴ)
        MatrixAlgebraKit.lq_compact_pullback!(ΔA, (L, Q), (ZeroTangent(), ΔQ))
        return NoTangent(), ΔA, ZeroTangent(), NoTangent()
    end
    function lq_null_pullback(::ZeroTangent) # is this extra definition useful?
        return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
    end
    return Nᴴ, lq_null_pullback
end
94+
# Generate the reverse rules for `eig_full!` and `eigh_full!`.
#
# For each prefix (`eig`, `eigh`) the rule runs `<prefix>_full!` on a copy of
# `A` and returns a pullback named `<prefix>_pullback` that accumulates the
# cotangent of `A` with `MatrixAlgebraKit.<prefix>_full_pullback!`. The returned
# tangents correspond to (function, A, output buffer, alg).
for prefix in (:eig, :eigh)
    full = Symbol(prefix, "_full")
    full! = Symbol(full, "!")
    full_pb! = Symbol(prefix, "_full_pullback!")
    pb = Symbol(prefix, "_pullback")
    @eval function ChainRulesCore.rrule(::typeof($full!), A::AbstractMatrix, DV, alg)
        A_work = copy_input($full, A)
        DV_out = $(full!)(A_work, DV, alg)

        # Shortcut when both output cotangents are structurally zero.
        # is this extra definition useful?
        function $pb(::Tuple{ZeroTangent,ZeroTangent})
            return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
        end
        function $pb(ΔDV)
            ΔA = zero(A)
            ΔDV_plain = unthunk.(ΔDV)
            MatrixAlgebraKit.$(full_pb!)(ΔA, DV_out, ΔDV_plain)
            return NoTangent(), ΔA, ZeroTangent(), NoTangent()
        end

        return DV_out, $pb
    end
end
116+
# Generate the reverse rules for `svd_compact!` and `svd_full!`.
#
# Each rule runs the in-place decomposition on a copy of `A` (obtained through
# `copy_input`) and returns a pullback that accumulates the cotangent of `A`
# with `svd_compact_pullback!`. The returned tangents correspond to
# (function, A, output buffer, alg).
for f in (:svd_compact, :svd_full)
    f! = Symbol(f, "!")
    @eval function ChainRulesCore.rrule(::typeof($f!), A::AbstractMatrix, USVᴴ, alg)
        A_work = copy_input($f, A)
        USVᴴ_out = $(f!)(A_work, USVᴴ, alg)

        # Shortcut when all three output cotangents are structurally zero.
        # is this extra definition useful?
        function svd_pullback(::Tuple{ZeroTangent,ZeroTangent,ZeroTangent})
            return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
        end
        function svd_pullback(ΔUSVᴴ)
            ΔA = zero(A)
            ΔUSVᴴ_plain = unthunk.(ΔUSVᴴ)
            MatrixAlgebraKit.svd_compact_pullback!(ΔA, USVᴴ_out, ΔUSVᴴ_plain)
            return NoTangent(), ΔA, ZeroTangent(), NoTangent()
        end

        return USVᴴ_out, svd_pullback
    end
end
135+
# Reverse rule for `svd_trunc!` with a `TruncatedAlgorithm`.
#
# A compact SVD of a copy of `A` is computed with the wrapped algorithm
# `alg.alg`; the primal result is that decomposition passed through
# `truncate!` with `alg.trunc`. The pullback hands the decomposition returned by
# `svd_compact!` to `svd_compact_pullback!`.
# NOTE(review): whether `truncate!` modifies that decomposition in place cannot
# be determined from this file — confirm against its implementation.
function ChainRulesCore.rrule(::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ,
                              alg::TruncatedAlgorithm)
    A_work = copy_input(svd_compact, A)
    USVᴴ_compact = svd_compact!(A_work, USVᴴ, alg.alg)

    # Shortcut when all three output cotangents are structurally zero.
    # is this extra definition useful?
    function svd_trunc_pullback(::Tuple{ZeroTangent,ZeroTangent,ZeroTangent})
        return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
    end
    function svd_trunc_pullback(ΔUSVᴴ)
        ΔA = zero(A)
        ΔUSVᴴ_plain = unthunk.(ΔUSVᴴ)
        MatrixAlgebraKit.svd_compact_pullback!(ΔA, USVᴴ_compact, ΔUSVᴴ_plain)
        return NoTangent(), ΔA, ZeroTangent(), NoTangent()
    end

    USVᴴ_trunc = MatrixAlgebraKit.truncate!(svd_trunc!, USVᴴ_compact, alg.trunc)
    return USVᴴ_trunc, svd_trunc_pullback
end
150+
# Reverse rule for `left_polar!`.
#
# The decomposition is computed on a copy of `A` (obtained through
# `copy_input`); the pullback accumulates the cotangent of `A` with
# `MatrixAlgebraKit.left_polar_pullback!`. The returned tangents correspond to
# (function, A, output buffer, alg).
function ChainRulesCore.rrule(::typeof(left_polar!), A::AbstractMatrix, WP, alg)
    A_work = copy_input(left_polar, A)
    WP_out = left_polar!(A_work, WP, alg)

    # Shortcut when both output cotangents are structurally zero.
    # is this extra definition useful?
    function left_polar_pullback(::Tuple{ZeroTangent,ZeroTangent})
        return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
    end
    function left_polar_pullback(ΔWP)
        ΔA = zero(A)
        ΔWP_plain = unthunk.(ΔWP)
        MatrixAlgebraKit.left_polar_pullback!(ΔA, WP_out, ΔWP_plain)
        return NoTangent(), ΔA, ZeroTangent(), NoTangent()
    end

    return WP_out, left_polar_pullback
end
164+
# Reverse rule for `right_polar!`.
#
# The decomposition is computed on a copy of `A` (obtained through
# `copy_input`); the pullback accumulates the cotangent of `A` with
# `MatrixAlgebraKit.right_polar_pullback!`. The returned tangents correspond to
# (function, A, output buffer, alg).
function ChainRulesCore.rrule(::typeof(right_polar!), A::AbstractMatrix, PWᴴ, alg)
    # Pass `right_polar` (not `left_polar`) so the input copy is made for the
    # operation this rule actually differentiates, as every other rule does.
    Ac = copy_input(right_polar, A)
    PWᴴ = right_polar!(Ac, PWᴴ, alg)
    function right_polar_pullback(ΔPWᴴ)
        ΔA = zero(A)
        MatrixAlgebraKit.right_polar_pullback!(ΔA, PWᴴ, unthunk.(ΔPWᴴ))
        return NoTangent(), ΔA, ZeroTangent(), NoTangent()
    end
    function right_polar_pullback(::Tuple{ZeroTangent,ZeroTangent}) # is this extra definition useful?
        return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
    end
    return PWᴴ, right_polar_pullback
end
178+
179+ end
0 commit comments