Skip to content

Commit 3c673a3

Browse files
authored
float support for lapack functions. (#3)
* float support for lapack functions. * adjusted lint rules.
1 parent 4b14418 commit 3c673a3

6 files changed

Lines changed: 350 additions & 7 deletions

File tree

.swiftlint.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ disabled_rules: # rule identifiers to exclude from running
33
- identifier_name
44
- force_try
55
- line_length
6+
- file_length
67
function_body_length:
78
- 100
89
excluded:

AccelerateArray.podspec

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ Pod::Spec.new do |s|
1616
#
1717

1818
s.name = "AccelerateArray"
19-
s.version = "0.2.0"
19+
s.version = "0.3.0"
2020
s.summary = "Swift Array Extensions for the Apple Accelerate Framework"
2121

2222
# This description is used to generate tags and improve search results.

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ additional types, which can be easily built on top of this package.
2323

2424
### Swift Package Manager
2525
dependencies: [
26-
.package(url: "https://github.com/dastrobu/AccelerateArray.git", from: "0.2.0"),
26+
.package(url: "https://github.com/dastrobu/AccelerateArray.git", from: "0.3.0"),
2727
],
2828
2929
### Cocoa Pods

Sources/AccelerateArray/lapack.swift

Lines changed: 201 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,207 @@ public enum LapackError: Error {
1515
///
1616
/// Float array extension
1717
public extension Array where Element == Float {
18+
/// SGETRF computes an LU factorization of a general M-by-N matrix A
19+
/// using partial pivoting with row interchanges.
20+
///
21+
/// The factorization has the form
22+
/// A = P * L * U
23+
/// where P is a permutation matrix, L is lower triangular with unit
24+
/// diagonal elements (lower trapezoidal if m > n), and U is upper
25+
/// triangular (upper trapezoidal if m < n).
26+
///
27+
/// This is the right-looking Level 3 BLAS version of the algorithm.
28+
///
29+
/// This array must be in column major storage.
30+
///
31+
/// http://www.netlib.org/lapack/explore-html/d8/ddc/group__real_g_ecomputational_ga8d99c11b94db3d5eac75cac46a0f2e17.html#ga8d99c11b94db3d5eac75cac46a0f2e17
32+
///
33+
/// - Parameters:
34+
/// - m: number of rows
35+
/// - n: number of columns
36+
///
37+
/// - Returns: The pivot indices; for 1 <= i <= min(M,N), row i of the matrix was interchanged with row IPIV(i).
38+
mutating func getrf(m: Int, n: Int) throws -> [Int32] {
39+
var ipiv = [Int32](repeating: 0, count: Swift.min(m, n))
40+
try getrf(m: m, n: n, ipiv: &ipiv)
41+
return ipiv
42+
}
43+
44+
/// SGETRF computes an LU factorization of a general M-by-N matrix A
45+
/// using partial pivoting with row interchanges.
46+
///
47+
/// The factorization has the form
48+
/// A = P * L * U
49+
/// where P is a permutation matrix, L is lower triangular with unit
50+
/// diagonal elements (lower trapezoidal if m > n), and U is upper
51+
/// triangular (upper trapezoidal if m < n).
52+
///
53+
/// This is the right-looking Level 3 BLAS version of the algorithm.
54+
///
55+
/// This array must be in column major storage.
56+
///
57+
/// http://www.netlib.org/lapack/explore-html/d8/ddc/group__real_g_ecomputational_ga8d99c11b94db3d5eac75cac46a0f2e17.html#ga8d99c11b94db3d5eac75cac46a0f2e17
58+
///
59+
/// - Parameters:
60+
/// - m: number of rows
61+
/// - n: number of columns
62+
/// - ipiv: the pivot indices (one-based, as in Fortran); for 0 <= i < min(m,n), row i + 1 of the matrix was interchanged with row ipiv[i].
63+
mutating func getrf(m m_: Int, n n_: Int, ipiv: inout [Int32]) throws {
64+
var m = Int32(m_)
65+
assert(m >= 0, "\(m) >= 0")
66+
var n = Int32(n_)
67+
assert(n >= 0, "\(n) >= 0")
68+
// leading dimension is the number of rows in column major order
69+
var lda = Int32(m)
70+
assert(lda >= Swift.max(1, m), "\(lda) >= max(1, \(m)")
71+
assert(count == lda * n, "\(count) == (\(lda),\(n))")
72+
73+
assert(ipiv.count >= Swift.min(m, n), "\(ipiv.count) > min(\(m), \(n)")
74+
75+
var info: Int32 = 0
76+
sgetrf_(&m, &n, &self, &lda, &ipiv, &info)
77+
if info != 0 {
78+
throw LapackError.getrf(info)
79+
}
80+
}
81+
82+
/// SGETRI computes the inverse of a matrix using the LU factorization
83+
/// computed by SGETRF.
84+
///
85+
/// This method inverts U and then computes inv(A) by solving the system
86+
/// inv(A)*L = inv(U) for inv(A).
87+
///
88+
///
89+
mutating func getri() throws {
90+
var n = Int32(Double(count).squareRoot())
91+
assert(count == n * n, "\(count) == \(n) * \(n)")
92+
var ipiv = try getrf(m: Int(n), n: Int(n))
93+
94+
var lda = Int32(count / Int(n))
95+
assert(lda >= Swift.max(1, n), "\(lda) >= max(1, \(n)")
96+
97+
var info: Int32 = 0
98+
99+
// do optimal workspace query
100+
var lwork: Int32 = -1
101+
var work = [__CLPK_real](repeating: 0.0, count: 1)
102+
sgetri_(&n, &self, &lda, &ipiv, &work, &lwork, &info)
103+
if info != 0 {
104+
throw LapackError.getri(info)
105+
}
106+
107+
// retrieve optimal workspace
108+
lwork = Int32(work[0])
109+
work = [__CLPK_real](repeating: 0.0, count: Int(lwork))
110+
111+
// do the inversion
112+
sgetri_(&n, &self, &lda, &ipiv, &work, &lwork, &info)
113+
if info != 0 {
114+
throw LapackError.getri(info)
115+
}
116+
}
117+
118+
/// SGESV computes the solution to a real system of linear equations
119+
/// A * X = B,
120+
/// where A is an N-by-N matrix and X and B are N-by-NRHS matrices.
121+
///
122+
/// The LU decomposition with partial pivoting and row interchanges is
123+
/// used to factor A as
124+
/// A = P * L * U,
125+
/// where P is a permutation matrix, L is unit lower triangular, and U is
126+
/// upper triangular. The factored form of A is then used to solve the
127+
/// system of equations A * X = B.
128+
///
129+
/// This array must be in column major storage.
130+
///
131+
/// http://www.netlib.org/lapack/explore-html/d8/ddc/group__real_g_ecomputational_ga1af62182327d0be67b1717db399d7d83.html#ga1af62182327d0be67b1717db399d7d83
132+
mutating func gesv(B: inout [Element]) throws {
133+
var ipiv: [Int32] = [Int32].init(repeating: 0, count: n)
134+
try gesv(ipiv: &ipiv, B: &B)
135+
}
136+
137+
/// SGESV computes the solution to a real system of linear equations
138+
/// A * X = B,
139+
/// where A is an N-by-N matrix and X and B are N-by-NRHS matrices.
140+
///
141+
/// The LU decomposition with partial pivoting and row interchanges is
142+
/// used to factor A as
143+
/// A = P * L * U,
144+
/// where P is a permutation matrix, L is unit lower triangular, and U is
145+
/// upper triangular. The factored form of A is then used to solve the
146+
/// system of equations A * X = B.
147+
///
148+
/// This array and B must be in column major storage.
149+
///
150+
/// http://www.netlib.org/lapack/explore-html/d8/ddc/group__real_g_ecomputational_ga461f4ac32685a5ca30e293ee73d32920.html#ga461f4ac32685a5ca30e293ee73d32920
151+
mutating func gesv(ipiv: inout [Int32], B: inout [Element]) throws {
152+
var n = Int32(self.n)
153+
assert(count == n * n, "\(count) == \(n) * \(n)")
154+
assert(ipiv.count == n, "\(ipiv.count) == \(n)")
18155

156+
var nrhs = Int32(B.count / Int(n))
157+
assert(nrhs >= 1, "\(nrhs) >= 1")
158+
assert(B.count == nrhs * n, "\(B.count) == \(nrhs) * \(n)")
159+
160+
var lda = n
161+
assert(lda >= Swift.max(1, n), "\(lda) >= max(1, \(n)")
162+
163+
var ldb = Int32(B.count / Int(nrhs))
164+
assert(ldb * nrhs == B.count, "\(ldb) * \(nrhs) == \(B.count)")
165+
assert(ldb >= Swift.max(1, n), "\(ldb) >= max(1, \(n))")
166+
167+
var info: Int32 = 0
168+
sgesv_(&n, &nrhs, &self, &lda, &ipiv, &B, &ldb, &info)
169+
if info != 0 {
170+
throw LapackError.dgesv(info)
171+
}
172+
}
173+
174+
/// SGTSV solves the equation
175+
///
176+
/// A*X = B,
177+
///
178+
/// where A is an n by n tridiagonal matrix, by Gaussian elimination with
179+
/// partial pivoting.
180+
///
181+
/// Note that the equation A**T*X = B may be solved by interchanging the
182+
/// order of the arguments DU and DL.
183+
///
184+
/// This array represents the diagonal of A.
185+
///
186+
/// http://www.netlib.org/lapack/explore-html/d1/d88/group__real_g_tsolve_gae1cbb7cd9c376c9cc72575d472eba346.html#gae1cbb7cd9c376c9cc72575d472eba346
187+
///
188+
/// - Parameters:
189+
/// - nrhs: The number of right hand sides, i.e., the number of columns of the matrix B. NRHS >= 0.
190+
/// - dl: On entry, DL must contain the (n-1) sub-diagonal elements of A.
191+
/// On exit, DL is overwritten by the (n-2) elements of the
192+
/// second super-diagonal of the upper triangular matrix U from
193+
/// the LU factorization of A, in DL(1), ..., DL(n-2).
194+
/// - du: On entry, DU must contain the (n-1) super-diagonal elements of A.
195+
/// On exit, DU is overwritten by the (n-1) elements of the first
196+
/// super-diagonal of U.
197+
/// - B: On entry, the N by NRHS right hand side matrix B.
198+
/// On exit, if no error was thrown, the N by NRHS solution matrix X.
199+
///
200+
mutating func gtsv(nrhs: Int, dl: inout [Element], du: inout [Element], B: inout [Element]) throws {
201+
assert(count - 1 == dl.count, "\(count) - 1 == \(dl.count)")
202+
assert(count - 1 == du.count, "\(count) - 1 == \(du.count)")
203+
var n = Int32(count)
204+
205+
var nrhs = Int32(B.count / Int(n))
206+
assert(nrhs >= 1, "\(nrhs) >= 1")
207+
assert(B.count == Int(nrhs) * count, "\(B.count) == \(nrhs) * \(n)")
208+
209+
var ldb = Int32(B.count / Int(nrhs))
210+
assert(ldb * nrhs == B.count, "\(ldb) * \(nrhs) == \(B.count)")
211+
assert(ldb >= Swift.max(1, n), "\(ldb) >= max(1, \(n))")
212+
213+
var info: Int32 = 0
214+
sgtsv_(&n, &nrhs, &dl, &self, &du, &B, &ldb, &info)
215+
if info != 0 {
216+
throw LapackError.dgesv(info)
217+
}
218+
}
19219
}
20220

21221
/// Array extension employing the LAPACK framework.
@@ -206,7 +406,7 @@ public extension Array where Element == Double {
206406
/// - B: On entry, the N by NRHS right hand side matrix B.
207407
/// On exit, if no error was thrown, the N by NRHS solution matrix X.
208408
///
209-
mutating func gtsv(nrhs: Int, dl: inout [Double], du: inout [Double], B: inout [Double]) throws {
409+
mutating func gtsv(nrhs: Int, dl: inout [Element], du: inout [Element], B: inout [Element]) throws {
210410
assert(count - 1 == dl.count, "\(count) - 1 == \(dl.count)")
211411
assert(count - 1 == du.count, "\(count) - 1 == \(du.count)")
212412
var n = Int32(count)

Tests/AccelerateArrayTests/lapack.swift

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,144 @@ import XCTest
22
@testable import AccelerateArray
33

44
class LapackTests: XCTestCase {
5+
func testGetrfFloat() throws {
6+
// A in row major
7+
let A: [Float] = [
8+
1.0, 2.0,
9+
3.0, 4.0,
10+
5.0, 6.0,
11+
7.0, 8.0
12+
]
13+
// convert A to col major
14+
var At = A.mtrans(m: 2, n: 4)
15+
let ipiv = try At.getrf(m: 4, n: 2)
16+
// convert solution to row major
17+
let X = At.mtrans(m: 4, n: 2)
18+
// L in row major
19+
let L: [Float] = [
20+
1.0, 0.0,
21+
X[2], 1.0,
22+
X[4], X[5],
23+
X[6], X[7],
24+
]
25+
// U in row major
26+
let U: [Float] = [
27+
X[0], X[1],
28+
0.0, X[3],
29+
]
30+
31+
// note, the indices in ipiv are one-based (Fortran)
32+
// construct the permutation vector
33+
// see: https://math.stackexchange.com/a/3112224/91477
34+
var p = [0, 1, 2, 3]
35+
for i in 0..<ipiv.count {
36+
p.swapAt(i, Int(ipiv[i] - 1))
37+
}
38+
39+
let n = 4
40+
var P: [Float] = Array(repeating: 0, count: n * n)
41+
for i in 0..<p.count {
42+
// i iterates columns of P (in row major)
43+
// p[i] indicates which element in the column must be set to one, to create the permutation matrix
44+
P[i + p[i] * n] = 1.0
45+
}
46+
47+
let PLU = P.mmul(B: L.mmul(B: U, m: 4, n: 2, p: 2), m: 4, n: 2, p: 4)
48+
XCTAssertEqual(A, PLU, accuracy: 1e-6)
49+
}
50+
51+
func testGetriFloat() throws {
52+
// inversion is independent of row/col major storage
53+
var A: [Float] = [
54+
1.0, 2.0,
55+
3.0, 4.0,
56+
]
57+
try A.getri()
58+
59+
let Ainv: [Float] = [
60+
-2.0, 1.0,
61+
1.5, -0.5,
62+
]
63+
64+
XCTAssertEqual(A, Ainv)
65+
}
66+
67+
func testGesvFloat() throws {
68+
// A in row major
69+
let A: [Float] = [
70+
1.0, 2.0,
71+
3.0, 4.0,
72+
]
73+
// convert A to col major
74+
var At = A.mtrans(m: 2, n: 2)
75+
let b: [Float] = [
76+
1.0, 1.0
77+
]
78+
// B in row major
79+
let B: [Float] = [
80+
1.0, 2.0, 3.0,
81+
1.0, 2.0, 3.0,
82+
]
83+
// B in col major
84+
var Bt = B.mtrans(m: 3, n: 2)
85+
// Ainv in row major
86+
let Ainv: [Float] = [
87+
-2.0, 1.0,
88+
1.5, -0.5,
89+
]
90+
let x1 = Ainv.mmul(B: b, m: 2, n: 1, p: 2)
91+
// X1 is in row major
92+
let X1 = Ainv.mmul(B: B, m: 2, n: 3, p: 2)
93+
94+
// solution is stored col major in Bt
95+
try At.gesv(B: &Bt)
96+
let X2 = Bt.mtrans(m: 2, n: 3)
97+
98+
XCTAssertEqual(x1[0], X1[0], accuracy: 1e-15)
99+
XCTAssertEqual(x1[1], X1[3], accuracy: 1e-15)
100+
XCTAssertEqual(X1, X2, accuracy: 1e-6)
101+
}
102+
103+
func testGtsvFloat() throws {
104+
// A in row major
105+
let A: [Float] = [
106+
1.0, 1.0, 0.0, 0.0,
107+
-1.0, 2.0, 2.0, 0.0,
108+
0.0, -2.0, 3.0, 3.0,
109+
0.0, 0.0, -3.0, 4.0,
110+
]
111+
// convert A to col major
112+
var At = A.mtrans(m: 4, n: 4)
113+
114+
// diagonals of A
115+
var d: [Float] = [1.0, 2.0, 3.0, 4.0, ]
116+
var du: [Float] = [1.0, 2.0, 3.0, ]
117+
var dl: [Float] = [-1.0, -2.0, -3.0, ]
118+
119+
// B in row major
120+
let B: [Float] = [
121+
1.0, 2.0,
122+
1.0, 2.0,
123+
1.0, 2.0,
124+
1.0, 2.0,
125+
]
126+
// B in col major
127+
var Bt = B.mtrans(m: 2, n: 4)
128+
// make a copy of Bt
129+
var Ct = Bt
130+
131+
// solve with general solver
132+
// solution is stored col major in Ct
133+
try At.gesv(B: &Ct)
134+
let X1 = Ct.mtrans(m: 4, n: 2)
135+
136+
// solution is stored col major in Bt
137+
try d.gtsv(nrhs: 2, dl: &dl, du: &du, B: &Bt)
138+
let X2 = Bt.mtrans(m: 4, n: 2)
139+
140+
XCTAssertEqual(X1, X2, accuracy: 1e-15)
141+
}
142+
5143
func testGetrfDouble() throws {
6144
// A in row major
7145
let A: [Double] = [
@@ -142,6 +280,10 @@ class LapackTests: XCTestCase {
142280

143281
static var allTests: [(String, (LapackTests) -> () throws -> Void)] {
144282
return [
283+
("testGetrfFloat", testGetrfFloat),
284+
("testGetriFloat", testGetriFloat),
285+
("testGesvFloat", testGesvFloat),
286+
("testGtsvFloat", testGtsvFloat),
145287
("testGetrfDouble", testGetrfDouble),
146288
("testGetriDouble", testGetriDouble),
147289
("testGesvDouble", testGesvDouble),

0 commit comments

Comments
 (0)