|
24 | 24 | import jakarta.servlet.http.HttpServletResponse; |
25 | 25 |
|
26 | 26 | import org.springframework.core.log.LogMessage; |
| 27 | +import org.springframework.http.HttpHeaders; |
27 | 28 | import org.springframework.http.HttpStatus; |
28 | 29 | import org.springframework.http.converter.HttpMessageConverter; |
29 | 30 | import org.springframework.http.server.ServletServerHttpResponse; |
@@ -90,24 +91,37 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter |
90 | 91 |
|
91 | 92 | private final AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource = new WebAuthenticationDetailsSource(); |
92 | 93 |
|
| 94 | + private final String realmName; |
| 95 | + |
93 | 96 | private AuthenticationConverter authenticationConverter; |
94 | 97 |
|
95 | 98 | private AuthenticationSuccessHandler authenticationSuccessHandler = this::onAuthenticationSuccess; |
96 | 99 |
|
97 | 100 | private AuthenticationFailureHandler authenticationFailureHandler = this::onAuthenticationFailure; |
98 | 101 |
|
| 102 | + /** |
| 103 | + * Internal error code used to distinguish missing authentication from invalid |
| 104 | + * authentication in order to display a WWW-Authenticate header when appropriate The |
| 105 | + * default failure handler will convert this to the spec-compliant 'invalid_client' |
| 106 | + * before returning to the caller. |
| 107 | + */ |
| 108 | + public static final String MISSING_CLIENT_AUTH_ERROR_CODE = "missing_client_auth"; |
| 109 | + |
99 | 110 | /** |
100 | 111 | * Constructs an {@code OAuth2ClientAuthenticationFilter} using the provided |
101 | 112 | * parameters. |
102 | 113 | * @param authenticationManager the {@link AuthenticationManager} used for |
103 | 114 | * authenticating the client |
104 | 115 | * @param requestMatcher the {@link RequestMatcher} used for matching against the |
105 | 116 | * {@code HttpServletRequest} |
| 117 | + * @param realmName realm name to use in WWW-Authenticate header for Basic auth |
106 | 118 | */ |
107 | | - public OAuth2ClientAuthenticationFilter(AuthenticationManager authenticationManager, |
108 | | - RequestMatcher requestMatcher) { |
| 119 | + public OAuth2ClientAuthenticationFilter(AuthenticationManager authenticationManager, RequestMatcher requestMatcher, |
| 120 | + String realmName) { |
| 121 | + this.realmName = realmName; |
109 | 122 | Assert.notNull(authenticationManager, "authenticationManager cannot be null"); |
110 | 123 | Assert.notNull(requestMatcher, "requestMatcher cannot be null"); |
| 124 | + Assert.notNull(realmName, "realmName cannot be null"); |
111 | 125 | this.authenticationManager = authenticationManager; |
112 | 126 | this.requestMatcher = requestMatcher; |
113 | 127 | // @formatter:off |
@@ -140,9 +154,12 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse |
140 | 154 | validateClientIdentifier(authenticationRequest); |
141 | 155 | Authentication authenticationResult = this.authenticationManager.authenticate(authenticationRequest); |
142 | 156 | this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, authenticationResult); |
| 157 | + filterChain.doFilter(request, response); |
| 158 | + } |
| 159 | + else { |
| 160 | + this.authenticationFailureHandler.onAuthenticationFailure(request, response, |
| 161 | + new OAuth2AuthenticationException(MISSING_CLIENT_AUTH_ERROR_CODE)); |
143 | 162 | } |
144 | | - filterChain.doFilter(request, response); |
145 | | - |
146 | 163 | } |
147 | 164 | catch (OAuth2AuthenticationException ex) { |
148 | 165 | if (this.logger.isTraceEnabled()) { |
@@ -204,27 +221,23 @@ private void onAuthenticationFailure(HttpServletRequest request, HttpServletResp |
204 | 221 |
|
205 | 222 | SecurityContextHolder.clearContext(); |
206 | 223 |
|
207 | | - // TODO |
208 | | - // The authorization server MAY return an HTTP 401 (Unauthorized) status code |
209 | | - // to indicate which HTTP authentication schemes are supported. |
210 | | - // If the client attempted to authenticate via the "Authorization" request header |
211 | | - // field, |
212 | | - // the authorization server MUST respond with an HTTP 401 (Unauthorized) status |
213 | | - // code and |
214 | | - // include the "WWW-Authenticate" response header field |
215 | | - // matching the authentication scheme used by the client. |
216 | | - |
217 | 224 | OAuth2Error error = ((OAuth2AuthenticationException) exception).getError(); |
218 | 225 | ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); |
219 | | - if (OAuth2ErrorCodes.INVALID_CLIENT.equals(error.getErrorCode())) { |
| 226 | + String errorCode = error.getErrorCode(); |
| 227 | + String authHeader = request.getHeader(HttpHeaders.AUTHORIZATION); |
| 228 | + |
| 229 | + if (MISSING_CLIENT_AUTH_ERROR_CODE.equals(errorCode) || (OAuth2ErrorCodes.INVALID_CLIENT.equals(errorCode) |
| 230 | + && authHeader != null && authHeader.trim().startsWith("Basic"))) { |
| 231 | + errorCode = OAuth2ErrorCodes.INVALID_CLIENT; |
220 | 232 | httpResponse.setStatusCode(HttpStatus.UNAUTHORIZED); |
| 233 | + httpResponse.getHeaders().set("WWW-Authenticate", "Basic realm=\"" + this.realmName + "\""); |
221 | 234 | } |
222 | 235 | else { |
223 | 236 | httpResponse.setStatusCode(HttpStatus.BAD_REQUEST); |
224 | 237 | } |
225 | 238 | // We don't want to reveal too much information to the caller so just return the |
226 | 239 | // error code |
227 | | - OAuth2Error errorResponse = new OAuth2Error(error.getErrorCode()); |
| 240 | + OAuth2Error errorResponse = new OAuth2Error(errorCode); |
228 | 241 | this.errorHttpResponseConverter.write(errorResponse, null, httpResponse); |
229 | 242 | } |
230 | 243 |
|
|
0 commit comments