Skip to content

Commit 8a434b5

Browse files
author
Lucas Fernandes Martins
committed
fixing formatting issues
1 parent 6ab983a commit 8a434b5

5 files changed

Lines changed: 90 additions & 78 deletions

File tree

mlx/backend/cpu/luf.cpp

Lines changed: 65 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -55,54 +55,60 @@ void luf_impl(
5555
encoder.set_output_array(pivots);
5656
encoder.set_output_array(row_indices);
5757

58-
encoder.dispatch(
59-
[a_ptr, pivots_ptr, row_indices_ptr, num_matrices, M, N, K, allow_singular]() mutable {
60-
int info;
61-
for (size_t i = 0; i < num_matrices; ++i) {
62-
// Compute LU factorization of A
63-
getrf<T>(
64-
/* m */ &M,
65-
/* n */ &N,
66-
/* a */ a_ptr,
67-
/* lda */ &M,
68-
/* ipiv */ reinterpret_cast<int*>(pivots_ptr),
69-
/* info */ &info);
58+
encoder.dispatch([a_ptr,
59+
pivots_ptr,
60+
row_indices_ptr,
61+
num_matrices,
62+
M,
63+
N,
64+
K,
65+
allow_singular]() mutable {
66+
int info;
67+
for (size_t i = 0; i < num_matrices; ++i) {
68+
// Compute LU factorization of A
69+
getrf<T>(
70+
/* m */ &M,
71+
/* n */ &N,
72+
/* a */ a_ptr,
73+
/* lda */ &M,
74+
/* ipiv */ reinterpret_cast<int*>(pivots_ptr),
75+
/* info */ &info);
7076

71-
if (info < 0) {
72-
std::stringstream ss;
73-
ss << "[LUF::eval_cpu] sgetrf_ failed with code " << info
74-
<< " because argument had an illegal value";
75-
throw std::runtime_error(ss.str());
76-
} else if (info > 0 && !allow_singular) {
77-
std::stringstream ss;
78-
ss << "[LUF::eval_cpu] sgetrf_ failed with code " << info
79-
<< " because matrix is singular";
80-
throw std::runtime_error(ss.str());
81-
}
77+
if (info < 0) {
78+
std::stringstream ss;
79+
ss << "[LUF::eval_cpu] sgetrf_ failed with code " << info
80+
<< " because argument had an illegal value";
81+
throw std::runtime_error(ss.str());
82+
} else if (info > 0 && !allow_singular) {
83+
std::stringstream ss;
84+
ss << "[LUF::eval_cpu] sgetrf_ failed with code " << info
85+
<< " because matrix is singular";
86+
throw std::runtime_error(ss.str());
87+
}
8288

83-
// Subtract 1 to get 0-based index
84-
int j = 0;
85-
for (; j < K; ++j) {
86-
pivots_ptr[j]--;
87-
row_indices_ptr[j] = j;
88-
}
89-
for (; j < M; ++j) {
90-
row_indices_ptr[j] = j;
91-
}
92-
for (int j = K - 1; j >= 0; --j) {
93-
auto piv = pivots_ptr[j];
94-
auto t1 = row_indices_ptr[piv];
95-
auto t2 = row_indices_ptr[j];
96-
row_indices_ptr[j] = t1;
97-
row_indices_ptr[piv] = t2;
98-
}
89+
// Subtract 1 to get 0-based index
90+
int j = 0;
91+
for (; j < K; ++j) {
92+
pivots_ptr[j]--;
93+
row_indices_ptr[j] = j;
94+
}
95+
for (; j < M; ++j) {
96+
row_indices_ptr[j] = j;
97+
}
98+
for (int j = K - 1; j >= 0; --j) {
99+
auto piv = pivots_ptr[j];
100+
auto t1 = row_indices_ptr[piv];
101+
auto t2 = row_indices_ptr[j];
102+
row_indices_ptr[j] = t1;
103+
row_indices_ptr[piv] = t2;
104+
}
99105

100-
// Advance pointers to the next matrix
101-
a_ptr += M * N;
102-
pivots_ptr += K;
103-
row_indices_ptr += M;
104-
}
105-
});
106+
// Advance pointers to the next matrix
107+
a_ptr += M * N;
108+
pivots_ptr += K;
109+
row_indices_ptr += M;
110+
}
111+
});
106112
}
107113

