Skip to content

Commit 89d03e2

Browse files
LemonPi authored and claude committed
Modernize APIs, improve numerical stability, bump to 0.5.0
- kronecker_product: replace manual implementation with torch.kron
- GELS: replace deprecated torch.cholesky with torch.linalg.cholesky
- Lookahead optimizer: fix deprecated add_(scalar, tensor) signature
- ls_cov: use torch.linalg.lstsq for params, torch.linalg.solve instead of explicit .inverse() for better numerical stability
- StandardScaler: precompute reciprocal to multiply instead of divide
- Bump version 0.4.3 → 0.5.0

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 52d67f2 commit 89d03e2

4 files changed

Lines changed: 18 additions & 26 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "arm_pytorch_utilities"
3-
version = "0.4.3"
3+
version = "0.5.0"
44
description = "Utilities for working with pytorch"
55
readme = "README.md" # Optional
66

src/arm_pytorch_utilities/linalg.py

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -42,20 +42,7 @@ def kronecker_product(t1, t2):
4242
Computes the Kronecker product between two tensors.
4343
See https://en.wikipedia.org/wiki/Kronecker_product
4444
"""
45-
t1_height, t1_width = t1.size()
46-
t2_height, t2_width = t2.size()
47-
out_height = t1_height * t2_height
48-
out_width = t1_width * t2_width
49-
50-
tiled_t2 = t2.repeat(t1_height, t1_width)
51-
expanded_t1 = (
52-
t1.unsqueeze(2)
53-
.unsqueeze(3)
54-
.repeat(1, t2_height, t2_width, 1)
55-
.view(out_height, out_width)
56-
)
57-
58-
return expanded_t1 * tiled_t2
45+
return torch.kron(t1, t2)
5946

6047

6148
def cov(x, rowvar=False, bias=False, ddof=None, aweights=None):
@@ -152,16 +139,16 @@ def forward(ctx, A, b):
152139
# A: (..., M, N)
153140
# b: (..., M, K)
154141
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/linalg_ops.py#L267
155-
u = torch.cholesky(torch.matmul(A.transpose(-1, -2), A), upper=True)
156-
ret = torch.cholesky_solve(torch.matmul(A.transpose(-1, -2), b), u, upper=True)
157-
ctx.save_for_backward(u, ret, A, b)
142+
L = torch.linalg.cholesky(torch.matmul(A.transpose(-1, -2), A))
143+
ret = torch.cholesky_solve(torch.matmul(A.transpose(-1, -2), b), L, upper=False)
144+
ctx.save_for_backward(L, ret, A, b)
158145
return ret
159146

160147
@staticmethod
161148
def backward(ctx, grad_output):
162149
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/linalg_grad.py#L223
163150
chol, x, a, b = ctx.saved_tensors
164-
z = torch.cholesky_solve(grad_output, chol, upper=True)
151+
z = torch.cholesky_solve(grad_output, chol, upper=False)
165152
xzt = torch.matmul(x, z.transpose(-1, -2))
166153
zx_sym = xzt + xzt.transpose(-1, -2)
167154
grad_A = - torch.matmul(a, zx_sym) + torch.matmul(b, z.transpose(-1, -2))
@@ -197,11 +184,13 @@ def ls(X, Y, weights=None):
197184
def ls_cov(X, Y, weights=None, make_symmetric=True, sigreg=1e-4):
198185
X, Y = _apply_weights(X, Y, weights)
199186

200-
pinvXX = X.pinverse()
201-
params = (pinvXX @ Y).t()
187+
# Solve least squares via lstsq (more stable than pinverse)
188+
result = torch.linalg.lstsq(X, Y)
189+
params = result.solution.t()
202190

203191
# estimate covariance according to: http://users.stat.umn.edu/~helwig/notes/mvlr-Notes.pdf (see up to slide 66)
204192
# hat/projection matrix - Yhat = H*Y
193+
pinvXX = X.pinverse()
205194
H = X @ pinvXX
206195

207196
N = X.shape[0]
@@ -231,8 +220,9 @@ def ls_cov(X, Y, weights=None, make_symmetric=True, sigreg=1e-4):
231220
XXXX = XXXX_sym
232221
error_covariance = error_covariance_sym
233222

234-
# TODO might be able to use cholesky decomp here since XXXX > 0
235-
covariance = kronecker_product(error_covariance, XXXX.inverse())
223+
# Use solve instead of explicit inverse: solve(A, I) = A^{-1}, more numerically stable
224+
XXXX_inv = torch.linalg.solve(XXXX, torch.eye(XXXX.shape[0], dtype=XXXX.dtype, device=XXXX.device)).contiguous()
225+
covariance = kronecker_product(error_covariance, XXXX_inv)
236226

237227
return params, covariance
238228

src/arm_pytorch_utilities/optim.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,12 +106,12 @@ def step(self, closure=None):
106106
for group in self.optimizer.param_groups:
107107
for p in group['params']:
108108
param_state = self.state[p]
109-
p.data.mul_(self.la_alpha).add_(1.0 - self.la_alpha, param_state['cached_params']) # crucial line
109+
p.data.mul_(self.la_alpha).add_(param_state['cached_params'], alpha=1.0 - self.la_alpha) # crucial line
110110
param_state['cached_params'].copy_(p.data)
111111
if self.pullback_momentum == "pullback":
112112
internal_momentum = self.optimizer.state[p]["momentum_buffer"]
113113
self.optimizer.state[p]["momentum_buffer"] = internal_momentum.mul_(self.la_alpha).add_(
114-
1.0 - self.la_alpha, param_state["cached_mom"])
114+
param_state["cached_mom"], alpha=1.0 - self.la_alpha)
115115
param_state["cached_mom"] = self.optimizer.state[p]["momentum_buffer"]
116116
elif self.pullback_momentum == "reset":
117117
self.optimizer.state[p]["momentum_buffer"] = torch.zeros_like(p.data)

src/arm_pytorch_utilities/preprocess.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,13 +311,15 @@ def __init__(self, **kwargs):
311311
super().__init__(**kwargs)
312312
self._m = None
313313
self._s = None
314+
self._inv_s = None
314315

315316
def fit(self, X):
316317
self._m = X.mean(0, keepdim=True)
317318
self._s = X.std(0, unbiased=False, keepdim=True)
319+
self._inv_s = 1.0 / self._s
318320

319321
def transform(self, X):
320-
return (X - self._m) / self._s
322+
return (X - self._m) * self._inv_s
321323

322324
def inverse_transform(self, X):
323325
return (X * self._s) + self._m

0 commit comments

Comments (0)