Skip to content

Commit c8536f5

Browse files
Fix compile_fuse broadcast split aliasing bug (ml-explore#3166)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
1 parent 0c8107c commit c8536f5

2 files changed

Lines changed: 24 additions & 1 deletion

File tree

mlx/compile.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -856,7 +856,8 @@ void compile_fuse(
856856
// are not fusable except for broadcast which we can split to avoid
857857
// stopping fusion
858858
if (!all_parents_in) {
859-
if (a.has_primitive() && is_broadcast(a.primitive())) {
859+
if (a.has_primitive() && is_broadcast(a.primitive()) &&
860+
input_set.size() < max_compile_arrays) {
860861
array b = split_one(a, parents_map, cache);
861862
recurse(b, depth, s, shape);
862863
} else {

python/tests/test_compile.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,6 +1049,28 @@ def fun(x, y, z):
10491049
self.assertTrue(mx.allclose(d[0], d_hat[0]))
10501050
self.assertTrue(mx.allclose(d[1], d_hat[1]))
10511051

1052+
def test_compile_large_graph_with_broadcasts(self):
1053+
N = 20
1054+
_as = [mx.array(2 * i, dtype=mx.float32) for i in range(N)]
1055+
_bs = [mx.array(i, dtype=mx.float32) for i in range(N)]
1056+
_c = mx.array(0.0)
1057+
x = mx.random.normal((2, 2))
1058+
1059+
def f(x):
1060+
y = 0
1061+
for i in range(N):
1062+
y = y + _as[i] * x * _bs[i] * _c
1063+
return y
1064+
1065+
ref = f(x)
1066+
mx.eval(ref)
1067+
f = mx.compile(f)
1068+
for i in range(2):
1069+
y = f(x)
1070+
mx.eval(y)
1071+
1072+
self.assertTrue(mx.allclose(y, ref))
1073+
10521074
def test_wrap_compiled(self):
10531075
@mx.compile
10541076
def inner():

0 commit comments

Comments
 (0)