Skip to content

Commit 370e651

Browse files
authored
added support for DGTSV. (#2)
1 parent d0af31d commit 370e651

3 files changed

Lines changed: 100 additions & 8 deletions

File tree

AccelerateArray.podspec

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ Pod::Spec.new do |s|
1616
#
1717

1818
s.name = "AccelerateArray"
19-
s.version = "0.0.2"
19+
s.version = "0.1.0"
2020
s.summary = "Swift Array Extensions for the Apple Accelerate Framework"
2121

2222
# This description is used to generate tags and improve search results.

Sources/AccelerateArray/lapack.swift

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,10 +174,55 @@ public extension Array where Element == Double {
174174
assert(ldb >= Swift.max(1, n), "\(ldb) >= max(1, \(n))")
175175

176176
var info: Int32 = 0
177-
178177
dgesv_(&n, &nrhs, &self, &lda, &ipiv, &B, &ldb, &info)
179178
if info != 0 {
180179
throw LapackError.dgesv(info)
181180
}
182181
}
182+
183+
/// DGTSV solves the equation
184+
///
185+
/// A*X = B,
186+
///
187+
/// where A is an n by n tridiagonal matrix, by Gaussian elimination with
188+
/// partial pivoting.
189+
///
190+
/// Note that the equation A**T*X = B may be solved by interchanging the
191+
/// order of the arguments DU and DL.
192+
///
193+
/// This array represents the diagonal of A.
194+
///
195+
/// http://www.netlib.org/lapack/explore-html/d4/d62/group__double_g_tsolve_ga2bf93f2ddefa5e671866eb2191dc19d4.html#ga2bf93f2ddefa5e671866eb2191dc19d4
196+
///
197+
/// - Parameters:
198+
/// - nrhs: The number of right hand sides, i.e., the number of columns of the matrix B. NRHS >= 0.
199+
/// - dl: On entry, DL must contain the (n-1) sub-diagonal elements of A.
200+
/// On exit, DL is overwritten by the (n-2) elements of the
201+
/// second super-diagonal of the upper triangular matrix U from
202+
/// the LU factorization of A, in DL(1), ..., DL(n-2).
203+
/// - du: On entry, DU must contain the (n-1) super-diagonal elements of A.
204+
/// On exit, DU is overwritten by the (n-1) elements of the first
205+
/// super-diagonal of U.
206+
/// - B: On entry, the N by NRHS matrix of right hand side matrix B.
207+
/// On exit, if no error was thrown, the N by NRHS solution matrix X.
208+
///
209+
mutating func gtsv(nrhs: Int, dl: inout [Double], du: inout [Double], B: inout [Double]) throws {
210+
assert(count - 1 == dl.count, "\(count) - 1 == \(dl.count)")
211+
assert(count - 1 == du.count, "\(count) - 1 == \(du.count)")
212+
var n = Int32(count)
213+
214+
var nrhs = Int32(B.count / Int(n))
215+
assert(nrhs >= 1, "\(nrhs) >= 1")
216+
assert(B.count == Int(nrhs) * count, "\(B.count) == \(nrhs) * \(n)")
217+
218+
var ldb = Int32(B.count / Int(nrhs))
219+
assert(ldb * nrhs == B.count, "\(ldb) * \(nrhs) == \(B.count)")
220+
assert(ldb >= Swift.max(1, n), "\(ldb) >= max(1, \(n))")
221+
222+
var info: Int32 = 0
223+
dgtsv_(&n, &nrhs, &dl, &self, &du, &B, &ldb, &info)
224+
if info != 0 {
225+
throw LapackError.dgesv(info)
226+
}
227+
}
183228
}

Tests/AccelerateArrayTests/lapack.swift

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class LapackTests: XCTestCase {
3838

3939
let n = 4
4040
var P: [Double] = Array(repeating: 0, count: n * n)
41-
for i in 0 ..< p.count {
41+
for i in 0..<p.count {
4242
// i iterates columns of P (in row major)
4343
// p[i] indicates which element in the column must be set to one, to create the permutation matrix
4444
P[i + p[i] * n] = 1.0
@@ -72,6 +72,9 @@ class LapackTests: XCTestCase {
7272
]
7373
// convert A to col major
7474
var At = A.mtrans(m: 2, n: 2)
75+
let b: [Double] = [
76+
1.0, 1.0
77+
]
7578
// B in row major
7679
let B: [Double] = [
7780
1.0, 2.0, 3.0,
@@ -84,21 +87,65 @@ class LapackTests: XCTestCase {
8487
-2.0, 1.0,
8588
1.5, -0.5,
8689
]
87-
// x1 is in row major
88-
let x1 = Ainv.mmul(B: B, m: 2, n: 3, p: 2)
90+
let x1 = Ainv.mmul(B: b, m: 2, n: 1, p: 2)
91+
// X1 is in row major
92+
let X1 = Ainv.mmul(B: B, m: 2, n: 3, p: 2)
8993

90-
// solution is stored row major in Bt
94+
// solution is stored col major in Bt
9195
try At.gesv(B: &Bt)
92-
let x2 = Bt.mtrans(m: 2, n: 3)
96+
let X2 = Bt.mtrans(m: 2, n: 3)
97+
98+
XCTAssertEqual(x1[0], X1[0], accuracy: 1e-15)
99+
XCTAssertEqual(x1[1], X1[3], accuracy: 1e-15)
100+
XCTAssertEqual(X1, X2, accuracy: 1e-15)
101+
}
102+
103+
func testGtsvDouble() throws {
104+
// A in row major
105+
let A: [Double] = [
106+
1.0, 1.0, 0.0, 0.0,
107+
-1.0, 2.0, 2.0, 0.0,
108+
0.0, -2.0, 3.0, 3.0,
109+
0.0, 0.0, -3.0, 4.0,
110+
]
111+
// convert A to col major
112+
var At = A.mtrans(m: 4, n: 4)
113+
114+
// diagonals of A
115+
var d: [Double] = [1.0, 2.0, 3.0, 4.0, ]
116+
var du: [Double] = [1.0, 2.0, 3.0, ]
117+
var dl: [Double] = [-1.0, -2.0, -3.0, ]
118+
119+
// B in row major
120+
let B: [Double] = [
121+
1.0, 2.0,
122+
1.0, 2.0,
123+
1.0, 2.0,
124+
1.0, 2.0,
125+
]
126+
// B in col major
127+
var Bt = B.mtrans(m: 2, n: 4)
128+
// make a copy of Bt
129+
var Ct = Bt
130+
131+
// solve with general solver
132+
// solution is stored col major in Ct
133+
try At.gesv(B: &Ct)
134+
let X1 = Ct.mtrans(m: 4, n: 2)
135+
136+
// solution is stored col major in Bt
137+
try d.gtsv(nrhs: 2, dl: &dl, du: &du, B: &Bt)
138+
let X2 = Bt.mtrans(m: 4, n: 2)
93139

94-
XCTAssertEqual(x1, x2, accuracy: 1e-15)
140+
XCTAssertEqual(X1, X2, accuracy: 1e-15)
95141
}
96142

97143
static var allTests: [(String, (LapackTests) -> () throws -> Void)] {
98144
return [
99145
("testGetrfDouble", testGetrfDouble),
100146
("testGetriDouble", testGetriDouble),
101147
("testGesvDouble", testGesvDouble),
148+
("testGtsvDouble", testGtsvDouble),
102149
]
103150
}
104151
}

0 commit comments

Comments
 (0)