diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java index 2c0d99ad6bf..4c15457c8bc 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java @@ -29,6 +29,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.Consumer; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -755,14 +756,26 @@ public static final class ClientSettings implements Serializable { @Serial private static final long serialVersionUID = 7495627155437124692L; - private boolean requireProofKey; - - private ClientSettings() { + private final Map settings; + private ClientSettings(Map settings) { + this.settings = Collections.unmodifiableMap(new LinkedHashMap<>(settings)); } + private static final String REQUIRE_PROOF_KEY = "settings.client.require-proof-key"; + public boolean isRequireProofKey() { - return this.requireProofKey; + return Boolean.TRUE.equals(getSetting(REQUIRE_PROOF_KEY)); + } + + @SuppressWarnings("unchecked") + public @Nullable T getSetting(String name) { + Assert.hasText(name, "name cannot be empty"); + return (T) this.settings.get(name); + } + + public Map getSettings() { + return this.settings; } @Override @@ -773,17 +786,17 @@ public boolean equals(@Nullable Object o) { if (!(o instanceof ClientSettings that)) { return false; } - return this.requireProofKey == that.requireProofKey; + return Objects.equals(this.settings, that.settings); } @Override public int hashCode() { - return Objects.hashCode(this.requireProofKey); + return Objects.hashCode(this.settings); } @Override public String toString() { - return "ClientSettings{" + "requireProofKey=" + this.requireProofKey + '}'; + return "ClientSettings{" + "settings=" + this.settings + '}'; } public static Builder builder() { @@ -792,11 +805,12 @@ public static Builder builder() { public static final class Builder { - private boolean requireProofKey = true; + private final Map settings = new LinkedHashMap<>(); private Builder() { + this.settings.put(REQUIRE_PROOF_KEY, true); } - + /** * Set to {@code true} if the client is required to provide a proof key * challenge and verifier when performing the Authorization Code Grant flow. @@ -805,14 +819,35 @@ private Builder() { * @return the {@link Builder} for further configuration */ public Builder requireProofKey(boolean requireProofKey) { - this.requireProofKey = requireProofKey; + return setting(REQUIRE_PROOF_KEY, requireProofKey); + } + + /** + * Sets a configuration setting. + * @param name the name of the setting + * @param value the value of the setting + * @return the {@link Builder} for further configuration + */ + public Builder setting(String name, Object value) { + Assert.hasText(name, "name cannot be empty"); + Assert.notNull(value, "value cannot be null"); + this.settings.put(name, value); return this; } + /** + * Sets the configuration settings. + * @param settings the configuration settings + * @return the {@link Builder} for further configuration + */ + public Builder settings(Consumer> settingsConsumer) { + Assert.notNull(settingsConsumer, "settingsConsumer cannot be null"); + settingsConsumer.accept(this.settings); + return this; + } + public ClientSettings build() { - ClientSettings clientSettings = new ClientSettings(); - clientSettings.requireProofKey = this.requireProofKey; - return clientSettings; + return new ClientSettings(this.settings); } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationTests.java index 8a029e16bdc..942f4aa3a7c 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationTests.java @@ -751,4 +751,79 @@ private static T getStaticValue(Field field, Class clazz) { } } + @Test + void buildWhenScopesHaveInvalidCharactersThenThrowException() { + assertThatIllegalArgumentException().isThrownBy(() -> + // @formatter:off + ClientRegistration.withRegistrationId("test") + .clientId("client") + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .redirectUri("{baseUrl}/login/oauth2/code/{registrationId}") + .authorizationUri("https://provider.com/auth") + .tokenUri("https://provider.com/token") + .scope("read", "invalid scope ^") // space is 0x20, which is outside the valid range + .build() + // @formatter:on + ); + } + + @Test + void buildWhenClientCredentialsMissingTokenUriThenThrowException() { + assertThatIllegalArgumentException().isThrownBy(() -> + // @formatter:off + ClientRegistration.withRegistrationId("test") + .clientId("client") + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) + // Missing tokenUri + .build() + // @formatter:on + ); + } + + @Test + void buildWhenValidThenSettingsAreCorrect() { + // @formatter:off + ClientRegistration registration = ClientRegistration.withRegistrationId("google") + .clientId("my-client") + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) + .tokenUri("https://google.com/token") + .build(); + // @formatter:on + assertThat(registration.getRegistrationId()).isEqualTo("google"); + assertThat(registration.getClientId()).isEqualTo("my-client"); + // ClientSettings assertions + assertThat(registration.getClientSettings()).isNotNull(); + assertThat(registration.getClientSettings().isRequireProofKey()).isTrue(); + assertThat(registration.getClientSettings().getSettings()).containsKey("settings.client.require-proof-key"); + } + + @Test + void clientSettingsWhenCustomSettingThenGetSettingReturnsValue() { + ClientRegistration.ClientSettings clientSettings = ClientRegistration.ClientSettings.builder() + .setting("custom.key", "customValue") + .build(); + assertThat(clientSettings.getSetting("custom.key")).isEqualTo("customValue"); + } + + @Test + void clientSettingsWhenSettingsConsumerThenSettingsApplied() { + ClientRegistration.ClientSettings clientSettings = ClientRegistration.ClientSettings.builder() + .settings(s -> s.put("custom.key", "value")) + .build(); + assertThat(clientSettings.getSetting("custom.key")).isEqualTo("value"); + } + + @Test + void clientSettingsGetSettingWhenNameEmptyThenThrowException() { + ClientRegistration.ClientSettings clientSettings = ClientRegistration.ClientSettings.builder().build(); + assertThatIllegalArgumentException() + .isThrownBy(() -> clientSettings.getSetting("")) + .withMessageContaining("name cannot be empty"); + } + + @Test + void clientSettingsSettingsConsumerWhenNullThenThrowException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> ClientRegistration.ClientSettings.builder().settings(null)); + } }