diff --git a/core/src/main/java/com/predic8/membrane/core/http/Header.java b/core/src/main/java/com/predic8/membrane/core/http/Header.java index 4be7cc3c14..8d39a61d44 100644 --- a/core/src/main/java/com/predic8/membrane/core/http/Header.java +++ b/core/src/main/java/com/predic8/membrane/core/http/Header.java @@ -16,6 +16,7 @@ import com.predic8.membrane.annot.Constants; import com.predic8.membrane.core.http.cookie.*; +import com.predic8.membrane.core.util.security.BasicAuthenticationUtil; import com.predic8.membrane.core.util.*; import jakarta.mail.internet.*; import org.jetbrains.annotations.*; @@ -31,6 +32,7 @@ import static com.predic8.membrane.core.http.MimeType.*; import static com.predic8.membrane.core.util.HttpUtil.*; +import static com.predic8.membrane.core.util.security.BasicAuthenticationUtil.createAuthorizationHeader; import static java.nio.charset.StandardCharsets.*; import static java.util.Arrays.*; import static java.util.Collections.*; @@ -428,9 +430,7 @@ public String toString() { * @param password the password for authentication */ public void setAuthorization(String user, String password) { - setValue("Authorization", "Basic " - + new String(encodeBase64((user + ":" + password) - .getBytes(UTF_8)), UTF_8)); + setValue("Authorization", createAuthorizationHeader(user, password)); } /** diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/oauth2/OAuth2ClientInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/oauth2/OAuth2ClientInterceptor.java new file mode 100644 index 0000000000..67b59f3b81 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/oauth2/OAuth2ClientInterceptor.java @@ -0,0 +1,232 @@ +package com.predic8.membrane.core.interceptor.oauth2; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.predic8.membrane.annot.MCAttribute; +import com.predic8.membrane.annot.MCElement; +import com.predic8.membrane.annot.Required; +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.AbstractInterceptor; +import com.predic8.membrane.core.interceptor.Outcome; +import com.predic8.membrane.core.transport.http.HttpClient; +import com.predic8.membrane.core.util.security.BasicAuthenticationUtil; +import org.jetbrains.annotations.NotNull; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static com.predic8.membrane.core.exceptions.ProblemDetails.gateway; +import static com.predic8.membrane.core.http.Header.ACCEPT; +import static com.predic8.membrane.core.http.Header.AUTHORIZATION; +import static com.predic8.membrane.core.http.MimeType.APPLICATION_JSON; +import static com.predic8.membrane.core.http.MimeType.APPLICATION_X_WWW_FORM_URLENCODED; +import static com.predic8.membrane.core.http.Request.post; +import static com.predic8.membrane.core.interceptor.Interceptor.Flow.Set.REQUEST_FLOW; +import static com.predic8.membrane.core.interceptor.Outcome.ABORT; +import static com.predic8.membrane.core.interceptor.Outcome.CONTINUE; +import static com.predic8.membrane.core.util.URLParamUtil.createQueryStringOmitNullValues; +import static com.predic8.membrane.core.util.security.BasicAuthenticationUtil.createAuthorizationHeader; +import static java.lang.Math.max; +import static java.lang.System.currentTimeMillis; +import static java.net.URLEncoder.encode; +import static java.nio.charset.StandardCharsets.UTF_8; + +/** + * @description Obtains an OAuth2 access token using the client credentials flow and forwards the request with a Bearer token. + * @yaml
+ * api:
+ * port: 2000
+ * flow:
+ * - oauth2Client:
+ * tokenUrl: https://auth.example.com/oauth2/token
+ * clientId: gateway
+ * clientSecret: secret
+ * scope: read write
+ * target:
+ * url: https://api.example.com
+ *
+ */
+@MCElement(name="oauth2Client")
+public class OAuth2ClientInterceptor extends AbstractInterceptor {
+
+ private static final Logger log = LoggerFactory.getLogger(OAuth2ClientInterceptor.class);
+
+ private String tokenUrl;
+
+ private String clientId;
+
+ private String clientSecret;
+
+ private String scope;
+
+ private final ObjectMapper objectMapper = new ObjectMapper();
+
+ private HttpClient httpClient;
+ private final Object tokenLock = new Object();
+
+ private volatile String cachedAccessToken;
+ private volatile long cachedAccessTokenValidUntilEpochMillis;
+
+ @Override
+ public void init() {
+ super.init();
+ name = "OAuth2 Client";
+ setAppliedFlow(REQUEST_FLOW);
+
+ httpClient = router.getHttpClientFactory().createClient(router.getHttpClientConfig());
+ }
+
+ @Override
+ public Outcome handleRequest(Exchange exc) {
+ try {
+ String token = getAccessToken();
+ exc.getRequest().getHeader().setValue(AUTHORIZATION, "Bearer " + token);
+ return CONTINUE;
+ } catch (Exception e) {
+ log.warn("Could not obtain OAuth2 access token from {}: {}", tokenUrl, e.getMessage());
+ log.debug("OAuth2 token request failed.", e);
+ gateway(router.getConfiguration().isProduction(), getDisplayName())
+ .title("Bad Gateway")
+ .status(502)
+ .addSubSee("oauth2-token")
+ .detail("Could not obtain an OAuth2 access token.")
+ .buildAndSetResponse(exc);
+ return ABORT;
+ }
+ }
+
+ private String getAccessToken() throws Exception {
+ if (hasValidCachedToken()) {
+ return cachedAccessToken;
+ }
+
+ synchronized (tokenLock) {
+ if (hasValidCachedToken()) {
+ return cachedAccessToken;
+ }
+ return fetchAccessToken();
+ }
+ }
+
+ private boolean hasValidCachedToken() {
+ return cachedAccessToken != null && currentTimeMillis() < cachedAccessTokenValidUntilEpochMillis;
+ }
+
+ private String fetchAccessToken() throws Exception {
+ Exchange tokenExchange = post(tokenUrl)
+ .contentType(APPLICATION_X_WWW_FORM_URLENCODED)
+ .header(ACCEPT, APPLICATION_JSON)
+ .header(AUTHORIZATION, buildBasicAuthorization())
+ .body(buildTokenRequestBody())
+ .buildExchange();
+
+ httpClient.call(tokenExchange);
+
+ var response = tokenExchange.getResponse();
+ String responseBody = response.getBodyAsStringDecoded();
+ if (response.getStatusCode() != 200) {
+ throw new IllegalStateException("Authorization server returned status " + response.getStatusCode() + ".");
+ }
+
+ var responseJson = objectMapper.readTree(responseBody);
+ String token = extractAccessToken(responseJson);
+
+ updateTokenCache(token, responseJson.path("expires_in").asLong(-1));
+ return token;
+ }
+
+ private static @NotNull String extractAccessToken(JsonNode responseJson) {
+ String token = responseJson.path("access_token").asText(null);
+ if (token == null || token.isBlank()) {
+ throw new IllegalStateException("Authorization server did not return an access token.");
+ }
+ return token;
+ }
+
+ private void updateTokenCache(String token, long expiresInSeconds) {
+ if (expiresInSeconds <= 0) {
+ cachedAccessToken = null;
+ cachedAccessTokenValidUntilEpochMillis = 0;
+ log.debug("Token response from {} has no usable expires_in. Token will not be cached.", tokenUrl);
+ return;
+ }
+
+ // Refresh slightly before expiry to avoid sending a token that expires mid-request.
+ long refreshBufferSeconds = Math.min(30, max(1, expiresInSeconds / 10));
+
+ cachedAccessToken = token;
+ cachedAccessTokenValidUntilEpochMillis = currentTimeMillis() + max(1, expiresInSeconds - refreshBufferSeconds) * 1000;
+ }
+
+ private String buildTokenRequestBody() {
+ return createQueryStringOmitNullValues(
+ "grant_type", "client_credentials",
+ "scope", scope == null || scope.isBlank() ? null : scope
+ );
+ }
+
+ private String buildBasicAuthorization() {
+ return createAuthorizationHeader(clientId, clientSecret, this::encodeClientCredential);
+ }
+
+ private String encodeClientCredential(String value) {
+ return encode(value, UTF_8);
+ }
+
+ /**
+ * @description The token endpoint used to obtain the OAuth2 access token.
+ * @required
+ * @example https://auth.example.com/oauth2/token
+ */
+ @MCAttribute
+ @Required
+ public void setTokenUrl(String tokenUrl) {
+ this.tokenUrl = tokenUrl;
+ }
+
+ public String getTokenUrl() {
+ return tokenUrl;
+ }
+
+ /**
+ * @description The OAuth2 client id used for the token request.
+ * @required
+ * @example gateway
+ */
+ @MCAttribute
+ @Required
+ public void setClientId(String clientId) {
+ this.clientId = clientId;
+ }
+
+ public String getClientId() {
+ return clientId;
+ }
+
+ /**
+ * @description The OAuth2 client secret used for the token request.
+ * @required
+ * @example secret
+ */
+ @MCAttribute
+ @Required
+ public void setClientSecret(String clientSecret) {
+ this.clientSecret = clientSecret;
+ }
+
+ public String getClientSecret() {
+ return clientSecret;
+ }
+
+ /**
+ * @description Space-separated scopes requested for the access token.
+ * @example read write
+ */
+ @MCAttribute
+ public void setScope(String scope) {
+ this.scope = scope;
+ }
+
+ public String getScope() {
+ return scope;
+ }
+}
diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/oauth2/OAuth2TokenBody.java b/core/src/main/java/com/predic8/membrane/core/interceptor/oauth2/OAuth2TokenBody.java
index 4af05e4eae..e3892d674e 100644
--- a/core/src/main/java/com/predic8/membrane/core/interceptor/oauth2/OAuth2TokenBody.java
+++ b/core/src/main/java/com/predic8/membrane/core/interceptor/oauth2/OAuth2TokenBody.java
@@ -14,10 +14,7 @@
package com.predic8.membrane.core.interceptor.oauth2;
-import java.util.function.Function;
-
-import static java.net.URLEncoder.encode;
-import static java.nio.charset.StandardCharsets.UTF_8;
+import com.predic8.membrane.core.util.URLParamUtil;
public class OAuth2TokenBody {
private String code;
@@ -70,27 +67,18 @@ public OAuth2TokenBody clientAssertion(String type, String assertion) {
}
public String build() {
- StringBuilder r = new StringBuilder("grant_type=" + grantType);
- appendParam(r, "refresh_token", refreshToken);
- appendParam(r, "code", code);
- appendParam(r, "redirect_uri", redirectUri);
- appendParam(r, "scope", scope);
- appendParam(r, "code_verifier", codeVerifier);
- appendParam(r, "client_id", clientId);
- appendParam(r, "client_secret", clientSecret);
- appendParam(r, "client_assertion_type", clientAssertionType);
- appendParam(r, "client_assertion", clientAssertion);
- return r.toString();
- }
-
- private void appendParam(StringBuilder sb, String paramName, String paramValue) {
- appendParam(sb, paramName, paramValue, e -> encode(e, UTF_8));
- }
-
- private void appendParam(StringBuilder sb, String paramName, String paramValue, FunctionConfiguration for an outbound HTTP proxy used by the HTTP client. @@ -167,7 +167,7 @@ public void setSslParser(SSLParser sslParser) { * The "Basic" authentication scheme defined in RFC 2617 does not properly define how to treat non-ASCII characters. */ public String getCredentials() { - return "Basic " + new String(encodeBase64((username + ":" + password).getBytes(UTF_8)), UTF_8); + return createAuthorizationHeader(username, password); } } diff --git a/core/src/main/java/com/predic8/membrane/core/util/URLParamUtil.java b/core/src/main/java/com/predic8/membrane/core/util/URLParamUtil.java index 0f9acd4b5d..e204e33ecc 100644 --- a/core/src/main/java/com/predic8/membrane/core/util/URLParamUtil.java +++ b/core/src/main/java/com/predic8/membrane/core/util/URLParamUtil.java @@ -115,6 +115,26 @@ public static String createQueryString(String... params) { return buf.toString(); } + public static String createQueryStringOmitNullValues(String... params) { + if (params.length % 2 != 0) + throw new IllegalArgumentException("params must contain key/value pairs"); + + StringBuilder buf = new StringBuilder(); + boolean first = true; + for (int i = 0; i < params.length; i += 2) { + if (params[i + 1] == null) + continue; + if (first) + first = false; + else + buf.append('&'); + buf.append(URLEncoder.encode(params[i], UTF_8)); + buf.append('='); + buf.append(URLEncoder.encode(params[i + 1], UTF_8)); + } + return buf.toString(); + } + /** * Parse a URL query into parameter pairs. The query is expected to be application/x-www-form-urlencoded . *
diff --git a/core/src/main/java/com/predic8/membrane/core/util/security/BasicAuthenticationUtil.java b/core/src/main/java/com/predic8/membrane/core/util/security/BasicAuthenticationUtil.java
index fb66686f65..fab0db78f8 100644
--- a/core/src/main/java/com/predic8/membrane/core/util/security/BasicAuthenticationUtil.java
+++ b/core/src/main/java/com/predic8/membrane/core/util/security/BasicAuthenticationUtil.java
@@ -17,8 +17,10 @@
import com.predic8.membrane.core.exchange.*;
import java.util.*;
+import java.util.function.UnaryOperator;
import static java.nio.charset.StandardCharsets.*;
+import static java.util.Objects.requireNonNull;
public class BasicAuthenticationUtil {
@@ -32,8 +34,8 @@ public class BasicAuthenticationUtil {
*/
public record BasicCredentials(String username, String password) {
public BasicCredentials {
- Objects.requireNonNull(username, "Username cannot be null");
- Objects.requireNonNull(password, "Password cannot be null");
+ requireNonNull(username, "Username cannot be null");
+ requireNonNull(password, "Password cannot be null");
}
/**
@@ -68,6 +70,19 @@ public static BasicCredentials getCredentials(Exchange exc) {
return parseCredentials(decodeAuthorizationHeader(exc));
}
+ public static String createAuthorizationHeader(String username, String password) {
+ return createAuthorizationHeader(username, password, UnaryOperator.identity());
+ }
+
+ public static String createAuthorizationHeader(String username, String password, UnaryOperator