Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public open class AuthenticationActivity : Activity() {
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
if (savedInstanceState != null) {
WebAuthProvider.onRestoreInstanceState(savedInstanceState)
WebAuthProvider.onRestoreInstanceState(savedInstanceState, this)
intentLaunched = savedInstanceState.getBoolean(EXTRA_INTENT_LAUNCHED, false)
}
}
Expand Down
15 changes: 12 additions & 3 deletions auth0/src/main/java/com/auth0/android/provider/OAuthManager.kt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ internal class OAuthManager(
@get:VisibleForTesting(otherwise = VisibleForTesting.PRIVATE)
internal val dPoP: DPoP? = null
) : ResumableManager() {

private val parameters: MutableMap<String, String>
private val headers: MutableMap<String, String>
private val ctOptions: CustomTabsOptions
Expand Down Expand Up @@ -211,7 +212,8 @@ internal class OAuthManager(
auth0 = account,
idTokenVerificationIssuer = idTokenVerificationIssuer,
idTokenVerificationLeeway = idTokenVerificationLeeway,
customAuthorizeUrl = this.customAuthorizeUrl
customAuthorizeUrl = this.customAuthorizeUrl,
dPoPEnabled = dPoP != null
)
}

Expand Down Expand Up @@ -387,14 +389,21 @@ internal class OAuthManager(

internal fun OAuthManager.Companion.fromState(
state: OAuthManagerState,
callback: Callback<Credentials, AuthenticationException>
callback: Callback<Credentials, AuthenticationException>,
context: Context
): OAuthManager {
// Enable DPoP on the restored PKCE's AuthenticationAPIClient so that
// the token exchange request includes the DPoP proof after process restore.
if (state.dPoPEnabled && state.pkce != null) {
state.pkce.apiClient.useDPoP(context)
}
return OAuthManager(
account = state.auth0,
ctOptions = state.ctOptions,
parameters = state.parameters,
callback = callback,
customAuthorizeUrl = state.customAuthorizeUrl
customAuthorizeUrl = state.customAuthorizeUrl,
dPoP = if (state.dPoPEnabled) DPoP(context) else null
).apply {
setHeaders(
state.headers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import android.util.Base64
import androidx.core.os.ParcelCompat
import com.auth0.android.Auth0
import com.auth0.android.authentication.AuthenticationAPIClient
import com.auth0.android.dpop.DPoP
import com.auth0.android.request.internal.GsonProvider
import com.google.gson.Gson

Expand All @@ -20,7 +19,7 @@ internal data class OAuthManagerState(
val idTokenVerificationLeeway: Int?,
val idTokenVerificationIssuer: String?,
val customAuthorizeUrl: String? = null,
val dPoP: DPoP? = null
val dPoPEnabled: Boolean = false
) {

private class OAuthManagerJson(
Expand All @@ -37,7 +36,7 @@ internal data class OAuthManagerState(
val idTokenVerificationLeeway: Int?,
val idTokenVerificationIssuer: String?,
val customAuthorizeUrl: String? = null,
val dPoP: DPoP? = null
val dPoPEnabled: Boolean = false
)

fun serializeToJson(
Expand All @@ -62,7 +61,7 @@ internal data class OAuthManagerState(
idTokenVerificationIssuer = idTokenVerificationIssuer,
idTokenVerificationLeeway = idTokenVerificationLeeway,
customAuthorizeUrl = this.customAuthorizeUrl,
dPoP = this.dPoP
dPoPEnabled = this.dPoPEnabled
)
return gson.toJson(json)
} finally {
Expand Down Expand Up @@ -112,7 +111,7 @@ internal data class OAuthManagerState(
idTokenVerificationIssuer = oauthManagerJson.idTokenVerificationIssuer,
idTokenVerificationLeeway = oauthManagerJson.idTokenVerificationLeeway,
customAuthorizeUrl = oauthManagerJson.customAuthorizeUrl,
dPoP = oauthManagerJson.dPoP
dPoPEnabled = oauthManagerJson.dPoPEnabled
)
} finally {
parcel.recycle()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,50 @@ public object WebAuthProvider : SenderConstraining<WebAuthProvider> {
private val callbacks = CopyOnWriteArraySet<Callback<Credentials, AuthenticationException>>()
private val parCallbacks = CopyOnWriteArraySet<Callback<AuthorizationCode, AuthenticationException>>()

// Buffers a state-restore result that completed before any callback was registered (process
// death during login: the restored AuthenticationActivity finishes the token exchange before
// the host app can subscribe). Delivered to the next addCallback subscriber.
private sealed class RecoveredResult {
class Success(val credentials: Credentials) : RecoveredResult()
class Failure(val error: AuthenticationException) : RecoveredResult()
}

private val recoveryLock = Any()
private var pendingRecovered: RecoveredResult? = null

@JvmStatic
@get:VisibleForTesting(otherwise = VisibleForTesting.PRIVATE)
internal var managerInstance: ResumableManager? = null
private set

/**
* Registers a callback for Universal Login results from the state-restore path
* ([onRestoreInstanceState]). A result buffered before this call is delivered immediately and
* consumed. Normal in-process logins resolve through the [Builder.start] callback, not here.
*/
@JvmStatic
public fun addCallback(callback: Callback<Credentials, AuthenticationException>) {
callbacks += callback
val buffered = synchronized(recoveryLock) {
val pending = pendingRecovered
if (pending != null) {
pendingRecovered = null
} else {
callbacks += callback
}
pending
}
when (buffered) {
is RecoveredResult.Success -> callback.onSuccess(buffered.credentials)
is RecoveredResult.Failure -> callback.onFailure(buffered.error)
null -> {}
}
}

@JvmStatic
public fun removeCallback(callback: Callback<Credentials, AuthenticationException>) {
callbacks -= callback
synchronized(recoveryLock) {
callbacks -= callback
}
}

// Public methods
Expand Down Expand Up @@ -142,7 +173,7 @@ public object WebAuthProvider : SenderConstraining<WebAuthProvider> {
}
}

internal fun onRestoreInstanceState(bundle: Bundle) {
internal fun onRestoreInstanceState(bundle: Bundle, context: Context) {
if (managerInstance == null) {
val oauthStateJson = bundle.getString(KEY_BUNDLE_OAUTH_MANAGER_STATE).orEmpty()
val parStateJson = bundle.getString(KEY_BUNDLE_PAR_MANAGER_STATE).orEmpty()
Expand All @@ -152,17 +183,32 @@ public object WebAuthProvider : SenderConstraining<WebAuthProvider> {
state,
object : Callback<Credentials, AuthenticationException> {
override fun onSuccess(result: Credentials) {
for (callback in callbacks) {
val subscribers = synchronized(recoveryLock) {
if (callbacks.isEmpty()) {
pendingRecovered = RecoveredResult.Success(result)
return
}
callbacks.toList()
}
for (callback in subscribers) {
callback.onSuccess(result)
}
}

override fun onFailure(error: AuthenticationException) {
for (callback in callbacks) {
val subscribers = synchronized(recoveryLock) {
if (callbacks.isEmpty()) {
pendingRecovered = RecoveredResult.Failure(error)
return
}
callbacks.toList()
}
for (callback in subscribers) {
callback.onFailure(error)
}
}
}
},
context
)
} else if (parStateJson.isNotBlank()) {
val state = PARCodeManagerState.deserializeState(parStateJson)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
package com.auth0.android.provider

import android.content.Context
import android.graphics.Color
import com.auth0.android.Auth0
import com.auth0.android.authentication.AuthenticationAPIClient
import com.auth0.android.authentication.AuthenticationException
import com.auth0.android.callback.Callback
import com.auth0.android.result.Credentials
import com.nhaarman.mockitokotlin2.mock
import com.nhaarman.mockitokotlin2.whenever
import org.hamcrest.MatcherAssert.assertThat
import org.hamcrest.core.Is.`is`
import org.junit.Assert
import org.junit.Test
import org.junit.runner.RunWith
Expand Down Expand Up @@ -44,4 +52,138 @@ internal class OAuthManagerStateTest {
Assert.assertEquals(1, deserializedState.idTokenVerificationLeeway)
Assert.assertEquals("issuer", deserializedState.idTokenVerificationIssuer)
}

@Test
fun `serialize should persist dPoPEnabled flag as true`() {
val auth0 = Auth0.getInstance("clientId", "domain")
val state = OAuthManagerState(
auth0 = auth0,
parameters = mapOf("param1" to "value1"),
headers = mapOf("header1" to "value1"),
requestCode = 1,
ctOptions = CustomTabsOptions.newBuilder()
.showTitle(true)
.withBrowserPicker(
BrowserPicker.newBuilder().withAllowedPackages(emptyList()).build()
)
.build(),
pkce = PKCE(mock(), "redirectUri", mapOf("header1" to "value1")),
idTokenVerificationLeeway = 1,
idTokenVerificationIssuer = "issuer",
dPoPEnabled = true
)

val json = state.serializeToJson()

Assert.assertTrue(json.isNotBlank())
Assert.assertTrue(json.contains("\"dPoPEnabled\":true"))

val deserializedState = OAuthManagerState.deserializeState(json)

Assert.assertTrue(deserializedState.dPoPEnabled)
}

@Test
fun `serialize should persist dPoPEnabled flag as false by default`() {
val auth0 = Auth0.getInstance("clientId", "domain")
val state = OAuthManagerState(
auth0 = auth0,
parameters = mapOf("param1" to "value1"),
headers = mapOf("header1" to "value1"),
requestCode = 1,
ctOptions = CustomTabsOptions.newBuilder()
.showTitle(true)
.withBrowserPicker(
BrowserPicker.newBuilder().withAllowedPackages(emptyList()).build()
)
.build(),
pkce = PKCE(mock(), "redirectUri", mapOf("header1" to "value1")),
idTokenVerificationLeeway = 1,
idTokenVerificationIssuer = "issuer"
)

val json = state.serializeToJson()

val deserializedState = OAuthManagerState.deserializeState(json)

Assert.assertFalse(deserializedState.dPoPEnabled)
}

@Test
fun `deserialize should default dPoPEnabled to false when field is missing from JSON`() {
val auth0 = Auth0.getInstance("clientId", "domain")
val state = OAuthManagerState(
auth0 = auth0,
parameters = emptyMap(),
headers = emptyMap(),
requestCode = 0,
ctOptions = CustomTabsOptions.newBuilder()
.showTitle(true)
.withBrowserPicker(
BrowserPicker.newBuilder().withAllowedPackages(emptyList()).build()
)
.build(),
pkce = PKCE(mock(), "redirectUri", emptyMap()),
idTokenVerificationLeeway = null,
idTokenVerificationIssuer = null
)

val json = state.serializeToJson()
// Remove the dPoPEnabled field to simulate legacy JSON
val legacyJson = json.replace(",\"dPoPEnabled\":false", "")

val deserializedState = OAuthManagerState.deserializeState(legacyJson)

Assert.assertFalse(deserializedState.dPoPEnabled)
}

@Test
fun `fromState should re-enable DPoP on the restored PKCE's API client when dPoPEnabled is true`() {
val context = mock<Context>()
whenever(context.applicationContext).thenReturn(context)
val auth0 = Auth0.getInstance("clientId", "domain")
val apiClient = AuthenticationAPIClient(auth0)
val state = OAuthManagerState(
auth0 = auth0,
parameters = emptyMap(),
headers = emptyMap(),
requestCode = 0,
ctOptions = CustomTabsOptions.newBuilder().build(),
pkce = PKCE(apiClient, "codeVerifier", "redirectUri", "codeChallenge", emptyMap()),
idTokenVerificationLeeway = null,
idTokenVerificationIssuer = null,
dPoPEnabled = true
)
val callback = mock<Callback<Credentials, AuthenticationException>>()

OAuthManager.fromState(state, callback, context)

// This is the actual regression guard: the token exchange after process death only
// includes the DPoP proof because fromState re-enables DPoP on the restored API client.
assertThat(apiClient.isDPoPEnabled, `is`(true))
}

@Test
fun `fromState should not enable DPoP on the restored PKCE's API client when dPoPEnabled is false`() {
val context = mock<Context>()
whenever(context.applicationContext).thenReturn(context)
val auth0 = Auth0.getInstance("clientId", "domain")
val apiClient = AuthenticationAPIClient(auth0)
val state = OAuthManagerState(
auth0 = auth0,
parameters = emptyMap(),
headers = emptyMap(),
requestCode = 0,
ctOptions = CustomTabsOptions.newBuilder().build(),
pkce = PKCE(apiClient, "codeVerifier", "redirectUri", "codeChallenge", emptyMap()),
idTokenVerificationLeeway = null,
idTokenVerificationIssuer = null,
dPoPEnabled = false
)
val callback = mock<Callback<Credentials, AuthenticationException>>()

OAuthManager.fromState(state, callback, context)

assertThat(apiClient.isDPoPEnabled, `is`(false))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import android.app.Activity
import android.content.Context
import android.content.Intent
import android.net.Uri
import android.os.Bundle
import android.os.Parcelable
import androidx.test.espresso.intent.matcher.IntentMatchers
import androidx.test.espresso.intent.matcher.UriMatchers
Expand Down Expand Up @@ -2958,6 +2959,31 @@ public class WebAuthProviderTest {
mockAPI.shutdown()
}

@Test
public fun shouldReEnableDPoPOnOAuthManagerAfterProcessDeathRestore() {
`when`(mockKeyStore.hasKeyPair()).thenReturn(true)
`when`(mockKeyStore.getKeyPair()).thenReturn(Pair(mock(), FakeECPublicKey()))

WebAuthProvider.useDPoP(mockContext)
.login(account)
.start(activity, callback)

val bundle = Bundle()
WebAuthProvider.onSaveInstanceState(bundle)

// Simulate the host process being killed and recreated: the manager instance is gone,
// and the activity is recreated with the saved state.
WebAuthProvider.resetManagerInstance()
WebAuthProvider.onRestoreInstanceState(bundle, activity)

val restoredManager = WebAuthProvider.managerInstance as OAuthManager
// This asserts the save/restore wiring reconstructs a DPoP-enabled manager. The actual
// regression guard — that DPoP is re-enabled on the restored PKCE's API client so the
// token exchange carries the proof — lives in OAuthManagerStateTest.fromState tests,
// since OAuthManager.pkce is private and not reachable here without reflection.
assertThat(restoredManager.dPoP, `is`(notNullValue()))
Comment thread
pmathew92 marked this conversation as resolved.
}

//** ** ** ** ** ** **//
//** ** ** ** ** ** **//
//** Helpers Functions**//
Expand Down
Loading