Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,127 +16,79 @@
package com.amplifyframework.auth.cognito

import android.content.Context
import androidx.annotation.VisibleForTesting
import com.amplifyframework.AmplifyException
import com.amplifyframework.auth.cognito.data.AWSCognitoAuthCredentialStore
import com.amplifyframework.auth.cognito.data.AWSCognitoLegacyCredentialStore
import com.amplifyframework.auth.cognito.helpers.collectWhile
import com.amplifyframework.auth.exceptions.InvalidStateException
import com.amplifyframework.logging.Logger
import com.amplifyframework.statemachine.StateChangeListenerToken
import com.amplifyframework.statemachine.codegen.data.AmplifyCredential
import com.amplifyframework.statemachine.codegen.data.CredentialType
import com.amplifyframework.statemachine.codegen.events.CredentialStoreEvent
import com.amplifyframework.statemachine.codegen.states.CredentialStoreState
import java.util.concurrent.atomic.AtomicBoolean
import kotlin.coroutines.resumeWithException
import kotlin.coroutines.suspendCoroutine
import kotlinx.coroutines.flow.drop
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.flow.onSubscription

internal interface StoreClientBehavior {
suspend fun loadCredentials(credentialType: CredentialType): AmplifyCredential
suspend fun storeCredentials(credentialType: CredentialType, amplifyCredential: AmplifyCredential)
suspend fun clearCredentials(credentialType: CredentialType)
}

