diff --git a/docs/src/python/linalg.rst b/docs/src/python/linalg.rst
index 495380c46f..1bee822bcd 100644
--- a/docs/src/python/linalg.rst
+++ b/docs/src/python/linalg.rst
@@ -14,6 +14,7 @@ Linear Algebra
     cholesky
     cholesky_inv
     cross
+    det
     qr
     svd
     eigvals
@@ -23,5 +24,6 @@ Linear Algebra
     lu
     lu_factor
     pinv
+    slogdet
     solve
     solve_triangular
diff --git a/mlx/backend/cpu/luf.cpp b/mlx/backend/cpu/luf.cpp
index 5f1507e183..e7e34c4d57 100644
--- a/mlx/backend/cpu/luf.cpp
+++ b/mlx/backend/cpu/luf.cpp
@@ -67,11 +67,10 @@ void luf_impl(
           /* ipiv */ reinterpret_cast<int*>(pivots_ptr),
           /* info */ &info);
 
-      if (info != 0) {
+      if (info < 0) {
         std::stringstream ss;
         ss << "[LUF::eval_cpu] sgetrf_ failed with code " << info
-           << ((info > 0) ? " because matrix is singular"
-                          : " because argument had an illegal value");
+           << " because argument had an illegal value";
         throw std::runtime_error(ss.str());
       }
 
diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp
index 7ac080dab8..db5dbd9771 100644
--- a/mlx/linalg.cpp
+++ b/mlx/linalg.cpp
@@ -705,4 +705,191 @@ array solve_triangular(
   return matmul(a_inv, b, s);
 }
 
-} // namespace mlx::core::linalg
\ No newline at end of file
+// Shared argument checks for det and slogdet. The stream is handed to
+// check_cpu_stream, then the input must have at least two dimensions and a
+// square trailing matrix; violations of either shape rule throw
+// std::invalid_argument with `fname` as the message prefix.
+void validate_det(
+    const array& a,
+    const StreamOrDevice& stream,
+    const std::string& fname) {
+  check_cpu_stream(stream, fname);
+
+  if (a.ndim() < 2) {
+    std::ostringstream msg;
+    msg << fname
+        << " Arrays must have >= 2 dimensions. Received array "
+           "with "
+        << a.ndim() << " dimensions.";
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (a.shape(-1) != a.shape(-2)) {
+    throw std::invalid_argument(fname + " Only defined for square matrices.");
+  }
+}
+
+// Closed-form determinant of the trailing n x n matrices of `a` for n <= 3.
+// It is built from elementwise ops on sliced entries, so any leading batch
+// dimensions carry through to the result. Callers must guarantee n <= 3.
+array det_raw_small(const array& a, StreamOrDevice s) {
+  int n = a.shape(-1);
+
+  // The determinant of an empty (0 x 0) matrix is the empty product, 1.
+  // Without this guard n == 0 would reach the 3x3 branch below and index
+  // entries that do not exist.
+  if (n == 0) {
+    auto batch_shape = Shape(a.ndim() - 2, 0);
+    for (size_t i = 0; i < batch_shape.size(); ++i) {
+      batch_shape[i] = a.shape(static_cast<int>(i));
+    }
+    return ones(batch_shape, a.dtype(), s);
+  }
+
+  // Helper to extract a[..., i, j] from the last two dims
+  auto elem = [&](int i, int j) {
+    auto starts = Shape(a.ndim(), 0);
+    auto stops = a.shape();
+    starts[a.ndim() - 2] = i;
+    stops[a.ndim() - 2] = i + 1;
+    starts[a.ndim() - 1] = j;
+    stops[a.ndim() - 1] = j + 1;
+    return squeeze(squeeze(slice(a, starts, stops, s), -1, s), -1, s);
+  };
+
+  if (n == 1) {
+    return elem(0, 0);
+  } else if (n == 2) {
+    return subtract(
+        multiply(elem(0, 0), elem(1, 1), s),
+        multiply(elem(0, 1), elem(1, 0), s),
+        s);
+  } else {
+    // 3x3 cofactor expansion along the first row:
+    //   a00*(a11*a22 - a12*a21) - a01*(a10*a22 - a12*a20)
+    //     + a02*(a10*a21 - a11*a20)
+    auto a00 = elem(0, 0), a01 = elem(0, 1), a02 = elem(0, 2);
+    auto a10 = elem(1, 0), a11 = elem(1, 1), a12 = elem(1, 2);
+    auto a20 = elem(2, 0), a21 = elem(2, 1), a22 = elem(2, 2);
+    return add(
+        subtract(
+            multiply(
+                a00,
+                subtract(multiply(a11, a22, s), multiply(a12, a21, s), s),
+                s),
+            multiply(
+                a01,
+                subtract(multiply(a10, a22, s), multiply(a12, a20, s), s),
+                s),
+            s),
+        multiply(
+            a02, subtract(multiply(a10, a21, s), multiply(a11, a20, s), s), s),
+        s);
+  }
+}
+
+// Returns {sign, log|det|} for the trailing square matrices of `input`, which
+// the callers in this file have already validated and cast to a float type.
+std::pair<array, array> slogdet_impl(const array& input, StreamOrDevice s) {
+  int n = input.shape(-1);
+  auto dtype = input.dtype();
+
+  // Only 0x0 and 1x1 inputs take the closed-form shortcut. For 2x2 and 3x3 the
+  // raw determinant is a sum of products that can overflow or underflow before
+  // the log is taken (e.g. diag(1e-30, 1e-30) in float32 would come back as
+  // sign = 0, logabsdet = -inf), so those sizes use the LU path as well.
+  if (n <= 1) {
+    auto raw = det_raw_small(input, s);
+    auto abs_raw = abs(raw, s);
+    auto sgn = sign(raw, s);
+    auto logabs = log(abs_raw, s);
+    return std::make_pair(sgn, logabs);
+  }
+
+  // General LU-based path
+  auto [LU, pivots] = lu_factor(input, s);
+
+  // Extract diagonal of U
+  auto diag = diagonal(LU, 0, -2, -1, s);
+
+  // Permutation parity: count positions where pivot[i] != i.
+  // NOTE(review): this assumes lu_factor returns zero-based row-swap pivots
+  // (LAPACK ipiv minus one), so each mismatch is one transposition -- confirm.
+  auto iota = arange(0, n, uint32, s);
+  auto parity = astype(
+      sum(not_equal(pivots, iota, s),
+          /* axis = */ -1,
+          /* keepdims = */ false,
+          s),
+      int32,
+      s);
+
+  // Count negative diagonal elements
+  auto num_neg = astype(
+      sum(less(diag, array(0.0f, dtype), s),
+          /* axis = */ -1,
+          /* keepdims = */ false,
+          s),
+      int32,
+      s);
+
+  // sign = (-1)^(parity + num_neg)
+  auto total = add(parity, num_neg, s);
+  auto sign_val = astype(
+      subtract(
+          array(1, int32),
+          multiply(array(2, int32), remainder(total, array(2, int32), s), s),
+          s),
+      dtype,
+      s);
+
+  // logabsdet = sum(log(abs(diag)))
+  auto logabsdet =
+      sum(log(abs(diag, s), s), /* axis = */ -1, /* keepdims = */ false, s);
+
+  // Handle singular matrices: any zero on diagonal
+  auto is_zero =
+      any(equal(diag, array(0.0f, dtype), s),
+          /* axis = */ -1,
+          /* keepdims = */ false,
+          s);
+  sign_val = where(is_zero, array(0.0f, dtype), sign_val, s);
+  logabsdet = where(
+      is_zero,
+      array(-std::numeric_limits<float>::infinity(), dtype),
+      logabsdet,
+      s);
+
+  return std::make_pair(sign_val, logabsdet);
+}
+
+// Public entry point: validates `a`, promotes it with at_least_float and
+// returns {sign, log|det|} for each trailing square matrix.
+std::pair<array, array> slogdet(const array& a, StreamOrDevice s /* = {} */) {
+  validate_det(a, s, "[linalg::slogdet]");
+
+  auto dtype = at_least_float(a.dtype());
+  auto input = astype(a, dtype, s);
+  return slogdet_impl(input, s);
+}
+
+// Public entry point: validates `a`, promotes it with at_least_float and
+// returns the determinant of each trailing square matrix.
+array det(const array& a, StreamOrDevice s /* = {} */) {
+  validate_det(a, s, "[linalg::det]");
+
+  auto dtype = at_least_float(a.dtype());
+  auto input = astype(a, dtype, s);
+  int n = input.shape(-1);
+
+  // Small-matrix fast path: compute directly, skip log/exp round-trip
+  if (n <= 3) {
+    return det_raw_small(input, s);
+  }
+
+  // General case: det = sign * exp(logabsdet)
+  auto [sign_val, logabsdet] = slogdet_impl(input, s);
+  return multiply(sign_val, exp(logabsdet, s), s);
+}
+
+} // namespace mlx::core::linalg
diff --git a/mlx/linalg.h b/mlx/linalg.h
index fe3f83c223..08e0bdd7ad 100644
--- a/mlx/linalg.h
+++ b/mlx/linalg.h
@@ -112,4 +112,8 @@ eigvalsh(const array& a, std::string UPLO = "L", StreamOrDevice s = {});
 MLX_API std::pair<array, array>
 eigh(const array& a, std::string UPLO = "L", StreamOrDevice s = {});
 
