Skip to content

Commit 15e2488

Browse files
committed
Sync branch 'main'
2 parents 65f43d6 + 8c6683a commit 15e2488

35 files changed

Lines changed: 1099 additions & 440 deletions

File tree

oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,8 @@ private static OAuth2AuthorizationCodeRequestAuthenticationException createExcep
500500
registeredClient);
501501
if (error.getErrorCode().equals(OAuth2ErrorCodes.INVALID_REQUEST)
502502
&& (parameterName.equals(OAuth2ParameterNames.CLIENT_ID)
503-
|| parameterName.equals(OAuth2ParameterNames.STATE))) {
503+
|| parameterName.equals(OAuth2ParameterNames.STATE)
504+
|| parameterName.equals(OAuth2ParameterNames.REQUEST_URI))) {
504505
redirectUri = null; // Prevent redirects
505506
}
506507

oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationValidator.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,8 +298,10 @@ private static OAuth2AuthorizationCodeRequestAuthenticationException createExcep
298298
String redirectUri = StringUtils.hasText(authorizationCodeRequestAuthentication.getRedirectUri())
299299
? authorizationCodeRequestAuthentication.getRedirectUri()
300300
: registeredClient.getRedirectUris().iterator().next();
301-
if (error.getErrorCode().equals(OAuth2ErrorCodes.INVALID_REQUEST)
302-
&& parameterName.equals(OAuth2ParameterNames.REDIRECT_URI)) {
301+
if ((error.getErrorCode().equals(OAuth2ErrorCodes.INVALID_REQUEST)
302+
|| error.getErrorCode().equals(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT))
303+
&& (parameterName.equals(OAuth2ParameterNames.CLIENT_ID)
304+
|| parameterName.equals(OAuth2ParameterNames.REDIRECT_URI))) {
303305
redirectUri = null; // Prevent redirects
304306
}
305307

oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeRequestAuthenticationConverter.java

Lines changed: 49 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -132,52 +132,66 @@ else if (!responseType.equals(OAuth2AuthorizationResponseType.CODE.getValue()))
132132
principal = ANONYMOUS_AUTHENTICATION;
133133
}
134134

135-
// redirect_uri (OPTIONAL)
136-
String redirectUri = parameters.getFirst(OAuth2ParameterNames.REDIRECT_URI);
137-
List<String> redirectUriParams = parameters.get(OAuth2ParameterNames.REDIRECT_URI);
138-
if (StringUtils.hasText(redirectUri) && redirectUriParams != null && redirectUriParams.size() != 1) {
139-
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI);
135+
String redirectUri = null;
136+
if (!StringUtils.hasText(requestUri)) {
137+
// redirect_uri (OPTIONAL)
138+
redirectUri = parameters.getFirst(OAuth2ParameterNames.REDIRECT_URI);
139+
List<String> redirectUriParams = parameters.get(OAuth2ParameterNames.REDIRECT_URI);
140+
if (StringUtils.hasText(redirectUri) && redirectUriParams != null && redirectUriParams.size() != 1) {
141+
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI);
142+
}
140143
}
141144

142-
// scope (OPTIONAL)
143145
Set<String> scopes = null;
144-
String scope = parameters.getFirst(OAuth2ParameterNames.SCOPE);
145-
List<String> scopeParams = parameters.get(OAuth2ParameterNames.SCOPE);
146-
if (StringUtils.hasText(scope) && scopeParams != null && scopeParams.size() != 1) {
147-
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.SCOPE);
148-
}
149-
if (StringUtils.hasText(scope)) {
150-
scopes = new HashSet<>(Arrays.asList(StringUtils.delimitedListToStringArray(scope, " ")));
146+
if (!StringUtils.hasText(requestUri)) {
147+
// scope (OPTIONAL)
148+
String scope = parameters.getFirst(OAuth2ParameterNames.SCOPE);
149+
List<String> scopeParams = parameters.get(OAuth2ParameterNames.SCOPE);
150+
if (StringUtils.hasText(scope) && scopeParams != null && scopeParams.size() != 1) {
151+
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.SCOPE);
152+
}
153+
if (StringUtils.hasText(scope)) {
154+
scopes = new HashSet<>(Arrays.asList(StringUtils.delimitedListToStringArray(scope, " ")));
155+
}
151156
}
152157

153-
// state (RECOMMENDED)
154-
String state = parameters.getFirst(OAuth2ParameterNames.STATE);
155-
List<String> stateParams = parameters.get(OAuth2ParameterNames.STATE);
156-
if (StringUtils.hasText(state) && stateParams != null && stateParams.size() != 1) {
157-
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.STATE);
158+
String state = null;
159+
if (!StringUtils.hasText(requestUri)) {
160+
// state (RECOMMENDED)
161+
state = parameters.getFirst(OAuth2ParameterNames.STATE);
162+
List<String> stateParams = parameters.get(OAuth2ParameterNames.STATE);
163+
if (StringUtils.hasText(state) && stateParams != null && stateParams.size() != 1) {
164+
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.STATE);
165+
}
158166
}
159167

