Skip to content

Commit 1944cf6

Browse files
authored
Add vmap for BroadcastAxes (#3344)
1 parent 939e425 commit 1944cf6

3 files changed

Lines changed: 101 additions & 31 deletions

File tree

mlx/ops.cpp

Lines changed: 22 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1676,23 +1676,19 @@ std::vector<array> broadcast_arrays(
16761676
for (int i = 0; i < inputs.size(); ++i) {
16771677
auto& in = inputs[i];
16781678
auto out_shape = check_and_get_shape(in);
1679-
if (in.shape() == out_shape) {
1680-
outputs.push_back(in);
1681-
} else {
1682-
// broadcasted array goes first followed by other stopgrad inputs
1683-
std::vector<array> p_inputs = {in};
1684-
for (int j = 0; j < inputs.size(); ++j) {
1685-
if (j == i) {
1686-
continue;
1687-
}
1688-
p_inputs.push_back(stop_grad_inputs[j]);
1679+
// broadcasted array goes first followed by other stopgrad inputs
1680+
std::vector<array> p_inputs = {in};
1681+
for (int j = 0; j < inputs.size(); ++j) {
1682+
if (j == i) {
1683+
continue;
16891684
}
1690-
outputs.push_back(array(
1691-
std::move(out_shape),
1692-
in.dtype(),
1693-
std::make_shared<BroadcastAxes>(to_stream(s), ignore_axes),
1694-
std::move(p_inputs)));
1685+
p_inputs.push_back(stop_grad_inputs[j]);
16951686
}
1687+
outputs.push_back(array(
1688+
out_shape,
1689+
in.dtype(),
1690+
std::make_shared<BroadcastAxes>(to_stream(s), ignore_axes),
1691+
std::move(p_inputs)));
16961692
}
16971693
return outputs;
16981694
}
@@ -1727,23 +1723,19 @@ std::vector<array> broadcast_arrays(
17271723
}
17281724
for (int i = 0; i < inputs.size(); ++i) {
17291725
auto& in = inputs[i];
1730-
if (in.shape() == shape) {
1731-
outputs.push_back(in);
1732-
} else {
1733-
// broadcasted array goes first followed by other stopgrad inputs
1734-
std::vector<array> p_inputs = {in};
1735-
for (int j = 0; j < inputs.size(); ++j) {
1736-
if (j == i) {
1737-
continue;
1738-
}
1739-
p_inputs.push_back(stop_grad_inputs[j]);
1726+
// broadcasted array goes first followed by other stopgrad inputs
1727+
std::vector<array> p_inputs = {in};
1728+
for (int j = 0; j < inputs.size(); ++j) {
1729+
if (j == i) {
1730+
continue;
17401731
}
1741-
outputs.push_back(array(
1742-
shape,
1743-
in.dtype(),
1744-
std::make_shared<Broadcast>(to_stream(s), shape),
1745-
std::move(p_inputs)));
1732+
p_inputs.push_back(stop_grad_inputs[j]);
17461733
}
1734+
outputs.push_back(array(
1735+
shape,
1736+
in.dtype(),
1737+
std::make_shared<Broadcast>(to_stream(s), shape),
1738+
std::move(p_inputs)));
17471739
}
17481740
return outputs;
17491741
}

mlx/primitives.cpp

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -904,7 +904,40 @@ std::vector<array> BroadcastAxes::jvp(
904904
std::pair<std::vector<array>, std::vector<int>> BroadcastAxes::vmap(
905905
const std::vector<array>& inputs,
906906
const std::vector<int>& axes) {
907-
throw std::invalid_argument("[BroadcastAxes] VMAP NYI");
907+
std::vector<array> new_inputs = inputs;
908+
std::vector<int> new_axes = axes;
909+
size_t ndim = 0;
910+
bool have_batch = false;
911+
for (int i = 0; i < inputs.size(); i++) {
912+
have_batch |= axes[i] >= 0;
913+
ndim = std::max(inputs[i].ndim(), ndim);
914+
}
915+
916+
std::vector<int> expand;
917+
expand.reserve(ndim);
918+
for (int i = 0; i < inputs.size(); i++) {
919+
int extra = ndim - inputs[i].ndim();
920+
if (axes[i] >= 0 && extra > 0) {
921+
new_axes[i] += extra;
922+
expand.resize(extra);
923+
std::iota(expand.begin(), expand.end(), 0);
924+
new_inputs[i] = expand_dims(new_inputs[i], expand, stream());
925+
}
926+
927+
if (new_axes[i] > 0) {
928+
new_inputs[i] = moveaxis(new_inputs[i], new_axes[i], 0, stream());
929+
}
930+
}
931+
932+
auto shape = output_shape(new_inputs, ignore_axes_);
933+
auto dtype = new_inputs[0].dtype();
934+
return {
935+
{array(
936+
shape,
937+
dtype,
938+
std::make_shared<BroadcastAxes>(stream(), ignore_axes_),
939+
std::move(new_inputs))},
940+
{have_batch ? 0 : -1}};
908941
}
909942

910943
bool BroadcastAxes::is_equivalent(const Primitive& other) const {

python/tests/test_vmap.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -899,6 +899,51 @@ def scatter_fn(x, m, src):
899899
out = double_scatter(a + 0, mask, src)
900900
self.assertTrue(mx.array_equal(expected, out))
901901

902+
def test_broadcast_axes_vmap(self):
903+
# Broadcast axes requires shapeless compile to properly test
904+
905+
counter = [0]
906+
907+
def fn(x, y):
908+
counter[0] += 1
909+
return mx.matmul(x, y)
910+
911+
x = mx.random.normal((2, 3, 1, 4, 5))
912+
y = mx.random.normal((1, 2, 5, 6))
913+
z = mx.random.normal((3, 2, 1, 4, 5))
914+
w = mx.random.normal((2, 3, 5, 6))
915+
916+
vmap_fn = mx.vmap(fn, in_axes=(0, 1))
917+
cvmap_fn = mx.compile(vmap_fn, shapeless=True)
918+
919+
expected = vmap_fn(x, y)
920+
out = cvmap_fn(x, y)
921+
self.assertTrue(mx.array_equal(expected, out))
922+
self.assertEqual(2, counter[0])
923+
924+
expected = vmap_fn(z, w)
925+
out = cvmap_fn(z, w)
926+
self.assertTrue(mx.array_equal(expected, out))
927+
self.assertEqual(3, counter[0])
928+
929+
x = mx.random.normal((2, 3, 1, 4, 5))
930+
y = mx.random.normal((1, 2, 5, 6))
931+
z = mx.random.normal((2, 3, 1, 7, 2))
932+
w = mx.random.normal((1, 2, 2, 3))
933+
934+
vmap_fn = mx.vmap(fn, in_axes=(0, None))
935+
cvmap_fn = mx.compile(vmap_fn, shapeless=True)
936+
937+
expected = vmap_fn(x, y)
938+
out = cvmap_fn(x, y)
939+
self.assertTrue(mx.array_equal(expected, out))
940+
self.assertEqual(5, counter[0])
941+
942+
expected = vmap_fn(z, w)
943+
out = cvmap_fn(z, w)
944+
self.assertTrue(mx.array_equal(expected, out))
945+
self.assertEqual(6, counter[0])
946+
902947

903948
if __name__ == "__main__":
904949
mlx_tests.MLXTestRunner()

0 commit comments

Comments
 (0)