Skip to content

Commit de16d62

Browse files
authored
Merge pull request #2 from UM-ARM-Lab/refactor/torch-compile
Refactor to improve performance/torch compile compatibility
2 parents 6312d50 + 89d03e2 commit de16d62

12 files changed

Lines changed: 605 additions & 215 deletions

File tree

benchmarks/bench_compile.py

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
#!/usr/bin/env python
2+
"""Benchmark script for arm_pytorch_utilities torch.compile optimization.
3+
4+
Measures eager-mode and torch.compile performance for refactor-target functions.
5+
Outputs a printed table and JSON file for before/after comparison.
6+
7+
Usage:
8+
python benchmarks/bench_compile.py --device cpu
9+
python benchmarks/bench_compile.py --device cuda
10+
"""
11+
import argparse
12+
import json
13+
import time
14+
from datetime import datetime
15+
from pathlib import Path
16+
17+
import torch
18+
19+
from arm_pytorch_utilities import math_utils, linalg, tensor_utils, preprocess, softknn
20+
21+
22+
def bench(fn, *args, warmup=5, repeats=20, device='cpu'):
    """Time a function with warmup and repeats. Returns median time in ms."""
    # Warm up first so one-time costs (allocations, caches) don't pollute timings.
    for _ in range(warmup):
        fn(*args)
    if device == 'cuda':
        torch.cuda.synchronize()

    samples = []
    for _ in range(repeats):
        # On CUDA, synchronize around the timed region so we measure the
        # kernel work, not just the async launch.
        if device == 'cuda':
            torch.cuda.synchronize()
        start = time.perf_counter()
        fn(*args)
        if device == 'cuda':
            torch.cuda.synchronize()
        samples.append((time.perf_counter() - start) * 1000)

    # Upper median of the sorted samples (same index as the original).
    return sorted(samples)[len(samples) // 2]
42+
43+
44+
def try_compile_bench(fn, *args, device='cpu', warmup=5, repeats=20):
    """Try to compile and benchmark a function.

    Always returns a 3-tuple ``(median_ms, success, error)``:
      - on success: ``(median_time_ms, True, None)``
      - on failure: ``(None, False, str(exception))``

    The previous version returned a 2-tuple on success and a 3-tuple on
    failure, contradicting its docstring; a uniform arity is simpler for
    callers to unpack (existing callers that branch on tuple length still
    work, since the 3-tuple path is handled there).
    """
    try:
        compiled_fn = torch.compile(fn, fullgraph=True)
        # Warmup includes compilation, so one-time compile cost is not timed.
        for _ in range(warmup):
            compiled_fn(*args)
        if device == 'cuda':
            torch.cuda.synchronize()

        times = []
        for _ in range(repeats):
            # Synchronize around the timed region on CUDA so async kernel
            # launches don't make the timings meaningless.
            if device == 'cuda':
                torch.cuda.synchronize()
            t0 = time.perf_counter()
            compiled_fn(*args)
            if device == 'cuda':
                torch.cuda.synchronize()
            t1 = time.perf_counter()
            times.append((t1 - t0) * 1000)

        times.sort()
        # Upper median.
        return times[len(times) // 2], True, None
    except Exception as e:
        # Compilation or execution failed; report the reason to the caller.
        return None, False, str(e)
69+
70+
71+
def make_psd(n, device):
    """Build a random symmetric positive-definite ``n x n`` matrix on ``device``."""
    rand = torch.randn(n, n, device=device)
    # Gram matrix rand^T rand is symmetric PSD; the scaled identity shifts
    # the eigenvalues away from zero to make it strictly positive definite.
    gram = rand.t() @ rand
    return gram + torch.eye(n, device=device) * 0.1
74+
75+
76+
def run_benchmarks(device_str):
    """Benchmark each target function in eager mode and under torch.compile.

    Prints a comparison table and returns a dict mapping function name to a
    record with keys ``eager_ms``, ``compile_ms``, ``compile_ok`` and
    ``compile_error`` (plus ``eager_error`` when eager execution fails).

    :param device_str: 'cpu' or 'cuda'
    """
    device = torch.device(device_str)
    results = {}

    # name -> (fn, args, needs_clone); needs_clone marks in-place functions
    # whose first argument must be re-cloned before every call.
    benchmarks = {}

    # --- Setup inputs ---
    # replace_nan_and_inf (in-place: mutates its input)
    x_nan = torch.randn(10000, 100, device=device)
    mask = torch.rand_like(x_nan) < 0.1
    x_nan[mask] = float('nan')
    benchmarks['replace_nan_and_inf'] = (math_utils.replace_nan_and_inf, (x_nan.clone(), 0), True)

    # angular_diff_batch
    a_ang = torch.randn(100000, device=device)
    b_ang = torch.randn(100000, device=device)
    benchmarks['angular_diff_batch'] = (math_utils.angular_diff_batch, (a_ang, b_ang), False)

    # angle_between_stable
    u_abs = torch.randn(200, 50, device=device)
    v_abs = torch.randn(150, 50, device=device)
    benchmarks['angle_between_stable'] = (math_utils.angle_between_stable, (u_abs, v_abs), False)

    # cos_sim_pairwise
    x1_cos = torch.randn(500, 50, device=device)
    x2_cos = torch.randn(300, 50, device=device)
    benchmarks['cos_sim_pairwise'] = (math_utils.cos_sim_pairwise, (x1_cos, x2_cos), False)

    # batch_batch_product
    X_bbp = torch.randn(10000, 20, device=device)
    A_bbp = torch.randn(10000, 20, 20, device=device)
    benchmarks['batch_batch_product'] = (linalg.batch_batch_product, (X_bbp, A_bbp), False)

    # batch_quadratic_product
    X_bqp = torch.randn(10000, 20, device=device)
    A_bqp = make_psd(20, device)
    benchmarks['batch_quadratic_product'] = (linalg.batch_quadratic_product, (X_bqp, A_bqp), False)

    # batch_outer_product
    u_bop = torch.randn(10000, 20, device=device)
    v_bop = torch.randn(10000, 20, device=device)
    benchmarks['batch_outer_product'] = (linalg.batch_outer_product, (u_bop, v_bop), False)

    # squeeze_n
    x_sq = torch.randn(1, 1, 1, 1000, 50, device=device)
    benchmarks['squeeze_n'] = (tensor_utils.squeeze_n, (x_sq, 3), False)

    # MinMaxScaler.transform
    x_mm = torch.randn(10000, 50, device=device)
    scaler = preprocess.MinMaxScaler()
    scaler.fit(x_mm)
    benchmarks['MinMaxScaler.transform'] = (scaler.transform, (x_mm,), False)

    # SoftKNN.forward
    x_knn = torch.randn(200, 10, device=device)
    knn = softknn.SoftKNN(min_k=20)
    benchmarks['SoftKNN.forward'] = (knn, (x_knn,), False)

    # sqrtm (CPU only due to .numpy())
    if device_str == 'cpu':
        A_sqrtm = make_psd(50, device)
        benchmarks['sqrtm'] = (linalg.sqrtm, (A_sqrtm,), False)

    # --- Run benchmarks ---
    print(f"\n{'Function':<30} {'Eager (ms)':>12} {'Compile (ms)':>14} {'Speedup':>10} {'Compile OK':>12}")
    print("-" * 80)

    for name, (fn, args, needs_clone) in benchmarks.items():
        if needs_clone:
            # In-place functions corrupt their input on the first call; wrap
            # them so every call (eager AND compiled) clones a fresh copy and
            # both paths measure identical work. Previously only the eager
            # path cloned, so the compiled benchmark ran on an already-cleaned
            # tensor and skipped the clone cost — an unfair comparison.
            template = args[0]
            rest_args = args[1:]

            def call_fn(*a, _fn=fn, _tpl=template, _rest=rest_args):
                return _fn(_tpl.clone(), *_rest)

            call_args = ()
        else:
            call_fn = fn
            call_args = args

        # Eager benchmark. Uniform error handling for both paths — the
        # cloning path previously had no try/except and could abort the run.
        try:
            eager_ms = bench(call_fn, *call_args, warmup=5, repeats=20, device=device_str)
        except Exception as e:
            print(f"{name:<30} {'ERROR':>12} {'N/A':>14} {'N/A':>10} {'N/A':>12} ({e})")
            results[name] = {'eager_ms': None, 'compile_ms': None, 'compile_ok': False,
                             'compile_error': None, 'eager_error': str(e)}
            continue

        # Compile benchmark on the same callable so the comparison is fair.
        compile_result = try_compile_bench(call_fn, *call_args, device=device_str)

        # try_compile_bench may return (ms, ok) or (ms, ok, err).
        if len(compile_result) == 2:
            compile_ms, compile_ok = compile_result
            compile_err = None
        else:
            compile_ms, compile_ok, compile_err = compile_result

        speedup = f"{eager_ms / compile_ms:.2f}x" if compile_ms else "N/A"
        compile_str = f"{compile_ms:.3f}" if compile_ms else "FAIL"

        print(f"{name:<30} {eager_ms:>12.3f} {compile_str:>14} {speedup:>10} {'yes' if compile_ok else 'no':>12}")

        results[name] = {
            'eager_ms': eager_ms,
            'compile_ms': compile_ms,
            'compile_ok': compile_ok,
            'compile_error': compile_err,
        }

    return results
185+
186+
187+
def main():
    """Parse CLI arguments, run the benchmark suite, and save results as JSON."""
    parser = argparse.ArgumentParser(description='Benchmark arm_pytorch_utilities functions')
    parser.add_argument('--device', choices=['cpu', 'cuda'], default='cpu')
    args = parser.parse_args()

    # Gracefully degrade when CUDA was requested but isn't present.
    if args.device == 'cuda' and not torch.cuda.is_available():
        print("CUDA not available, falling back to CPU")
        args.device = 'cpu'

    print(f"Running benchmarks on {args.device}")
    print(f"PyTorch version: {torch.__version__}")

    results = run_benchmarks(args.device)

    # Persist a timestamped JSON snapshot next to this script for
    # before/after comparisons.
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    output_file = Path(__file__).parent / f"results_{args.device}_{timestamp}.json"
    payload = {
        'device': args.device,
        'torch_version': torch.__version__,
        'timestamp': timestamp,
        'results': results,
    }
    with open(output_file, 'w') as f:
        json.dump(payload, f, indent=2)
    print(f"\nResults saved to {output_file}")


if __name__ == '__main__':
    main()

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "arm_pytorch_utilities"
3-
version = "0.4.3"
3+
version = "0.5.0"
44
description = "Utilities for working with pytorch"
55
readme = "README.md" # Optional
66

@@ -68,7 +68,7 @@ dependencies = [# Optional
6868
# Similar to `dependencies` above, these must be valid existing
6969
# projects.
7070
[project.optional-dependencies] # Optional
71-
test = ["pytest"]
71+
test = ["pytest", "pytest-benchmark"]
7272

7373
# List URLs that are relevant to your project
7474
#

src/arm_pytorch_utilities/linalg.py

Lines changed: 15 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -42,20 +42,7 @@ def kronecker_product(t1, t2):
4242
Computes the Kronecker product between two tensors.
4343
See https://en.wikipedia.org/wiki/Kronecker_product
4444
"""
45-
t1_height, t1_width = t1.size()
46-
t2_height, t2_width = t2.size()
47-
out_height = t1_height * t2_height
48-
out_width = t1_width * t2_width
49-
50-
tiled_t2 = t2.repeat(t1_height, t1_width)
51-
expanded_t1 = (
52-
t1.unsqueeze(2)
53-
.unsqueeze(3)
54-
.repeat(1, t2_height, t2_width, 1)
55-
.view(out_height, out_width)
56-
)
57-
58-
return expanded_t1 * tiled_t2
45+
return torch.kron(t1, t2)
5946

6047

6148
def cov(x, rowvar=False, bias=False, ddof=None, aweights=None):
@@ -115,7 +102,7 @@ class MatrixSquareRoot(Function):
115102

116103
@staticmethod
117104
def forward(ctx, input):
118-
m = input.detach().numpy().astype(np.float_)
105+
m = input.detach().numpy().astype(np.float64)
119106
sqrtm = torch.from_numpy(scipy.linalg.sqrtm(m).real).type_as(input)
120107
ctx.save_for_backward(sqrtm)
121108
return sqrtm
@@ -125,8 +112,8 @@ def backward(ctx, grad_output):
125112
grad_input = None
126113
if ctx.needs_input_grad[0]:
127114
sqrtm, = ctx.saved_tensors
128-
sqrtm = sqrtm.data.numpy().astype(np.float_)
129-
gm = grad_output.data.numpy().astype(np.float_)
115+
sqrtm = sqrtm.data.numpy().astype(np.float64)
116+
gm = grad_output.data.numpy().astype(np.float64)
130117

131118
# Given a positive semi-definite matrix X,
132119
# since X = X^{1/2}X^{1/2}, we can compute the gradient of the
@@ -152,16 +139,16 @@ def forward(ctx, A, b):
152139
# A: (..., M, N)
153140
# b: (..., M, K)
154141
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/linalg_ops.py#L267
155-
u = torch.cholesky(torch.matmul(A.transpose(-1, -2), A), upper=True)
156-
ret = torch.cholesky_solve(torch.matmul(A.transpose(-1, -2), b), u, upper=True)
157-
ctx.save_for_backward(u, ret, A, b)
142+
L = torch.linalg.cholesky(torch.matmul(A.transpose(-1, -2), A))
143+
ret = torch.cholesky_solve(torch.matmul(A.transpose(-1, -2), b), L, upper=False)
144+
ctx.save_for_backward(L, ret, A, b)
158145
return ret
159146

160147
@staticmethod
161148
def backward(ctx, grad_output):
162149
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/linalg_grad.py#L223
163150
chol, x, a, b = ctx.saved_tensors
164-
z = torch.cholesky_solve(grad_output, chol, upper=True)
151+
z = torch.cholesky_solve(grad_output, chol, upper=False)
165152
xzt = torch.matmul(x, z.transpose(-1, -2))
166153
zx_sym = xzt + xzt.transpose(-1, -2)
167154
grad_A = - torch.matmul(a, zx_sym) + torch.matmul(b, z.transpose(-1, -2))
@@ -197,11 +184,13 @@ def ls(X, Y, weights=None):
197184
def ls_cov(X, Y, weights=None, make_symmetric=True, sigreg=1e-4):
198185
X, Y = _apply_weights(X, Y, weights)
199186

200-
pinvXX = X.pinverse()
201-
params = (pinvXX @ Y).t()
187+
# Solve least squares via lstsq (more stable than pinverse)
188+
result = torch.linalg.lstsq(X, Y)
189+
params = result.solution.t()
202190

203191
# estimate covariance according to: http://users.stat.umn.edu/~helwig/notes/mvlr-Notes.pdf (see up to slide 66)
204192
# hat/projection matrix - Yhat = H*Y
193+
pinvXX = X.pinverse()
205194
H = X @ pinvXX
206195

207196
N = X.shape[0]
@@ -231,8 +220,9 @@ def ls_cov(X, Y, weights=None, make_symmetric=True, sigreg=1e-4):
231220
XXXX = XXXX_sym
232221
error_covariance = error_covariance_sym
233222

234-
# TODO might be able to use cholesky decomp here since XXXX > 0
235-
covariance = kronecker_product(error_covariance, XXXX.inverse())
223+
# Use solve instead of explicit inverse: solve(A, I) = A^{-1}, more numerically stable
224+
XXXX_inv = torch.linalg.solve(XXXX, torch.eye(XXXX.shape[0], dtype=XXXX.dtype, device=XXXX.device)).contiguous()
225+
covariance = kronecker_product(error_covariance, XXXX_inv)
236226

237227
return params, covariance
238228

src/arm_pytorch_utilities/math_utils.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,7 @@ def clip(a, min_val, max_val):
1111

1212
def replace_nan_and_inf(a, replacement=0):
1313
"""Replaces nan,inf,-inf values with replacement value in place"""
14-
a[torch.isnan(a)] = replacement
15-
a[a == float('inf')] = replacement
16-
a[a == -float('inf')] = replacement
14+
torch.nan_to_num(a, nan=replacement, posinf=replacement, neginf=replacement, out=a)
1715
return a
1816

1917

@@ -66,8 +64,8 @@ def angle_between_stable(u: torch.tensor, v: torch.tensor):
6664
dim = -1
6765
u_norm = u.norm(dim=dim, keepdim=True)
6866
v_norm = v.norm(dim=dim, keepdim=True)
69-
uv = u.unsqueeze(1).repeat(1, v.shape[0], 1) * v_norm
70-
vu = v.unsqueeze(0).repeat(u.shape[0], 1, 1) * u_norm.unsqueeze(1)
67+
uv = u.unsqueeze(1) * v_norm.transpose(-2, -1).unsqueeze(-1)
68+
vu = v.unsqueeze(0) * u_norm.unsqueeze(1)
7169
num = (uv - vu).norm(dim=dim)
7270
den = (uv + vu).norm(dim=dim)
7371
return 2 * torch.atan2(num, den)
@@ -104,9 +102,7 @@ def angular_diff(a, b):
104102
def angular_diff_batch(a, b):
105103
"""Angle difference from b to a (a - b)"""
106104
d = a - b
107-
d[d > math.pi] -= 2 * math.pi
108-
d[d < -math.pi] += 2 * math.pi
109-
return d
105+
return ((d + math.pi) % (2 * math.pi)) - math.pi
110106

111107

112108
def angle_normalize(a):

src/arm_pytorch_utilities/optim.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,12 +106,12 @@ def step(self, closure=None):
106106
for group in self.optimizer.param_groups:
107107
for p in group['params']:
108108
param_state = self.state[p]
109-
p.data.mul_(self.la_alpha).add_(1.0 - self.la_alpha, param_state['cached_params']) # crucial line
109+
p.data.mul_(self.la_alpha).add_(param_state['cached_params'], alpha=1.0 - self.la_alpha) # crucial line
110110
param_state['cached_params'].copy_(p.data)
111111
if self.pullback_momentum == "pullback":
112112
internal_momentum = self.optimizer.state[p]["momentum_buffer"]
113113
self.optimizer.state[p]["momentum_buffer"] = internal_momentum.mul_(self.la_alpha).add_(
114-
1.0 - self.la_alpha, param_state["cached_mom"])
114+
param_state["cached_mom"], alpha=1.0 - self.la_alpha)
115115
param_state["cached_mom"] = self.optimizer.state[p]["momentum_buffer"]
116116
elif self.pullback_momentum == "reset":
117117
self.optimizer.state[p]["momentum_buffer"] = torch.zeros_like(p.data)

src/arm_pytorch_utilities/preprocess.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,13 +311,15 @@ def __init__(self, **kwargs):
311311
super().__init__(**kwargs)
312312
self._m = None
313313
self._s = None
314+
self._inv_s = None
314315

315316
def fit(self, X):
316317
self._m = X.mean(0, keepdim=True)
317318
self._s = X.std(0, unbiased=False, keepdim=True)
319+
self._inv_s = 1.0 / self._s
318320

319321
def transform(self, X):
320-
return (X - self._m) / self._s
322+
return (X - self._m) * self._inv_s
321323

322324
def inverse_transform(self, X):
323325
return (X * self._s) + self._m

0 commit comments

Comments
 (0)