160-
// code_challenge (REQUIRED for public clients) - RFC 7636 (PKCE)
161-
String codeChallenge = parameters.getFirst(PkceParameterNames.CODE_CHALLENGE);
162-
List<String> codeChallengeParams = parameters.get(PkceParameterNames.CODE_CHALLENGE);
163-
if (StringUtils.hasText(codeChallenge) && codeChallengeParams != null && codeChallengeParams.size() != 1) {
164-
throwError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE, PKCE_ERROR_URI);
168+
if (!StringUtils.hasText(requestUri)) {
169+
// code_challenge (REQUIRED for public clients) - RFC 7636 (PKCE)
170+
String codeChallenge = parameters.getFirst(PkceParameterNames.CODE_CHALLENGE);
171+
List<String> codeChallengeParams = parameters.get(PkceParameterNames.CODE_CHALLENGE);
172+
if (StringUtils.hasText(codeChallenge) && codeChallengeParams != null && codeChallengeParams.size() != 1) {
173+
throwError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE, PKCE_ERROR_URI);
174+
}
165175
}
166176

167-
// code_challenge_method (OPTIONAL for public clients) - RFC 7636 (PKCE)
168-
String codeChallengeMethod = parameters.getFirst(PkceParameterNames.CODE_CHALLENGE_METHOD);
169-
List<String> codeChallengeMethodParams = parameters.get(PkceParameterNames.CODE_CHALLENGE_METHOD);
170-
if (StringUtils.hasText(codeChallengeMethod) && codeChallengeMethodParams != null
171-
&& codeChallengeMethodParams.size() != 1) {
172-
throwError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE_METHOD, PKCE_ERROR_URI);
177+
if (!StringUtils.hasText(requestUri)) {
178+
// code_challenge_method (OPTIONAL for public clients) - RFC 7636 (PKCE)
179+
String codeChallengeMethod = parameters.getFirst(PkceParameterNames.CODE_CHALLENGE_METHOD);
180+
List<String> codeChallengeMethodParams = parameters.get(PkceParameterNames.CODE_CHALLENGE_METHOD);
181+
if (StringUtils.hasText(codeChallengeMethod) && codeChallengeMethodParams != null
182+
&& codeChallengeMethodParams.size() != 1) {
183+
throwError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE_METHOD, PKCE_ERROR_URI);
184+
}
173185
}
174186

175-
// prompt (OPTIONAL for OpenID Connect 1.0 Authentication Request)
176-
if (!CollectionUtils.isEmpty(scopes) && scopes.contains(OidcScopes.OPENID)) {
177-
String prompt = parameters.getFirst("prompt");
178-
List<String> promptParams = parameters.get("prompt");
179-
if (StringUtils.hasText(prompt) && promptParams != null && promptParams.size() != 1) {
180-
throwError(OAuth2ErrorCodes.INVALID_REQUEST, "prompt");
187+
if (!StringUtils.hasText(requestUri)) {
188+
// prompt (OPTIONAL for OpenID Connect 1.0 Authentication Request)
189+
if (!CollectionUtils.isEmpty(scopes) && scopes.contains(OidcScopes.OPENID)) {
190+
String prompt = parameters.getFirst("prompt");
191+
List<String> promptParams = parameters.get("prompt");
192+
if (StringUtils.hasText(prompt) && promptParams != null && promptParams.size() != 1) {
193+
throwError(OAuth2ErrorCodes.INVALID_REQUEST, "prompt");
194+
}
181195
}
182196
}
183197

oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProviderTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ public void authenticateWhenClientNotAuthorizedToRequestCodeThenThrowOAuth2Autho
311311
assertThatExceptionOfType(OAuth2AuthorizationCodeRequestAuthenticationException.class)
312312
.isThrownBy(() -> this.authenticationProvider.authenticate(authentication))
313313
.satisfies((ex) -> assertAuthenticationException(ex, OAuth2ErrorCodes.UNAUTHORIZED_CLIENT,
314-
OAuth2ParameterNames.CLIENT_ID, authentication.getRedirectUri()));
314+
OAuth2ParameterNames.CLIENT_ID, null));
315315
}
316316

317317
@Test

oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2PushedAuthorizationRequestAuthenticationProviderTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ public void authenticateWhenClientNotAuthorizedToRequestCodeThenThrowOAuth2Autho
128128
assertThatExceptionOfType(OAuth2AuthorizationCodeRequestAuthenticationException.class)
129129
.isThrownBy(() -> this.authenticationProvider.authenticate(authentication))
130130
.satisfies((ex) -> assertAuthenticationException(ex, OAuth2ErrorCodes.UNAUTHORIZED_CLIENT,
131-
OAuth2ParameterNames.CLIENT_ID, authentication.getRedirectUri()));
131+
OAuth2ParameterNames.CLIENT_ID, null));
132132
}
133133

134134
@Test

saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,14 @@ tasks.register("opensaml5Test", Test) {
148148
classpath = sourceSets.opensaml5Test.output + sourceSets.opensaml5Test.runtimeClasspath
149149
}
150150

151+
tasks.register("bouncyCastleTest", Test) {
152+
useJUnitPlatform()
153+
testClassesDirs = sourceSets.test.output.classesDirs
154+
classpath = sourceSets.test.runtimeClasspath
155+
include "**/JdbcAssertingPartyMetadataRepositoryBouncyCastleTests.class"
156+
}
157+
151158
tasks.named("test") {
152-
dependsOn opensaml5Test
159+
exclude "**/JdbcAssertingPartyMetadataRepositoryBouncyCastleTests.class"
160+
dependsOn opensaml5Test, bouncyCastleTest
153161
}

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/internal/Saml2Utils.java

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.io.ByteArrayOutputStream;
2020
import java.io.IOException;
21+
import java.io.OutputStream;
2122
import java.nio.charset.StandardCharsets;
2223
import java.util.Arrays;
2324
import java.util.Base64;
@@ -64,7 +65,7 @@ static byte[] samlDeflate(String s) {
6465
static String samlInflate(byte[] b) {
6566
try {
6667
ByteArrayOutputStream out = new ByteArrayOutputStream();
67-
InflaterOutputStream iout = new InflaterOutputStream(out, new Inflater(true));
68+
InflaterOutputStream iout = new InflaterOutputStream(new CappedOutputStream(out), new Inflater(true));
6869
iout.write(b);
6970
iout.finish();
7071
return new String(out.toByteArray(), StandardCharsets.UTF_8);
@@ -193,4 +194,27 @@ void checkAcceptable(String ins) {
193194

194195
}
195196

197+
static class CappedOutputStream extends OutputStream {
198+
199+
private static final long MAX_SIZE = 1024 * 1024;
200+
201+
private final OutputStream delegate;
202+
203+
private int size;
204+
205+
CappedOutputStream(OutputStream delegate) {
206+
this.delegate = delegate;
207+
}
208+
209+
@Override
210+
public void write(int b) throws IOException {
211+
if (this.size >= MAX_SIZE) {
212+
throw new IOException("SAML payload exceeded maximum size of " + MAX_SIZE);
213+
}
214+
this.delegate.write(b);
215+
this.size++;
216+
}
217+
218+
}
219+
196220
}

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/BaseOpenSamlAuthenticationProvider.java

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -331,62 +331,75 @@ private void process(Saml2AuthenticationToken token, Response response) {
331331
boolean responseSigned = response.isSigned();
332332

333333
ResponseToken responseToken = new ResponseToken(response, token);
334-
Saml2ResponseValidatorResult result = this.responseSignatureValidator.convert(responseToken);
334+
Collection<Saml2Error> responseSignatureErrors = this.responseSignatureValidator.convert(responseToken)
335+
.getErrors();
336+
if (!responseSignatureErrors.isEmpty()) {
337+
reportErrors(response, responseSignatureErrors);
338+
return;
339+
}
340+
341+
Collection<Saml2Error> errors = new ArrayList<>();
335342
if (responseSigned) {
336343
this.responseElementsDecrypter.accept(responseToken);
337344
}
338345
else if (!response.getEncryptedAssertions().isEmpty()) {
339-
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
346+
errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
340347
"Did not decrypt response [" + response.getID() + "] since it is not signed"));
341348
}
342349
if (!this.validateResponseAfterAssertions) {
343-
result = result.concat(this.responseValidator.convert(responseToken));
350+
errors.addAll(this.responseValidator.convert(responseToken).getErrors());
344351
}
345352
boolean allAssertionsSigned = true;
346353
for (Assertion assertion : response.getAssertions()) {
347354
AssertionToken assertionToken = new AssertionToken(assertion, token);
348-
result = result.concat(this.assertionSignatureValidator.convert(assertionToken));
355+
Collection<Saml2Error> assertionSignatureErrors = this.assertionSignatureValidator.convert(assertionToken)
356+
.getErrors();
357+
errors.addAll(assertionSignatureErrors);
349358
allAssertionsSigned = allAssertionsSigned && assertion.isSigned();
359+
if (!assertionSignatureErrors.isEmpty()) {
360+
continue;
361+
}
350362
if (responseSigned || assertion.isSigned()) {
351363
this.assertionElementsDecrypter.accept(new AssertionToken(assertion, token));
352364
}
353-
result = result.concat(this.assertionValidator.convert(assertionToken));
365+
errors.addAll(this.assertionValidator.convert(assertionToken).getErrors());
354366
}
355367
if (!responseSigned && !allAssertionsSigned) {
356368
String description = "Either the response or one of the assertions is unsigned. "
357369
+ "Please either sign the response or all of the assertions.";
358-
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, description));
370+
errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE, description));
359371
}
360372
if (this.validateResponseAfterAssertions) {
361-
result = result.concat(this.responseValidator.convert(responseToken));
373+
errors.addAll(this.responseValidator.convert(responseToken).getErrors());
362374
}
363375
else {
364376
Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions());
365377
if (firstAssertion != null && !hasName(firstAssertion)) {
366-
Saml2Error error = new Saml2Error(Saml2ErrorCodes.SUBJECT_NOT_FOUND,
367-
"Assertion [" + firstAssertion.getID() + "] is missing a subject");
368-
result = result.concat(error);
378+
errors.add(new Saml2Error(Saml2ErrorCodes.SUBJECT_NOT_FOUND,
379+
"Assertion [" + firstAssertion.getID() + "] is missing a subject"));
369380
}
370381
}
371382

