@@ -42,20 +42,7 @@ def kronecker_product(t1, t2):
4242 Computes the Kronecker product between two tensors.
4343 See https://en.wikipedia.org/wiki/Kronecker_product
4444 """
45- t1_height , t1_width = t1 .size ()
46- t2_height , t2_width = t2 .size ()
47- out_height = t1_height * t2_height
48- out_width = t1_width * t2_width
49-
50- tiled_t2 = t2 .repeat (t1_height , t1_width )
51- expanded_t1 = (
52- t1 .unsqueeze (2 )
53- .unsqueeze (3 )
54- .repeat (1 , t2_height , t2_width , 1 )
55- .view (out_height , out_width )
56- )
57-
58- return expanded_t1 * tiled_t2
45+ return torch .kron (t1 , t2 )
5946
6047
6148def cov (x , rowvar = False , bias = False , ddof = None , aweights = None ):
@@ -152,16 +139,16 @@ def forward(ctx, A, b):
152139 # A: (..., M, N)
153140 # b: (..., M, K)
154141 # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/linalg_ops.py#L267
155- u = torch .cholesky (torch .matmul (A .transpose (- 1 , - 2 ), A ), upper = True )
156- ret = torch .cholesky_solve (torch .matmul (A .transpose (- 1 , - 2 ), b ), u , upper = True )
157- ctx .save_for_backward (u , ret , A , b )
142+ L = torch .linalg . cholesky (torch .matmul (A .transpose (- 1 , - 2 ), A ))
143+ ret = torch .cholesky_solve (torch .matmul (A .transpose (- 1 , - 2 ), b ), L , upper = False )
144+ ctx .save_for_backward (L , ret , A , b )
158145 return ret
159146
160147 @staticmethod
161148 def backward (ctx , grad_output ):
162149 # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/linalg_grad.py#L223
163150 chol , x , a , b = ctx .saved_tensors
164- z = torch .cholesky_solve (grad_output , chol , upper = True )
151+ z = torch .cholesky_solve (grad_output , chol , upper = False )
165152 xzt = torch .matmul (x , z .transpose (- 1 , - 2 ))
166153 zx_sym = xzt + xzt .transpose (- 1 , - 2 )
167154 grad_A = - torch .matmul (a , zx_sym ) + torch .matmul (b , z .transpose (- 1 , - 2 ))
@@ -197,11 +184,13 @@ def ls(X, Y, weights=None):
197184def ls_cov (X , Y , weights = None , make_symmetric = True , sigreg = 1e-4 ):
198185 X , Y = _apply_weights (X , Y , weights )
199186
200- pinvXX = X .pinverse ()
201- params = (pinvXX @ Y ).t ()
187+ # Solve least squares via lstsq (more stable than pinverse)
188+ result = torch .linalg .lstsq (X , Y )
189+ params = result .solution .t ()
202190
203191 # estimate covariance according to: http://users.stat.umn.edu/~helwig/notes/mvlr-Notes.pdf (see up to slide 66)
204192 # hat/projection matrix - Yhat = H*Y
193+ pinvXX = X .pinverse ()
205194 H = X @ pinvXX
206195
207196 N = X .shape [0 ]
@@ -231,8 +220,9 @@ def ls_cov(X, Y, weights=None, make_symmetric=True, sigreg=1e-4):
231220 XXXX = XXXX_sym
232221 error_covariance = error_covariance_sym
233222
234- # TODO might be able to use cholesky decomp here since XXXX > 0
235- covariance = kronecker_product (error_covariance , XXXX .inverse ())
223+ # Use solve instead of explicit inverse: solve(A, I) = A^{-1}, more numerically stable
224+ XXXX_inv = torch .linalg .solve (XXXX , torch .eye (XXXX .shape [0 ], dtype = XXXX .dtype , device = XXXX .device )).contiguous ()
225+ covariance = kronecker_product (error_covariance , XXXX_inv )
236226
237227 return params , covariance
238228
0 commit comments