@@ -55,54 +55,60 @@ void luf_impl(
5555 encoder.set_output_array (pivots);
5656 encoder.set_output_array (row_indices);
5757
58- encoder.dispatch (
59- [a_ptr, pivots_ptr, row_indices_ptr, num_matrices, M, N, K, allow_singular]() mutable {
60- int info;
61- for (size_t i = 0 ; i < num_matrices; ++i) {
62- // Compute LU factorization of A
63- getrf<T>(
64- /* m */ &M,
65- /* n */ &N,
66- /* a */ a_ptr,
67- /* lda */ &M,
68- /* ipiv */ reinterpret_cast <int *>(pivots_ptr),
69- /* info */ &info);
58+ encoder.dispatch ([a_ptr,
59+ pivots_ptr,
60+ row_indices_ptr,
61+ num_matrices,
62+ M,
63+ N,
64+ K,
65+ allow_singular]() mutable {
66+ int info;
67+ for (size_t i = 0 ; i < num_matrices; ++i) {
68+ // Compute LU factorization of A
69+ getrf<T>(
70+ /* m */ &M,
71+ /* n */ &N,
72+ /* a */ a_ptr,
73+ /* lda */ &M,
74+ /* ipiv */ reinterpret_cast <int *>(pivots_ptr),
75+ /* info */ &info);
7076
71- if (info < 0 ) {
72- std::stringstream ss;
73- ss << " [LUF::eval_cpu] sgetrf_ failed with code " << info
74- << " because argument had an illegal value" ;
75- throw std::runtime_error (ss.str ());
76- } else if (info > 0 && !allow_singular) {
77- std::stringstream ss;
78- ss << " [LUF::eval_cpu] sgetrf_ failed with code " << info
79- << " because matrix is singular" ;
80- throw std::runtime_error (ss.str ());
81- }
77+ if (info < 0 ) {
78+ std::stringstream ss;
79+ ss << " [LUF::eval_cpu] sgetrf_ failed with code " << info
80+ << " because argument had an illegal value" ;
81+ throw std::runtime_error (ss.str ());
82+ } else if (info > 0 && !allow_singular) {
83+ std::stringstream ss;
84+ ss << " [LUF::eval_cpu] sgetrf_ failed with code " << info
85+ << " because matrix is singular" ;
86+ throw std::runtime_error (ss.str ());
87+ }
8288
83- // Subtract 1 to get 0-based index
84- int j = 0 ;
85- for (; j < K; ++j) {
86- pivots_ptr[j]--;
87- row_indices_ptr[j] = j;
88- }
89- for (; j < M; ++j) {
90- row_indices_ptr[j] = j;
91- }
92- for (int j = K - 1 ; j >= 0 ; --j) {
93- auto piv = pivots_ptr[j];
94- auto t1 = row_indices_ptr[piv];
95- auto t2 = row_indices_ptr[j];
96- row_indices_ptr[j] = t1;
97- row_indices_ptr[piv] = t2;
98- }
89+ // Subtract 1 to get 0-based index
90+ int j = 0 ;
91+ for (; j < K; ++j) {
92+ pivots_ptr[j]--;
93+ row_indices_ptr[j] = j;
94+ }
95+ for (; j < M; ++j) {
96+ row_indices_ptr[j] = j;
97+ }
98+ for (int j = K - 1 ; j >= 0 ; --j) {
99+ auto piv = pivots_ptr[j];
100+ auto t1 = row_indices_ptr[piv];
101+ auto t2 = row_indices_ptr[j];
102+ row_indices_ptr[j] = t1;
103+ row_indices_ptr[piv] = t2;
104+ }
99105
100- // Advance pointers to the next matrix
101- a_ptr += M * N;
102- pivots_ptr += K;
103- row_indices_ptr += M;
104- }
105- });
106+ // Advance pointers to the next matrix
107+ a_ptr += M * N;
108+ pivots_ptr += K;
109+ row_indices_ptr += M;
110+ }
111+ });
106112}
107113
108114void LUF::eval_cpu (
@@ -111,10 +117,22 @@ void LUF::eval_cpu(
111117 assert (inputs.size () == 1 );
112118 switch (inputs[0 ].dtype ()) {
113119 case float32:
114- luf_impl<float >(inputs[0 ], outputs[0 ], outputs[1 ], outputs[2 ], stream (), allow_singular_);
120+ luf_impl<float >(
121+ inputs[0 ],
122+ outputs[0 ],
123+ outputs[1 ],
124+ outputs[2 ],
125+ stream (),
126+ allow_singular_);
115127 break ;
116128 case float64:
117- luf_impl<double >(inputs[0 ], outputs[0 ], outputs[1 ], outputs[2 ], stream (), allow_singular_);
129+ luf_impl<double >(
130+ inputs[0 ],
131+ outputs[0 ],
132+ outputs[1 ],
133+ outputs[2 ],
134+ stream (),
135+ allow_singular_);
118136 break ;
119137 default :
120138 throw std::runtime_error (
0 commit comments