Skip to content

Commit dac6254

Browse files
committed
Removed fun rust guard style transactions
1 parent 9ba10a5 commit dac6254

11 files changed

Lines changed: 102 additions & 124 deletions

File tree

Sources/Feather/Connection.swift

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,8 @@ public protocol Connection: Actor {
1616
/// Cancels the observation for the given subscriber
1717
nonisolated func cancel(subscriber: DatabaseSubscriber)
1818

19-
func begin(
20-
_ transaction: TransactionKind
21-
) async throws(FeatherError) -> sending Transaction
22-
23-
func didCommit(transaction: borrowing Transaction)
19+
func begin<Output>(
20+
_ kind: Transaction.Kind,
21+
execute: (borrowing Transaction) throws -> Output
22+
) async throws -> Output
2423
}

Sources/Feather/ConnectionPool.swift

Lines changed: 53 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -51,17 +51,9 @@ public actor ConnectionPool: Sendable {
5151
// Turn on WAL mode
5252
try connection.execute(sql: "PRAGMA journal_mode=WAL;")
5353

54-
let tx = try Transaction(
55-
connection: connection,
56-
kind: .write,
57-
pool: nil
58-
)
59-
54+
let tx = try Transaction(connection: connection, kind: .write)
6055
try MigrationRunner.execute(migrations: migrations, tx: tx)
61-
62-
// We don't want an async inti so we can skip the reclaim to remove the await
63-
// and manually add it to the availableConnections manually.
64-
try tx.commitWithoutReclaim()
56+
try tx.commit()
6557

6658
self.availableConnections = [connection]
6759
}
@@ -71,47 +63,26 @@ public actor ConnectionPool: Sendable {
7163
return count >= limit
7264
}
7365