108114
void LUF::eval_cpu(
@@ -111,10 +117,22 @@ void LUF::eval_cpu(
111117
assert(inputs.size() == 1);
112118
switch (inputs[0].dtype()) {
113119
case float32:
114-
luf_impl<float>(inputs[0], outputs[0], outputs[1], outputs[2], stream(), allow_singular_);
120+
luf_impl<float>(
121+
inputs[0],
122+
outputs[0],
123+
outputs[1],
124+
outputs[2],
125+
stream(),
126+
allow_singular_);
115127
break;
116128
case float64:
117-
luf_impl<double>(inputs[0], outputs[0], outputs[1], outputs[2], stream(), allow_singular_);
129+
luf_impl<double>(
130+
inputs[0],
131+
outputs[0],
132+
outputs[1],
133+
outputs[2],
134+
stream(),
135+
allow_singular_);
118136
break;
119137
default:
120138
throw std::runtime_error(

mlx/linalg.cpp

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -584,9 +584,9 @@ void validate_lu(
584584
}
585585

586586
std::vector<array> lu_helper(
587-
const array& a,
588-
StreamOrDevice s /* = {} */,
589-
bool allow_singular = false) {
587+
const array& a,
588+
StreamOrDevice s /* = {} */,
589+
bool allow_singular = false) {
590590
int m = a.shape()[a.shape().size() - 2];
591591
int n = a.shape()[a.shape().size() - 1];
592592

@@ -605,7 +605,7 @@ std::vector<array> lu_helper(
605605
std::vector<array> lu(const array& a, StreamOrDevice s /* = {} */) {
606606
validate_lu(a, s, "[linalg::lu]");
607607

608-
auto out = lu_helper(a, s, /*allow_singular=*/ false);
608+
auto out = lu_helper(a, s, /*allow_singular=*/false);
609609
auto& LU = out[0];
610610
auto& row_pivots = out[2];
611611
auto L = tril(LU, /* k = */ -1, s);
@@ -629,9 +629,7 @@ std::vector<array> lu(const array& a, StreamOrDevice s /* = {} */) {
629629
return {row_pivots, L, U};
630630
}
631631

632-
std::pair<array, array> lu_factor(
633-
const array& a,
634-
StreamOrDevice s /* = {} */) {
632+
std::pair<array, array> lu_factor(const array& a, StreamOrDevice s /* = {} */) {
635633
validate_lu(a, s, "[linalg::lu_factor]");
636634
auto out = lu_helper(a, s);
637635
return std::make_pair(out[0], out[1]);
@@ -711,8 +709,8 @@ array solve_triangular(
711709
}
712710

713711
void validate_det(
714-
const array& a,
715-
const StreamOrDevice& stream,
712+
const array& a,
713+
const StreamOrDevice& stream,
716714
const std::string& fname) {
717715
check_cpu_stream(stream, fname);
718716
check_float(a.dtype(), fname);
@@ -723,10 +721,10 @@ void validate_det(
723721
"with "
724722
<< a.ndim() << " dimensions.";
725723
throw std::invalid_argument(msg.str());
726-
}
724+
}
727725
}
728726

729-
array det(const array& a, StreamOrDevice s /* = {} */){
727+
array det(const array& a, StreamOrDevice s /* = {} */) {
730728
validate_det(a, s, "[linalg::det]");
731729

732730
auto out = lu_helper(a, s, /*allow_singular=*/true);

mlx/linalg.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,6 @@ eigvalsh(const array& a, std::string UPLO = "L", StreamOrDevice s = {});
112112
MLX_API std::pair<array, array>
113113
eigh(const array& a, std::string UPLO = "L", StreamOrDevice s = {});
114114

115-
MLX_API array
116-
det(const array& a, StreamOrDevice s = {});
115+
MLX_API array det(const array& a, StreamOrDevice s = {});
117116

118117
} // namespace mlx::core::linalg

mlx/primitives.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2537,8 +2537,8 @@ class Eigh : public Primitive {
25372537
/* LU Factorization primitive. */
25382538
class LUF : public Primitive {
25392539
public:
2540-
explicit LUF(Stream stream, bool allow_singular = false) : Primitive(stream),
2541-
allow_singular_(allow_singular) {}
2540+
explicit LUF(Stream stream, bool allow_singular = false)
2541+
: Primitive(stream), allow_singular_(allow_singular) {}
25422542
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
25432543
override;
25442544
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)

python/tests/test_det.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import pytest
21
import mlx.core as mx
2+
import pytest
33

44

55
class TestDet:
@@ -20,32 +20,30 @@ def test_2x2_singular(self):
2020
assert mx.linalg.det(a).item() == pytest.approx(0.0, abs=1e-3)
2121

2222
def test_3x3_general(self):
23-
a = mx.array([[1.0, 2.0, 3.0],
24-
[4.0, 5.0, 6.0],
25-
[7.0, 8.0, 0.0]])
23+
a = mx.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 0.0]])
2624
assert mx.linalg.det(a).item() == pytest.approx(27.0, abs=1e-4)
2725

2826
def test_3x3_diagonal(self):
29-
a = mx.array([[2.0, 0.0, 0.0],
30-
[0.0, 3.0, 0.0],
31-
[0.0, 0.0, 4.0]])
27+
a = mx.array([[2.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 4.0]])
3228
assert mx.linalg.det(a).item() == pytest.approx(24.0, abs=1e-4)
3329

3430
def test_3x3_upper_triangular(self):
35-
a = mx.array([[2.0, 1.0, 3.0],
36-
[0.0, 5.0, 7.0],
37-
[0.0, 0.0, 4.0]])
31+
a = mx.array([[2.0, 1.0, 3.0], [0.0, 5.0, 7.0], [0.0, 0.0, 4.0]])
3832
assert mx.linalg.det(a).item() == pytest.approx(40.0, abs=1e-4)
3933

4034
def test_2x2_permutation_negative_det(self):
4135
a = mx.array([[0.0, 1.0], [1.0, 0.0]])
4236
assert mx.linalg.det(a).item() == pytest.approx(-1.0, abs=1e-4)
4337

4438
def test_4x4_general(self):
45-
a = mx.array([[1.0, 0.0, 2.0, -1.0],
46-
[3.0, 0.0, 0.0, 5.0],
47-
[2.0, 1.0, 4.0, -3.0],
48-
[1.0, 0.0, 5.0, 0.0]])
39+
a = mx.array(
40+
[
41+
[1.0, 0.0, 2.0, -1.0],
42+
[3.0, 0.0, 0.0, 5.0],
43+
[2.0, 1.0, 4.0, -3.0],
44+
[1.0, 0.0, 5.0, 0.0],
45+
]
46+
)
4947
assert mx.linalg.det(a).item() == pytest.approx(30.0, abs=1e-4)
5048

5149
def test_4x4_scaled_identity(self):
@@ -66,8 +64,7 @@ def test_1d_throws(self):
6664

6765
def test_non_square_throws(self):
6866
with pytest.raises(Exception):
69-
mx.eval(mx.linalg.det(mx.array([[1.0, 2.0, 3.0],
70-
[4.0, 5.0, 6.0]])))
67+
mx.eval(mx.linalg.det(mx.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])))
7168

7269
def test_scalar_throws(self):
7370
with pytest.raises(Exception):

0 commit comments

Comments
 (0)