Skip to content

Commit 8874f50

Browse files
committed
Ensure ID Token is updated after refresh token (Reactive)
Closes gh-17188 Signed-off-by: Evgeniy Cheban <mister.cheban@gmail.com>
1 parent 20493ef commit 8874f50

7 files changed

Lines changed: 876 additions & 8 deletions

File tree

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,314 @@
1+
/*
2+
* Copyright 2004-present the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.oauth2.client;
18+
19+
import java.time.Duration;
20+
import java.util.Collection;
21+
import java.util.HashSet;
22+
import java.util.List;
23+
import java.util.Map;
24+
import java.util.Set;
25+
26+
import reactor.core.publisher.Mono;
27+
28+
import org.springframework.security.core.Authentication;
29+
import org.springframework.security.core.GrantedAuthority;
30+
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
31+
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
32+
import org.springframework.security.core.context.SecurityContext;
33+
import org.springframework.security.core.context.SecurityContextImpl;
34+
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
35+
import org.springframework.security.oauth2.client.oidc.authentication.ReactiveOidcIdTokenDecoderFactory;
36+
import org.springframework.security.oauth2.client.oidc.userinfo.OidcReactiveOAuth2UserService;
37+
import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest;
38+
import org.springframework.security.oauth2.client.registration.ClientRegistration;
39+
import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService;
40+
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
41+
import org.springframework.security.oauth2.core.OAuth2Error;
42+
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
43+
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
44+
import org.springframework.security.oauth2.core.oidc.OidcScopes;
45+
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
46+
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
47+
import org.springframework.security.oauth2.jwt.JwtException;
48+
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
49+
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory;
50+
import org.springframework.security.web.server.context.ServerSecurityContextRepository;
51+
import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
52+
import org.springframework.util.Assert;
53+
import org.springframework.util.StringUtils;
54+
import org.springframework.web.server.ServerWebExchange;
55+
56+
/**
57+
* A {@link ReactiveOAuth2AuthorizationSuccessHandler} that refreshes an {@link OidcUser}
58+
* in the {@link SecurityContext} if the refreshed {@link OidcIdToken} is valid according
59+
* to <a href=
60+
* "https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse">OpenID
61+
* Connect Core 1.0 - Section 12.2 Successful Refresh Response</a>
62+
*
63+
* @author Evgeniy Cheban
64+
* @since 7.1
65+
*/
66+
public final class RefreshTokenReactiveOAuth2AuthorizationSuccessHandler
67+
implements ReactiveOAuth2AuthorizationSuccessHandler {
68+
69+
private static final String INVALID_ID_TOKEN_ERROR_CODE = "invalid_id_token";
70+
71+
private static final String INVALID_NONCE_ERROR_CODE = "invalid_nonce";
72+
73+
private static final String REFRESH_TOKEN_RESPONSE_ERROR_URI = "https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse";
74+
75+
// @formatter:off
76+
private static final Mono<ServerWebExchange> currentServerWebExchangeMono = Mono.deferContextual(Mono::just)
77+
.filter((c) -> c.hasKey(ServerWebExchange.class))
78+
.map((c) -> c.get(ServerWebExchange.class));
79+
// @formatter:on
80+
81+
private ServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository();
82+
83+
private ReactiveJwtDecoderFactory<ClientRegistration> jwtDecoderFactory = new ReactiveOidcIdTokenDecoderFactory();
84+
85+
private ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService = new OidcReactiveOAuth2UserService();
86+
87+
private GrantedAuthoritiesMapper authoritiesMapper = (authorities) -> authorities;
88+
89+
private Duration clockSkew = Duration.ofSeconds(60);
90+
91+
/**
92+
* Sets a {@link ServerSecurityContextRepository} to use for refreshing a
93+
* {@link SecurityContext}, defaults to
94+
* {@link WebSessionServerSecurityContextRepository}.
95+
* @param serverSecurityContextRepository the {@link ServerSecurityContextRepository}
96+
* to use
97+
*/
98+
public void setServerSecurityContextRepository(ServerSecurityContextRepository serverSecurityContextRepository) {
99+
Assert.notNull(serverSecurityContextRepository, "serverSecurityContextRepository cannot be null");
100+
this.serverSecurityContextRepository = serverSecurityContextRepository;
101+
}
102+
103+
/**
104+
* Sets a {@link ReactiveJwtDecoderFactory} to use for decoding refreshed oidc
105+
* id-token, defaults to {@link ReactiveOidcIdTokenDecoderFactory}.
106+
* @param jwtDecoderFactory the {@link ReactiveJwtDecoderFactory} to use
107+
*/
108+
public void setJwtDecoderFactory(ReactiveJwtDecoderFactory<ClientRegistration> jwtDecoderFactory) {
109+
Assert.notNull(jwtDecoderFactory, "jwtDecoderFactory cannot be null");
110+
this.jwtDecoderFactory = jwtDecoderFactory;
111+
}
112+
113+
/**
114+
* Sets a {@link GrantedAuthoritiesMapper} to use for mapping
115+
* {@link GrantedAuthority}s, defaults to no-op implementation.
116+
* @param authoritiesMapper the {@link GrantedAuthoritiesMapper} to use
117+
*/
118+
public void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) {
119+
Assert.notNull(authoritiesMapper, "authoritiesMapper cannot be null");
120+
this.authoritiesMapper = authoritiesMapper;
121+
}
122+
123+
/**
124+
* Sets a {@link ReactiveOAuth2UserService} to use for loading an {@link OidcUser}
125+
* from refreshed oidc id-token, defaults to {@link OidcReactiveOAuth2UserService}.
126+
* @param userService the {@link ReactiveOAuth2UserService} to use
127+
*/
128+
public void setUserService(ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService) {
129+
Assert.notNull(userService, "userService cannot be null");
130+
this.userService = userService;
131+
}
132+
133+
/**
134+
* Sets the maximum acceptable clock skew, which is used when checking the
135+
* {@link OidcIdToken#getIssuedAt()} to match the existing
136+
* {@link OidcUser#getIdToken()}'s issuedAt time, defaults to 60 seconds.
137+
* @param clockSkew the maximum acceptable clock skew to use
138+
*/
139+
public void setClockSkew(Duration clockSkew) {
140+
Assert.notNull(clockSkew, "clockSkew cannot be null");
141+
Assert.isTrue(clockSkew.getSeconds() >= 0, "clockSkew must be >= 0");
142+
this.clockSkew = clockSkew;
143+
}
144+
145+
@Override
146+
public Mono<Void> onAuthorizationSuccess(OAuth2AuthorizedClient authorizedClient, Authentication principal,
147+
Map<String, Object> attributes) {
148+
if (!(principal instanceof OAuth2AuthenticationToken authenticationToken)
149+
|| authenticationToken.getClass() != OAuth2AuthenticationToken.class) {
150+
// If the application customizes the authentication result, then a custom
151+
// handler should be provided.
152+
return Mono.empty();
153+
}
154+
// The current principal must be an OidcUser.
155+
if (!(authenticationToken.getPrincipal() instanceof OidcUser existingOidcUser)) {
156+
return Mono.empty();
157+
}
158+
ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
159+
// The registrationId must match the one used to log in.
160+
if (!authenticationToken.getAuthorizedClientRegistrationId().equals(clientRegistration.getRegistrationId())) {
161+
return Mono.empty();
162+
}
163+
// Create, validate OidcIdToken and refresh OidcUser in the SecurityContext.
164+
return Mono.zip(serverWebExchange(attributes), accessTokenResponse(attributes)).flatMap((t2) -> {
165+
ReactiveJwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(clientRegistration);
166+
Map<String, Object> additionalParameters = t2.getT2().getAdditionalParameters();
167+
return jwtDecoder.decode((String) additionalParameters.get(OidcParameterNames.ID_TOKEN))
168+
.onErrorMap(JwtException.class, (ex) -> {
169+
OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, ex.getMessage(),
170+
null);
171+
return new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), ex);
172+
})
173+
.map((jwt) -> new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(),
174+
jwt.getClaims()))
175+
.doOnNext((idToken) -> validateIdToken(existingOidcUser, idToken))
176+
.flatMap((idToken) -> {
177+
OidcUserRequest userRequest = new OidcUserRequest(clientRegistration,
178+
authorizedClient.getAccessToken(), idToken);
179+
return this.userService.loadUser(userRequest);
180+
})
181+
.flatMap((oidcUser) -> refreshSecurityContext(t2.getT1(), clientRegistration, authenticationToken,
182+
oidcUser));
183+
});
184+
}
185+
186+
private Mono<ServerWebExchange> serverWebExchange(Map<String, Object> attributes) {
187+
if (attributes.get(ServerWebExchange.class.getName()) instanceof ServerWebExchange exchange) {
188+
return Mono.just(exchange);
189+
}
190+
return currentServerWebExchangeMono;
191+
}
192+
193+
private Mono<OAuth2AccessTokenResponse> accessTokenResponse(Map<String, Object> attributes) {
194+
if (attributes.get(OAuth2AccessTokenResponse.class.getName()) instanceof OAuth2AccessTokenResponse response) {
195+
// The response must contain the openid scope
196+
if (!response.getAccessToken().getScopes().contains(OidcScopes.OPENID)) {
197+
return Mono.empty();
198+
}
199+
// The response must contain an id_token
200+
Map<String, Object> additionalParameters = response.getAdditionalParameters();
201+
if (!StringUtils.hasText((String) additionalParameters.get(OidcParameterNames.ID_TOKEN))) {
202+
return Mono.empty();
203+
}
204+
return Mono.just(response);
205+
}
206+
return Mono.empty();
207+
}
208+
209+
private void validateIdToken(OidcUser existingOidcUser, OidcIdToken idToken) {
210+
// OpenID Connect Core 1.0 - Section 12.2 Successful Refresh Response
211+
// If an ID Token is returned as a result of a token refresh request, the
212+
// following requirements apply:
213+
// its iss Claim Value MUST be the same as in the ID Token issued when the
214+
// original authentication occurred,
215+
validateIssuer(existingOidcUser, idToken);
216+
// its sub Claim Value MUST be the same as in the ID Token issued when the
217+
// original authentication occurred,
218+
validateSubject(existingOidcUser, idToken);
219+
// its iat Claim MUST represent the time that the new ID Token is issued,
220+
validateIssuedAt(existingOidcUser, idToken);
221+
// its aud Claim Value MUST be the same as in the ID Token issued when the
222+
// original authentication occurred,
223+
validateAudience(existingOidcUser, idToken);
224+
// if the ID Token contains an auth_time Claim, its value MUST represent the time
225+
// of the original authentication - not the time that the new ID token is issued,
226+
validateAuthenticatedAt(existingOidcUser, idToken);
227+
// it SHOULD NOT have a nonce Claim, even when the ID Token issued at the time of
228+
// the original authentication contained nonce; however, if it is present, its
229+
// value MUST be the same as in the ID Token issued at the time of the original
230+
// authentication,
231+
validateNonce(existingOidcUser, idToken);
232+
}
233+
234+
private void validateIssuer(OidcUser existingOidcUser, OidcIdToken idToken) {
235+
if (!idToken.getIssuer().toString().equals(existingOidcUser.getIdToken().getIssuer().toString())) {
236+
OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid issuer",
237+
REFRESH_TOKEN_RESPONSE_ERROR_URI);
238+
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
239+
}
240+
}
241+
242+
private void validateSubject(OidcUser existingOidcUser, OidcIdToken idToken) {
243+
if (!idToken.getSubject().equals(existingOidcUser.getIdToken().getSubject())) {
244+
OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid subject",
245+
REFRESH_TOKEN_RESPONSE_ERROR_URI);
246+
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
247+
}
248+
}
249+
250+
private void validateIssuedAt(OidcUser existingOidcUser, OidcIdToken idToken) {
251+
if (!idToken.getIssuedAt().isAfter(existingOidcUser.getIdToken().getIssuedAt().minus(this.clockSkew))) {
252+
OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid issued at time",
253+
REFRESH_TOKEN_RESPONSE_ERROR_URI);
254+
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
255+
}
256+
}
257+
258+
private void validateAudience(OidcUser existingOidcUser, OidcIdToken idToken) {
259+
if (!isValidAudience(existingOidcUser, idToken)) {
260+
OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid audience",
261+
REFRESH_TOKEN_RESPONSE_ERROR_URI);
262+
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
263+
}
264+
}
265+
266+
private boolean isValidAudience(OidcUser existingOidcUser, OidcIdToken idToken) {
267+
List<String> idTokenAudiences = idToken.getAudience();
268+
Set<String> oidcUserAudiences = new HashSet<>(existingOidcUser.getIdToken().getAudience());
269+
if (idTokenAudiences.size() != oidcUserAudiences.size()) {
270+
return false;
271+
}
272+
for (String audience : idTokenAudiences) {
273+
if (!oidcUserAudiences.contains(audience)) {
274+
return false;
275+
}
276+
}
277+
return true;
278+
}
279+
280+
private void validateAuthenticatedAt(OidcUser existingOidcUser, OidcIdToken idToken) {
281+
if (idToken.getAuthenticatedAt() == null) {
282+
return;
283+
}
284+
if (!idToken.getAuthenticatedAt().equals(existingOidcUser.getIdToken().getAuthenticatedAt())) {
285+
OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid authenticated at time",
286+
REFRESH_TOKEN_RESPONSE_ERROR_URI);
287+
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
288+
}
289+
}
290+
291+
private void validateNonce(OidcUser existingOidcUser, OidcIdToken idToken) {
292+
if (!StringUtils.hasText(idToken.getNonce())) {
293+
return;
294+
}
295+
if (!idToken.getNonce().equals(existingOidcUser.getIdToken().getNonce())) {
296+
OAuth2Error oauth2Error = new OAuth2Error(INVALID_NONCE_ERROR_CODE, "Invalid nonce",
297+
REFRESH_TOKEN_RESPONSE_ERROR_URI);
298+
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
299+
}
300+
}
301+
302+
private Mono<Void> refreshSecurityContext(ServerWebExchange exchange, ClientRegistration clientRegistration,
303+
OAuth2AuthenticationToken authenticationToken, OidcUser oidcUser) {
304+
Collection<? extends GrantedAuthority> mappedAuthorities = this.authoritiesMapper
305+
.mapAuthorities(oidcUser.getAuthorities());
306+
OAuth2AuthenticationToken authenticationResult = new OAuth2AuthenticationToken(oidcUser, mappedAuthorities,
307+
clientRegistration.getRegistrationId());
308+
authenticationResult.setDetails(authenticationToken.getDetails());
309+
SecurityContextImpl securityContext = new SecurityContextImpl(authenticationResult);
310+
return this.serverSecurityContextRepository.save(exchange, securityContext)
311+
.contextWrite(ReactiveSecurityContextHolder.withSecurityContext(Mono.just(securityContext)));
312+
}
313+
314+
}

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientManager;
3333
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider;
3434
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProviderBuilder;
35+
import org.springframework.security.oauth2.client.RefreshTokenReactiveOAuth2AuthorizationSuccessHandler;
3536
import org.springframework.security.oauth2.client.RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler;
3637
import org.springframework.security.oauth2.client.registration.ClientRegistration;
3738
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
@@ -85,6 +86,7 @@
8586
*
8687
* @author Joe Grandja
8788
* @author Phil Clay
89+
* @author Evgeniy Cheban
8890
* @since 5.2
8991
* @see ReactiveOAuth2AuthorizedClientManager
9092
* @see ReactiveOAuth2AuthorizedClientProvider
@@ -115,6 +117,8 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React
115117

