Skip to content

Commit 76105af

Browse files
committed
Proposed fix for missing WWW-Authenticate header
Current implementation does not include the WWW-Authenticate header when returning a 401 for missing/invalid credentials when attempting to access the token endpoints. Fixes-468 Signed-off-by: Lucian Holland <lucian@patientsknowbest.com>
1 parent b76300b commit 76105af

3 files changed

Lines changed: 53 additions & 22 deletions

File tree

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2ClientAuthenticationConfigurer.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ public final class OAuth2ClientAuthenticationConfigurer extends AbstractOAuth2Co
7979

8080
private AuthenticationFailureHandler errorResponseHandler;
8181

82+
private String realmName = "oauth";
83+
8284
/**
8385
* Restrict for internal use only.
8486
* @param objectPostProcessor an {@code ObjectPostProcessor}
@@ -102,6 +104,18 @@ public OAuth2ClientAuthenticationConfigurer authenticationConverter(
102104
return this;
103105
}
104106

107+
/**
108+
* Sets the realm name for Http Basic when returning a WWW-Authenticate header on
109+
* client authentication failure.
110+
* @param realmName the Http Basic realm name
111+
* @return the {@link OAuth2ClientAuthenticationConfigurer} for further configuration
112+
*/
113+
public OAuth2ClientAuthenticationConfigurer realmName(String realmName) {
114+
Assert.hasText(realmName, "realmName cannot be empty");
115+
this.realmName = realmName;
116+
return this;
117+
}
118+
105119
/**
106120
* Sets the {@code Consumer} providing access to the {@code List} of default and
107121
* (optionally) added {@link #authenticationConverter(AuthenticationConverter)
@@ -213,7 +227,7 @@ void init(HttpSecurity httpSecurity) {
213227
void configure(HttpSecurity httpSecurity) {
214228
AuthenticationManager authenticationManager = httpSecurity.getSharedObject(AuthenticationManager.class);
215229
OAuth2ClientAuthenticationFilter clientAuthenticationFilter = new OAuth2ClientAuthenticationFilter(
216-
authenticationManager, this.requestMatcher);
230+
authenticationManager, this.requestMatcher, this.realmName);
217231
List<AuthenticationConverter> authenticationConverters = createDefaultAuthenticationConverters();
218232
if (!this.authenticationConverters.isEmpty()) {
219233
authenticationConverters.addAll(0, this.authenticationConverters);

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import jakarta.servlet.http.HttpServletResponse;
2525

2626
import org.springframework.core.log.LogMessage;
27+
import org.springframework.http.HttpHeaders;
2728
import org.springframework.http.HttpStatus;
2829
import org.springframework.http.converter.HttpMessageConverter;
2930
import org.springframework.http.server.ServletServerHttpResponse;
@@ -90,24 +91,37 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
9091

9192
private final AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource = new WebAuthenticationDetailsSource();
9293

94+
private final String realmName;
95+
9396
private AuthenticationConverter authenticationConverter;
9497

9598
private AuthenticationSuccessHandler authenticationSuccessHandler = this::onAuthenticationSuccess;
9699

97100
private AuthenticationFailureHandler authenticationFailureHandler = this::onAuthenticationFailure;
98101

102+
/**
103+
* Internal error code used to distinguish missing authentication from invalid
104+
* authentication in order to display a WWW-Authenticate header when appropriate The
105+
* default failure handler will convert this to the spec-compliant 'invalid_client'
106+
* before returning to the caller.
107+
*/
108+
public static final String MISSING_CLIENT_AUTH_ERROR_CODE = "missing_client_auth";
109+
99110
/**
100111
* Constructs an {@code OAuth2ClientAuthenticationFilter} using the provided
101112
* parameters.
102113
* @param authenticationManager the {@link AuthenticationManager} used for
103114
* authenticating the client
104115
* @param requestMatcher the {@link RequestMatcher} used for matching against the
105116
* {@code HttpServletRequest}
117+
* @param realmName realm name to use in WWW-Authenticate header for Basic auth
106118
*/
107-
public OAuth2ClientAuthenticationFilter(AuthenticationManager authenticationManager,
108-
RequestMatcher requestMatcher) {
119+
public OAuth2ClientAuthenticationFilter(AuthenticationManager authenticationManager, RequestMatcher requestMatcher,
120+
String realmName) {
121+
this.realmName = realmName;
109122
Assert.notNull(authenticationManager, "authenticationManager cannot be null");
110123
Assert.notNull(requestMatcher, "requestMatcher cannot be null");
124+
Assert.notNull(realmName, "realmName cannot be null");
111125
this.authenticationManager = authenticationManager;
112126
this.requestMatcher = requestMatcher;
113127
// @formatter:off
@@ -140,9 +154,12 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
140154
validateClientIdentifier(authenticationRequest);
141155
Authentication authenticationResult = this.authenticationManager.authenticate(authenticationRequest);
142156
this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, authenticationResult);
157+
filterChain.doFilter(request, response);
158+
}
159+
else {
160+
this.authenticationFailureHandler.onAuthenticationFailure(request, response,
161+
new OAuth2AuthenticationException(MISSING_CLIENT_AUTH_ERROR_CODE));
143162
}
144-
filterChain.doFilter(request, response);
145-
146163
}
147164
catch (OAuth2AuthenticationException ex) {
148165
if (this.logger.isTraceEnabled()) {
@@ -204,27 +221,23 @@ private void onAuthenticationFailure(HttpServletRequest request, HttpServletResp
204221

205222
SecurityContextHolder.clearContext();
206223

207-
// TODO
208-
// The authorization server MAY return an HTTP 401 (Unauthorized) status code
209-
// to indicate which HTTP authentication schemes are supported.
210-
// If the client attempted to authenticate via the "Authorization" request header
211-
// field,
212-
// the authorization server MUST respond with an HTTP 401 (Unauthorized) status
213-
// code and
214-
// include the "WWW-Authenticate" response header field
215-
// matching the authentication scheme used by the client.
216-
217224
OAuth2Error error = ((OAuth2AuthenticationException) exception).getError();
218225
ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
219-
if (OAuth2ErrorCodes.INVALID_CLIENT.equals(error.getErrorCode())) {
226+
String errorCode = error.getErrorCode();
227+
String authHeader = request.getHeader(HttpHeaders.AUTHORIZATION);
228+
229+
if (MISSING_CLIENT_AUTH_ERROR_CODE.equals(errorCode) || (OAuth2ErrorCodes.INVALID_CLIENT.equals(errorCode)
230+
&& authHeader != null && authHeader.trim().startsWith("Basic"))) {
231+
errorCode = OAuth2ErrorCodes.INVALID_CLIENT;
220232
httpResponse.setStatusCode(HttpStatus.UNAUTHORIZED);
233+
httpResponse.getHeaders().set("WWW-Authenticate", "Basic realm=\"" + this.realmName + "\"");
221234
}
222235
else {
223236
httpResponse.setStatusCode(HttpStatus.BAD_REQUEST);
224237
}
225238
// We don't want to reveal too much information to the caller so just return the
226239
// error code
227-
OAuth2Error errorResponse = new OAuth2Error(error.getErrorCode());
240+
OAuth2Error errorResponse = new OAuth2Error(errorCode);
228241
this.errorHttpResponseConverter.write(errorResponse, null, httpResponse);
229242
}
230243

oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilterTests.java

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.junit.jupiter.api.Test;
2727
import org.mockito.ArgumentCaptor;
2828

29+
import org.springframework.http.HttpHeaders;
2930
import org.springframework.http.HttpMethod;
3031
import org.springframework.http.HttpStatus;
3132
import org.springframework.http.converter.HttpMessageConverter;
@@ -81,7 +82,7 @@ public class OAuth2ClientAuthenticationFilterTests {
8182
public void setUp() {
8283
this.authenticationManager = mock(AuthenticationManager.class);
8384
this.requestMatcher = new AntPathRequestMatcher(this.filterProcessesUrl, HttpMethod.POST.name());
84-
this.filter = new OAuth2ClientAuthenticationFilter(this.authenticationManager, this.requestMatcher);
85+
this.filter = new OAuth2ClientAuthenticationFilter(this.authenticationManager, this.requestMatcher, "realm");
8586
this.authenticationConverter = mock(AuthenticationConverter.class);
8687
this.filter.setAuthenticationConverter(this.authenticationConverter);
8788
}
@@ -93,14 +94,14 @@ public void cleanup() {
9394

9495
@Test
9596
public void constructorWhenAuthenticationManagerNullThenThrowIllegalArgumentException() {
96-
assertThatThrownBy(() -> new OAuth2ClientAuthenticationFilter(null, this.requestMatcher))
97+
assertThatThrownBy(() -> new OAuth2ClientAuthenticationFilter(null, this.requestMatcher, "realm"))
9798
.isInstanceOf(IllegalArgumentException.class)
9899
.hasMessage("authenticationManager cannot be null");
99100
}
100101

101102
@Test
102103
public void constructorWhenRequestMatcherNullThenThrowIllegalArgumentException() {
103-
assertThatThrownBy(() -> new OAuth2ClientAuthenticationFilter(this.authenticationManager, null))
104+
assertThatThrownBy(() -> new OAuth2ClientAuthenticationFilter(this.authenticationManager, null, "realm"))
104105
.isInstanceOf(IllegalArgumentException.class)
105106
.hasMessage("requestMatcher cannot be null");
106107
}
@@ -149,8 +150,10 @@ public void doFilterWhenRequestMatchesAndEmptyCredentialsThenNotProcessed() thro
149150

150151
this.filter.doFilter(request, response, filterChain);
151152

152-
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
153-
verifyNoInteractions(this.authenticationManager);
153+
verifyNoInteractions(this.authenticationManager, filterChain);
154+
assertThat(response.getStatus()).isEqualTo(HttpStatus.UNAUTHORIZED.value());
155+
OAuth2Error error = readError(response);
156+
assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
154157
}
155158

156159
@Test
@@ -225,6 +228,7 @@ public void doFilterWhenRequestMatchesAndBadCredentialsThenInvalidClientError()
225228

226229
MockHttpServletRequest request = new MockHttpServletRequest("POST", this.filterProcessesUrl);
227230
request.setServletPath(this.filterProcessesUrl);
231+
request.addHeader(HttpHeaders.AUTHORIZATION, "Basic invalid-secret");
228232
MockHttpServletResponse response = new MockHttpServletResponse();
229233
FilterChain filterChain = mock(FilterChain.class);
230234

0 commit comments

Comments
 (0)