Skip to content

Commit 2545bb1

Browse files
committed
fix: introduce SQLResultSet to support column metadata for empty result sets across all providers
1 parent bd945fd commit 2545bb1

9 files changed

Lines changed: 163 additions & 43 deletions

File tree

Sources/CosmoMSSQL/MSSQLConnection.swift

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,18 @@ public final class MSSQLConnection: SQLDatabase, AdvancedSQLDatabase, @unchecked
584584

585585
/// Execute a query and return **all** result sets (e.g. from a stored procedure
586586
/// that contains multiple SELECT statements).
587-
public func queryMulti(_ sql: String, _ binds: [SQLValue] = []) async throws -> [[SQLRow]] {
587+
/// Execute a query and return the first result set as a ``SQLDataTable``.
588+
/// This is the sql-nio equivalent of .NET ``DataTable.Load(reader)``.
589+
public func queryTable(_ sql: String, _ binds: [SQLValue] = []) async throws -> SQLDataTable {
590+
let sets = try await queryMulti(sql, binds)
591+
if let first = sets.first {
592+
return SQLDataTable(name: sql, resultSet: first)
593+
}
594+
return SQLDataTable(name: sql, resultSet: SQLResultSet(columns: [], rows: []))
595+
}
596+
597+
598+
public func queryMulti(_ sql: String, _ binds: [SQLValue] = []) async throws -> [SQLResultSet] {
588599
guard !isClosed else { throw SQLError.connectionClosed }
589600
logger.debug("MSSQL queryMulti: \(sql.prefix(120))")
590601
return try await withTimeout(config.queryTimeout) {

Sources/CosmoMSSQL/MSSQLProcResult.swift

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,14 @@ import Foundation
55
public struct MSSQLProcResult: Sendable {
66

77
/// All result sets returned by the procedure (one per SELECT statement).
8-
public let resultSets: [[SQLRow]]
8+
public let resultSets: [SQLResultSet]
99

1010
/// First result set — convenience shorthand for `resultSets.first ?? []`.
11-
public var rows: [SQLRow] { resultSets.first ?? [] }
11+
public var rows: [SQLRow] { resultSets.first?.rows ?? [] }
12+
13+
/// First result set columns.
14+
public var columns: [SQLColumn] { resultSets.first?.columns ?? [] }
15+
1216

1317
/// Output parameter values keyed by name **including** the leading `@`
1418
/// (e.g. `outputParameters["@NewId"]`).
@@ -33,6 +37,6 @@ public struct MSSQLProcResult: Sendable {
3337
/// Decode result set at `index` into an array of `T`.
3438
public func decode<T: Decodable>(_ index: Int, as type: T.Type = T.self) throws -> [T] {
3539
guard index < resultSets.count else { return [] }
36-
return try resultSets[index].map { try SQLRowDecoder().decode(T.self, from: $0) }
40+
return try resultSets[index].rows.map { try SQLRowDecoder().decode(T.self, from: $0) }
3741
}
3842
}

Sources/CosmoMSSQL/TDS/TDSDecoder.swift

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@ struct TDSTokenDecoder {
1515
private var currentRows: [SQLRow] = []
1616

1717
// All completed result sets (flushed on COLMETADATA/DONE)
18-
private(set) var resultSets: [[SQLRow]] = []
18+
private(set) var resultSets: [SQLResultSet] = []
1919

2020
// First result set — convenience alias used by simple query callers.
2121
// Includes unflushed rows if no result sets were formally closed yet.
2222
var rows: [SQLRow] {
23-
if let first = resultSets.first { return first }
23+
if let first = resultSets.first { return first.rows }
2424
return currentRows
2525
}
2626

@@ -129,8 +129,8 @@ struct TDSTokenDecoder {
129129
return
130130
}
131131
// Flush any rows accumulated from a prior result set before starting a new one
132-
if !currentRows.isEmpty {
133-
resultSets.append(currentRows)
132+
if !columns.isEmpty {
133+
resultSets.append(SQLResultSet(columns: columns, rows: currentRows))
134134
currentRows = []
135135
}
136136
columns = []
@@ -191,8 +191,8 @@ struct TDSTokenDecoder {
191191
let count: UInt64 = buf.readInteger(endianness: .little) // rowCount (8 bytes in TDS 7.2+)
192192
else { throw TDSError.incomplete }
193193
// Flush current rows into resultSets on any DONE token
194-
if !currentRows.isEmpty {
195-
resultSets.append(currentRows)
194+
if !columns.isEmpty {
195+
resultSets.append(SQLResultSet(columns: columns, rows: currentRows))
196196
currentRows = []
197197
}
198198
// Only trust the rowcount when the DONE_COUNT bit (0x10) is set

Sources/CosmoMySQL/MySQLConnection.swift

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -362,17 +362,17 @@ public final class MySQLConnection: SQLDatabase, AdvancedSQLDatabase, @unchecked
362362
///
363363
/// Note: MySQL requires the `CLIENT_MULTI_STATEMENTS` capability for this to work.
364364
/// The driver negotiates this automatically during handshake.
365-
public func queryMulti(_ sql: String, _ binds: [SQLValue] = []) async throws -> [[SQLRow]] {
365+
public func queryMulti(_ sql: String, _ binds: [SQLValue] = []) async throws -> [SQLResultSet] {
366366
guard !isClosed else { throw SQLError.connectionClosed }
367367
let rendered = renderQuery(sql, binds: binds)
368368
logger.debug("MySQL queryMulti: \(rendered)")
369369
try await sendQuery(rendered)
370370

371-
var allSets: [[SQLRow]] = []
371+
var allSets: [SQLResultSet] = []
372372
// MySQL returns one result set at a time; the OK/EOF has a "more results" flag
373373
while true {
374374
let resultSet = try await readResultSetMulti()
375-
allSets.append(resultSet.rows)
375+
allSets.append(resultSet.resultSet)
376376
if !resultSet.hasMore { break }
377377
}
378378
return allSets
@@ -599,7 +599,7 @@ public final class MySQLConnection: SQLDatabase, AdvancedSQLDatabase, @unchecked
599599
// MARK: - Wire helpers
600600

601601
private struct ResultSetChunk {
602-
let rows: [SQLRow]
602+
let resultSet: SQLResultSet
603603
let hasMore: Bool
604604
}
605605

@@ -609,15 +609,15 @@ public final class MySQLConnection: SQLDatabase, AdvancedSQLDatabase, @unchecked
609609
let firstResponse = try MySQLResponse.decode(packet: &firstPacket, capabilities: capabilities)
610610
switch firstResponse {
611611
case .ok(_, _, let status, _):
612-
return ResultSetChunk(rows: [], hasMore: status.contains(.moreResultsExist))
612+
return ResultSetChunk(resultSet: SQLResultSet(columns: [], rows: []), hasMore: status.contains(.moreResultsExist))
613613
case .err(let code, _, let message):
614614
throw SQLError.serverError(code: Int(code), message: message)
615615
case .data(var countPacket):
616616
countPacket.moveReaderIndex(forwardBy: 4)
617617
let columnCount = countPacket.readLengthEncodedInt() ?? 0
618618
return try await readColumnsMulti(count: Int(columnCount))
619619
default:
620-
return ResultSetChunk(rows: [], hasMore: false)
620+
return ResultSetChunk(resultSet: SQLResultSet(columns: [], rows: []), hasMore: false)
621621
}
622622
}
623623

@@ -655,7 +655,7 @@ public final class MySQLConnection: SQLDatabase, AdvancedSQLDatabase, @unchecked
655655
_ = pkt.readInteger(endianness: .little) as UInt16? // warnings
656656
let statusRaw = pkt.readInteger(endianness: .little) as UInt16? ?? 0
657657
let status = MySQLServerStatus(rawValue: statusRaw)
658-
return ResultSetChunk(rows: rows, hasMore: status.contains(.moreResultsExist))
658+
return ResultSetChunk(resultSet: SQLResultSet(columns: sqlCols, rows: rows), hasMore: status.contains(.moreResultsExist))
659659
}
660660
if firstByte == 0x00 && capabilities.contains(.deprecateEOF) {
661661
// OK (deprecateEOF style)
@@ -664,7 +664,7 @@ public final class MySQLConnection: SQLDatabase, AdvancedSQLDatabase, @unchecked
664664
_ = pkt.readLengthEncodedInt() // last insert id
665665
let statusRaw = pkt.readInteger(endianness: .little) as UInt16? ?? 0
666666
let status = MySQLServerStatus(rawValue: statusRaw)
667-
return ResultSetChunk(rows: rows, hasMore: status.contains(.moreResultsExist))
667+
return ResultSetChunk(resultSet: SQLResultSet(columns: sqlCols, rows: rows), hasMore: status.contains(.moreResultsExist))
668668
}
669669

670670
// Data row
@@ -683,7 +683,7 @@ public final class MySQLConnection: SQLDatabase, AdvancedSQLDatabase, @unchecked
683683
}
684684
rows.append(SQLRow(columns: sqlCols, values: values))
685685
}
686-
return ResultSetChunk(rows: rows, hasMore: false)
686+
return ResultSetChunk(resultSet: SQLResultSet(columns: sqlCols, rows: rows), hasMore: false)
687687
}
688688

689689
private func send(_ buffer: ByteBuffer) {

Sources/CosmoPostgres/PostgresConnection.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -418,15 +418,15 @@ public final class PostgresConnection: SQLDatabase, AdvancedSQLDatabase, @unchec
418418
///
419419
/// PostgreSQL allows multiple statements separated by `;` in a single query string.
420420
/// Each statement that returns rows produces one element in the returned array.
421-
public func queryMulti(_ sql: String, _ binds: [SQLValue] = []) async throws -> [[SQLRow]] {
421+
public func queryMulti(_ sql: String, _ binds: [SQLValue] = []) async throws -> [SQLResultSet] {
422422
guard !isClosed else { throw SQLError.connectionClosed }
423423
let rendered = renderQuery(sql, binds: binds)
424424
logger.debug("PostgreSQL queryMulti: \(rendered)")
425425

426426
let msg = PGFrontend.query(rendered, allocator: channel.allocator)
427427
send(msg)
428428

429-
var allSets: [[SQLRow]] = []
429+
var allSets: [SQLResultSet] = []
430430
var current: [SQLRow] = []
431431
var columns: [PGColumnDesc] = []
432432
var sqlCols: [SQLColumn] = [] // computed once per RowDescription
@@ -449,7 +449,7 @@ public final class PostgresConnection: SQLDatabase, AdvancedSQLDatabase, @unchec
449449
}
450450
case .commandComplete:
451451
if !current.isEmpty || !columns.isEmpty {
452-
allSets.append(current)
452+
allSets.append(SQLResultSet(columns: sqlCols, rows: current))
453453
current = []
454454
columns = []
455455
sqlCols = []

Sources/CosmoSQLCore/SQLDataTable.swift

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,13 @@ public struct SQLDataTable: Sendable {
175175

176176
// MARK: Init from SQLRows
177177

178+
public init(name: String? = nil, resultSet: SQLResultSet) {
179+
self.name = name
180+
self.columns = resultSet.columns.map { SQLDataColumn(name: $0.name, table: $0.table) }
181+
self.rows = resultSet.rows.map { row in row.values.map { SQLCellValue($0) } }
182+
self._colIndex = Dictionary(uniqueKeysWithValues: self.columns.enumerated().map { ($1.name.lowercased(), $0) })
183+
}
184+
178185
public init(name: String? = nil, rows sqlRows: [SQLRow]) {
179186
self.name = name
180187
self.columns = (sqlRows.first?.columns ?? []).map {
@@ -351,11 +358,11 @@ extension Array where Element == SQLRow {
351358
}
352359
}
353360

354-
extension Array where Element == [SQLRow] {
361+
extension Array where Element == SQLResultSet {
355362
/// Convert multi-result-set rows to a `SQLDataSet`.
356363
public func asDataSet(names: [String?]? = nil) -> SQLDataSet {
357364
let tables = enumerated().map { (i, rows) in
358-
SQLDataTable(name: names?[safe: i] ?? nil, rows: rows)
365+
SQLDataTable(name: names?[safe: i] ?? nil, resultSet: rows)
359366
}
360367
return SQLDataSet(tables: tables)
361368
}

Sources/CosmoSQLCore/SQLRow.swift

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,19 @@
1+
/// A result set returned by a SQL query, containing column metadata and zero or more rows.
2+
public struct SQLResultSet: Sendable {
3+
public let columns: [SQLColumn]
4+
public let rows: [SQLRow]
5+
6+
public init(columns: [SQLColumn], rows: [SQLRow]) {
7+
self.columns = columns
8+
self.rows = rows
9+
}
10+
11+
public var isEmpty: Bool { rows.isEmpty }
12+
public var count: Int { rows.count }
13+
public subscript(index: Int) -> SQLRow { rows[index] }
14+
}
15+
116
/// A single row returned by a SQL query.
2-
///
3-
/// Access values by column name or zero-based index:
4-
/// ```swift
5-
/// let name = try row["name"].require().asString()
6-
/// let id = row[0].asInt64()
7-
/// ```
817
public struct SQLRow: Sendable {
918
public let columns: [SQLColumn]
1019
public let values: [SQLValue]
@@ -15,16 +24,10 @@ public struct SQLRow: Sendable {
1524
self.values = values
1625
}
1726

18-
// MARK: - Subscript by index
19-
2027
public subscript(index: Int) -> SQLValue {
2128
values[index]
2229
}
2330

24-
// MARK: - Subscript by column name (case-insensitive)
25-
26-
/// Returns the value for the first column whose name matches (case-insensitively).
27-
/// Returns `.null` if no such column exists.
2831
public subscript(column: String) -> SQLValue {
2932
let lower = column.lowercased()
3033
guard let idx = columns.firstIndex(where: { $0.name.lowercased() == lower }) else {
@@ -34,11 +37,7 @@ public struct SQLRow: Sendable {
3437
}
3538
}
3639

37-
// MARK: - Helpers
38-
3940
public extension SQLValue {
40-
/// Throws ``SQLError/columnNotFound(_:)`` when the value is `.null` and was
41-
/// produced by a missing column lookup.
4241
func require(column: String = "<unknown>") throws -> SQLValue {
4342
if case .null = self {
4443
throw SQLError.columnNotFound(column)

Sources/CosmoSQLite/SQLiteConnection.swift

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,18 +164,53 @@ public final class SQLiteConnection: SQLDatabase, AdvancedSQLDatabase, @unchecke
164164

165165
// MARK: - Multi-statement (split on ";")
166166

167-
public func queryMulti(_ sql: String, _ binds: [SQLValue] = []) async throws -> [[SQLRow]] {
167+
public func queryTable(_ sql: String, _ binds: [SQLValue] = []) async throws -> SQLDataTable {
168+
let stmts = sql.split(separator: ";", omittingEmptySubsequences: true).map { $0.trimmingCharacters(in: .whitespacesAndNewlines) }.filter { !$0.isEmpty }
169+
if let first = stmts.first {
170+
let rs = try await _runQuery(String(first), binds: binds)
171+
return SQLDataTable(name: String(first), resultSet: rs)
172+
}
173+
return SQLDataTable(name: sql, resultSet: SQLResultSet(columns: [], rows: []))
174+
}
175+
176+
177+
public func queryMulti(_ sql: String, _ binds: [SQLValue] = []) async throws -> [SQLResultSet] {
168178
let stmts = sql
169179
.split(separator: ";", omittingEmptySubsequences: true)
170180
.map { $0.trimmingCharacters(in: .whitespacesAndNewlines) }
171181
.filter { !$0.isEmpty }
172-
var results: [[SQLRow]] = []
182+
var results: [SQLResultSet] = []
173183
for stmt in stmts {
174-
results.append(try await query(stmt, binds))
184+
results.append(try await _runQuery(String(stmt), binds: binds))
175185
}
176186
return results
177187
}
178188

189+
private func _runQuery(_ sql: String, binds: [SQLValue]) async throws -> SQLResultSet {
190+
let prepared = renderQuery(sql, binds: binds)
191+
return try await pool.runIfActive(eventLoop: group.next()) {
192+
guard let db = self.db else { throw SQLError.connectionClosed }
193+
var stmt: OpaquePointer?
194+
let rc = sqlite3_prepare_v2(db, prepared, -1, &stmt, nil)
195+
guard rc == SQLITE_OK, let stmt else { throw self.sqliteError(db: db, code: rc, context: "prepare") }
196+
defer { sqlite3_finalize(stmt) }
197+
try self.bindParams(stmt: stmt, binds: binds, db: db)
198+
let colCount = Int(sqlite3_column_count(stmt))
199+
let columns = self.makeColumns(stmt: stmt, count: colCount)
200+
var rows: [SQLRow] = []
201+
while true {
202+
let stepRc = sqlite3_step(stmt)
203+
if stepRc == SQLITE_ROW {
204+
let values = (0..<colCount).map { self.readColumn(stmt: stmt, index: Int32($0)) }
205+
rows.append(SQLRow(columns: columns, values: values))
206+
} else if stepRc == SQLITE_DONE { break } else {
207+
throw self.sqliteError(db: db, code: stepRc, context: "step")
208+
}
209+
}
210+
return SQLResultSet(columns: columns, rows: rows)
211+
}.get()
212+
}
213+
179214
// MARK: - Blocking internals (run on thread pool)
180215

181216
private func execQuery(_ sql: String, binds: [SQLValue]) throws -> [SQLRow] {
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import XCTest
2+
@testable import CosmoMSSQL
3+
import CosmoSQLCore
4+
5+
final class InvoiceHeaderTests: XCTestCase, @unchecked Sendable {
6+
7+
override func setUp() async throws {
8+
try skipUnlessIntegration()
9+
}
10+
11+
private func ensureProcedureExists(conn: MSSQLConnection) async throws {
12+
_ = try await conn.execute("IF OBJECT_ID('InvoiceHeader', 'P') IS NOT NULL DROP PROCEDURE InvoiceHeader;")
13+
14+
_ = try await conn.execute("CREATE PROCEDURE InvoiceHeader @TransactionID NVARCHAR(50), @FinancialYear INT, @AdminUser INT, @Language NVARCHAR(50) AS BEGIN SELECT @TransactionID as TransactionID, @FinancialYear as FinancialYear, @AdminUser as AdminUser, @Language as Language, 'Test' as DummyData; END")
15+
}
16+
17+
func testInvoiceHeaderProcedure_ShouldSucceed() {
18+
runAsync {
19+
try await TestDatabase.withConnection { conn in
20+
try await self.ensureProcedureExists(conn: conn)
21+
22+
let parameters: [SQLParameter] = [
23+
.init("1-C-96/25", name: "TransactionID"),
24+
.init(2025, name: "FinancialYear"),
25+
.init(1, name: "AdminUser"),
26+
.init("English", name: "Language")
27+
]
28+
29+
let result = try await conn.callProcedure("InvoiceHeader", parameters: parameters)
30+
31+
XCTAssertEqual(result.rows.count, 1)
32+
XCTAssertEqual(result.rows[0]["TransactionID"].asString(), "1-C-96/25")
33+
XCTAssertEqual(result.resultSets.first?.columns.count, 5)
34+
}
35+
}
36+
}
37+
38+
func testQueryTable_WithEmptyResult_ShouldHaveSchema() {
39+
runAsync {
40+
try await TestDatabase.withConnection { conn in
41+
let table = try await conn.queryTable("SELECT 'A' as Col1, 1 as Col2 WHERE 1=0")
42+
43+
XCTAssertEqual(table.rowCount, 0)
44+
XCTAssertEqual(table.columnCount, 2)
45+
XCTAssertEqual(table.columns[0].name, "Col1")
46+
XCTAssertEqual(table.columns[1].name, "Col2")
47+
}
48+
}
49+
}
50+
51+
func testQueryTable_WithInvoiceHeader_ShouldHaveData() {
52+
runAsync {
53+
try await TestDatabase.withConnection { conn in
54+
try await self.ensureProcedureExists(conn: conn)
55+
56+
let table = try await conn.queryTable("exec InvoiceHeader @TransactionID='1-C-96/25', @FinancialYear=2025, @AdminUser=1, @Language='English'")
57+
58+
XCTAssertTrue(table.columnCount > 0)
59+
XCTAssertTrue(table.rowCount > 0)
60+
XCTAssertEqual(table[0, "TransactionID"].displayString, "1-C-96/25")
61+
}
62+
}
63+
}
64+
}

0 commit comments

Comments
 (0)