Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -308,11 +308,21 @@ internal class QueryChannelsLogic(

internal suspend fun parseChatEventResults(chatEvents: List<ChatEvent>): List<EventHandlingResult> {
val cids = chatEvents.filterIsInstance<CidEvent>().map { it.cid }.distinct()
val cachedChannels = queryChannelsDatabaseLogic
.selectChannels(cids).associateBy { it.cid }
// Prefer in-memory per-channel state which has already been updated by the channel
// event handlers. Fall back to DB for channels that are not currently active in memory.
val inMemoryChannels = cids.mapNotNull { cid ->
queryChannelsStateLogic.getActiveChannelState(cid)?.let { cid to it }
}.toMap()
val remainingCids = cids - inMemoryChannels.keys
val dbChannels = if (remainingCids.isEmpty()) {
emptyMap()
} else {
queryChannelsDatabaseLogic.selectChannels(remainingCids).associateBy { it.cid }
}
val resolvedChannels = inMemoryChannels + dbChannels

return chatEvents.map { event ->
val channel = (event as? CidEvent)?.let { cachedChannels[it.cid] }
val channel = (event as? CidEvent)?.let { resolvedChannels[it.cid] }
queryChannelsStateLogic.handleChatEvent(event, channel)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,16 @@ internal class QueryChannelsStateLogic(
mutableState.setChannels(newChannels)
}

/**
* Returns the current [Channel] snapshot from the in-memory per-channel state if the
* channel is active, or `null` otherwise.
*/
internal fun getActiveChannelState(cid: String): Channel? {
val (type, id) = cid.cidToTypeAndId()
if (!stateRegistry.isActiveChannel(type, id)) return null
return stateRegistry.channel(type, id).toChannel()
}

/**
* Refreshes member state in all channels from this query.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
package io.getstream.chat.android.client.internal.state.plugin.logic.querychannels.internal

import io.getstream.chat.android.client.ChatClient
import io.getstream.chat.android.client.api.event.EventHandlingResult
import io.getstream.chat.android.client.api.models.QueryChannelsRequest
import io.getstream.chat.android.client.api.state.QueryChannelsState
import io.getstream.chat.android.client.query.QueryChannelsSpec
import io.getstream.chat.android.client.query.pagination.AnyChannelPaginationRequest
import io.getstream.chat.android.client.test.randomNewMessageEvent
import io.getstream.chat.android.models.Channel
import io.getstream.chat.android.models.FilterObject
import io.getstream.chat.android.models.Filters
Expand All @@ -31,6 +33,7 @@ import io.getstream.chat.android.test.asCall
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.test.runTest
import org.junit.Rule
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.mockito.kotlin.any
Expand Down Expand Up @@ -283,4 +286,75 @@ internal class QueryChannelsLogicTest {
}

// endregion

// region parseChatEventResults

@Test
fun `parseChatEventResults should resolve channels from in-memory state and skip DB`() = runTest {
// Given
val channel = randomChannel(type = "messaging", id = "ch1")
val event = randomNewMessageEvent(cid = channel.cid, channelType = "messaging", channelId = "ch1")
val expectedResult = EventHandlingResult.Skip

whenever(queryChannelsStateLogic.getActiveChannelState(channel.cid)) doReturn channel
whenever(queryChannelsStateLogic.handleChatEvent(eq(event), eq(channel))) doReturn expectedResult

// When
val results = logic.parseChatEventResults(listOf(event))

// Then
verify(queryChannelsDatabaseLogic, never()).selectChannels(any())
assertEquals(listOf(expectedResult), results)
}

@Test
fun `parseChatEventResults should fall back to DB when channel is not active in memory`() = runTest {
// Given
val channel = randomChannel(type = "messaging", id = "ch1")
val event = randomNewMessageEvent(cid = channel.cid, channelType = "messaging", channelId = "ch1")
val expectedResult = EventHandlingResult.Skip

whenever(queryChannelsStateLogic.getActiveChannelState(channel.cid)) doReturn null
whenever(queryChannelsDatabaseLogic.selectChannels(listOf(channel.cid))) doReturn listOf(channel)
whenever(queryChannelsStateLogic.handleChatEvent(eq(event), eq(channel))) doReturn expectedResult

// When
val results = logic.parseChatEventResults(listOf(event))

// Then
verify(queryChannelsDatabaseLogic).selectChannels(listOf(channel.cid))
assertEquals(listOf(expectedResult), results)
}

@Test
fun `parseChatEventResults should use mixed resolution - memory for active, DB for inactive`() = runTest {
// Given
val inMemoryChannel = randomChannel(type = "messaging", id = "active")
val dbChannel = randomChannel(type = "messaging", id = "inactive")
val event1 = randomNewMessageEvent(
cid = inMemoryChannel.cid,
channelType = "messaging",
channelId = "active",
)
val event2 = randomNewMessageEvent(
cid = dbChannel.cid,
channelType = "messaging",
channelId = "inactive",
)

whenever(queryChannelsStateLogic.getActiveChannelState(inMemoryChannel.cid)) doReturn inMemoryChannel
whenever(queryChannelsStateLogic.getActiveChannelState(dbChannel.cid)) doReturn null
whenever(queryChannelsDatabaseLogic.selectChannels(listOf(dbChannel.cid))) doReturn listOf(dbChannel)
whenever(queryChannelsStateLogic.handleChatEvent(any(), any())) doReturn EventHandlingResult.Skip

// When
logic.parseChatEventResults(listOf(event1, event2))

// Then – only the inactive channel should be fetched from DB
verify(queryChannelsDatabaseLogic).selectChannels(listOf(dbChannel.cid))
verify(queryChannelsStateLogic).handleChatEvent(event1, inMemoryChannel)
verify(queryChannelsStateLogic).handleChatEvent(event2, dbChannel)
}

// endregion
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ import io.getstream.chat.android.test.TestCoroutineRule
import kotlinx.coroutines.test.runTest
import org.amshove.kluent.`should contain same`
import org.junit.Rule
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertNull
import org.junit.jupiter.api.Test
import org.mockito.kotlin.any
import org.mockito.kotlin.doReturn
Expand Down Expand Up @@ -133,4 +135,28 @@ internal class QueryChannelsStateLogicTest {
queryChannelsSpec.cids `should contain same` setOf(testCid, channel1.cid, channel2.cid)
verify(mutableState).setChannels(channels.associateBy { it.cid })
}

@Test
fun `getActiveChannelState should return channel when it is active in state registry`() {
val channel = randomChannel(type = type, id = id)
val channelState: ChannelState = mock {
on(it.toChannel()) doReturn channel
}

whenever(stateRegistry.isActiveChannel(type, id)) doReturn true
whenever(stateRegistry.channel(type, id)) doReturn channelState

val result = queryChannelsStateLogic.getActiveChannelState(testCid)

assertEquals(channel, result)
}

@Test
fun `getActiveChannelState should return null when channel is not active in state registry`() {
whenever(stateRegistry.isActiveChannel(type, id)) doReturn false

val result = queryChannelsStateLogic.getActiveChannelState(testCid)

assertNull(result)
}
}
Loading