Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import bootstrap.liftweb.AuthenticationMethods
import bootstrap.liftweb.RudderConfig
import bootstrap.liftweb.RudderInMemoryUserDetailsService
import bootstrap.liftweb.RudderProperties
import cats.syntax.apply.*
import com.github.benmanes.caffeine.cache.Caffeine
import com.normation.errors.IOResult
import com.normation.plugins.RudderPluginModule
Expand Down Expand Up @@ -72,6 +73,7 @@ import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration
import org.springframework.core.ParameterizedTypeReference
import org.springframework.core.convert.converter.Converter
import org.springframework.http.MediaType
import org.springframework.http.RequestEntity
import org.springframework.http.ResponseEntity
import org.springframework.security.authentication.AbstractAuthenticationToken
Expand Down Expand Up @@ -1050,7 +1052,7 @@ class RudderOidcUserService(
roleApiMapping: RoleApiMapping
) extends OidcUserService with RudderUserServerMapping[OidcUserRequest, OidcUser, RudderUserDetail & OidcUser] {
// we need to use our copy of DefaultOAuth2UserService to log/manage errors
super.setOauth2UserService(new RudderDefaultOAuth2UserService()): @unchecked
super.setOauth2UserService(new RudderDefaultOAuth2UserService(registrationRepository)): @unchecked

override val protocolId = RudderOidcUserService.PROTOCOL_ID
override val protocolName = "OIDC"
Expand All @@ -1077,7 +1079,7 @@ class RudderOAuth2UserService(
roleApiMapping: RoleApiMapping
) extends OAuth2UserService[OAuth2UserRequest, OAuth2User]
with RudderUserServerMapping[OAuth2UserRequest, OAuth2User, RudderUserDetail & OAuth2User] {
val defaultUserService = new RudderDefaultOAuth2UserService()
val defaultUserService = new RudderDefaultOAuth2UserService(registrationRepository)

override val protocolId = RudderOAuth2UserService.PROTOCOL_ID
override val protocolName = "OAuth2"
Expand Down Expand Up @@ -1130,7 +1132,8 @@ object BuildLogout {
}
}

class RudderDefaultOAuth2UserService extends DefaultOAuth2UserService with DebugOAuth2Attributes {
class RudderDefaultOAuth2UserService(registrationRepository: RudderClientRegistrationRepository)
extends DefaultOAuth2UserService with DebugOAuth2Attributes {

/*
* this is a copy of parent method with more logs/error management
Expand All @@ -1147,7 +1150,9 @@ class RudderDefaultOAuth2UserService extends DefaultOAuth2UserService with Debug
private val requestEntityConverter: Converter[OAuth2UserRequest, RequestEntity[?]] = new OAuth2UserRequestEntityConverter

private val restOperations: RestOperations = {
val restTemplate = new RestTemplate
val restTemplate = new RestTemplate()
// TODO: set message converters instead of intercepting exception ? https://docs.spring.io/spring-framework/reference/web/webmvc/mvc-config/message-converters.html
// restTemplate.setMessageConverters(List(JwtConverter(registration.getProviderDetails.getJwkSetUri)).asJava)
restTemplate.setErrorHandler(new OAuth2ErrorResponseErrorHandler)
restTemplate
}
Expand Down Expand Up @@ -1205,8 +1210,11 @@ class RudderDefaultOAuth2UserService extends DefaultOAuth2UserService with Debug
userRequest: OAuth2UserRequest,
request: RequestEntity[?]
): ResponseEntity[java.util.Map[String, AnyRef]] = {
given jwtConfig: Option[JwtConfig] =
registrationRepository.registrations.get(userRequest.getClientRegistration.getRegistrationId).map(_.jwtConfig)
given ClientRegistration = userRequest.getClientRegistration
try {
return this.restOperations.exchange(request, PARAMETERIZED_RESPONSE_TYPE)
this.restOperations.exchange(request, PARAMETERIZED_RESPONSE_TYPE)
} catch {
case ex: OAuth2AuthorizationException =>
var oauth2Error: OAuth2Error = ex.getError
Expand All @@ -1226,6 +1234,8 @@ class RudderDefaultOAuth2UserService extends DefaultOAuth2UserService with Debug
null
)
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString, ex)
case JwtContentTypeException(jwtResponse) =>
jwtResponse
case ex: UnknownContentTypeException =>
val errorMessage: String =
"An error occurred while attempting to retrieve the UserInfo Resource from '" + userRequest.getClientRegistration.getProviderDetails.getUserInfoEndpoint.getUri + "': response contains invalid content type '" + ex.getContentType.toString + "'. " + "The UserInfo Response should return a JSON object (content type 'application/json') " + "that contains a collection of name and value pairs of the claims about the authenticated End-User. " + "Please ensure the UserInfo Uri in UserInfoEndpoint for Client Registration '" + userRequest.getClientRegistration.getRegistrationId + "' conforms to the UserInfo Endpoint, " + "as defined in OpenID Connect 1.0: 'https://openid.net/specs/openid-connect-core-1_0.html#UserInfo'"
Expand Down Expand Up @@ -1565,3 +1575,23 @@ trait DebugOAuth2TokenAttributes extends DebugOAuth2Attributes {
def pivotAttributeRegistration: RegistrationWithPivotAttribute
def pivotAttribute: String = pivotAttributeRegistration.pivotAttributeName
}

// Special case for JWT interception upon content-type exception in getResponse : signed/encrypted response.
// Spring has no direct support for these, we need to use Nimbus for decrypting
object JwtContentTypeException {

private val APPLICATION_JWT: MediaType = new MediaType("application", "jwt")

def unapply(
ex: UnknownContentTypeException
)(using
jwtConfig: Option[JwtConfig],
clientRegistration: ClientRegistration
): Option[ResponseEntity[util.Map[String, AnyRef]]] = {
given JwtConfig = jwtConfig.getOrElse(JwtConfig.default(clientRegistration.getProviderDetails.getJwkSetUri))
Option(ex.getResponseHeaders).filter(_ => ex.getContentType.isCompatibleWith(APPLICATION_JWT)) *>
decodeJwt(ex.getResponseBodyAsString).toOption
.map(jwt => ResponseEntity.ok().body(jwt.getClaims))
}

}
Loading