116118
private Function<OAuth2AuthorizeRequest, Mono<Map<String, Object>>> contextAttributesMapper = new DefaultContextAttributesMapper();
117119

120+
private ReactiveOAuth2AuthorizationSuccessHandler refreshTokenSuccessHandler = new RefreshTokenReactiveOAuth2AuthorizationSuccessHandler();
121+
118122
private ReactiveOAuth2AuthorizationSuccessHandler authorizationSuccessHandler;
119123

120124
private ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler;
@@ -132,15 +136,23 @@ public DefaultReactiveOAuth2AuthorizedClientManager(
132136
Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
133137
this.clientRegistrationRepository = clientRegistrationRepository;
134138
this.authorizedClientRepository = authorizedClientRepository;
135-
this.authorizationSuccessHandler = (authorizedClient, principal, attributes) -> authorizedClientRepository
136-
.saveAuthorizedClient(authorizedClient, principal,
137-
(ServerWebExchange) attributes.get(ServerWebExchange.class.getName()));
139+
this.authorizationSuccessHandler = getAuthorizationSuccessHandler(authorizedClientRepository);
138140
this.authorizationFailureHandler = new RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler(
139141
(clientRegistrationId, principal, attributes) -> authorizedClientRepository.removeAuthorizedClient(
140142
clientRegistrationId, principal,
141143
(ServerWebExchange) attributes.get(ServerWebExchange.class.getName())));
142144
}
143145