372-
if (result.hasErrors()) {
373-
Collection<Saml2Error> errors = result.getErrors();
374-
if (this.logger.isTraceEnabled()) {
375-
this.logger.trace("Found " + errors.size() + " validation errors in SAML response [" + response.getID()
376-
+ "]: " + errors);
377-
}
378-
else if (this.logger.isDebugEnabled()) {
379-
this.logger
380-
.debug("Found " + errors.size() + " validation errors in SAML response [" + response.getID() + "]");
381-
}
382-
Saml2Error first = errors.iterator().next();
383-
throw new Saml2AuthenticationException(first);
384-
}
385-
else {
383+
reportErrors(response, errors);
384+
}
385+
386+
private void reportErrors(Response response, Collection<Saml2Error> errors) {
387+
if (errors.isEmpty()) {
386388
if (this.logger.isDebugEnabled()) {
387389
this.logger.debug("Successfully processed SAML Response [" + response.getID() + "]");
388390
}
391+
return;
392+
}
393+
if (this.logger.isTraceEnabled()) {
394+
this.logger.trace("Found " + errors.size() + " validation errors in SAML response [" + response.getID()
395+
+ "]: " + errors);
396+
}
397+
else if (this.logger.isDebugEnabled()) {
398+
this.logger
399+
.debug("Found " + errors.size() + " validation errors in SAML response [" + response.getID() + "]");
389400
}
401+
Saml2Error first = errors.iterator().next();
402+
throw new Saml2AuthenticationException(first);
390403
}
391404

392405
private Converter<ResponseToken, Saml2ResponseValidatorResult> createDefaultResponseSignatureValidator() {

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2Utils.java

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.io.ByteArrayOutputStream;
2020
import java.io.IOException;
21+
import java.io.OutputStream;
2122
import java.nio.charset.StandardCharsets;
2223
import java.util.Arrays;
2324
import java.util.Base64;
@@ -64,7 +65,7 @@ static byte[] samlDeflate(String s) {
6465
static String samlInflate(byte[] b) {
6566
try {
6667
ByteArrayOutputStream out = new ByteArrayOutputStream();
67-
InflaterOutputStream iout = new InflaterOutputStream(out, new Inflater(true));
68+
InflaterOutputStream iout = new InflaterOutputStream(new CappedOutputStream(out), new Inflater(true));
6869
iout.write(b);
6970
iout.finish();
7071
return new String(out.toByteArray(), StandardCharsets.UTF_8);
@@ -193,4 +194,27 @@ void checkAcceptable(String ins) {
193194

194195
}
195196

197+
static class CappedOutputStream extends OutputStream {
198+
199+
private static final long MAX_SIZE = 1024 * 1024;
200+
201+
private final OutputStream delegate;
202+
203+
private int size;
204+
205+
CappedOutputStream(OutputStream delegate) {
206+
this.delegate = delegate;
207+
}
208+
209+
@Override
210+
public void write(int b) throws IOException {
211+
if (this.size >= MAX_SIZE) {
212+
throw new IOException("SAML payload exceeded maximum size of " + MAX_SIZE);
213+
}
214+
this.delegate.write(b);
215+
this.size++;
216+
}
217+
218+
}
219+
196220
}

0 commit comments

Comments
 (0)