Skip to content

Commit d35810d

Browse files
committed
fix: harden hw pairing flow
1 parent f3e3e6a commit d35810d

7 files changed

Lines changed: 136 additions & 19 deletions

File tree

app/src/main/java/to/bitkit/repositories/HwWalletRepo.kt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,10 @@ class HwWalletRepo @Inject constructor(
127127
suspend fun hasKnownDevice(deviceId: String): Boolean = trezorRepo.hasKnownDevice(deviceId)
128128

129129
/** Connects and pairs a discovered device, persisting it as a watch-only known device. */
130-
suspend fun connect(deviceId: String): Result<TrezorFeatures> = trezorRepo.connect(deviceId)
130+
suspend fun connect(deviceId: String): Result<TrezorFeatures> {
131+
trezorRepo.resetWalletSelection()
132+
return trezorRepo.connect(deviceId)
133+
}
131134

132135
/**
133136
* Persists the Bitkit-side funds label for a paired device. Applied to every entry sharing the

app/src/main/java/to/bitkit/repositories/TrezorRepo.kt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -828,9 +828,9 @@ class TrezorRepo @Inject constructor(
828828
}
829829

830830
private suspend fun addOrUpdateKnownDevice(deviceInfo: TrezorDeviceInfo, features: TrezorFeatures) {
831-
val existing = _state.value.knownDevices
832-
val existingIds = existing.map { it.id }.toSet()
833-
val knownDevices = existing + hwWalletStore.loadKnownDevices().filter { it.id !in existingIds }
831+
val stored = hwWalletStore.loadKnownDevices()
832+
val storedIds = stored.map { it.id }.toSet()
833+
val knownDevices = stored + _state.value.knownDevices.filter { it.id !in storedIds }
834834
val previous = knownDevices.find { it.id == deviceInfo.id }
835835
val known = KnownDevice(
836836
id = deviceInfo.id,

app/src/main/java/to/bitkit/ui/sheets/hardware/HardwareSheet.kt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,10 @@ fun HardwareSheet(
160160
isConnecting = uiState.isConnecting,
161161
errorMessage = uiState.errorMessage,
162162
onConnect = { viewModel.onConnectClick(route.deviceId) },
163-
onCancel = appViewModel::hideSheet,
163+
onCancel = {
164+
viewModel.cancelConnect()
165+
appViewModel.hideSheet()
166+
},
164167
)
165168
}
166169
composableWithDefaultTransitions<HardwareRoute.Paired> {
@@ -180,6 +183,7 @@ fun HardwareSheet(
180183
}
181184

182185
BackHandler {
186+
viewModel.cancelConnect()
183187
appViewModel.hideSheet()
184188
}
185189

app/src/main/java/to/bitkit/ui/sheets/hardware/HwConnectViewModel.kt

Lines changed: 53 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class HwConnectViewModel @Inject constructor(
4747
val effects = _effects.asSharedFlow()
4848

4949
private var searchJob: Job? = null
50+
private var connectJob: Job? = null
5051
private var labelInitialized = false
5152
private var includeBluetoothInScan = true
5253
private var scanUsbBeforeConnect = false
@@ -85,26 +86,64 @@ class HwConnectViewModel @Inject constructor(
8586
fun onConnectClick(deviceIdOverride: String? = null) {
8687
val state = _uiState.value
8788
val deviceId = deviceIdOverride ?: state.foundDeviceId ?: return
89+
if (connectJob?.isActive == true) return
8890
val shouldScanUsbBeforeConnect = scanUsbBeforeConnect
8991
searchJob?.cancel()
9092
_uiState.update { it.copy(isConnecting = true, errorMessage = null) }
91-
viewModelScope.launch {
93+
connectJob = viewModelScope.launch {
94+
var resolvedDeviceId = deviceId
95+
var resolvedDeviceModel = state.deviceModel
9296
if (shouldScanUsbBeforeConnect) {
9397
hwWalletRepo.scan(includeBluetooth = false)
94-
}
95-
hwWalletRepo.connect(deviceId)
96-
.onSuccess { onConnected(deviceId, it) }
97-
.onFailure {
98-
_uiState.update { state ->
99-
state.copy(
100-
isConnecting = false,
101-
errorMessage = context.getString(R.string.hardware__connect_error),
102-
)
98+
.onSuccess { devices ->
99+
devices.firstOrNull { it.id == deviceId || it.path == deviceId }?.let { device ->
100+
resolvedDeviceId = device.id
101+
resolvedDeviceModel = resolveHwWalletName(label = null, model = device.model)
102+
_uiState.update {
103+
it.copy(
104+
foundDeviceId = resolvedDeviceId,
105+
deviceModel = resolvedDeviceModel,
106+
)
107+
}
108+
}
103109
}
104-
}
110+
.onFailure {
111+
onConnectFailed(resolvedDeviceId, resolvedDeviceModel)
112+
return@launch
113+
}
114+
}
115+
hwWalletRepo.connect(resolvedDeviceId)
116+
.onSuccess { onConnected(resolvedDeviceId, it) }
117+
.onFailure { onConnectFailed(resolvedDeviceId, resolvedDeviceModel) }
118+
connectJob = null
105119
}
106120
}
107121

122+
private fun onConnectFailed(deviceId: String, deviceModel: String) {
123+
_uiState.update {
124+
it.copy(
125+
isConnecting = false,
126+
foundDeviceId = deviceId,
127+
deviceModel = deviceModel,
128+
errorMessage = context.getString(R.string.hardware__connect_error),
129+
)
130+
}
131+
setEffect(
132+
HwConnectEffect.NavigateToFound(
133+
deviceId = deviceId,
134+
deviceModel = deviceModel,
135+
)
136+
)
137+
connectJob = null
138+
}
139+
140+
fun cancelConnect() {
141+
connectJob?.cancel()
142+
connectJob = null
143+
hwWalletRepo.cancelPairingCode()
144+
_uiState.update { it.copy(isConnecting = false) }
145+
}
146+
108147
fun onLabelChange(value: String) = _uiState.update { it.copy(labelInput = value.take(DEVICE_LABEL_MAX_LENGTH)) }
109148

110149
fun onFinishClick() {
@@ -122,6 +161,9 @@ class HwConnectViewModel @Inject constructor(
122161
fun resetState() {
123162
searchJob?.cancel()
124163
searchJob = null
164+
connectJob?.cancel()
165+
connectJob = null
166+
hwWalletRepo.cancelPairingCode()
125167
labelInitialized = false
126168
includeBluetoothInScan = true
127169
scanUsbBeforeConnect = false

app/src/test/java/to/bitkit/repositories/HwWalletRepoTest.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -773,6 +773,7 @@ class HwWalletRepoTest : BaseUnitTest() {
773773

774774
sut.connect("dev1")
775775

776+
verify(trezorRepo).resetWalletSelection()
776777
verify(trezorRepo).connect("dev1")
777778
}
778779

app/src/test/java/to/bitkit/repositories/TrezorRepoTest.kt

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ class TrezorRepoTest : BaseUnitTest() {
146146
model: String? = DEVICE_MODEL,
147147
transportType: TransportType = TransportType.USB,
148148
xpubs: Map<String, String> = emptyMap(),
149+
customLabel: String? = null,
149150
) = KnownDevice(
150151
id = id,
151152
name = name,
@@ -155,6 +156,7 @@ class TrezorRepoTest : BaseUnitTest() {
155156
model = model,
156157
lastConnectedAt = 123L,
157158
xpubs = xpubs,
159+
customLabel = customLabel,
158160
)
159161

160162
// region initialize
@@ -584,6 +586,26 @@ class TrezorRepoTest : BaseUnitTest() {
584586
)
585587
}
586588

589+
@Test
590+
fun `connect preserves stored custom label over stale state label`() = test {
591+
val features = mockFeatures()
592+
val device = mockDeviceInfo()
593+
whenever(hwWalletStore.loadKnownDevices())
594+
.thenReturn(listOf(mockKnownDevice()))
595+
.thenReturn(listOf(mockKnownDevice(customLabel = "Cold Storage")))
596+
whenever(trezorService.connect(eq(DEVICE_ID), any())).thenReturn(features)
597+
whenever(trezorService.scan()).thenReturn(listOf(device))
598+
sut = createSut()
599+
600+
sut.scan()
601+
val result = sut.connect(DEVICE_ID)
602+
603+
assertTrue(result.isSuccess)
604+
val captor = argumentCaptor<List<KnownDevice>>()
605+
verify(hwWalletStore).saveKnownDevices(captor.capture())
606+
assertEquals("Cold Storage", captor.lastValue.single().customLabel)
607+
}
608+
587609
@Test
588610
fun `connect should retry once for retryable THP errors`() = test {
589611
val features = mockFeatures()

app/src/test/java/to/bitkit/ui/sheets/hardware/HwConnectViewModelTest.kt

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,46 @@ class HwConnectViewModelTest : BaseUnitTest() {
118118
verify(hwWalletRepo).connect("usb1")
119119
}
120120

121+
@Test
122+
fun `onConnectClick uses scanned device id for usb route path`() = test {
123+
val path = "/dev/bus/usb/001/002"
124+
val connectedFeatures = features(model = "Safe 5")
125+
val usbDevice = deviceInfo(
126+
id = "core-usb-id",
127+
model = "Safe 5",
128+
transportType = TrezorTransportType.USB,
129+
path = path,
130+
)
131+
whenever(hwWalletRepo.scan(includeBluetooth = false)).thenReturn(Result.success(listOf(usbDevice)))
132+
whenever(hwWalletRepo.connect("core-usb-id")).thenReturn(Result.success(connectedFeatures))
133+
sut.onFoundRoute(deviceId = path, deviceModel = "Trezor")
134+
135+
sut.effects.test {
136+
sut.onConnectClick()
137+
assertEquals(HwConnectEffect.NavigateToPaired, awaitItem())
138+
cancelAndIgnoreRemainingEvents()
139+
}
140+
141+
verify(hwWalletRepo).connect("core-usb-id")
142+
assertEquals("core-usb-id", sut.uiState.value.foundDeviceId)
143+
}
144+
145+
@Test
146+
fun `onConnectClick returns to found when connect fails from pair code`() = test {
147+
whenever(hwWalletRepo.scan(includeBluetooth = false)).thenReturn(Result.success(emptyList()))
148+
whenever(hwWalletRepo.connect("usb1")).thenReturn(Result.failure(AppError("connect failed")))
149+
sut.onFoundRoute(deviceId = "usb1", deviceModel = "Trezor Safe 5")
150+
151+
sut.effects.test {
152+
sut.onConnectClick()
153+
assertEquals(HwConnectEffect.NavigateToFound("usb1", "Trezor Safe 5"), awaitItem())
154+
cancelAndIgnoreRemainingEvents()
155+
}
156+
157+
assertFalse(sut.uiState.value.isConnecting)
158+
assertEquals(CONNECT_ERROR, sut.uiState.value.errorMessage)
159+
}
160+
121161
@Test
122162
fun `onConnectClick connects the found device and advances to paired`() = test {
123163
givenDeviceFound()
@@ -203,11 +243,16 @@ class HwConnectViewModelTest : BaseUnitTest() {
203243
sut.onIntroContinue()
204244
}
205245

206-
private fun deviceInfo(id: String, model: String?) = TrezorDeviceInfo(
246+
private fun deviceInfo(
247+
id: String,
248+
model: String?,
249+
transportType: TrezorTransportType = TrezorTransportType.BLUETOOTH,
250+
path: String = "ble:$id",
251+
) = TrezorDeviceInfo(
207252
id = id,
208-
transportType = TrezorTransportType.BLUETOOTH,
253+
transportType = transportType,
209254
name = null,
210-
path = "ble:$id",
255+
path = path,
211256
label = null,
212257
model = model,
213258
isBootloader = false,

0 commit comments

Comments
 (0)