217 changes: 217 additions & 0 deletions benchmarks/bench_compile.py
@@ -0,0 +1,217 @@
#!/usr/bin/env python
"""Benchmark script for arm_pytorch_utilities torch.compile optimization.

Measures eager-mode and torch.compile performance for refactor-target functions.
Outputs a printed table and JSON file for before/after comparison.

Usage:
    python benchmarks/bench_compile.py --device cpu
    python benchmarks/bench_compile.py --device cuda
"""
import argparse
import json
import time
from datetime import datetime
from pathlib import Path

import torch

from arm_pytorch_utilities import math_utils, linalg, tensor_utils, preprocess, softknn


def bench(fn, *args, warmup=5, repeats=20, device='cpu'):
    """Time a function with warmup and repeats. Returns median time in ms."""
    for _ in range(warmup):
        fn(*args)
    if device == 'cuda':
        torch.cuda.synchronize()

    times = []
    for _ in range(repeats):
        if device == 'cuda':
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        fn(*args)
        if device == 'cuda':
            torch.cuda.synchronize()
        t1 = time.perf_counter()
        times.append((t1 - t0) * 1000)

    times.sort()
    return times[len(times) // 2]


def try_compile_bench(fn, *args, device='cpu', warmup=5, repeats=20):
    """Try to compile and benchmark a function.

    Returns (time_ms, success, error_msg); on failure time_ms is None.
    """
    try:
        compiled_fn = torch.compile(fn, fullgraph=True)
        # Warmup includes compilation
        for _ in range(warmup):
            compiled_fn(*args)
        if device == 'cuda':
            torch.cuda.synchronize()

        times = []
        for _ in range(repeats):
            if device == 'cuda':
                torch.cuda.synchronize()
            t0 = time.perf_counter()
            compiled_fn(*args)
            if device == 'cuda':
                torch.cuda.synchronize()
            t1 = time.perf_counter()
            times.append((t1 - t0) * 1000)

        times.sort()
        return times[len(times) // 2], True, None
    except Exception as e:
        return None, False, str(e)


def make_psd(n, device):
    """Make a random symmetric positive definite n x n matrix."""
    R = torch.randn(n, n, device=device)
    return R.t() @ R + torch.eye(n, device=device) * 0.1


def run_benchmarks(device_str):
    device = torch.device(device_str)
    results = {}

    # Each entry: name -> (fn, args, needs_clone); needs_clone marks in-place functions
    benchmarks = {}

    # --- Setup inputs ---
    # replace_nan_and_inf
    x_nan = torch.randn(10000, 100, device=device)
    mask = torch.rand_like(x_nan) < 0.1
    x_nan[mask] = float('nan')
    benchmarks['replace_nan_and_inf'] = (math_utils.replace_nan_and_inf, (x_nan.clone(), 0), True)

    # angular_diff_batch
    a_ang = torch.randn(100000, device=device)
    b_ang = torch.randn(100000, device=device)
    benchmarks['angular_diff_batch'] = (math_utils.angular_diff_batch, (a_ang, b_ang), False)

    # angle_between_stable
    u_abs = torch.randn(200, 50, device=device)
    v_abs = torch.randn(150, 50, device=device)
    benchmarks['angle_between_stable'] = (math_utils.angle_between_stable, (u_abs, v_abs), False)

    # cos_sim_pairwise
    x1_cos = torch.randn(500, 50, device=device)
    x2_cos = torch.randn(300, 50, device=device)
    benchmarks['cos_sim_pairwise'] = (math_utils.cos_sim_pairwise, (x1_cos, x2_cos), False)

    # batch_batch_product
    X_bbp = torch.randn(10000, 20, device=device)
    A_bbp = torch.randn(10000, 20, 20, device=device)
    benchmarks['batch_batch_product'] = (linalg.batch_batch_product, (X_bbp, A_bbp), False)

    # batch_quadratic_product
    X_bqp = torch.randn(10000, 20, device=device)
    A_bqp = make_psd(20, device)
    benchmarks['batch_quadratic_product'] = (linalg.batch_quadratic_product, (X_bqp, A_bqp), False)

    # batch_outer_product
    u_bop = torch.randn(10000, 20, device=device)
    v_bop = torch.randn(10000, 20, device=device)
    benchmarks['batch_outer_product'] = (linalg.batch_outer_product, (u_bop, v_bop), False)

    # squeeze_n
    x_sq = torch.randn(1, 1, 1, 1000, 50, device=device)
    benchmarks['squeeze_n'] = (tensor_utils.squeeze_n, (x_sq, 3), False)

    # MinMaxScaler.transform
    x_mm = torch.randn(10000, 50, device=device)
    scaler = preprocess.MinMaxScaler()
    scaler.fit(x_mm)
    benchmarks['MinMaxScaler.transform'] = (scaler.transform, (x_mm,), False)

    # SoftKNN.forward
    x_knn = torch.randn(200, 10, device=device)
    knn = softknn.SoftKNN(min_k=20)
    benchmarks['SoftKNN.forward'] = (knn, (x_knn,), False)

    # sqrtm (CPU only due to .numpy())
    if device_str == 'cpu':
        A_sqrtm = make_psd(50, device)
        benchmarks['sqrtm'] = (linalg.sqrtm, (A_sqrtm,), False)

    # --- Run benchmarks ---
    print(f"\n{'Function':<30} {'Eager (ms)':>12} {'Compile (ms)':>14} {'Speedup':>10} {'Compile OK':>12}")
    print("-" * 80)

    for name, (fn, args, needs_clone) in benchmarks.items():
        # Eager benchmark
        if needs_clone:
            # For in-place functions, clone the first arg each call
            template = args[0]
            rest_args = args[1:]

            def cloning_fn(*a, _fn=fn, _tpl=template, _rest=rest_args):
                return _fn(_tpl.clone(), *_rest)

            eager_ms = bench(cloning_fn, warmup=5, repeats=20, device=device_str)
        else:
            try:
                eager_ms = bench(fn, *args, device=device_str)
            except Exception as e:
                print(f"{name:<30} {'ERROR':>12} {'N/A':>14} {'N/A':>10} {'N/A':>12} ({e})")
                results[name] = {'eager_ms': None, 'compile_ms': None, 'compile_ok': False,
                                 'compile_error': None, 'eager_error': str(e)}
                continue

        # Compile benchmark; try_compile_bench always returns a 3-tuple
        compile_ms, compile_ok, compile_err = try_compile_bench(fn, *args, device=device_str)

        speedup = f"{eager_ms / compile_ms:.2f}x" if compile_ms else "N/A"
        compile_str = f"{compile_ms:.3f}" if compile_ms else "FAIL"

        print(f"{name:<30} {eager_ms:>12.3f} {compile_str:>14} {speedup:>10} {'yes' if compile_ok else 'no':>12}")

        results[name] = {
            'eager_ms': eager_ms,
            'compile_ms': compile_ms,
            'compile_ok': compile_ok,
            'compile_error': compile_err,
        }

    return results


def main():
    parser = argparse.ArgumentParser(description='Benchmark arm_pytorch_utilities functions')
    parser.add_argument('--device', choices=['cpu', 'cuda'], default='cpu')
    args = parser.parse_args()

    if args.device == 'cuda' and not torch.cuda.is_available():
        print("CUDA not available, falling back to CPU")
        args.device = 'cpu'

    print(f"Running benchmarks on {args.device}")
    print(f"PyTorch version: {torch.__version__}")

    results = run_benchmarks(args.device)

    # Save JSON
    output_dir = Path(__file__).parent
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    output_file = output_dir / f"results_{args.device}_{timestamp}.json"
    output = {
        'device': args.device,
        'torch_version': torch.__version__,
        'timestamp': timestamp,
        'results': results,
    }
    with open(output_file, 'w') as f:
        json.dump(output, f, indent=2)
    print(f"\nResults saved to {output_file}")


if __name__ == '__main__':
    main()
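The JSON files make before/after comparisons scriptable. A minimal sketch of such a comparison, where `compare_results.py` is a hypothetical helper and not part of this PR:

```python
#!/usr/bin/env python
"""compare_results.py: hypothetical helper (not part of this PR).

Compares eager-mode timings from two bench_compile JSON outputs.
"""
import json
import sys

before = json.load(open(sys.argv[1]))['results']
after = json.load(open(sys.argv[2]))['results']

for name in sorted(set(before) & set(after)):
    b, a = before[name].get('eager_ms'), after[name].get('eager_ms')
    if b and a:
        print(f"{name:<30} {b:>10.3f} -> {a:>10.3f} ms ({b / a:.2f}x)")
```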
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "arm_pytorch_utilities"
-version = "0.4.3"
+version = "0.5.0"
 description = "Utilities for working with pytorch"
 readme = "README.md" # Optional

@@ -68,7 +68,7 @@ dependencies = [# Optional
 # Similar to `dependencies` above, these must be valid existing
 # projects.
 [project.optional-dependencies] # Optional
-test = ["pytest"]
+test = ["pytest", "pytest-benchmark"]

 # List URLs that are relevant to your project
 #
40 changes: 15 additions & 25 deletions src/arm_pytorch_utilities/linalg.py
@@ -42,20 +42,7 @@ def kronecker_product(t1, t2):
     Computes the Kronecker product between two tensors.
     See https://en.wikipedia.org/wiki/Kronecker_product
     """
-    t1_height, t1_width = t1.size()
-    t2_height, t2_width = t2.size()
-    out_height = t1_height * t2_height
-    out_width = t1_width * t2_width
-
-    tiled_t2 = t2.repeat(t1_height, t1_width)
-    expanded_t1 = (
-        t1.unsqueeze(2)
-        .unsqueeze(3)
-        .repeat(1, t2_height, t2_width, 1)
-        .view(out_height, out_width)
-    )
-
-    return expanded_t1 * tiled_t2
+    return torch.kron(t1, t2)


 def cov(x, rowvar=False, bias=False, ddof=None, aweights=None):
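For reference, `torch.kron` reproduces exactly what the removed manual construction computed. A quick equivalence sketch (uses only public PyTorch APIs):

```python
import torch

t1 = torch.randn(3, 4)
t2 = torch.randn(5, 2)
# Kronecker product: out[i*p + k, j*q + l] = t1[i, j] * t2[k, l] for t2 of shape (p, q)
manual = (t1.unsqueeze(2).unsqueeze(3) * t2).permute(0, 2, 1, 3).reshape(3 * 5, 4 * 2)
assert torch.allclose(torch.kron(t1, t2), manual)
```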
@@ -115,7 +102,7 @@ class MatrixSquareRoot(Function):

     @staticmethod
     def forward(ctx, input):
-        m = input.detach().numpy().astype(np.float_)
+        m = input.detach().numpy().astype(np.float64)
         sqrtm = torch.from_numpy(scipy.linalg.sqrtm(m).real).type_as(input)
         ctx.save_for_backward(sqrtm)
         return sqrtm
@@ -125,8 +112,8 @@ def backward(ctx, grad_output):
         grad_input = None
         if ctx.needs_input_grad[0]:
            sqrtm, = ctx.saved_tensors
-            sqrtm = sqrtm.data.numpy().astype(np.float_)
-            gm = grad_output.data.numpy().astype(np.float_)
+            sqrtm = sqrtm.data.numpy().astype(np.float64)
+            gm = grad_output.data.numpy().astype(np.float64)

             # Given a positive semi-definite matrix X,
             # since X = X^{1/2}X^{1/2}, we can compute the gradient of the
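Context for the collapsed remainder of `backward`: differentiating X = S·S at S = X^{1/2} gives dX = dS·S + S·dS, so the gradient G must satisfy the Sylvester equation S·G + G·S = grad_output. A sketch of that solve in SciPy (an assumption about what the hidden lines do; this hunk does not show them):

```python
import numpy as np
import scipy.linalg

S = np.diag([1.0, 2.0, 3.0])  # stand-in for sqrtm
gm = np.ones((3, 3))          # stand-in for the incoming gradient
# solve S @ G + G @ S = gm for the gradient G
G = scipy.linalg.solve_sylvester(S, S, gm)
assert np.allclose(S @ G + G @ S, gm)
```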
@@ -152,16 +139,16 @@ def forward(ctx, A, b):
         # A: (..., M, N)
         # b: (..., M, K)
         # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/linalg_ops.py#L267
-        u = torch.cholesky(torch.matmul(A.transpose(-1, -2), A), upper=True)
-        ret = torch.cholesky_solve(torch.matmul(A.transpose(-1, -2), b), u, upper=True)
-        ctx.save_for_backward(u, ret, A, b)
+        L = torch.linalg.cholesky(torch.matmul(A.transpose(-1, -2), A))
+        ret = torch.cholesky_solve(torch.matmul(A.transpose(-1, -2), b), L, upper=False)
+        ctx.save_for_backward(L, ret, A, b)
         return ret

     @staticmethod
     def backward(ctx, grad_output):
         # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/linalg_grad.py#L223
         chol, x, a, b = ctx.saved_tensors
-        z = torch.cholesky_solve(grad_output, chol, upper=True)
+        z = torch.cholesky_solve(grad_output, chol, upper=False)
         xzt = torch.matmul(x, z.transpose(-1, -2))
         zx_sym = xzt + xzt.transpose(-1, -2)
         grad_A = - torch.matmul(a, zx_sym) + torch.matmul(b, z.transpose(-1, -2))
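This hunk migrates from the removed `torch.cholesky(..., upper=True)` to `torch.linalg.cholesky`, which returns a lower-triangular factor, hence `upper=False` in both solves. A quick sanity sketch of the normal-equations path against `torch.linalg.lstsq` (assumes a well-conditioned A):

```python
import torch

A = torch.randn(100, 10, dtype=torch.float64)
b = torch.randn(100, 3, dtype=torch.float64)
# normal equations: (A^T A) x = A^T b, solved with the lower Cholesky factor
L = torch.linalg.cholesky(A.T @ A)
x = torch.cholesky_solve(A.T @ b, L, upper=False)
assert torch.allclose(x, torch.linalg.lstsq(A, b).solution, atol=1e-8)
```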
@@ -197,11 +184,13 @@ def ls(X, Y, weights=None):
 def ls_cov(X, Y, weights=None, make_symmetric=True, sigreg=1e-4):
     X, Y = _apply_weights(X, Y, weights)

-    pinvXX = X.pinverse()
-    params = (pinvXX @ Y).t()
+    # Solve least squares via lstsq (more stable than pinverse)
+    result = torch.linalg.lstsq(X, Y)
+    params = result.solution.t()

     # estimate covariance according to: http://users.stat.umn.edu/~helwig/notes/mvlr-Notes.pdf (see up to slide 66)
     # hat/projection matrix - Yhat = H*Y
+    pinvXX = X.pinverse()
     H = X @ pinvXX

     N = X.shape[0]
@@ -231,8 +220,9 @@ def ls_cov(X, Y, weights=None, make_symmetric=True, sigreg=1e-4):
         XXXX = XXXX_sym
         error_covariance = error_covariance_sym

-    # TODO might be able to use cholesky decomp here since XXXX > 0
-    covariance = kronecker_product(error_covariance, XXXX.inverse())
+    # Use solve instead of explicit inverse: solve(A, I) = A^{-1}, more numerically stable
+    XXXX_inv = torch.linalg.solve(XXXX, torch.eye(XXXX.shape[0], dtype=XXXX.dtype, device=XXXX.device)).contiguous()
+    covariance = kronecker_product(error_covariance, XXXX_inv)

     return params, covariance
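Since XXXX is symmetric positive definite at this point (sigreg regularizes it), the removed TODO could also be resolved with a Cholesky-based inverse. A possible alternative sketch, not what this PR chose:

```python
import torch

XXXX = torch.randn(10, 10, dtype=torch.float64)
XXXX = XXXX @ XXXX.T + 1e-4 * torch.eye(10, dtype=torch.float64)  # make it SPD
# exploit positive definiteness: invert via the Cholesky factor
XXXX_inv = torch.cholesky_inverse(torch.linalg.cholesky(XXXX))
assert torch.allclose(XXXX_inv @ XXXX, torch.eye(10, dtype=torch.float64), atol=1e-6)
```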
12 changes: 4 additions & 8 deletions src/arm_pytorch_utilities/math_utils.py
@@ -11,9 +11,7 @@ def clip(a, min_val, max_val):

 def replace_nan_and_inf(a, replacement=0):
     """Replaces nan,inf,-inf values with replacement value in place"""
-    a[torch.isnan(a)] = replacement
-    a[a == float('inf')] = replacement
-    a[a == -float('inf')] = replacement
+    torch.nan_to_num(a, nan=replacement, posinf=replacement, neginf=replacement, out=a)
     return a

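`torch.nan_to_num` with `out=a` preserves the in-place contract of the old masked assignments in a single fused op. A quick check:

```python
import torch

a = torch.tensor([1.0, float('nan'), float('inf'), -float('inf')])
torch.nan_to_num(a, nan=0.0, posinf=0.0, neginf=0.0, out=a)
print(a)  # tensor([1., 0., 0., 0.])
```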
@@ -66,8 +64,8 @@ def angle_between_stable(u: torch.tensor, v: torch.tensor):
     dim = -1
     u_norm = u.norm(dim=dim, keepdim=True)
     v_norm = v.norm(dim=dim, keepdim=True)
-    uv = u.unsqueeze(1).repeat(1, v.shape[0], 1) * v_norm
-    vu = v.unsqueeze(0).repeat(u.shape[0], 1, 1) * u_norm.unsqueeze(1)
+    uv = u.unsqueeze(1) * v_norm.transpose(-2, -1).unsqueeze(-1)
+    vu = v.unsqueeze(0) * u_norm.unsqueeze(1)
     num = (uv - vu).norm(dim=dim)
     den = (uv + vu).norm(dim=dim)
     return 2 * torch.atan2(num, den)
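The broadcast form builds the same (n, m, d) intermediates as the old `repeat`-based code without materializing explicit copies first; the underlying identity is the numerically stable angle formula 2·atan2(‖|v|u − |u|v‖, ‖|v|u + |u|v‖), often attributed to Kahan. A sanity sketch against the naive acos path:

```python
import torch
from arm_pytorch_utilities import math_utils

u = torch.randn(200, 50, dtype=torch.float64)
v = torch.randn(150, 50, dtype=torch.float64)
# naive pairwise angles from normalized dot products, clamped for acos
cos = (u / u.norm(dim=-1, keepdim=True)) @ (v / v.norm(dim=-1, keepdim=True)).T
naive = torch.acos(cos.clamp(-1, 1))
assert torch.allclose(naive, math_utils.angle_between_stable(u, v), atol=1e-6)
```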
@@ -104,9 +102,7 @@ def angular_diff(a, b):

 def angular_diff_batch(a, b):
     """Angle difference from b to a (a - b)"""
     d = a - b
-    d[d > math.pi] -= 2 * math.pi
-    d[d < -math.pi] += 2 * math.pi
-    return d
+    return ((d + math.pi) % (2 * math.pi)) - math.pi


 def angle_normalize(a):
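The modulo form wraps differences of any magnitude into [−π, π) in one branch-free expression (which also traces cleanly under torch.compile), whereas the old masked version only corrected a single ±2π overshoot and left |d| ≥ 3π wrong. One boundary differs: d = π now maps to −π instead of π. For example:

```python
import math
import torch

d = torch.tensor([0.1 - 2 * math.pi, 3 * math.pi])
# both values wrap into [-pi, pi); the old masked code would have left 3*pi at pi
print(((d + math.pi) % (2 * math.pi)) - math.pi)  # tensor([ 0.1000, -3.1416])
```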
4 changes: 2 additions & 2 deletions src/arm_pytorch_utilities/optim.py
@@ -106,12 +106,12 @@ def step(self, closure=None):
         for group in self.optimizer.param_groups:
             for p in group['params']:
                 param_state = self.state[p]
-                p.data.mul_(self.la_alpha).add_(1.0 - self.la_alpha, param_state['cached_params'])  # crucial line
+                p.data.mul_(self.la_alpha).add_(param_state['cached_params'], alpha=1.0 - self.la_alpha)  # crucial line
                 param_state['cached_params'].copy_(p.data)
                 if self.pullback_momentum == "pullback":
                     internal_momentum = self.optimizer.state[p]["momentum_buffer"]
                     self.optimizer.state[p]["momentum_buffer"] = internal_momentum.mul_(self.la_alpha).add_(
-                        1.0 - self.la_alpha, param_state["cached_mom"])
+                        param_state["cached_mom"], alpha=1.0 - self.la_alpha)
                     param_state["cached_mom"] = self.optimizer.state[p]["momentum_buffer"]
                 elif self.pullback_momentum == "reset":
                     self.optimizer.state[p]["momentum_buffer"] = torch.zeros_like(p.data)
4 changes: 3 additions & 1 deletion src/arm_pytorch_utilities/preprocess.py
@@ -311,13 +311,15 @@ def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self._m = None
         self._s = None
+        self._inv_s = None

     def fit(self, X):
         self._m = X.mean(0, keepdim=True)
         self._s = X.std(0, unbiased=False, keepdim=True)
+        self._inv_s = 1.0 / self._s

     def transform(self, X):
-        return (X - self._m) / self._s
+        return (X - self._m) * self._inv_s

     def inverse_transform(self, X):
         return (X * self._s) + self._m
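Caching the reciprocal turns a broadcast division in the hot `transform` path into a multiplication; results are identical up to floating-point rounding, and the zero-variance failure mode is unchanged (1/0 gives inf, just as division did). A quick equivalence check:

```python
import torch

X = torch.randn(10000, 50)
m = X.mean(0, keepdim=True)
s = X.std(0, unbiased=False, keepdim=True)
inv_s = 1.0 / s
# multiply-by-reciprocal matches division up to rounding
assert torch.allclose((X - m) / s, (X - m) * inv_s, atol=1e-6)
```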