Skip to content

Commit 68cf2fd

Browse files
Fix mx.prod vjp for complex types (#3433)
1 parent c594e6e commit 68cf2fd

2 files changed

Lines changed: 22 additions & 1 deletion

File tree

mlx/primitives.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3951,7 +3951,7 @@ std::vector<array> Reduce::vjp(
39513951
auto p1 = cumprod(x, axis, /*reverse=*/false, /*inclusive=*/false, s);
39523952
auto p2 = cumprod(x, axis, /*reverse=*/true, /*inclusive=*/false, s);
39533953
auto exclusive_prod = multiply(p1, p2, s);
3954-
return multiply(exclusive_prod, cotan, s);
3954+
return multiply(conjugate(exclusive_prod, s), cotan, s);
39553955
};
39563956

39573957
// To compute a numerically stable gradient for prod we need an exclusive

python/tests/test_autograd.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -966,6 +966,27 @@ def test_reduce_jvp(self):
966966
out, jout = mx.jvp(mx.max, primals=(a,), tangents=(b,))
967967
self.assertEqual(jout[0].item(), 0)
968968

969+
def test_complex_prod_vjp(self):
970+
def prod(x):
971+
return x.prod(axis=0)
972+
973+
primal = mx.random.normal((2, 20), dtype=mx.complex64)
974+
cotangent = mx.random.normal((20,), dtype=mx.complex64)
975+
976+
_, vjps = mx.vjp(prod, [primal], [cotangent])
977+
978+
expected = mx.stack(
979+
[mx.conj(primal[1]) * cotangent, mx.conj(primal[0]) * cotangent]
980+
)
981+
982+
# Check against hand-computed vjps
983+
self.assertTrue(mx.array_equal(vjps[0], expected))
984+
985+
# Ensure that prod agrees with multiply for complex values
986+
_, vjps_multiply = mx.vjp(mx.multiply, [primal[0], primal[1]], [cotangent])
987+
988+
self.assertTrue(mx.array_equal(mx.stack(vjps_multiply), vjps[0]))
989+
969990

970991
if __name__ == "__main__":
971992
mlx_tests.MLXTestRunner()

0 commit comments

Comments
 (0)