@@ -899,6 +899,51 @@ def scatter_fn(x, m, src):
899899 out = double_scatter (a + 0 , mask , src )
900900 self .assertTrue (mx .array_equal (expected , out ))
901901
902+ def test_broadcast_axes_vmap (self ):
903+ # Broadcast axes requires shapeless compile to properly test
904+
905+ counter = [0 ]
906+
907+ def fn (x , y ):
908+ counter [0 ] += 1
909+ return mx .matmul (x , y )
910+
911+ x = mx .random .normal ((2 , 3 , 1 , 4 , 5 ))
912+ y = mx .random .normal ((1 , 2 , 5 , 6 ))
913+ z = mx .random .normal ((3 , 2 , 1 , 4 , 5 ))
914+ w = mx .random .normal ((2 , 3 , 5 , 6 ))
915+
916+ vmap_fn = mx .vmap (fn , in_axes = (0 , 1 ))
917+ cvmap_fn = mx .compile (vmap_fn , shapeless = True )
918+
919+ expected = vmap_fn (x , y )
920+ out = cvmap_fn (x , y )
921+ self .assertTrue (mx .array_equal (expected , out ))
922+ self .assertEqual (2 , counter [0 ])
923+
924+ expected = vmap_fn (z , w )
925+ out = cvmap_fn (z , w )
926+ self .assertTrue (mx .array_equal (expected , out ))
927+ self .assertEqual (3 , counter [0 ])
928+
929+ x = mx .random .normal ((2 , 3 , 1 , 4 , 5 ))
930+ y = mx .random .normal ((1 , 2 , 5 , 6 ))
931+ z = mx .random .normal ((2 , 3 , 1 , 7 , 2 ))
932+ w = mx .random .normal ((1 , 2 , 2 , 3 ))
933+
934+ vmap_fn = mx .vmap (fn , in_axes = (0 , None ))
935+ cvmap_fn = mx .compile (vmap_fn , shapeless = True )
936+
937+ expected = vmap_fn (x , y )
938+ out = cvmap_fn (x , y )
939+ self .assertTrue (mx .array_equal (expected , out ))
940+ self .assertEqual (5 , counter [0 ])
941+
942+ expected = vmap_fn (z , w )
943+ out = cvmap_fn (z , w )
944+ self .assertTrue (mx .array_equal (expected , out ))
945+ self .assertEqual (6 , counter [0 ])
946+
902947
903948if __name__ == "__main__" :
904949 mlx_tests .MLXTestRunner ()
0 commit comments