@@ -10,8 +10,15 @@ import kotlinx.coroutines.flow.MutableSharedFlow
1010import kotlinx.coroutines.flow.MutableStateFlow
1111import org.junit.Before
1212import org.junit.Test
13+ import org.mockito.kotlin.any
14+ import org.mockito.kotlin.anyOrNull
15+ import org.mockito.kotlin.eq
1316import org.mockito.kotlin.mock
17+ import org.mockito.kotlin.never
18+ import org.mockito.kotlin.verify
1419import org.mockito.kotlin.whenever
20+ import to.bitkit.data.SettingsData
21+ import to.bitkit.data.SettingsStore
1522import to.bitkit.data.TrezorData
1623import to.bitkit.data.TrezorStore
1724import to.bitkit.test.BaseUnitTest
@@ -21,8 +28,10 @@ class HwWalletRepoTest : BaseUnitTest() {
2128
2229 private val trezorRepo = mock<TrezorRepo >()
2330 private val trezorStore = mock<TrezorStore >()
31+ private val settingsStore = mock<SettingsStore >()
2432
2533 private lateinit var storeData: MutableStateFlow <TrezorData >
34+ private lateinit var settingsData: MutableStateFlow <SettingsData >
2635 private lateinit var trezorState: MutableStateFlow <TrezorState >
2736 private lateinit var watcherEvents: MutableSharedFlow <Pair <String , WatcherEvent >>
2837
@@ -39,16 +48,20 @@ class HwWalletRepoTest : BaseUnitTest() {
3948 @Before
4049 fun setUp () {
4150 storeData = MutableStateFlow (TrezorData (knownDevices = listOf (device)))
51+ settingsData = MutableStateFlow (SettingsData ())
4252 trezorState = MutableStateFlow (TrezorState ())
4353 watcherEvents = MutableSharedFlow (extraBufferCapacity = 8 )
4454 whenever(trezorStore.data).thenReturn(storeData)
55+ whenever(settingsStore.data).thenReturn(settingsData)
4556 whenever(trezorRepo.state).thenReturn(trezorState)
4657 whenever(trezorRepo.watcherEvents).thenReturn(watcherEvents)
4758 }
4859
60+ private fun createRepo () = HwWalletRepo (trezorRepo, trezorStore, settingsStore, testDispatcher)
61+
4962 @Test
5063 fun `lists a known device with zero balance before any watcher event` () = test {
51- val sut = HwWalletRepo (trezorRepo, trezorStore, testDispatcher )
64+ val sut = createRepo( )
5265
5366 val wallet = sut.hardwareWallets.value.single()
5467 assertEquals(" dev1" , wallet.id)
@@ -59,7 +72,7 @@ class HwWalletRepoTest : BaseUnitTest() {
5972
6073 @Test
6174 fun `transactions changed event sets device balance and maps activity` () = test {
62- val sut = HwWalletRepo (trezorRepo, trezorStore, testDispatcher )
75+ val sut = createRepo( )
6376
6477 watcherEvents.emit(
6578 " dev1|nativeSegwit" to WatcherEvent .TransactionsChanged (
@@ -81,7 +94,7 @@ class HwWalletRepoTest : BaseUnitTest() {
8194
8295 @Test
8396 fun `balances from multiple address-type watchers are summed per device` () = test {
84- val sut = HwWalletRepo (trezorRepo, trezorStore, testDispatcher )
97+ val sut = createRepo( )
8598
8699 watcherEvents.emit(
87100 " dev1|nativeSegwit" to WatcherEvent .TransactionsChanged (
@@ -98,14 +111,37 @@ class HwWalletRepoTest : BaseUnitTest() {
98111 transactions = emptyList(),
99112 txCount = 0u ,
100113 blockHeight = 1u ,
101- accountType = AccountType .NATIVE_SEGWIT ,
114+ accountType = AccountType .TAPROOT ,
102115 )
103116 )
104117
105118 assertEquals(150uL , sut.hardwareWallets.value.single().balanceSats)
106119 assertEquals(150uL , sut.totalHardwareSats.value)
107120 }
108121
122+ @Test
123+ fun `starts watchers only for the address types the user monitors` () = test {
124+ storeData.value = TrezorData (
125+ knownDevices = listOf (
126+ device.copy(
127+ xpubs = mapOf (
128+ " nativeSegwit" to " zpubNS" ,
129+ " taproot" to " zpubTR" ,
130+ " legacy" to " xpubLG" ,
131+ )
132+ )
133+ )
134+ )
135+ settingsData.value = SettingsData (addressTypesToMonitor = listOf (" nativeSegwit" , " taproot" ))
136+ whenever(trezorRepo.startWatcher(any(), any(), any(), any(), anyOrNull())).thenReturn(Result .success(Unit ))
137+
138+ createRepo()
139+
140+ verify(trezorRepo).startWatcher(eq(" dev1|nativeSegwit" ), any(), any(), any(), anyOrNull())
141+ verify(trezorRepo).startWatcher(eq(" dev1|taproot" ), any(), any(), any(), anyOrNull())
142+ verify(trezorRepo, never()).startWatcher(eq(" dev1|legacy" ), any(), any(), any(), anyOrNull())
143+ }
144+
109145 private fun walletBalance (total : ULong ) = WalletBalance (
110146 confirmed = total,
111147 immature = 0uL ,
0 commit comments