Skip to content

Commit 776177b

Browse files
author
Lucas Fernandes Martins
committed
Adding linalg.det
1 parent 6cef1e9 commit 776177b

File tree

6 files changed

+219
-49
lines changed

6 files changed

+219
-49
lines changed

mlx/backend/cpu/luf.cpp

Lines changed: 67 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ void luf_impl(
1616
array& lu,
1717
array& pivots,
1818
array& row_indices,
19-
Stream stream) {
19+
Stream stream,
20+
bool allow_singular /* = false */) {
2021
int M = a.shape(-2);
2122
int N = a.shape(-1);
2223
int K = std::min(M, N);
@@ -54,50 +55,60 @@ void luf_impl(
5455
encoder.set_output_array(pivots);
5556
encoder.set_output_array(row_indices);
5657

57-
encoder.dispatch(
58-
[a_ptr, pivots_ptr, row_indices_ptr, num_matrices, M, N, K]() mutable {
59-
int info;
60-
for (size_t i = 0; i < num_matrices; ++i) {
61-
// Compute LU factorization of A
62-
getrf<T>(
63-
/* m */ &M,
64-
/* n */ &N,
65-
/* a */ a_ptr,
66-
/* lda */ &M,
67-
/* ipiv */ reinterpret_cast<int*>(pivots_ptr),
68-
/* info */ &info);
58+
encoder.dispatch([a_ptr,
59+
pivots_ptr,
60+
row_indices_ptr,
61+
num_matrices,
62+
M,
63+
N,
64+
K,
65+
allow_singular]() mutable {
66+
int info;
67+
for (size_t i = 0; i < num_matrices; ++i) {
68+
// Compute LU factorization of A
69+
getrf<T>(
70+
/* m */ &M,
71+
/* n */ &N,
72+
/* a */ a_ptr,
73+
/* lda */ &M,
74+
/* ipiv */ reinterpret_cast<int*>(pivots_ptr),
75+
/* info */ &info);
6976

70-
if (info != 0) {
71-
std::stringstream ss;
72-
ss << "[LUF::eval_cpu] sgetrf_ failed with code " << info
73-
<< ((info > 0) ? " because matrix is singular"
74-
: " because argument had an illegal value");
75-
throw std::runtime_error(ss.str());
76-
}
77+
if (info < 0) {
78+
std::stringstream ss;
79+
ss << "[LUF::eval_cpu] sgetrf_ failed with code " << info
80+
<< " because argument had an illegal value";
81+
throw std::runtime_error(ss.str());
82+
} else if (info > 0 && !allow_singular) {
83+
std::stringstream ss;
84+
ss << "[LUF::eval_cpu] sgetrf_ failed with code " << info
85+
<< " because matrix is singular";
86+
throw std::runtime_error(ss.str());
87+
}
7788

78-
// Subtract 1 to get 0-based index
79-
int j = 0;
80-
for (; j < K; ++j) {
81-
pivots_ptr[j]--;
82-
row_indices_ptr[j] = j;
83-
}
84-
for (; j < M; ++j) {
85-
row_indices_ptr[j] = j;
86-
}
87-
for (int j = K - 1; j >= 0; --j) {
88-
auto piv = pivots_ptr[j];
89-
auto t1 = row_indices_ptr[piv];
90-
auto t2 = row_indices_ptr[j];
91-
row_indices_ptr[j] = t1;
92-
row_indices_ptr[piv] = t2;
93-
}
89+
// Subtract 1 to get 0-based index
90+
int j = 0;
91+
for (; j < K; ++j) {
92+
pivots_ptr[j]--;
93+
row_indices_ptr[j] = j;
94+
}
95+
for (; j < M; ++j) {
96+
row_indices_ptr[j] = j;
97+
}
98+
for (int j = K - 1; j >= 0; --j) {
99+
auto piv = pivots_ptr[j];
100+
auto t1 = row_indices_ptr[piv];
101+
auto t2 = row_indices_ptr[j];
102+
row_indices_ptr[j] = t1;
103+
row_indices_ptr[piv] = t2;
104+
}
94105

95-
// Advance pointers to the next matrix
96-
a_ptr += M * N;
97-
pivots_ptr += K;
98-
row_indices_ptr += M;
99-
}
100-
});
106+
// Advance pointers to the next matrix
107+
a_ptr += M * N;
108+
pivots_ptr += K;
109+
row_indices_ptr += M;
110+
}
111+
});
101112
}
102113

103114
void LUF::eval_cpu(
@@ -106,10 +117,22 @@ void LUF::eval_cpu(
106117
assert(inputs.size() == 1);
107118
switch (inputs[0].dtype()) {
108119
case float32:
109-
luf_impl<float>(inputs[0], outputs[0], outputs[1], outputs[2], stream());
120+
luf_impl<float>(
121+
inputs[0],
122+
outputs[0],
123+
outputs[1],
124+
outputs[2],
125+
stream(),
126+
allow_singular_);
110127
break;
111128
case float64:
112-
luf_impl<double>(inputs[0], outputs[0], outputs[1], outputs[2], stream());
129+
luf_impl<double>(
130+
inputs[0],
131+
outputs[0],
132+
outputs[1],
133+
outputs[2],
134+
stream(),
135+
allow_singular_);
113136
break;
114137
default:
115138
throw std::runtime_error(

mlx/linalg.cpp

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -583,7 +583,10 @@ void validate_lu(
583583
}
584584
}
585585

586-
std::vector<array> lu_helper(const array& a, StreamOrDevice s /* = {} */) {
586+
std::vector<array> lu_helper(
587+
const array& a,
588+
StreamOrDevice s /* = {} */,
589+
bool allow_singular = false) {
587590
int m = a.shape()[a.shape().size() - 2];
588591
int n = a.shape()[a.shape().size() - 1];
589592

@@ -595,14 +598,14 @@ std::vector<array> lu_helper(const array& a, StreamOrDevice s /* = {} */) {
595598
return array::make_arrays(
596599
{a.shape(), pivots_shape, row_idx_shape},
597600
{a.dtype(), uint32, uint32},
598-
std::make_shared<LUF>(to_stream(s)),
601+
std::make_shared<LUF>(to_stream(s), allow_singular),
599602
{astype(a, a.dtype(), s)});
600603
}
601604

602605
std::vector<array> lu(const array& a, StreamOrDevice s /* = {} */) {
603606
validate_lu(a, s, "[linalg::lu]");
604607

605-
auto out = lu_helper(a, s);
608+
auto out = lu_helper(a, s, /*allow_singular=*/false);
606609
auto& LU = out[0];
607610
auto& row_pivots = out[2];
608611
auto L = tril(LU, /* k = */ -1, s);
@@ -705,4 +708,36 @@ array solve_triangular(
705708
return matmul(a_inv, b, s);
706709
}
707710

711+
void validate_det(
712+
const array& a,
713+
const StreamOrDevice& stream,
714+
const std::string& fname) {
715+
check_cpu_stream(stream, fname);
716+
check_float(a.dtype(), fname);
717+
if (a.ndim() != 2 || a.shape(-2) != a.shape(-1)) {
718+
std::ostringstream msg;
719+
msg << fname
720+
<< " For determinant to be calculated, array must have 2 dimensions. Received array "
721+
"with "
722+
<< a.ndim() << " dimensions.";
723+
throw std::invalid_argument(msg.str());
724+
}
725+
}
726+
727+
array det(const array& a, StreamOrDevice s /* = {} */) {
728+
validate_det(a, s, "[linalg::det]");
729+
730+
auto out = lu_helper(a, s, /*allow_singular=*/true);
731+
auto& LU = out[0];
732+
auto& pivots = out[1];
733+
734+
auto det_val = prod(diag(LU, 0, s), s);
735+
736+
auto indices = arange(pivots.shape(-1), pivots.dtype(), s);
737+
auto num_swaps = sum(not_equal(pivots, indices, s), -1, false, s);
738+
auto sign = power(array(-1.0f), astype(num_swaps, float32, s), s);
739+
740+
return multiply(sign, det_val, s);
741+
}
742+
708743
} // namespace mlx::core::linalg

mlx/linalg.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,4 +112,6 @@ eigvalsh(const array& a, std::string UPLO = "L", StreamOrDevice s = {});
112112
MLX_API std::pair<array, array>
113113
eigh(const array& a, std::string UPLO = "L", StreamOrDevice s = {});
114114

115-
} // namespace mlx::core::linalg
115+
MLX_API array det(const array& a, StreamOrDevice s = {});
116+
117+
} // namespace mlx::core::linalg

mlx/primitives.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2537,13 +2537,17 @@ class Eigh : public Primitive {
25372537
/* LU Factorization primitive. */
25382538
class LUF : public Primitive {
25392539
public:
2540-
explicit LUF(Stream stream) : Primitive(stream) {}
2540+
explicit LUF(Stream stream, bool allow_singular = false)
2541+
: Primitive(stream), allow_singular_(allow_singular) {}
25412542
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
25422543
override;
25432544
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
25442545
override;
25452546

25462547
DEFINE_NAME(LUF)
2548+
2549+
private:
2550+
bool allow_singular_{false};
25472551
};
25482552

25492553
} // namespace mlx::core

python/src/linalg.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -660,4 +660,31 @@ void init_linalg(nb::module_& parent_module) {
660660
Returns:
661661
array: The unique solution to the system ``AX = B``.
662662
)pbdoc");
663+
m.def(
664+
"det",
665+
&mx::linalg::det,
666+
"a"_a,
667+
nb::kw_only(),
668+
"stream"_a = nb::none(),
669+
nb::sig(
670+
"def det(a: array, *, stream: Union[None, Stream, Device] = None) -> array"),
671+
R"pbdoc(
672+
Compute the determinant of a square matrix.
673+
674+
This function computes the determinant using LU factorization.
675+
Singular matrices return a determinant of ``0``.
676+
677+
Args:
678+
a (array): Input square matrix. Must be 2-D.
679+
stream (Stream, optional): Stream or device. Defaults to ``None``
680+
in which case the default stream of the default device is used.
681+
682+
Returns:
683+
array: The determinant of the input matrix.
684+
685+
Example:
686+
>>> a = mx.array([[1., 2.], [3., 4.]])
687+
>>> mx.linalg.det(a)
688+
array(-2, dtype=float32)
689+
)pbdoc");
663690
}

python/tests/test_linalg.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,85 @@ def test_solve_triangular(self):
616616
expected = np.linalg.solve(a, b)
617617
self.assertTrue(np.allclose(result, expected))
618618

619+
def test_det(self):
620+
# 2x2 basic
621+
a = mx.array([[1.0, 2.0], [3.0, 4.0]])
622+
self.assertTrue(abs(mx.linalg.det(a).item() - (-2.0)) < 1e-4)
623+
624+
# 3x3 identity
625+
a = mx.eye(3)
626+
self.assertTrue(abs(mx.linalg.det(a).item() - 1.0) < 1e-4)
627+
628+
# 1x1
629+
a = mx.array([[5.0]])
630+
self.assertTrue(abs(mx.linalg.det(a).item() - 5.0) < 1e-4)
631+
632+
# 2x2 singular
633+
a = mx.array([[1.0, 2.0], [2.0, 4.0]])
634+
self.assertTrue(abs(mx.linalg.det(a).item()) < 1e-3)
635+
636+
# 3x3 general
637+
a = mx.array([[1.0, 2.0, 3.0],
638+
[4.0, 5.0, 6.0],
639+
[7.0, 8.0, 0.0]])
640+
self.assertTrue(abs(mx.linalg.det(a).item() - 27.0) < 1e-4)
641+
642+
# 3x3 diagonal
643+
a = mx.array([[2.0, 0.0, 0.0],
644+
[0.0, 3.0, 0.0],
645+
[0.0, 0.0, 4.0]])
646+
self.assertTrue(abs(mx.linalg.det(a).item() - 24.0) < 1e-4)
647+
648+
# 3x3 upper triangular
649+
a = mx.array([[2.0, 1.0, 3.0],
650+
[0.0, 5.0, 7.0],
651+
[0.0, 0.0, 4.0]])
652+
self.assertTrue(abs(mx.linalg.det(a).item() - 40.0) < 1e-4)
653+
654+
# 2x2 permutation (negative det)
655+
a = mx.array([[0.0, 1.0], [1.0, 0.0]])
656+
self.assertTrue(abs(mx.linalg.det(a).item() - (-1.0)) < 1e-4)
657+
658+
# 4x4 general
659+
a = mx.array([[1.0, 0.0, 2.0, -1.0],
660+
[3.0, 0.0, 0.0, 5.0],
661+
[2.0, 1.0, 4.0, -3.0],
662+
[1.0, 0.0, 5.0, 0.0]])
663+
self.assertTrue(abs(mx.linalg.det(a).item() - 30.0) < 1e-4)
664+
665+
# 4x4 scaled identity
666+
a = 3.0 * mx.eye(4)
667+
self.assertTrue(abs(mx.linalg.det(a).item() - 81.0) < 1e-4)
668+
669+
# 3x3 zeros (singular)
670+
a = mx.zeros((3, 3))
671+
self.assertTrue(abs(mx.linalg.det(a).item()) < 1e-3)
672+
673+
# 2x2 all negative
674+
a = mx.array([[-1.0, -2.0], [-3.0, -4.0]])
675+
self.assertTrue(abs(mx.linalg.det(a).item() - (-2.0)) < 1e-4)
676+
677+
# float64
678+
a = mx.array([[1.0, 2.0], [3.0, 4.0]], dtype=mx.float64)
679+
self.assertTrue(abs(mx.linalg.det(a).item() - (-2.0)) < 1e-10)
680+
681+
def test_det_throws(self):
682+
# 1D array
683+
with self.assertRaises(Exception):
684+
mx.eval(mx.linalg.det(mx.array([1.0, 2.0, 3.0])))
685+
686+
# Non-square matrix
687+
with self.assertRaises(Exception):
688+
mx.eval(mx.linalg.det(mx.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])))
689+
690+
# 0D scalar
691+
with self.assertRaises(Exception):
692+
mx.eval(mx.linalg.det(mx.array(5.0)))
693+
694+
# Integer input
695+
with self.assertRaises(Exception):
696+
mx.eval(mx.linalg.det(mx.array([[1, 2], [3, 4]])))
697+
619698

620699
if __name__ == "__main__":
621700
mlx_tests.MLXTestRunner()

0 commit comments

Comments
 (0)