+MLX_API array det(const array& a, StreamOrDevice s = {});
+
+MLX_API std::pair<array, array> slogdet(const array& a, StreamOrDevice s = {});
+
 } // namespace mlx::core::linalg
diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp
index 58f9be7766..36b9040395 100644
--- a/python/src/linalg.cpp
+++ b/python/src/linalg.cpp
@@ -660,4 +660,77 @@ void init_linalg(nb::module_& parent_module) {
         Returns:
             array: The unique solution to the system ``AX = B``.
       )pbdoc");
+
+  m.def(
+      "det",
+      &mx::linalg::det,
+      "a"_a,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def det(a: array, *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Compute the determinant of a square matrix.
+
+        This function supports arrays with at least 2 dimensions. When the
+        input has more than two dimensions, the determinant is computed for
+        each matrix in the last two dimensions.
+
+        Args:
+            a (array): Input array.
+            stream (Stream, optional): Stream or device. Defaults to ``None``
+              in which case the default stream of the default device is used.
+
+        Returns:
+            array: The determinant(s) of the input matrix (matrices).
+
+        Example:
+            >>> A = mx.array([[1., 2.], [3., 4.]])
+            >>> mx.linalg.det(A, stream=mx.cpu)
+            array(-2, dtype=float32)
+      )pbdoc");
+
+  m.def(
+      "slogdet",
+      [](const mx::array& a, mx::StreamOrDevice s) {
+        auto result = mx::linalg::slogdet(a, s);
+        return nb::make_tuple(result.first, result.second);
+      },
+      "a"_a,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def slogdet(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"),
+      R"pbdoc(
+        Compute the sign and natural log of the absolute value of the
+        determinant of a square matrix.
+
+        This function supports arrays with at least 2 dimensions. When the
+        input has more than two dimensions, the sign and log-absolute-determinant
+        are computed for each matrix in the last two dimensions.
+
+        For a singular matrix, ``sign`` is 0 and ``logabsdet`` is ``-inf``.
+
+        The determinant can be reconstructed as ``det = sign * exp(logabsdet)``.
+        This is more numerically stable than computing the determinant directly
+        for matrices with large or small determinants.
+
+        Args:
+            a (array): Input array.
+            stream (Stream, optional): Stream or device. Defaults to ``None``
+              in which case the default stream of the default device is used.
+
+        Returns:
+            tuple(array, array): The ``sign`` and ``logabsdet`` of the
+              determinant. ``sign`` is -1, 0, or +1. ``logabsdet`` is the
+              natural log of the absolute value of the determinant.
+
+        Example:
+            >>> A = mx.array([[1., 2.], [3., 4.]])
+            >>> sign, logabsdet = mx.linalg.slogdet(A, stream=mx.cpu)
+            >>> sign
+            array(-1, dtype=float32)
+            >>> logabsdet
+            array(0.693147, dtype=float32)
+      )pbdoc");
 }
diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py
index 8e2444f206..3efb10e23f 100644
--- a/python/tests/test_linalg.py
+++ b/python/tests/test_linalg.py
@@ -520,6 +520,19 @@ def test_lu(self):
         P, L, U = mx.linalg.lu(a, stream=mx.cpu)
         self.assertTrue(mx.allclose(L[P, :] @ U, a))
 
+        # Test singular matrix (should not throw)
+        a = mx.array(
+            [
+                [1.0, 2.0, 3.0, 4.0],
+                [2.0, 4.0, 6.0, 8.0],
+                [0.0, 1.0, 1.0, 0.0],
+                [1.0, 0.0, 0.0, 1.0],
+            ]
+        )
+        P, L, U = mx.linalg.lu(a, stream=mx.cpu)
+        L_permuted = mx.take_along_axis(L, P[..., None], axis=-2)
+        self.assertTrue(mx.allclose(L_permuted @ U, a))
+
     def test_lu_factor(self):
         mx.random.seed(7)
 
@@ -616,6 +629,222 @@ def test_solve_triangular(self):
         expected = np.linalg.solve(a, b)
         self.assertTrue(np.allclose(result, expected))
 
+    def test_det(self):
+        # 1x1 fast path
+        A = mx.array([[5.0]])
+        self.assertTrue(np.allclose(mx.linalg.det(A, stream=mx.cpu), 5.0))
+
+        # 2x2 fast path
+        A = mx.array([[1.0, 2.0], [3.0, 4.0]])
+        d = mx.linalg.det(A, stream=mx.cpu)
+        self.assertTrue(np.allclose(d, -2.0))
+
+        # 3x3 fast path
+        A = mx.array([[1.0, 2.0, 3.0], [0.0, 1.0, 4.0], [5.0, 6.0, 0.0]])
+        d = mx.linalg.det(A, stream=mx.cpu)
+        expected = np.linalg.det(np.array(A))
+        self.assertTrue(np.allclose(d, expected, atol=1e-5))
+
+        # 4x4 LU path: compare with numpy
+        np.random.seed(42)
+        A_np = np.random.randn(4, 4).astype(np.float32)
+        A_mx = mx.array(A_np)
+        d_mx = mx.linalg.det(A_mx, stream=mx.cpu)
+        d_np = np.linalg.det(A_np)
+        self.assertTrue(np.allclose(d_mx, d_np, atol=1e-4))
+
+        # 5x5 LU path
+        A_np = np.random.randn(5, 5).astype(np.float32)
+        A_mx = mx.array(A_np)
+        d_mx = mx.linalg.det(A_mx, stream=mx.cpu)
+        d_np = np.linalg.det(A_np)
+        self.assertTrue(np.allclose(d_mx, d_np, atol=1e-4))
+
+        # Identity matrix
+        A = mx.eye(5)
+        self.assertTrue(np.allclose(mx.linalg.det(A, stream=mx.cpu), 1.0))
+
+        # Empty matrix: determinant is the empty product, 1
+        A = mx.zeros((0, 0))
+        self.assertTrue(np.allclose(mx.linalg.det(A, stream=mx.cpu), 1.0))
+
+        # Batched: (3, 4, 4)
+        A_np = np.random.randn(3, 4, 4).astype(np.float32)
+        A_mx = mx.array(A_np)
+        d_mx = mx.linalg.det(A_mx, stream=mx.cpu)
+        d_np = np.linalg.det(A_np)
+        self.assertTrue(np.allclose(d_mx, d_np, atol=1e-4))
+
+        # Multi-batch: (2, 3, 3, 3)
+        A_np = np.random.randn(2, 3, 3, 3).astype(np.float32)
+        A_mx = mx.array(A_np)
+        d_mx = mx.linalg.det(A_mx, stream=mx.cpu)
+        d_np = np.linalg.det(A_np)
+        self.assertTrue(np.allclose(d_mx, d_np, atol=1e-4))
+
+        # Integer input auto-promotes to float
+        A = mx.array([[1, 2], [3, 4]])
+        d = mx.linalg.det(A, stream=mx.cpu)
+        self.assertTrue(np.allclose(d, -2.0))
+
+        # float64
+        A_np = np.random.randn(4, 4).astype(np.float64)
+        A_mx = mx.array(A_np)
+        d_mx = mx.linalg.det(A_mx, stream=mx.cpu)
+        d_np = np.linalg.det(A_np)
+        self.assertTrue(np.allclose(d_mx, d_np, atol=1e-10))
+
+        # Singular 4x4 matrix (LU path): det should be 0
+        A = mx.array(
+            [
+                [1.0, 2.0, 3.0, 4.0],
+                [2.0, 4.0, 6.0, 8.0],
+                [0.0, 1.0, 1.0, 0.0],
+                [1.0, 0.0, 0.0, 1.0],
+            ]
+        )
+        d = mx.linalg.det(A, stream=mx.cpu)
+        self.assertTrue(np.allclose(d, 0.0, atol=1e-5))
+
+        # Singular 5x5 matrix (LU path)
+        A_np = np.ones((5, 5), dtype=np.float32)
+        A_mx = mx.array(A_np)
+        d = mx.linalg.det(A_mx, stream=mx.cpu)
+        self.assertTrue(np.allclose(d, 0.0, atol=1e-5))
+
+        # Batched singular matrices (LU path)
+        A_np = np.array([np.diag([1.0, 2.0, 0.0, 3.0]), np.eye(4, dtype=np.float32)])
+        A_mx = mx.array(A_np)
+        d_mx = mx.linalg.det(A_mx, stream=mx.cpu)
+        d_np = np.linalg.det(A_np)
+        self.assertTrue(np.allclose(d_mx, d_np, atol=1e-5))
+
+        # Error: non-square
+        with self.assertRaises(ValueError):
+            mx.linalg.det(mx.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), stream=mx.cpu)
+
+        # Error: 1D
+        with self.assertRaises(ValueError):
+            mx.linalg.det(mx.array([1.0, 2.0]), stream=mx.cpu)
+
+    def test_slogdet(self):
+        # 2x2: det = -2 => sign = -1, logabsdet = log(2)
+        A = mx.array([[1.0, 2.0], [3.0, 4.0]])
+        sign, logabsdet = mx.linalg.slogdet(A, stream=mx.cpu)
+        self.assertTrue(np.allclose(sign, -1.0))
+        self.assertTrue(np.allclose(logabsdet, np.log(2.0), atol=1e-5))
+
+        # Identity: sign = 1, logabsdet = 0
+        A = mx.eye(4)
+        sign, logabsdet = mx.linalg.slogdet(A, stream=mx.cpu)
+        self.assertTrue(np.allclose(sign, 1.0))
+        self.assertTrue(np.allclose(logabsdet, 0.0, atol=1e-6))
+
+        # Compare with numpy for random matrices
+        np.random.seed(42)
+        for n in [1, 2, 3, 4, 5]:
+            A_np = np.random.randn(n, n).astype(np.float32)
+            A_mx = mx.array(A_np)
+            sign_mx, logabs_mx = mx.linalg.slogdet(A_mx, stream=mx.cpu)
+            sign_np, logabs_np = np.linalg.slogdet(A_np)
+            with self.subTest(n=n):
+                self.assertTrue(np.allclose(sign_mx, sign_np, atol=1e-5))
+                self.assertTrue(np.allclose(logabs_mx, logabs_np, atol=1e-4))
+
+        # Singular matrix 2x2: sign = 0, logabsdet = -inf
+        A = mx.array([[1.0, 2.0], [2.0, 4.0]])
+        sign, logabsdet = mx.linalg.slogdet(A, stream=mx.cpu)
+        self.assertEqual(float(sign), 0.0)
+        self.assertEqual(float(logabsdet), float("-inf"))
+
+        # Singular 4x4 matrix (LU path): sign = 0, logabsdet = -inf
+        A = mx.array(
+            [
+                [1.0, 2.0, 3.0, 4.0],
+                [2.0, 4.0, 6.0, 8.0],
+                [0.0, 1.0, 1.0, 0.0],
+                [1.0, 0.0, 0.0, 1.0],
+            ]
+        )
+        sign, logabsdet = mx.linalg.slogdet(A, stream=mx.cpu)
+        self.assertEqual(float(sign), 0.0)
+        self.assertEqual(float(logabsdet), float("-inf"))
+
+        # Singular 5x5 matrix (LU path): all-ones matrix
+        A = mx.array(np.ones((5, 5), dtype=np.float32))
+        sign, logabsdet = mx.linalg.slogdet(A, stream=mx.cpu)
+        self.assertEqual(float(sign), 0.0)
+        self.assertEqual(float(logabsdet), float("-inf"))
+
+        # Batched with mix of singular and non-singular (LU path)
+        A_np = np.array([np.diag([1.0, 2.0, 0.0, 3.0]), np.eye(4, dtype=np.float32)])
+        A_mx = mx.array(A_np)
+        sign_mx, logabs_mx = mx.linalg.slogdet(A_mx, stream=mx.cpu)
+        sign_np, logabs_np = np.linalg.slogdet(A_np)
+        self.assertTrue(np.allclose(sign_mx, sign_np, atol=1e-5))
+        # Check -inf for singular, 0.0 for identity
+        self.assertEqual(float(logabs_mx[0]), float("-inf"))
+        self.assertTrue(np.allclose(logabs_mx[1], 0.0, atol=1e-6))
+
+        # Batched
+        A_np = np.random.randn(3, 4, 4).astype(np.float32)
+        A_mx = mx.array(A_np)
+        sign_mx, logabs_mx = mx.linalg.slogdet(A_mx, stream=mx.cpu)
+        sign_np, logabs_np = np.linalg.slogdet(A_np)
+        self.assertTrue(np.allclose(sign_mx, sign_np, atol=1e-5))
+        self.assertTrue(np.allclose(logabs_mx, logabs_np, atol=1e-4))
+
+        # Multi-batch
+        A_np = np.random.randn(2, 3, 3, 3).astype(np.float32)
+        A_mx = mx.array(A_np)
+        sign_mx, logabs_mx = mx.linalg.slogdet(A_mx, stream=mx.cpu)
+        sign_np, logabs_np = np.linalg.slogdet(A_np)
+        self.assertTrue(np.allclose(sign_mx, sign_np, atol=1e-5))
+        self.assertTrue(np.allclose(logabs_mx, logabs_np, atol=1e-4))
+
+        # Numerical stability: large matrix where det underflows
+        # 0.1 * I_100 has det = 0.1^100 which underflows in float32
+        # but slogdet should give sign=1, logabsdet = 100*log(0.1)
+        n = 100
+        A = mx.array(0.1) * mx.eye(n)
+        sign, logabsdet = mx.linalg.slogdet(A, stream=mx.cpu)
+        self.assertTrue(np.allclose(sign, 1.0))
+        self.assertTrue(np.allclose(logabsdet, n * np.log(0.1), atol=1e-3))
+
+        # Small matrices must be stable too: det(1e-30 * I_2) = 1e-60
+        # underflows to 0 in float32, but slogdet should still report
+        # sign = 1 and logabsdet = 2 * log(1e-30)
+        A = mx.array(1e-30) * mx.eye(2)
+        sign, logabsdet = mx.linalg.slogdet(A, stream=mx.cpu)
+        self.assertTrue(np.allclose(sign, 1.0))
+        self.assertTrue(np.allclose(logabsdet, 2 * np.log(1e-30), atol=1e-3))
+
+        # Verify det = sign * exp(logabsdet) for non-singular cases
+        A_np = np.random.randn(5, 5).astype(np.float32)
+        A_mx = mx.array(A_np)
+        sign_mx, logabs_mx = mx.linalg.slogdet(A_mx, stream=mx.cpu)
+        det_mx = mx.linalg.det(A_mx, stream=mx.cpu)
+        reconstructed = float(sign_mx) * np.exp(float(logabs_mx))
+        self.assertTrue(np.allclose(float(det_mx), reconstructed, rtol=1e-4))
+
+        # float64
+        A_np = np.random.randn(4, 4).astype(np.float64)
+        A_mx = mx.array(A_np)
+        sign_mx, logabs_mx = mx.linalg.slogdet(A_mx, stream=mx.cpu)
+        sign_np, logabs_np = np.linalg.slogdet(A_np)
+        self.assertTrue(np.allclose(sign_mx, sign_np))
+        self.assertTrue(np.allclose(logabs_mx, logabs_np, atol=1e-10))
+
+        # Error: non-square
+        with self.assertRaises(ValueError):
+            mx.linalg.slogdet(
+                mx.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), stream=mx.cpu
+            )
+
+        # Error: 1D
+        with self.assertRaises(ValueError):
+            mx.linalg.slogdet(mx.array([1.0, 2.0]), stream=mx.cpu)
+
 
 if __name__ == "__main__":
     mlx_tests.MLXTestRunner()
diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp
index 591f6910e7..4b229edfa5 100644
--- a/tests/linalg_tests.cpp
+++ b/tests/linalg_tests.cpp
@@ -637,3 +637,68 @@ TEST_CASE("test solve_triangluar") {
   expected = array({-3., 2., 3.});
   CHECK(allclose(expected, result).item<bool>());
 }
+
+TEST_CASE("test det") {
+  // 1x1 fast path
+  {
+    array a = array({5.0f}, {1, 1});
+    auto d = det(a, Device::cpu);
+    CHECK_EQ(d.item<float>(), doctest::Approx(5.0f));
+  }
+
+  // 2x2 fast path: det([[1,2],[3,4]]) = -2
+  {
+    array a = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
+    auto d = det(a, Device::cpu);
+    CHECK_EQ(d.item<float>(), doctest::Approx(-2.0f));
+  }
+
+  // 3x3 fast path: det([[1,2,3],[0,1,4],[5,6,0]]) = 1
+  {
+    array a =
+        array({1.0f, 2.0f, 3.0f, 0.0f, 1.0f, 4.0f, 5.0f, 6.0f, 0.0f}, {3, 3});
+    auto d = det(a, Device::cpu);
+    CHECK_EQ(d.item<float>(), doctest::Approx(1.0f));
+  }
+
+  // 4x4 LU path: identity matrix det = 1
+  {
+    array a = eye(4);
+    auto d = det(a, Device::cpu);
+    CHECK_EQ(d.item<float>(), doctest::Approx(1.0f));
+  }
+
+  // Non-square should throw
+  CHECK_THROWS(
+      det(array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {2, 3}), Device::cpu));
+
+  // 1D should throw
+  CHECK_THROWS(det(array({1.0f, 2.0f}), Device::cpu));
+}
+
+TEST_CASE("test slogdet") {
+  // 2x2: det = -2, so sign = -1, logabsdet = log(2)
+  {
+    array a = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
+    auto [s, logabs] = slogdet(a, Device::cpu);
+    CHECK_EQ(s.item<float>(), doctest::Approx(-1.0f));
+    CHECK_EQ(logabs.item<float>(), doctest::Approx(std::log(2.0f)));
+  }
+
+  // Identity: sign = 1, logabsdet = 0
+  {
+    array a = eye(4);
+    auto [s, logabs] = slogdet(a, Device::cpu);
+    CHECK_EQ(s.item<float>(), doctest::Approx(1.0f));
+    CHECK_EQ(logabs.item<float>(), doctest::Approx(0.0f));
+  }
+
+  // Singular: sign = 0, logabsdet = -inf
+  {
+    array a = array({1.0f, 2.0f, 2.0f, 4.0f}, {2, 2});
+    auto [s, logabs] = slogdet(a, Device::cpu);
+    CHECK_EQ(s.item<float>(), 0.0f);
+    CHECK(std::isinf(logabs.item<float>()));
+    CHECK(logabs.item<float>() < 0);
+  }
+}