74-
/// Gives the connection back to the pool.
75-
func reclaim(
76-
connection: SQLiteConnection,
77-
txKind: TransactionKind
78-
) async {
79-
availableConnections.append(connection)
80-
alertAnyWaitersOfAvailableConnection()
81-
82-
if txKind == .write {
83-
await writeLock.unlock()
84-
}
85-
}
86-
}
87-
88-
extension ConnectionPool: Connection {
89-
public nonisolated func observe(subscriber: any DatabaseSubscriber) {
90-
observer.subscribe(subscriber: subscriber)
91-
}
92-
93-
public nonisolated func cancel(subscriber: any DatabaseSubscriber) {
94-
observer.cancel(subscriber: subscriber)
95-
}
96-
97-
public nonisolated func didCommit(transaction: borrowing Transaction) {
98-
observer.didCommit()
99-
}
100-
10166
/// Starts a transaction.
102-
public func begin(
103-
_ kind: TransactionKind
67+
private func begin(
68+
_ kind: Transaction.Kind
10469
) async throws(FeatherError) -> sending Transaction {
10570
// Writes must be exclusive, make sure to wait on any pending writes.
10671
if kind == .write {
10772
await writeLock.lock()
10873
}
10974

110-
return try await Transaction(
111-
connection: getConnection(),
112-
kind: kind,
113-
pool: self
114-
)
75+
return try await Transaction(connection: getConnection(), kind: kind)
76+
}
77+
78+
/// Gives the connection back to the pool.
79+
private func reclaim(tx: borrowing Transaction) async {
80+
availableConnections.append(tx.connection)
81+
alertAnyWaitersOfAvailableConnection()
82+
83+
if tx.kind == .write {
84+
await writeLock.unlock()
85+
}
11586
}
11687

11788
/// Will get, wait or create a connection to the database
@@ -147,3 +118,41 @@ extension ConnectionPool: Connection {
147118
waiter.resume(with: .success(connection))
148119
}
149120
}
121+
122+
extension ConnectionPool: Connection {
123+
public nonisolated func observe(subscriber: any DatabaseSubscriber) {
124+
observer.subscribe(subscriber: subscriber)
125+
}
126+
127+
public nonisolated func cancel(subscriber: any DatabaseSubscriber) {
128+
observer.cancel(subscriber: subscriber)
129+
}
130+
131+
/// Starts a transaction.
132+
public func begin<Output>(
133+
_ kind: Transaction.Kind,
134+
execute: (borrowing Transaction) throws -> Output
135+
) async throws -> Output {
136+
let tx = try await begin(kind)
137+
138+
// The `Result` wrapper seems weird, but allows us to keep
139+
// tx functions consuming. Cause we cannot call `commit` in
140+
// the `do` and on failure call `rollback` since it would
141+
// have been consumed in the `commit`.
142+
let result = Result {
143+
try execute(tx)
144+
}
145+
146+
await reclaim(tx: tx)
147+
148+
switch result {
149+
case .success(let output):
150+
try tx.commit()
151+
observer.didCommit()
152+
return output
153+
case .failure(let error):
154+
try tx.commitOrRollback()
155+
throw error
156+
}
157+
}
158+
}

Sources/Feather/DatabaseQuery.swift

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
public protocol DatabaseQuery<Input, Output>: Query {
99
/// Whether the query requires a read or write transaction.
10-
var transactionKind: TransactionKind { get }
10+
var transactionKind: Transaction.Kind { get }
1111

1212
var connection: any Connection { get }
1313

@@ -19,10 +19,9 @@ public protocol DatabaseQuery<Input, Output>: Query {
1919

2020
public extension DatabaseQuery {
2121
func execute(with input: Input) async throws -> Output {
22-
let tx = try await connection.begin(transactionKind)
23-
let output = try execute(with: input, tx: tx)
24-
try await tx.commit()
25-
return output
22+
try await connection.begin(transactionKind) { tx in
23+
try execute(with: input, tx: tx)
24+
}
2625
}
2726

2827
func observe(with input: Input) -> any QueryObservation<Output> {

Sources/Feather/Migration.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ public struct MigrationRunner {
99
static let migrationTableName = "__featherMigrations"
1010

1111
public static func execute(migrations: [String], pool: ConnectionPool) async throws {
12-
let tx = try await pool.begin(.write)
13-
try execute(migrations: migrations, tx: tx)
14-
try tx.commit()
12+
try await pool.begin(.write) { tx in
13+
try execute(migrations: migrations, tx: tx)
14+
}
1515
}
1616

1717
public static func execute(migrations: [String], tx: borrowing Transaction) throws {

Sources/Feather/Queries/AnyDatabaseQuery.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ public struct AnyDatabaseQuery<Input, Output>: DatabaseQuery
1010
where Input: Sendable, Output: Sendable
1111
{
1212
public let connection: any Connection
13-
public let transactionKind: TransactionKind
13+
public let transactionKind: Transaction.Kind
1414
public let execute: @Sendable (Input, borrowing Transaction) throws -> Output
1515

1616

@@ -28,7 +28,7 @@ public struct AnyDatabaseQuery<Input, Output>: DatabaseQuery
2828
/// - connection: The connection to execute the query with
2929
/// - execute: A closure to run on `execute`.
3030
public init(
31-
_ transactionKind: TransactionKind,
31+
_ transactionKind: Transaction.Kind,
3232
in connection: any Connection,
3333
execute: @escaping @Sendable (Input, borrowing Transaction) throws -> Output
3434
) {

Sources/Feather/Queries/Map.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ extension Queries.Map: DatabaseQuery where Base: DatabaseQuery {
5555
return base.connection
5656
}
5757

58-
public var transactionKind: TransactionKind {
58+
public var transactionKind: Transaction.Kind {
5959
return base.transactionKind
6060
}
6161

Sources/Feather/Queries/MapInput.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ extension Queries.MapInput: DatabaseQuery where Base: DatabaseQuery {
2929
return base.connection
3030
}
3131

32-
public var transactionKind: TransactionKind {
32+
public var transactionKind: Transaction.Kind {
3333
return base.transactionKind
3434
}
3535

Sources/Feather/Queries/Then.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ extension Queries {
2020
return first.connection
2121
}
2222

23-
public var transactionKind: TransactionKind {
23+
public var transactionKind: Transaction.Kind {
2424
return max(first.transactionKind, second.transactionKind)
2525
}
2626

Sources/Feather/SQL.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ public struct SQL: ExpressibleByStringLiteral, ExpressibleByStringInterpolation,
3030
output.append("(\(primitives.map { _ in "?" }.joined(separator: ",")))")
3131
parameters.append(contentsOf: primitives)
3232
}
33+
34+
public mutating func appendInterpolation<T>(raw: T) {
35+
output.append("\(raw)")
36+
}
3337
}
3438

3539
public init(stringLiteral value: String) {

Sources/Feather/Transaction.swift

Lines changed: 23 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -8,88 +8,53 @@
88
/// A SQLite transaction.
99
public struct Transaction: ~Copyable {
1010
let connection: SQLiteConnection
11-
let kind: TransactionKind
11+
let kind: Kind
1212
let behavior: Behavior
13-
private var didCommit = false
14-
private let pool: ConnectionPool?
1513

1614
public enum Behavior: String, Sendable {
1715
case deferred = "DEFERRED"
1816
case immediate = "IMMEDIATE"
1917
case exclusive = "EXCLUSIVE"
2018
}
2119

20+
public enum Kind: Int, Sendable, Comparable {
21+
case read
22+
case write
23+
24+
public static func < (lhs: Kind, rhs: Kind) -> Bool {
25+
return lhs.rawValue < rhs.rawValue
26+
}
27+
}
28+
2229
init(
2330
connection: SQLiteConnection,
24-
kind: TransactionKind,
25-
behavior: Behavior = .deferred,
26-
pool: ConnectionPool?
31+
kind: Kind,
32+
behavior: Behavior = .deferred
2733
) throws(FeatherError) {
2834
self.connection = connection
2935
self.kind = kind
3036
self.behavior = behavior
31-
self.pool = pool
3237
try connection.execute(sql: "BEGIN \(behavior.rawValue) TRANSACTION;")
3338
}
3439

40+
/// Executes the raw SQL
3541
public func execute(sql: String) throws(FeatherError) {
3642
try connection.execute(sql: sql)
3743
}
3844

39-
public consuming func commit() async throws(FeatherError) {
40-
guard !didCommit else {
41-
// This should never happen since its ~Copyable in a consuming
42-
// function but cant hurt to double check
43-
throw .alreadyCommited
44-
}
45-
46-
didCommit = true
45+
/// Commits any changes to the db
46+
public consuming func commit() throws(FeatherError) {
4747
try connection.execute(sql: "COMMIT")
48-
49-
pool?.didCommit(transaction: self)
50-
51-
await pool?.reclaim(connection: connection, txKind: kind)
5248
}
5349

54-
consuming func commitWithoutReclaim() throws(FeatherError) {
55-
guard !didCommit else {
56-
throw .alreadyCommited
50+
/// Should be called on error. If it is a read then it will just commit
51+
/// but writes will be rolled back.
52+
public consuming func commitOrRollback() throws(FeatherError) {
53+
switch kind {
54+
case .read:
55+
try connection.execute(sql: "COMMIT")
56+
case .write:
57+
try connection.execute(sql: "ROLLBACK")
5758
}
58-
59-
didCommit = true
60-
try connection.execute(sql: "COMMIT")
61-
}
62-
63-
deinit {
64-
guard didCommit else { return }
65-
66-
do {
67-
// Did not commit, need to either auto commit or rollback the changes.
68-
switch kind {
69-
case .read:
70-
try connection.execute(sql: "COMMIT")
71-
case .write:
72-
try connection.execute(sql: "ROLLBACK")
73-
}
74-
75-
// Feels dirty having this task here but it cannot be done
76-
// in a synchronous way...
77-
Task { [pool, connection, kind] in
78-
await pool?.reclaim(connection: connection, txKind: kind)
79-
}
80-
} catch {
81-
assertionFailure("Failed to commit or rollback")
82-
}
83-
}
84-
}
85-
86-
public enum TransactionKind: Int, Sendable {
87-
case read
88-
case write
89-
}
90-
91-
extension TransactionKind: Comparable {
92-
public static func < (lhs: TransactionKind, rhs: TransactionKind) -> Bool {
93-
return lhs.rawValue < rhs.rawValue
9459
}
9560
}

0 commit comments

Comments
 (0)