
Commit 52d67f2

LemonPi and claude committed
Optimize math_utils for performance and fix sqrtm numpy 2.0 crash
- replace_nan_and_inf: use torch.nan_to_num (~3.9x faster)
- angular_diff_batch: use modulo wrapping (~1.4x faster, fixes correctness for large diffs)
- angle_between_stable: use broadcasting instead of .repeat() (~1.15x faster)
- sqrtm: replace removed np.float_ alias with np.float64
- bench_compile: remove lambda wrappers that broke torch.compile tracing

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 5dc5a71 commit 52d67f2

4 files changed

Lines changed: 40 additions & 35 deletions
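
The headline numbers in the commit message come from the repo's benchmark harness; a rough spot-check of the replace_nan_and_inf claim is possible with a plain timing loop. A minimal sketch (the tensor shape matches the benchmark setup below, but the timing loop is an illustrative assumption, not the project's bench() helper):

import time
import torch

def replace_masked(a, replacement=0):
    # old approach: three separate boolean-mask writes
    a[torch.isnan(a)] = replacement
    a[a == float('inf')] = replacement
    a[a == -float('inf')] = replacement
    return a

def replace_fused(a, replacement=0):
    # new approach: one fused kernel, written in place via out=
    torch.nan_to_num(a, nan=replacement, posinf=replacement, neginf=replacement, out=a)
    return a

x = torch.randn(10000, 100)
x[torch.rand_like(x) < 0.1] = float('nan')

for fn in (replace_masked, replace_fused):
    start = time.perf_counter()
    for _ in range(100):
        fn(x.clone(), 0)  # fresh clone so each call sees dirty input
    print(f"{fn.__name__}: {(time.perf_counter() - start) * 10:.3f} ms/iter")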


benchmarks/bench_compile.py

Lines changed: 22 additions & 21 deletions
@@ -84,69 +84,73 @@ def run_benchmarks(device_str):
     x_nan = torch.randn(10000, 100, device=device)
     mask = torch.rand_like(x_nan) < 0.1
     x_nan[mask] = float('nan')
-    benchmarks['replace_nan_and_inf'] = (math_utils.replace_nan_and_inf, (x_nan.clone(), 0))
+    benchmarks['replace_nan_and_inf'] = (math_utils.replace_nan_and_inf, (x_nan.clone(), 0), True)

     # angular_diff_batch
     a_ang = torch.randn(100000, device=device)
     b_ang = torch.randn(100000, device=device)
-    benchmarks['angular_diff_batch'] = (math_utils.angular_diff_batch, (a_ang, b_ang))
+    benchmarks['angular_diff_batch'] = (math_utils.angular_diff_batch, (a_ang, b_ang), False)

     # angle_between_stable
     u_abs = torch.randn(200, 50, device=device)
     v_abs = torch.randn(150, 50, device=device)
-    benchmarks['angle_between_stable'] = (math_utils.angle_between_stable, (u_abs, v_abs))
+    benchmarks['angle_between_stable'] = (math_utils.angle_between_stable, (u_abs, v_abs), False)

     # cos_sim_pairwise
     x1_cos = torch.randn(500, 50, device=device)
     x2_cos = torch.randn(300, 50, device=device)
-    benchmarks['cos_sim_pairwise'] = (math_utils.cos_sim_pairwise, (x1_cos, x2_cos))
+    benchmarks['cos_sim_pairwise'] = (math_utils.cos_sim_pairwise, (x1_cos, x2_cos), False)

     # batch_batch_product
     X_bbp = torch.randn(10000, 20, device=device)
     A_bbp = torch.randn(10000, 20, 20, device=device)
-    benchmarks['batch_batch_product'] = (linalg.batch_batch_product, (X_bbp, A_bbp))
+    benchmarks['batch_batch_product'] = (linalg.batch_batch_product, (X_bbp, A_bbp), False)

     # batch_quadratic_product
     X_bqp = torch.randn(10000, 20, device=device)
     A_bqp = make_psd(20, device)
-    benchmarks['batch_quadratic_product'] = (linalg.batch_quadratic_product, (X_bqp, A_bqp))
+    benchmarks['batch_quadratic_product'] = (linalg.batch_quadratic_product, (X_bqp, A_bqp), False)

     # batch_outer_product
     u_bop = torch.randn(10000, 20, device=device)
     v_bop = torch.randn(10000, 20, device=device)
-    benchmarks['batch_outer_product'] = (linalg.batch_outer_product, (u_bop, v_bop))
+    benchmarks['batch_outer_product'] = (linalg.batch_outer_product, (u_bop, v_bop), False)

     # squeeze_n
     x_sq = torch.randn(1, 1, 1, 1000, 50, device=device)
-    benchmarks['squeeze_n'] = (lambda x: tensor_utils.squeeze_n(x, 3), (x_sq,))
+    benchmarks['squeeze_n'] = (tensor_utils.squeeze_n, (x_sq, 3), False)

     # MinMaxScaler.transform
     x_mm = torch.randn(10000, 50, device=device)
     scaler = preprocess.MinMaxScaler()
     scaler.fit(x_mm)
-    benchmarks['MinMaxScaler.transform'] = (scaler.transform, (x_mm,))
+    benchmarks['MinMaxScaler.transform'] = (scaler.transform, (x_mm,), False)

     # SoftKNN.forward
     x_knn = torch.randn(200, 10, device=device)
     knn = softknn.SoftKNN(min_k=20)
-    benchmarks['SoftKNN.forward'] = (knn, (x_knn,))
+    benchmarks['SoftKNN.forward'] = (knn, (x_knn,), False)

     # sqrtm (CPU only due to .numpy())
     if device_str == 'cpu':
         A_sqrtm = make_psd(50, device)
-        benchmarks['sqrtm'] = (linalg.sqrtm, (A_sqrtm,))
+        benchmarks['sqrtm'] = (linalg.sqrtm, (A_sqrtm,), False)

     # --- Run benchmarks ---
     print(f"\n{'Function':<30} {'Eager (ms)':>12} {'Compile (ms)':>14} {'Speedup':>10} {'Compile OK':>12}")
     print("-" * 80)

-    for name, (fn, args) in benchmarks.items():
+    for name, (fn, args, needs_clone) in benchmarks.items():
         # Eager benchmark
-        # For replace_nan_and_inf, need fresh clone each call
-        if name == 'replace_nan_and_inf':
-            def eager_fn(x_template=x_nan):
-                return math_utils.replace_nan_and_inf(x_template.clone(), 0)
-            eager_ms = bench(eager_fn, warmup=5, repeats=20, device=device_str)
+        if needs_clone:
+            # For in-place functions, clone first arg each call
+            template = args[0]
+            rest_args = args[1:]
+
+            def cloning_fn(*a, _fn=fn, _tpl=template, _rest=rest_args):
+                return _fn(_tpl.clone(), *_rest)
+
+            eager_ms = bench(cloning_fn, warmup=5, repeats=20, device=device_str)
         else:
             try:
                 eager_ms = bench(fn, *args, device=device_str)
@@ -157,10 +161,7 @@ def eager_fn(x_template=x_nan):
                 continue

         # Compile benchmark
-        if name == 'replace_nan_and_inf':
-            compile_result = try_compile_bench(eager_fn, device=device_str)
-        else:
-            compile_result = try_compile_bench(fn, *args, device=device_str)
+        compile_result = try_compile_bench(fn, *args, device=device_str)

         if len(compile_result) == 2:
             compile_ms, compile_ok = compile_result
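
The (fn, args, needs_clone) triple replaces the old special-casing by name. Two details are worth noting: the clone keeps in-place functions from seeing already-cleaned input on repeat calls, and binding fn/template through default arguments freezes them per entry, sidestepping Python's late-binding closures. A standalone sketch of the same pattern (function names here are illustrative, not from the repo):

import torch

def add(a, b):
    return a + b

def fill_nan_(a, value):
    # in-place: mutates its first argument
    return torch.nan_to_num(a, nan=value, out=a)

x = torch.randn(4)
x[0] = float('nan')

# each entry: (callable, args, needs_clone)
benchmarks = {
    'add': (add, (torch.randn(4), torch.randn(4)), False),
    'fill_nan_': (fill_nan_, (x, 0.0), True),
}

for name, (fn, args, needs_clone) in benchmarks.items():
    if needs_clone:
        template, rest = args[0], args[1:]

        # default-argument binding freezes fn/template for this entry
        def call(_fn=fn, _tpl=template, _rest=rest):
            return _fn(_tpl.clone(), *_rest)
    else:
        def call(_fn=fn, _args=args):
            return _fn(*_args)

    print(name, call())

Per the commit message, dropping the lambda wrappers also matters for the compile column: try_compile_bench now always receives the plain function plus its args, which torch.compile can trace.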

src/arm_pytorch_utilities/linalg.py

Lines changed: 3 additions & 3 deletions
@@ -115,7 +115,7 @@ class MatrixSquareRoot(Function):

     @staticmethod
     def forward(ctx, input):
-        m = input.detach().numpy().astype(np.float_)
+        m = input.detach().numpy().astype(np.float64)
         sqrtm = torch.from_numpy(scipy.linalg.sqrtm(m).real).type_as(input)
         ctx.save_for_backward(sqrtm)
         return sqrtm
@@ -125,8 +125,8 @@ def backward(ctx, grad_output):
         grad_input = None
         if ctx.needs_input_grad[0]:
             sqrtm, = ctx.saved_tensors
-            sqrtm = sqrtm.data.numpy().astype(np.float_)
-            gm = grad_output.data.numpy().astype(np.float_)
+            sqrtm = sqrtm.data.numpy().astype(np.float64)
+            gm = grad_output.data.numpy().astype(np.float64)

             # Given a positive semi-definite matrix X,
             # since X = X^{1/2}X^{1/2}, we can compute the gradient of the
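
NumPy 2.0 removed the np.float_ alias (it was identical to np.float64), so naming the dtype explicitly is the drop-in replacement. A round-trip sanity check of the forward path, as a sketch (the PSD construction is an assumption standing in for make_psd):

import numpy as np
import scipy.linalg
import torch

# build a well-conditioned PSD matrix (stand-in for make_psd)
A = torch.randn(50, 50, dtype=torch.float64)
A = A @ A.T + 50 * torch.eye(50, dtype=torch.float64)

# same cast as MatrixSquareRoot.forward, with the NumPy 2.0-safe dtype
m = A.detach().numpy().astype(np.float64)
s = torch.from_numpy(scipy.linalg.sqrtm(m).real).type_as(A)

# sqrt(A) @ sqrt(A) should recover A
assert torch.allclose(s @ s, A, atol=1e-6)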

src/arm_pytorch_utilities/math_utils.py

Lines changed: 4 additions & 8 deletions
@@ -11,9 +11,7 @@ def clip(a, min_val, max_val):

 def replace_nan_and_inf(a, replacement=0):
     """Replaces nan,inf,-inf values with replacement value in place"""
-    a[torch.isnan(a)] = replacement
-    a[a == float('inf')] = replacement
-    a[a == -float('inf')] = replacement
+    torch.nan_to_num(a, nan=replacement, posinf=replacement, neginf=replacement, out=a)
     return a


@@ -66,8 +64,8 @@ def angle_between_stable(u: torch.tensor, v: torch.tensor):
     dim = -1
     u_norm = u.norm(dim=dim, keepdim=True)
     v_norm = v.norm(dim=dim, keepdim=True)
-    uv = u.unsqueeze(1).repeat(1, v.shape[0], 1) * v_norm
-    vu = v.unsqueeze(0).repeat(u.shape[0], 1, 1) * u_norm.unsqueeze(1)
+    uv = u.unsqueeze(1) * v_norm.transpose(-2, -1).unsqueeze(-1)
+    vu = v.unsqueeze(0) * u_norm.unsqueeze(1)
     num = (uv - vu).norm(dim=dim)
     den = (uv + vu).norm(dim=dim)
     return 2 * torch.atan2(num, den)
@@ -104,9 +102,7 @@ def angular_diff(a, b):
 def angular_diff_batch(a, b):
     """Angle difference from b to a (a - b)"""
     d = a - b
-    d[d > math.pi] -= 2 * math.pi
-    d[d < -math.pi] += 2 * math.pi
-    return d
+    return ((d + math.pi) % (2 * math.pi)) - math.pi


 def angle_normalize(a):
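
The modulo form is where the correctness fix in the commit title lives: the old pair of masked shifts wraps at most once, so it only maps differences smaller than one full turn into (-pi, pi]. A quick equivalence check, as a sketch (wrap_once/wrap_mod are illustrative names for the old and new bodies):

import math
import torch

def wrap_once(d):
    # old body: shifts by at most one turn
    d = d.clone()
    d[d > math.pi] -= 2 * math.pi
    d[d < -math.pi] += 2 * math.pi
    return d

def wrap_mod(d):
    # new body: correct for any magnitude
    return ((d + math.pi) % (2 * math.pi)) - math.pi

d_small = (torch.rand(1000) - 0.5) * 2 * math.pi  # within one wrap
d_large = torch.tensor([10.0, -10.0, 20.0])       # several turns out

assert torch.allclose(wrap_once(d_small), wrap_mod(d_small), atol=1e-6)
print(wrap_mod(d_large))  # ~[-2.566, 2.566, 1.150], all in (-pi, pi]

The angle_between_stable change is a pure memory optimization: broadcasting produces the same (N_u, N_v) pairwise grid that .repeat() did, without materializing the repeated copies.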

tests/test_math.py

Lines changed: 11 additions & 3 deletions
@@ -130,15 +130,23 @@ def test_angular_diff_batch():
     assert (result > -math.pi).all()
     assert (result <= math.pi + 1e-6).all()

-    # Compare against element-wise angular_diff
+    # Compare against element-wise angular_diff with inputs where |a-b| < 2*pi
+    # (scalar angular_diff only wraps once, so it's only correct in that range)
     N = 50
-    a = (torch.rand(N) - 0.5) * 4 * math.pi
-    b = (torch.rand(N) - 0.5) * 4 * math.pi
+    a = (torch.rand(N) - 0.5) * 2 * math.pi
+    b = (torch.rand(N) - 0.5) * 2 * math.pi
     batch_result = math_utils.angular_diff_batch(a, b)
     for i in range(N):
         scalar_result = math_utils.angular_diff(a[i].item(), b[i].item())
         assert abs(batch_result[i].item() - scalar_result) < 1e-5

+    # Verify batch version handles large differences correctly (beyond single-wrap range)
+    a_large = torch.tensor([10.0, -10.0, 20.0])
+    b_large = torch.tensor([0.0, 0.0, 0.0])
+    result_large = math_utils.angular_diff_batch(a_large, b_large)
+    assert (result_large > -math.pi).all()
+    assert (result_large <= math.pi + 1e-6).all()
+

 def test_get_bounds():
     assert math_utils.get_bounds(None, 5) == (-5, 5)
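
The narrowed comparison range plus the new large-angle block pin down the asymmetry the comment describes. A quick way to see it interactively, as a sketch (the import path is an assumption mirroring the tests, and the printed values follow from the single-wrap behavior the comment states):

import torch
from arm_pytorch_utilities import math_utils

# the scalar helper wraps once, so 20 rad stays out of range;
# the batch version reduces it fully
print(math_utils.angular_diff(20.0, 0.0))                  # ~13.72, outside (-pi, pi]
print(math_utils.angular_diff_batch(torch.tensor([20.0]),
                                    torch.tensor([0.0])))  # ~1.15, inside (-pi, pi]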
