Skip to content

Commit f8e39e6

Browse files
committed
fix: watch only monitored hw address types
1 parent 477575a commit f8e39e6

2 files changed

Lines changed: 56 additions & 7 deletions

File tree

app/src/main/java/to/bitkit/repositories/HwWalletRepo.kt

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@ import kotlinx.coroutines.flow.MutableStateFlow
1717
import kotlinx.coroutines.flow.SharingStarted
1818
import kotlinx.coroutines.flow.StateFlow
1919
import kotlinx.coroutines.flow.combine
20+
import kotlinx.coroutines.flow.distinctUntilChanged
2021
import kotlinx.coroutines.flow.map
2122
import kotlinx.coroutines.flow.stateIn
2223
import kotlinx.coroutines.flow.update
2324
import kotlinx.coroutines.launch
25+
import to.bitkit.data.SettingsStore
2426
import to.bitkit.data.TrezorStore
2527
import to.bitkit.di.IoDispatcher
2628
import to.bitkit.env.Env
@@ -43,6 +45,7 @@ import javax.inject.Singleton
4345
class HwWalletRepo @Inject constructor(
4446
private val trezorRepo: TrezorRepo,
4547
private val trezorStore: TrezorStore,
48+
private val settingsStore: SettingsStore,
4649
@IoDispatcher private val ioDispatcher: CoroutineDispatcher,
4750
) {
4851
companion object {
@@ -100,9 +103,19 @@ class HwWalletRepo @Inject constructor(
100103

101104
private fun syncWatchers() {
102105
scope.launch {
103-
trezorStore.data.collect { data ->
104-
val wanted = data.knownDevices.flatMap { device ->
105-
device.xpubs.map { (addressType, xpub) -> WatcherSpec(device.id, addressType, xpub) }
106+
combine(
107+
trezorStore.data,
108+
settingsStore.data.map { it.addressTypesToMonitor.toSet() }.distinctUntilChanged(),
109+
) { data, monitoredTypes ->
110+
data.knownDevices to monitoredTypes
111+
}.collect { (knownDevices, monitoredTypes) ->
112+
// Only watch the address types the user monitors (Settings > Advanced > Address Type),
113+
// mirroring the on-chain wallet. Xpubs for all types are still captured on connect, so
114+
// toggling a type on later starts its watcher without reconnecting the device.
115+
val wanted = knownDevices.flatMap { device ->
116+
device.xpubs
117+
.filterKeys { it in monitoredTypes }
118+
.map { (addressType, xpub) -> WatcherSpec(device.id, addressType, xpub) }
106119
}
107120
val wantedIds = wanted.map { it.watcherId }.toSet()
108121

app/src/test/java/to/bitkit/repositories/HwWalletRepoTest.kt

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,15 @@ import kotlinx.coroutines.flow.MutableSharedFlow
1010
import kotlinx.coroutines.flow.MutableStateFlow
1111
import org.junit.Before
1212
import org.junit.Test
13+
import org.mockito.kotlin.any
14+
import org.mockito.kotlin.anyOrNull
15+
import org.mockito.kotlin.eq
1316
import org.mockito.kotlin.mock
17+
import org.mockito.kotlin.never
18+
import org.mockito.kotlin.verify
1419
import org.mockito.kotlin.whenever
20+
import to.bitkit.data.SettingsData
21+
import to.bitkit.data.SettingsStore
1522
import to.bitkit.data.TrezorData
1623
import to.bitkit.data.TrezorStore
1724
import to.bitkit.test.BaseUnitTest
@@ -21,8 +28,10 @@ class HwWalletRepoTest : BaseUnitTest() {
2128

2229
private val trezorRepo = mock<TrezorRepo>()
2330
private val trezorStore = mock<TrezorStore>()
31+
private val settingsStore = mock<SettingsStore>()
2432

2533
private lateinit var storeData: MutableStateFlow<TrezorData>
34+
private lateinit var settingsData: MutableStateFlow<SettingsData>
2635
private lateinit var trezorState: MutableStateFlow<TrezorState>
2736
private lateinit var watcherEvents: MutableSharedFlow<Pair<String, WatcherEvent>>
2837

@@ -39,16 +48,20 @@ class HwWalletRepoTest : BaseUnitTest() {
3948
@Before
4049
fun setUp() {
4150
storeData = MutableStateFlow(TrezorData(knownDevices = listOf(device)))
51+
settingsData = MutableStateFlow(SettingsData())
4252
trezorState = MutableStateFlow(TrezorState())
4353
watcherEvents = MutableSharedFlow(extraBufferCapacity = 8)
4454
whenever(trezorStore.data).thenReturn(storeData)
55+
whenever(settingsStore.data).thenReturn(settingsData)
4556
whenever(trezorRepo.state).thenReturn(trezorState)
4657
whenever(trezorRepo.watcherEvents).thenReturn(watcherEvents)
4758
}
4859

60+
private fun createRepo() = HwWalletRepo(trezorRepo, trezorStore, settingsStore, testDispatcher)
61+
4962
@Test
5063
fun `lists a known device with zero balance before any watcher event`() = test {
51-
val sut = HwWalletRepo(trezorRepo, trezorStore, testDispatcher)
64+
val sut = createRepo()
5265

5366
val wallet = sut.hardwareWallets.value.single()
5467
assertEquals("dev1", wallet.id)
@@ -59,7 +72,7 @@ class HwWalletRepoTest : BaseUnitTest() {
5972

6073
@Test
6174
fun `transactions changed event sets device balance and maps activity`() = test {
62-
val sut = HwWalletRepo(trezorRepo, trezorStore, testDispatcher)
75+
val sut = createRepo()
6376

6477
watcherEvents.emit(
6578
"dev1|nativeSegwit" to WatcherEvent.TransactionsChanged(
@@ -81,7 +94,7 @@ class HwWalletRepoTest : BaseUnitTest() {
8194

8295
@Test
8396
fun `balances from multiple address-type watchers are summed per device`() = test {
84-
val sut = HwWalletRepo(trezorRepo, trezorStore, testDispatcher)
97+
val sut = createRepo()
8598

8699
watcherEvents.emit(
87100
"dev1|nativeSegwit" to WatcherEvent.TransactionsChanged(
@@ -98,14 +111,37 @@ class HwWalletRepoTest : BaseUnitTest() {
98111
transactions = emptyList(),
99112
txCount = 0u,
100113
blockHeight = 1u,
101-
accountType = AccountType.NATIVE_SEGWIT,
114+
accountType = AccountType.TAPROOT,
102115
)
103116
)
104117

105118
assertEquals(150uL, sut.hardwareWallets.value.single().balanceSats)
106119
assertEquals(150uL, sut.totalHardwareSats.value)
107120
}
108121

122+
@Test
123+
fun `starts watchers only for the address types the user monitors`() = test {
124+
storeData.value = TrezorData(
125+
knownDevices = listOf(
126+
device.copy(
127+
xpubs = mapOf(
128+
"nativeSegwit" to "zpubNS",
129+
"taproot" to "zpubTR",
130+
"legacy" to "xpubLG",
131+
)
132+
)
133+
)
134+
)
135+
settingsData.value = SettingsData(addressTypesToMonitor = listOf("nativeSegwit", "taproot"))
136+
whenever(trezorRepo.startWatcher(any(), any(), any(), any(), anyOrNull())).thenReturn(Result.success(Unit))
137+
138+
createRepo()
139+
140+
verify(trezorRepo).startWatcher(eq("dev1|nativeSegwit"), any(), any(), any(), anyOrNull())
141+
verify(trezorRepo).startWatcher(eq("dev1|taproot"), any(), any(), any(), anyOrNull())
142+
verify(trezorRepo, never()).startWatcher(eq("dev1|legacy"), any(), any(), any(), anyOrNull())
143+
}
144+
109145
private fun walletBalance(total: ULong) = WalletBalance(
110146
confirmed = total,
111147
immature = 0uL,

0 commit comments

Comments
 (0)