146+
private ReactiveOAuth2AuthorizationSuccessHandler getAuthorizationSuccessHandler(
147+
ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
148+
return (authorizedClient, principal, attributes) -> {
149+
Mono<Void> saveAuthorizedClient = authorizedClientRepository.saveAuthorizedClient(authorizedClient,
150+
principal, (ServerWebExchange) attributes.get(ServerWebExchange.class.getName()));
151+
return saveAuthorizedClient.then(Mono.defer(() -> this.refreshTokenSuccessHandler
152+
.onAuthorizationSuccess(authorizedClient, principal, attributes)));
153+
};
154+
}
155+
144156
@Override
145157
public Mono<OAuth2AuthorizedClient> authorize(OAuth2AuthorizeRequest authorizeRequest) {
146158
Assert.notNull(authorizeRequest, "authorizeRequest cannot be null");
@@ -274,6 +286,19 @@ public void setContextAttributesMapper(
274286
this.contextAttributesMapper = contextAttributesMapper;
275287
}
276288

289+
/**
290+
* Sets the {@link ReactiveOAuth2AuthorizationSuccessHandler} to use for handling
291+
* successful refresh token request. Defaults to
292+
* {@link RefreshTokenReactiveOAuth2AuthorizationSuccessHandler}.
293+
* @param refreshTokenSuccessHandler the
294+
* {@link ReactiveOAuth2AuthorizationSuccessHandler} to use
295+
* @since 7.1
296+
*/
297+
public void setRefreshTokenSuccessHandler(ReactiveOAuth2AuthorizationSuccessHandler refreshTokenSuccessHandler) {
298+
Assert.notNull(refreshTokenSuccessHandler, "refreshTokenSuccessHandler cannot be null");
299+
this.refreshTokenSuccessHandler = refreshTokenSuccessHandler;
300+
}
301+
277302
/**
278303
* Sets the handler that handles successful authorizations.
279304
*
@@ -318,10 +343,10 @@ public Mono<Map<String, Object>> apply(OAuth2AuthorizeRequest authorizeRequest)
318343
return Mono.justOrEmpty(serverWebExchange)
319344
.switchIfEmpty(currentServerWebExchangeMono)
320345
.flatMap((exchange) -> {
321-
Map<String, Object> contextAttributes = Collections.emptyMap();
346+
Map<String, Object> contextAttributes = new HashMap<>();
347+
contextAttributes.put(ServerWebExchange.class.getName(), serverWebExchange);
322348
String scope = exchange.getRequest().getQueryParams().getFirst(OAuth2ParameterNames.SCOPE);
323349
if (StringUtils.hasText(scope)) {
324-
contextAttributes = new HashMap<>();
325350
contextAttributes.put(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME,
326351
StringUtils.delimitedListToStringArray(scope, " "));
327352
}

0 commit comments

Comments
 (0)