Skip to content

Commit 103cc25

Browse files
authored
feat: implement thread-safe access to sessions (#4)
1 parent ecf7cef commit 103cc25

4 files changed

Lines changed: 246 additions & 0 deletions

File tree

Sources/DXProtocol/Session/Session.swift

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ public struct Session: Codable {
109109

110110
try identityStore.saveIdentity(theirIdentityKey, for: address)
111111

112+
let lock = SessionLock(address: address)
113+
lock.lock()
114+
defer { lock.unlock() }
115+
112116
var session = try sessionStore.loadSession(for: address)
113117
if nil == session {
114118
session = Self(state: state)
@@ -154,6 +158,10 @@ public struct Session: Codable {
154158
direction: .receiving) else {
155159
throw DXError.untrustedIdentity("Abort processing PreKey Message for untrusted identity")
156160
}
161+
162+
let lock = SessionLock(address: address)
163+
lock.lock()
164+
defer { lock.unlock() }
157165

158166
let theirBaseKey = message.senderBaseKey
159167
let messageVersion = Int(message.messageVersion)
@@ -227,6 +235,10 @@ public struct Session: Codable {
227235
for address: ProtocolAddress,
228236
sessionStore: SessionStorable,
229237
identityStore: IdentityKeyStorable) throws -> MessageContainer {
238+
let lock = SessionLock(address: address)
239+
lock.lock()
240+
defer { lock.unlock() }
241+
230242
let result = try self.state.encrypt(
231243
data: data,
232244
sessionStore: sessionStore,
@@ -311,6 +323,10 @@ extension Session {
311323
identityStore: IdentityKeyStorable,
312324
preKeyStore: PreKeyStorable,
313325
signedPreKeyStore: SignedPreKeyStorable) throws -> Data {
326+
let lock = SessionLock(address: address)
327+
lock.lock()
328+
defer { lock.unlock() }
329+
314330
var session = try self.processPreKeyMessage(
315331
preKeyMessage,
316332
from: address,
@@ -347,6 +363,10 @@ extension Session {
347363
from address: ProtocolAddress,
348364
sessionStore: SessionStorable,
349365
identityStore: IdentityKeyStorable) throws -> Data {
366+
let lock = SessionLock(address: address)
367+
lock.lock()
368+
defer { lock.unlock() }
369+
350370
// This code is not covered by tests
351371
guard var session = try sessionStore.loadSession(for: address) else {
352372
throw DXError.sessionNotFound("Failed to find session while decrypting message")
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//
2+
// SessionLock.swift
3+
//
4+
//
5+
// Created by Andriy Vasyk on 19.02.2024.
6+
//
7+
8+
import Foundation
9+
10+
struct SessionLock {
11+
/// The underlying implementation of this session lock
12+
private var impl: NSRecursiveLock
13+
14+
/// Initialises a new lock for session that corresponds to specified protocol address
15+
/// - Parameter address: The protocol address that uniquely identifies session
16+
init(address: ProtocolAddress) {
17+
self.impl = SessionLockStorage.shared.lock(for: address)
18+
}
19+
20+
// MARK: - Interface
21+
22+
/// Blocks a thread’s execution until the lock can be acquired.
23+
func lock() {
24+
self.impl.lock()
25+
}
26+
27+
/// Relinquishes a previously acquired lock.
28+
func unlock() {
29+
self.impl.unlock()
30+
}
31+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
//
2+
// SessionLockStorage.swift
3+
//
4+
//
5+
// Created by Andriy Vasyk on 19.02.2024.
6+
//
7+
8+
import Foundation
9+
10+
final class SessionLockStorage {
11+
static let shared = SessionLockStorage()
12+
13+
/// Internal queue allowing to sync access to storage with locks
14+
private let syncAccessQueue = DispatchQueue(label: "StorageAccessQueue")
15+
16+
/// Container to store locks and associate them with corresponding protocol addresses
17+
private var storage: [ProtocolAddress: NSRecursiveLock] = [:]
18+
19+
/// Returns a lock for corresponding protocol addresses
20+
/// - Parameter address: The protocol address that uniquely identifies session and lock for it
21+
/// - Returns: A lock allowing to sync access to the session
22+
func lock(for address: ProtocolAddress) -> NSRecursiveLock {
23+
var result: NSRecursiveLock
24+
if let lock = self.storage[address] {
25+
result = lock
26+
} else {
27+
result = NSRecursiveLock()
28+
self.storage[address] = result
29+
}
30+
31+
return result
32+
}
33+
}

Tests/DXProtocolTests/SessionTests/SessionTests.swift

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,168 @@ final class SessionTests: XCTestCase {
265265
let decryptedText2 = try XCTUnwrap(String(data: decrypted2, encoding: .utf8))
266266
XCTAssertEqual(plaintext2, decryptedText2)
267267
}
268+
269+
func testThreadSafeSimultaneousDecrypt() async throws {
270+
let senderClient = try TestClient(userId: UUID()) // Alice
271+
let recipientClient = try TestClient(userId: UUID()) // Bob
272+
let recipientAddress = recipientClient.protocolAddress
273+
274+
try initializeSession(senderClient: senderClient, recipientClient: recipientClient)
275+
276+
// Alice creates a set of messages
277+
let aliceMessageCount = 50
278+
var aliceMessages: [(Data, MessageContainer)] = []
279+
for index in 0..<aliceMessageCount {
280+
var aliceSession = try senderClient.sessionStore.loadSession(for: recipientAddress)
281+
let data = try XCTUnwrap("From Alice \(index)".data(using: .utf8))
282+
let message = try XCTUnwrap(
283+
try aliceSession?.encrypt(
284+
data: data,
285+
for: recipientAddress,
286+
sessionStore: senderClient.sessionStore,
287+
identityStore: senderClient.identityKeyStore)
288+
)
289+
let item = (data, message)
290+
aliceMessages.append(item)
291+
}
292+
aliceMessages.shuffle()
293+
294+
var tasks = [Task<(Data, Data), Error>]()
295+
for aliceMessage in aliceMessages {
296+
let task = Task {
297+
let decryptedMessage = try Session.decrypt(
298+
message: aliceMessage.1,
299+
from: senderClient.protocolAddress,
300+
sessionStore: recipientClient.sessionStore,
301+
identityStore: recipientClient.identityKeyStore,
302+
preKeyStore: recipientClient.preKeyStore,
303+
signedPreKeyStore: recipientClient.signedPreKeyStore)
304+
let expectedMessage = aliceMessage.0
305+
return (decryptedMessage, expectedMessage)
306+
}
307+
tasks.append(task)
308+
}
309+
310+
var results = [(decrypted: Data, expected: Data)]()
311+
for task in tasks {
312+
let result = try await task.value
313+
results.append(result)
314+
}
315+
316+
for result in results {
317+
XCTAssertEqual(result.decrypted, result.expected)
318+
}
319+
}
320+
321+
func testThreadSafeSimultaneousEncrypt() async throws {
322+
let senderClient = try TestClient(userId: UUID()) // Alice
323+
let recipientClient = try TestClient(userId: UUID()) // Bob
324+
let recipientAddress = recipientClient.protocolAddress
325+
326+
try initializeSession(senderClient: senderClient, recipientClient: recipientClient)
327+
328+
// Actually test for Thread Safe begins here
329+
330+
let plaintext = "Do not despair when your enemy attacks you."
331+
let data = try XCTUnwrap(plaintext.data(using: .utf8))
332+
let task1 = Task {
333+
var session = try XCTUnwrap(try senderClient.sessionStore.loadSession(for: recipientAddress))
334+
return try session.encrypt(
335+
data: data,
336+
for: recipientClient.protocolAddress,
337+
sessionStore: senderClient.sessionStore,
338+
identityStore: senderClient.identityKeyStore)
339+
}
340+
341+
let task2 = Task {
342+
var session = try XCTUnwrap(try senderClient.sessionStore.loadSession(for: recipientAddress))
343+
return try session.encrypt(
344+
data: data,
345+
for: recipientClient.protocolAddress,
346+
sessionStore: senderClient.sessionStore,
347+
identityStore: senderClient.identityKeyStore)
348+
}
349+
350+
let task3 = Task {
351+
var session = try XCTUnwrap(try senderClient.sessionStore.loadSession(for: recipientAddress))
352+
return try session.encrypt(
353+
data: data,
354+
for: recipientClient.protocolAddress,
355+
sessionStore: senderClient.sessionStore,
356+
identityStore: senderClient.identityKeyStore)
357+
}
358+
359+
let task4 = Task {
360+
var session = try XCTUnwrap(try senderClient.sessionStore.loadSession(for: recipientAddress))
361+
return try session.encrypt(
362+
data: data,
363+
for: recipientClient.protocolAddress,
364+
sessionStore: senderClient.sessionStore,
365+
identityStore: senderClient.identityKeyStore)
366+
}
367+
368+
let task5 = Task {
369+
var session = try XCTUnwrap(try senderClient.sessionStore.loadSession(for: recipientAddress))
370+
return try session.encrypt(
371+
data: data,
372+
for: recipientClient.protocolAddress,
373+
sessionStore: senderClient.sessionStore,
374+
identityStore: senderClient.identityKeyStore)
375+
}
376+
377+
let message1 = try await task1.value
378+
let message2 = try await task2.value
379+
let message3 = try await task3.value
380+
let message4 = try await task4.value
381+
let message5 = try await task5.value
382+
383+
let decryptedMessage1 = try Session.decrypt(
384+
message: message1,
385+
from: senderClient.protocolAddress,
386+
sessionStore: recipientClient.sessionStore,
387+
identityStore: recipientClient.identityKeyStore,
388+
preKeyStore: recipientClient.preKeyStore,
389+
signedPreKeyStore: recipientClient.signedPreKeyStore)
390+
391+
let decryptedMessage2 = try Session.decrypt(
392+
message: message2,
393+
from: senderClient.protocolAddress,
394+
sessionStore: recipientClient.sessionStore,
395+
identityStore: recipientClient.identityKeyStore,
396+
preKeyStore: recipientClient.preKeyStore,
397+
signedPreKeyStore: recipientClient.signedPreKeyStore)
398+
399+
let decryptedMessage3 = try Session.decrypt(
400+
message: message3,
401+
from: senderClient.protocolAddress,
402+
sessionStore: recipientClient.sessionStore,
403+
identityStore: recipientClient.identityKeyStore,
404+
preKeyStore: recipientClient.preKeyStore,
405+
signedPreKeyStore: recipientClient.signedPreKeyStore)
406+
407+
let decryptedMessage4 = try Session.decrypt(
408+
message: message4,
409+
from: senderClient.protocolAddress,
410+
sessionStore: recipientClient.sessionStore,
411+
identityStore: recipientClient.identityKeyStore,
412+
preKeyStore: recipientClient.preKeyStore,
413+
signedPreKeyStore: recipientClient.signedPreKeyStore)
414+
415+
let decryptedMessage5 = try Session.decrypt(
416+
message: message5,
417+
from: senderClient.protocolAddress,
418+
sessionStore: recipientClient.sessionStore,
419+
identityStore: recipientClient.identityKeyStore,
420+
preKeyStore: recipientClient.preKeyStore,
421+
signedPreKeyStore: recipientClient.signedPreKeyStore)
422+
423+
XCTAssertEqual(decryptedMessage1, data)
424+
XCTAssertEqual(decryptedMessage2, data)
425+
XCTAssertEqual(decryptedMessage3, data)
426+
XCTAssertEqual(decryptedMessage4, data)
427+
XCTAssertEqual(decryptedMessage5, data)
428+
}
429+
268430
// FIXME: - Need fix
269431
func testOutOfOrder() throws {
270432
let senderClient = try TestClient(userId: UUID()) // Alice

0 commit comments

Comments
 (0)