internal class CredentialStoreClient(configuration: AuthConfiguration, context: Context, val logger: Logger) :
StoreClientBehavior {
private val credentialStoreStateMachine = createCredentialStoreStateMachine(configuration, context)
internal class CredentialStoreClient @VisibleForTesting constructor(
private val credentialStoreStateMachine: CredentialStoreStateMachine,
val logger: Logger
) : StoreClientBehavior {

private fun createCredentialStoreStateMachine(
configuration: AuthConfiguration,
context: Context
): CredentialStoreStateMachine {
val awsCognitoAuthCredentialStore = AWSCognitoAuthCredentialStore(context.applicationContext, configuration)
val legacyCredentialStore = AWSCognitoLegacyCredentialStore(context.applicationContext, configuration)
val credentialStoreEnvironment =
CredentialStoreEnvironment(awsCognitoAuthCredentialStore, legacyCredentialStore, logger)
return CredentialStoreStateMachine(credentialStoreEnvironment)
}
constructor(configuration: AuthConfiguration, context: Context, logger: Logger) : this(
credentialStoreStateMachine = createCredentialStoreStateMachine(configuration, context, logger),
logger = logger
)

private fun listenForResult(
event: CredentialStoreEvent,
onSuccess: (Result<AmplifyCredential>) -> Unit,
onError: (Exception) -> Unit
) {
val token = StateChangeListenerToken()
val credentialStoreStateListener = OneShotCredentialStoreStateListener(
{
credentialStoreStateMachine.cancel(token)
onSuccess(it)
},
{
credentialStoreStateMachine.cancel(token)
onError(it)
},
logger
private suspend fun listenForResult(event: CredentialStoreEvent.EventType): AmplifyCredential {
var result: Result<AmplifyCredential>? = null
credentialStoreStateMachine.state
.onSubscription { credentialStoreStateMachine.send(CredentialStoreEvent(event)) }
.drop(1) // skip current state
Comment thread
mattcreaser marked this conversation as resolved.
.onEach { state ->
when (state) {
is CredentialStoreState.Error -> result = result ?: Result.failure(state.error)
is CredentialStoreState.Success -> result = Result.success(state.storedCredentials)
else -> Unit // no-op
}
}
.collectWhile { state -> state !is CredentialStoreState.Idle }
return result?.getOrThrow() ?: throw InvalidStateException(
message = "Credential operation failed",
recoverySuggestion = AmplifyException.REPORT_BUG_TO_AWS_SUGGESTION
)
credentialStoreStateMachine.listen(
token,
credentialStoreStateListener::listen
) { credentialStoreStateMachine.send(event) }
}

override suspend fun loadCredentials(credentialType: CredentialType): AmplifyCredential =
suspendCoroutine { continuation ->
listenForResult(
CredentialStoreEvent(CredentialStoreEvent.EventType.LoadCredentialStore(credentialType)),
{ continuation.resumeWith(it) },
{ continuation.resumeWithException(it) }
)
}
listenForResult(CredentialStoreEvent.EventType.LoadCredentialStore(credentialType))

override suspend fun storeCredentials(credentialType: CredentialType, amplifyCredential: AmplifyCredential) =
suspendCoroutine { continuation ->
listenForResult(
CredentialStoreEvent(
CredentialStoreEvent.EventType.StoreCredentials(credentialType, amplifyCredential)
),
{ continuation.resumeWith(Result.success(Unit)) },
{ continuation.resumeWithException(it) }
)
}

override suspend fun clearCredentials(credentialType: CredentialType) = suspendCoroutine { continuation ->
listenForResult(
CredentialStoreEvent(CredentialStoreEvent.EventType.ClearCredentialStore(credentialType)),
{ continuation.resumeWith(Result.success(Unit)) },
{ continuation.resumeWithException(it) }
)
override suspend fun storeCredentials(credentialType: CredentialType, amplifyCredential: AmplifyCredential) {
listenForResult(CredentialStoreEvent.EventType.StoreCredentials(credentialType, amplifyCredential))
}

/*
This class is a necessary workaround due to undesirable threading issues within the Auth State Machine. If state
machine threading is improved, this class should be considered for removal.
*/
internal class OneShotCredentialStoreStateListener(
val onSuccess: (Result<AmplifyCredential>) -> Unit,
val onError: (Exception) -> Unit,
val logger: Logger
) {
private var capturedSuccess: Result<AmplifyCredential>? = null
private var capturedError: Exception? = null
private val isActive = AtomicBoolean(true)
fun listen(storeState: CredentialStoreState) {
logger.verbose("Credential Store State Change: $storeState")
when (storeState) {
is CredentialStoreState.Success -> {
capturedSuccess = Result.success(storeState.storedCredentials)
}

is CredentialStoreState.Error -> {
capturedError = storeState.error
}

is CredentialStoreState.Idle -> {
val success = capturedSuccess
val error = capturedError
override suspend fun clearCredentials(credentialType: CredentialType) {
listenForResult(CredentialStoreEvent.EventType.ClearCredentialStore(credentialType))
}

if ((success != null || error != null) && isActive.getAndSet(false)) {
if (success != null) {
onSuccess(success)
} else if (error != null) {
onError(error)
}
}
}
else -> Unit
}
companion object {
private fun createCredentialStoreStateMachine(
configuration: AuthConfiguration,
context: Context,
logger: Logger
): CredentialStoreStateMachine {
val awsCognitoAuthCredentialStore =
AWSCognitoAuthCredentialStore(context.applicationContext, configuration)
val legacyCredentialStore = AWSCognitoLegacyCredentialStore(context.applicationContext, configuration)
val credentialStoreEnvironment =
CredentialStoreEnvironment(awsCognitoAuthCredentialStore, legacyCredentialStore, logger)
return CredentialStoreStateMachine(credentialStoreEnvironment)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

package com.amplifyframework.statemachine

import java.util.UUID
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.DelicateCoroutinesApi
Expand All @@ -30,14 +29,6 @@ import kotlinx.coroutines.launch
import kotlinx.coroutines.newSingleThreadContext
import kotlinx.coroutines.withContext

internal typealias OnSubscribedCallback = () -> Unit

internal class StateChangeListenerToken private constructor(val uuid: UUID) {
constructor() : this(UUID.randomUUID())
override fun equals(other: Any?) = other is StateChangeListenerToken && other.uuid == uuid
override fun hashCode() = uuid.hashCode()
}

/**
* Model, mutate and process effects of a system as a finite state automaton. It consists of:
* State - which represents the current state of the system
Expand All @@ -49,7 +40,6 @@ internal class StateChangeListenerToken private constructor(val uuid: UUID) {
* @param resolver responsible for mutating state based on incoming events
* @param environment holds system specific environment info accessible to Effects/Actions
* @param executor responsible for invoking effects
* @param concurrentQueue event queue or thread pool for effect executor and subscription callback
* @param initialState starting state of the system (resolver default state will be used if omitted)
*/
internal open class StateMachine<StateType : State, EnvironmentType : Environment>(
Expand Down Expand Up @@ -77,50 +67,6 @@ internal open class StateMachine<StateType : State, EnvironmentType : Environmen
private val stateMachineContext = SupervisorJob() + newSingleThreadContext("StateMachineContext")
private val stateMachineScope = CoroutineScope(stateMachineContext)

// weak wrapper ??
private val subscribers: MutableMap<StateChangeListenerToken, (StateType) -> Unit> = mutableMapOf()

// atomic value ??
private val pendingCancellations: MutableSet<StateChangeListenerToken> = mutableSetOf()

/**
* Start listening to state changes updates. Asynchronously invoke listener on a background queue with the current state.
* Both `listener` and `onSubscribe` will be invoked on a background queue.
* @param listener listener to be invoked on state changes
* @param onSubscribe callback to invoke when subscription is complete
* @return token that can be used to unsubscribe the listener
*/
@Deprecated("Collect from state flow instead")
fun listen(token: StateChangeListenerToken, listener: (StateType) -> Unit, onSubscribe: OnSubscribedCallback?) {
stateMachineScope.launch {
addSubscription(token, listener, onSubscribe)
}
}

/**
* Stop listening to state changes updates. Register a pending cancellation if a new event comes in between the time
* `cancel` is called and the time the pending cancellation is processed, the event will not be dispatched to the listener.
* @param token identifies the listener to be removed
*/
@Deprecated("Collect from state flow instead")
fun cancel(token: StateChangeListenerToken) {
pendingCancellations.add(token)
stateMachineScope.launch {
removeSubscription(token)
}
}

/**
* Invoke `completion` with the current state
* @param completion callback to invoke with the current state
*/
@Deprecated("Use suspending version instead")
fun getCurrentState(completion: (StateType) -> Unit) {
stateMachineScope.launch {
completion(getCurrentState())
}
}

/**
* Get the current state, dispatching to the state machine context for the read.
*/
Expand All @@ -130,35 +76,6 @@ internal open class StateMachine<StateType : State, EnvironmentType : Environmen
_state.tryEmit(newState)
}

/**
* Register a listener.
* @param token token, which will be retained in the subscribers map
* @param listener listener to invoke when the state has changed
* @param onSubscribe callback to invoke when subscription is complete
*/
private suspend fun addSubscription(
token: StateChangeListenerToken,
listener: (StateType) -> Unit,
onSubscribe: OnSubscribedCallback?
) {
if (pendingCancellations.contains(token)) return
val currentState = getCurrentState()
subscribers[token] = listener
onSubscribe?.invoke()
stateMachineScope.launch(dispatcherQueue) {
listener.invoke(currentState)
}
}

/**
* Unregister a listener.
* @param token token of the listener to remove
*/
private fun removeSubscription(token: StateChangeListenerToken) {
pendingCancellations.remove(token)
subscribers.remove(token)
}

/**
* Send `event` to the StateMachine for resolution, and applies any effects and new states returned from the resolution.
* @param event event to send to the system
Expand All @@ -169,22 +86,6 @@ internal open class StateMachine<StateType : State, EnvironmentType : Environmen
}
}

/**
* Notify all the listeners with the new state.
* @param subscriber pair containing the subscriber token and listener
* @param newState new state to be sent
* @return true if the subscriber was notified, false if the token was null or a cancellation was pending
*/
private fun notifySubscribers(
subscriber: Map.Entry<StateChangeListenerToken, (StateType) -> Unit>,
newState: StateType
): Boolean {
val token = subscriber.key
if (pendingCancellations.contains(token)) return false
subscriber.value(newState)
return true
}

/**
* Resolver mutates the state based on current state and incoming event, and returns resolution with new state and
* effects. If the state machine's state after resolving is not equal to the state before the event, update the
Expand All @@ -197,8 +98,6 @@ internal open class StateMachine<StateType : State, EnvironmentType : Environmen
val resolution = resolver.resolve(currentState, event)
if (currentState != resolution.newState) {
setCurrentState(resolution.newState)
val subscribersToRemove = subscribers.filter { !notifySubscribers(it, resolution.newState) }
subscribersToRemove.forEach { subscribers.remove(it.key) }
}
execute(resolution.actions)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,12 @@ import io.mockk.mockkStatic
import io.mockk.slot
import java.io.File
import java.util.TimeZone
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit
import kotlin.reflect.full.callSuspend
import kotlin.reflect.full.declaredFunctions
import kotlin.test.assertEquals
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.newSingleThreadContext
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.resetMain
import kotlinx.coroutines.test.setMain
import kotlinx.serialization.json.Json
Expand Down Expand Up @@ -211,13 +210,7 @@ class AWSCognitoAuthPluginFeatureTest(private val testCase: FeatureTestCase) {
actual shouldEqualJson expected
}
is ExpectationShapes.State -> {
val getStateLatch = CountDownLatch(1)
var authState: AuthState? = null
authStateMachine.getCurrentState {
authState = it
getStateLatch.countDown()
}
getStateLatch.await(10, TimeUnit.SECONDS)
val authState = runBlocking { authStateMachine.getCurrentState() }
assertEquals(getState(validation.expectedState), authState)
}
}
Expand Down
Loading
Loading