@@ -16,7 +16,9 @@ import org.junit.Test
1616import org.junit.rules.TemporaryFolder
1717import org.mockito.kotlin.any
1818import org.mockito.kotlin.anyOrNull
19+ import org.mockito.kotlin.argumentCaptor
1920import org.mockito.kotlin.mock
21+ import org.mockito.kotlin.times
2022import org.mockito.kotlin.verify
2123import org.mockito.kotlin.whenever
2224import to.bitkit.data.TrezorStore
@@ -162,6 +164,23 @@ class TrezorRepoTest : BaseUnitTest() {
162164 assertFalse(sut.state.value.isScanning)
163165 }
164166
167+ @Test
168+ fun `scan should exclude known devices from nearbyDevices state` () = test {
169+ val knownDevice = mockKnownDevice()
170+ val known = mockDeviceInfo()
171+ val nearby = mockDeviceInfo(id = " device-456" , path = " /dev/trezor1" )
172+ whenever(trezorStore.loadKnownDevices()).thenReturn(listOf (knownDevice))
173+ whenever(trezorService.scan()).thenReturn(listOf (known, nearby))
174+ sut = createSut()
175+
176+ sut.initialize()
177+ val result = sut.scan()
178+
179+ assertTrue(result.isSuccess)
180+ assertEquals(listOf (known, nearby), result.getOrNull())
181+ assertEquals(listOf (nearby), sut.state.value.nearbyDevices)
182+ }
183+
165184 @Test
166185 fun `scan should set error on failure` () = test {
167186 whenever(trezorService.scan()).thenThrow(RuntimeException (" scan failed" ))
@@ -198,6 +217,56 @@ class TrezorRepoTest : BaseUnitTest() {
198217 assertFalse(sut.state.value.isConnecting)
199218 }
200219
220+ @Test
221+ fun `connect should persist connected device as known device` () = test {
222+ val features = mockFeatures(label = " Savings" , model = " Safe 5" )
223+ val device = mockDeviceInfo()
224+ whenever(trezorService.connect(DEVICE_ID )).thenReturn(features)
225+ whenever(trezorService.scan()).thenReturn(listOf (device))
226+ sut = createSut()
227+
228+ sut.scan()
229+ val result = sut.connect(DEVICE_ID )
230+
231+ assertTrue(result.isSuccess)
232+ val captor = argumentCaptor<List <KnownDevice >>()
233+ verify(trezorStore).saveKnownDevices(captor.capture())
234+ val saved = captor.firstValue.single()
235+ assertEquals(DEVICE_ID , saved.id)
236+ assertEquals(KnownDeviceTransportType .USB , saved.transportType)
237+ assertEquals(" Savings" , saved.label)
238+ assertEquals(" Safe 5" , saved.model)
239+ }
240+
241+ @Test
242+ fun `connect should retry once for retryable THP errors` () = test {
243+ val features = mockFeatures()
244+ val device = mockDeviceInfo()
245+ whenever(trezorService.connect(DEVICE_ID ))
246+ .thenThrow(RuntimeException (" thp timeout" ))
247+ .thenReturn(features)
248+ whenever(trezorService.scan()).thenReturn(listOf (device))
249+ sut = createSut()
250+
251+ sut.scan()
252+ val result = sut.connect(DEVICE_ID )
253+
254+ assertTrue(result.isSuccess)
255+ assertEquals(features, result.getOrNull())
256+ verify(trezorService, times(2 )).connect(DEVICE_ID )
257+ }
258+
259+ @Test
260+ fun `connect should not retry non-retryable errors` () = test {
261+ whenever(trezorService.connect(DEVICE_ID )).thenThrow(RuntimeException (" bad pin" ))
262+ sut = createSut()
263+
264+ val result = sut.connect(DEVICE_ID )
265+
266+ assertTrue(result.isFailure)
267+ verify(trezorService, times(1 )).connect(DEVICE_ID )
268+ }
269+
201270 @Test
202271 fun `connect should set error on failure` () = test {
203272 whenever(trezorService.connect(DEVICE_ID )).thenThrow(RuntimeException (" connect failed" ))
@@ -408,6 +477,62 @@ class TrezorRepoTest : BaseUnitTest() {
408477
409478 // endregion
410479
480+ // region autoReconnect
481+
482+ @Test
483+ fun `autoReconnect should fail when no known devices exist` () = test {
484+ sut = createSut()
485+
486+ val result = sut.autoReconnect()
487+
488+ assertTrue(result.isFailure)
489+ assertEquals(" No known devices" , result.exceptionOrNull()?.message)
490+ }
491+
492+ @Test
493+ fun `autoReconnect should scan and connect known nearby device` () = test {
494+ val knownDevice = mockKnownDevice()
495+ val device = mockDeviceInfo()
496+ val features = mockFeatures()
497+ whenever(trezorStore.loadKnownDevices()).thenReturn(listOf (knownDevice))
498+ whenever(trezorService.scan()).thenReturn(listOf (device))
499+ whenever(trezorService.connect(DEVICE_ID )).thenReturn(features)
500+ whenever(trezorService.isConnected()).thenReturn(false )
501+ sut = createSut()
502+
503+ sut.initialize()
504+ val result = sut.autoReconnect()
505+
506+ assertTrue(result.isSuccess)
507+ assertEquals(features, result.getOrNull())
508+ assertEquals(DEVICE_ID , sut.state.value.connectedDeviceId)
509+ assertFalse(sut.state.value.isAutoReconnecting)
510+ }
511+
512+ // endregion
513+
514+ // region connectKnownDevice
515+
516+ @Test
517+ fun `connectKnownDevice should connect exact known device match` () = test {
518+ val knownDevice = mockKnownDevice()
519+ val device = mockDeviceInfo()
520+ val features = mockFeatures()
521+ whenever(trezorStore.loadKnownDevices()).thenReturn(listOf (knownDevice))
522+ whenever(trezorService.scan()).thenReturn(listOf (device))
523+ whenever(trezorService.connect(DEVICE_ID )).thenReturn(features)
524+ sut = createSut()
525+
526+ sut.initialize()
527+ val result = sut.connectKnownDevice(DEVICE_ID )
528+
529+ assertTrue(result.isSuccess)
530+ assertEquals(features, result.getOrNull())
531+ assertEquals(DEVICE_ID , sut.state.value.connectedDeviceId)
532+ }
533+
534+ // endregion
535+
411536 // region clearError
412537
413538 @Test
@@ -442,6 +567,40 @@ class TrezorRepoTest : BaseUnitTest() {
442567
443568 // endregion
444569
570+ // region ensureConnected
571+
572+ @Test
573+ fun `getAddress should reconnect known device before reading address` () = test {
574+ val knownDevice = mockKnownDevice()
575+ val device = mockDeviceInfo()
576+ val features = mockFeatures()
577+ val addressResponse = mock<TrezorAddressResponse >()
578+ whenever(trezorStore.loadKnownDevices()).thenReturn(listOf (knownDevice))
579+ whenever(trezorService.isConnected()).thenReturn(false )
580+ whenever(trezorService.scan()).thenReturn(listOf (device))
581+ whenever(trezorService.connect(DEVICE_ID )).thenReturn(features)
582+ whenever(
583+ trezorService.getAddress(
584+ path = any(),
585+ coin = any(),
586+ showOnTrezor = any(),
587+ scriptType = anyOrNull(),
588+ )
589+ ).thenReturn(addressResponse)
590+ sut = createSut()
591+
592+ sut.initialize()
593+ val result = sut.getAddress()
594+
595+ assertTrue(result.isSuccess)
596+ assertEquals(addressResponse, result.getOrNull())
597+ assertEquals(DEVICE_ID , sut.state.value.connectedDeviceId)
598+ verify(trezorService).scan()
599+ verify(trezorService).connect(DEVICE_ID )
600+ }
601+
602+ // endregion
603+
445604 // region forgetDevice
446605
447606 @Test
0 commit comments