Skip to content

Commit 1d8d693

Browse files
belkakari, Gleb Sterkin, angeloskath
authored
[Metal] Add implicit matmul pathway for mx.conv3d (#3147)
Co-authored-by: Gleb Sterkin <g_sterkin@apple.com> Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
1 parent d4c8106 commit 1d8d693

13 files changed

Lines changed: 1238 additions & 65 deletions

File tree

benchmarks/python/conv3d_bench.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
import math
2+
import time
3+
4+
import mlx.core as mx
5+
import numpy as np
6+
import torch
7+
8+
# Number of untimed warm-up calls made by bench() before timing starts.
N_warmup = 2
9+
# Number of timed calls made by bench() for each measurement.
N_iter_bench = 10
10+
# Number of conv3d round trips chained inside one benchmarked function call.
N_iter_func = 10
11+
12+
13+
def bench(f, a, b, b_prime):
    """Time repeated calls of ``f(a, b, b_prime)`` and return total seconds.

    ``f`` is first called ``N_warmup`` times untimed, followed by one
    ``torch.mps.synchronize()``. It is then called ``N_iter_bench`` times
    between two ``time.perf_counter_ns()`` readings.

    Nothing is synchronized inside the timed region, so ``f`` must block
    until its own work has completed for the measurement to be meaningful.

    Args:
        f: Callable taking ``(a, b, b_prime)``.
        a: First argument forwarded to ``f`` (the input).
        b: Second argument forwarded to ``f``.
        b_prime: Third argument forwarded to ``f``.

    Returns:
        Wall-clock duration of the ``N_iter_bench`` timed calls, in seconds.
    """
    # Untimed warm-up calls.
    for _ in range(N_warmup):
        f(a, b, b_prime)
    torch.mps.synchronize()

    start_ns = time.perf_counter_ns()
    for _ in range(N_iter_bench):
        f(a, b, b_prime)
    stop_ns = time.perf_counter_ns()

    elapsed_ns = stop_ns - start_ns
    return elapsed_ns * 1e-9
23+
24+
25+
def make_mx_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
    """Build the MLX conv3d workload used by the benchmark.

    Args:
        strides: Passed as ``stride`` to every ``mx.conv3d`` call.
        padding: Passed as ``padding`` to every ``mx.conv3d`` call.
        groups: Passed as ``groups`` to every ``mx.conv3d`` call.

    Returns:
        A function ``mx_conv_3D(a, b, b_prime)`` that, ``N_iter_func`` times,
        convolves the running result with ``b`` and then with ``b_prime``,
        calls ``mx.eval`` on the final array and returns it.
    """
    conv_kwargs = {"stride": strides, "padding": padding, "groups": groups}

    def mx_conv_3D(a, b, b_prime):
        out = a
        for _ in range(N_iter_func):
            hidden = mx.conv3d(out, b, **conv_kwargs)
            out = mx.conv3d(hidden, b_prime, **conv_kwargs)
        # Force evaluation of the chained graph before returning.
        mx.eval(out)
        return out

    return mx_conv_3D
35+
36+
37+
def make_pt_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
    """Build the PyTorch conv3d workload used by the benchmark.

    Args:
        strides: Passed as ``stride`` to every ``torch.conv3d`` call.
        padding: Passed as ``padding`` to every ``torch.conv3d`` call.
        groups: Passed as ``groups`` to every ``torch.conv3d`` call.

    Returns:
        A function ``pt_conv_3D(a, b, b_prime)`` running under
        ``torch.no_grad()`` that, ``N_iter_func`` times, convolves the running
        result with ``b`` and then with ``b_prime``, calls
        ``torch.mps.synchronize()`` and returns the final tensor.
    """
    conv_kwargs = {"stride": strides, "padding": padding, "groups": groups}

    @torch.no_grad()
    def pt_conv_3D(a, b, b_prime):
        out = a
        for _ in range(N_iter_func):
            hidden = torch.conv3d(out, b, **conv_kwargs)
            out = torch.conv3d(hidden, b_prime, **conv_kwargs)
        # Wait for the queued MPS work before handing the result back.
        torch.mps.synchronize()
        return out

    return pt_conv_3D
48+
49+
50+
def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
    """Benchmark one conv3d configuration on MLX and on PyTorch (MPS).

    Random data is generated with NumPy in channels-last layout:
    input ``(N, D, H, W, C)``, weights ``(O, kD, kH, kW, C // groups)`` and
    a second weight ``(C, kD, kH, kW, O // groups)`` used for the second
    convolution of each round trip. The PyTorch copies are transposed with
    ``(0, 4, 1, 2, 3)`` and moved to the ``"mps"`` device.

    The function then
      1. times the PyTorch and MLX workloads with ``bench``,
      2. records memory counters around a single conv3d on each framework,
      3. compares a single MLX conv3d against PyTorch run on CPU and prints a
         message if they differ by more than the dtype-dependent ``atol``.

    Args:
        N, D, H, W, C: Input batch, depth, height, width and channels.
        kD, kH, kW: Kernel extent along depth, height and width.
        O: Number of output channels of the first convolution.
        strides: Stride passed to every conv3d call.
        padding: Padding passed to every conv3d call.
        groups: Group count passed to every conv3d call.
        np_dtype: NumPy dtype the random data is cast to.

    Returns:
        Tuple ``(time_mlx, time_torch, mlx_peak_mb, mlx_active_mb,
        pt_current_mb, pt_driver_mb)``; times are in seconds as returned by
        ``bench`` and memory values are byte counts divided by ``1024**2``.
    """
    scale = 1.0 / math.sqrt(kD * kH * kW * C)

    # Integer floor division keeps the shape entries exact ints without a
    # float round-trip.
    a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(
        -scale, scale, (O, kD, kH, kW, C // groups)
    ).astype(np_dtype)
    b_prime_np = np.random.uniform(
        -scale, scale, (C, kD, kH, kW, O // groups)
    ).astype(np_dtype)

    host_arrays = (a_np, b_np, b_prime_np)
    a_mx, b_mx, b_prime_mx = (mx.array(x) for x in host_arrays)
    # Channels-last -> channels-first for PyTorch.
    a_pt, b_pt, b_prime_pt = (
        torch.from_numpy(x.transpose(0, 4, 1, 2, 3)).to("mps") for x in host_arrays
    )

    torch.mps.synchronize()

    f_mx = make_mx_conv_3D(strides, padding, groups)
    f_pt = make_pt_conv_3D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt, b_prime_pt)
    time_mlx = bench(f_mx, a_mx, b_mx, b_prime_mx)

    # Measure MLX memory
    mx.clear_cache()
    mx.reset_peak_memory()
    y = mx.conv3d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    mx.eval(y)
    mlx_peak_mb = mx.get_peak_memory() / 1024**2
    mlx_active_mb = mx.get_active_memory() / 1024**2
    del y

    # Measure PyTorch MPS memory
    torch.mps.synchronize()
    torch.mps.empty_cache()
    y = torch.conv3d(a_pt, b_pt, stride=strides, padding=padding, groups=groups)
    torch.mps.synchronize()
    pt_current_mb = torch.mps.current_allocated_memory() / 1024**2
    pt_driver_mb = torch.mps.driver_allocated_memory() / 1024**2
    del y

    # Correctness check: MLX result vs. PyTorch computed on CPU.
    # The MLX output is converted to NumPy explicitly so both the allclose
    # test and the error magnitude are computed on plain ndarrays.
    out_mx = np.array(
        mx.conv3d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    )
    out_pt = torch.conv3d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    # Back to channels-last so it lines up with the MLX output.
    out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 5e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} "
            f"[strides = {strides}, padding = {padding}, groups = {groups}] "
            f"with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch, mlx_peak_mb, mlx_active_mb, pt_current_mb, pt_driver_mb
106+
107+
108+
if __name__ == "__main__":
    # Script entry point: run every shape for every dtype and print one table
    # per dtype. "diff%" is time_torch / time_mlx - 1, expressed in percent,
    # so positive values mean the MLX run took less time.
    dtypes = ("float16", "float32")
    # Each entry: (N, D, H, W, C, kD, kH, kW, O, strides, padding, groups)
    shapes = (
        # (C % 16 == 0)
        (4, 16, 16, 16, 32, 3, 3, 3, 32, (1, 1, 1), (1, 1, 1), 1),
        (4, 16, 16, 16, 64, 3, 3, 3, 64, (1, 1, 1), (1, 1, 1), 1),
        (4, 16, 16, 16, 128, 3, 3, 3, 128, (1, 1, 1), (1, 1, 1), 1),
        (4, 32, 32, 32, 64, 3, 3, 3, 64, (1, 1, 1), (1, 1, 1), 1),
        (4, 32, 32, 32, 128, 3, 3, 3, 128, (1, 1, 1), (1, 1, 1), 1),
        # Larger spatial dims
        (2, 64, 64, 64, 32, 3, 3, 3, 64, (1, 1, 1), (1, 1, 1), 1),
        (1, 64, 64, 64, 64, 3, 3, 3, 128, (1, 1, 1), (1, 1, 1), 1),
        # Strided
        (4, 32, 32, 32, 64, 3, 3, 3, 128, (2, 2, 2), (1, 1, 1), 1),
        # Asymmetric kernels
        (4, 32, 32, 32, 64, 3, 1, 1, 128, (1, 1, 1), (1, 0, 0), 1),
        (4, 32, 32, 32, 64, 1, 3, 3, 128, (1, 1, 1), (0, 1, 1), 1),
        # (C % 16 != 0)
        (4, 16, 16, 16, 21, 3, 3, 3, 21, (1, 1, 1), (1, 1, 1), 1),
        (4, 16, 16, 16, 55, 3, 3, 3, 55, (1, 1, 1), (1, 1, 1), 1),
        (4, 32, 32, 32, 55, 3, 3, 3, 55, (1, 1, 1), (1, 1, 1), 1),
        (4, 16, 16, 16, 3, 3, 3, 3, 32, (1, 1, 1), (1, 1, 1), 1),
    )

    for dtype in dtypes:
        # The NumPy dtype depends only on the dtype name, so resolve it once
        # per table instead of once per shape.
        np_dtype = getattr(np, dtype)

        print(f"\n{'=' * 120}" f"\n dtype: {dtype}" f"\n{'=' * 120}")
        print(
            f"{'(N, D, H, W, C)':<26s} {'( O, kD, kH, kW, C)':<24s} "
            f"{'stride':<12s} {'pads':<12s} {'groups':>6s} "
            f"{'diff%':>7s} "
            f"{'MLX peak':>9s} {'MLX act':>8s} {'PT cur':>8s} {'PT drv':>8s}"
        )
        for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
            time_mlx, time_torch, mlx_peak, mlx_act, pt_cur, pt_drv = bench_shape(
                N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), "
                f"{strides}, {padding}, {groups:6d}, "
                f"{100. * diff:+6.1f}% "
                f"{mlx_peak:8.1f} {mlx_act:7.1f} {pt_cur:7.1f} {pt_drv:7.1f}"
            )

mlx/backend/metal/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ if(MLX_METAL_JIT)
7171
kernels/steel/conv/loaders/loader_channel_l.h
7272
kernels/steel/conv/loaders/loader_channel_n.h)
7373
make_jit_source(steel/conv/kernels/steel_conv)
74+
make_jit_source(steel/conv/kernels/steel_conv_3d)
7475
make_jit_source(steel/conv/kernels/steel_conv_general kernels/steel/defines.h
7576
kernels/steel/conv/loaders/loader_general.h)
7677

0 commit comments

Comments
 (0)