diff --git a/core/src/main/java/com/predic8/membrane/core/http/Header.java b/core/src/main/java/com/predic8/membrane/core/http/Header.java index 4be7cc3c14..8d39a61d44 100644 --- a/core/src/main/java/com/predic8/membrane/core/http/Header.java +++ b/core/src/main/java/com/predic8/membrane/core/http/Header.java @@ -16,6 +16,7 @@ import com.predic8.membrane.annot.Constants; import com.predic8.membrane.core.http.cookie.*; +import com.predic8.membrane.core.util.security.BasicAuthenticationUtil; import com.predic8.membrane.core.util.*; import jakarta.mail.internet.*; import org.jetbrains.annotations.*; @@ -31,6 +32,7 @@ import static com.predic8.membrane.core.http.MimeType.*; import static com.predic8.membrane.core.util.HttpUtil.*; +import static com.predic8.membrane.core.util.security.BasicAuthenticationUtil.createAuthorizationHeader; import static java.nio.charset.StandardCharsets.*; import static java.util.Arrays.*; import static java.util.Collections.*; @@ -428,9 +430,7 @@ public String toString() { * @param password the password for authentication */ public void setAuthorization(String user, String password) { - setValue("Authorization", "Basic " - + new String(encodeBase64((user + ":" + password) - .getBytes(UTF_8)), UTF_8)); + setValue("Authorization", createAuthorizationHeader(user, password)); } /** diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/oauth2/OAuth2ClientInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/oauth2/OAuth2ClientInterceptor.java new file mode 100644 index 0000000000..67b59f3b81 --- /dev/null +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/oauth2/OAuth2ClientInterceptor.java @@ -0,0 +1,232 @@ +package com.predic8.membrane.core.interceptor.oauth2; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.predic8.membrane.annot.MCAttribute; +import com.predic8.membrane.annot.MCElement; +import com.predic8.membrane.annot.Required; +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.interceptor.AbstractInterceptor; +import com.predic8.membrane.core.interceptor.Outcome; +import com.predic8.membrane.core.transport.http.HttpClient; +import com.predic8.membrane.core.util.security.BasicAuthenticationUtil; +import org.jetbrains.annotations.NotNull; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static com.predic8.membrane.core.exceptions.ProblemDetails.gateway; +import static com.predic8.membrane.core.http.Header.ACCEPT; +import static com.predic8.membrane.core.http.Header.AUTHORIZATION; +import static com.predic8.membrane.core.http.MimeType.APPLICATION_JSON; +import static com.predic8.membrane.core.http.MimeType.APPLICATION_X_WWW_FORM_URLENCODED; +import static com.predic8.membrane.core.http.Request.post; +import static com.predic8.membrane.core.interceptor.Interceptor.Flow.Set.REQUEST_FLOW; +import static com.predic8.membrane.core.interceptor.Outcome.ABORT; +import static com.predic8.membrane.core.interceptor.Outcome.CONTINUE; +import static com.predic8.membrane.core.util.URLParamUtil.createQueryStringOmitNullValues; +import static com.predic8.membrane.core.util.security.BasicAuthenticationUtil.createAuthorizationHeader; +import static java.lang.Math.max; +import static java.lang.System.currentTimeMillis; +import static java.net.URLEncoder.encode; +import static java.nio.charset.StandardCharsets.UTF_8; + +/** + * @description Obtains an OAuth2 access token using the client credentials flow and forwards the request with a Bearer token. + * @yaml

+ * api:
+ *   port: 2000
+ *   flow:
+ *     - oauth2Client:
+ *         tokenUrl: https://auth.example.com/oauth2/token
+ *         clientId: gateway
+ *         clientSecret: secret
+ *         scope: read write
+ *   target:
+ *     url: https://api.example.com
+ * 
+ */ +@MCElement(name="oauth2Client") +public class OAuth2ClientInterceptor extends AbstractInterceptor { + + private static final Logger log = LoggerFactory.getLogger(OAuth2ClientInterceptor.class); + + private String tokenUrl; + + private String clientId; + + private String clientSecret; + + private String scope; + + private final ObjectMapper objectMapper = new ObjectMapper(); + + private HttpClient httpClient; + private final Object tokenLock = new Object(); + + private volatile String cachedAccessToken; + private volatile long cachedAccessTokenValidUntilEpochMillis; + + @Override + public void init() { + super.init(); + name = "OAuth2 Client"; + setAppliedFlow(REQUEST_FLOW); + + httpClient = router.getHttpClientFactory().createClient(router.getHttpClientConfig()); + } + + @Override + public Outcome handleRequest(Exchange exc) { + try { + String token = getAccessToken(); + exc.getRequest().getHeader().setValue(AUTHORIZATION, "Bearer " + token); + return CONTINUE; + } catch (Exception e) { + log.warn("Could not obtain OAuth2 access token from {}: {}", tokenUrl, e.getMessage()); + log.debug("OAuth2 token request failed.", e); + gateway(router.getConfiguration().isProduction(), getDisplayName()) + .title("Bad Gateway") + .status(502) + .addSubSee("oauth2-token") + .detail("Could not obtain an OAuth2 access token.") + .buildAndSetResponse(exc); + return ABORT; + } + } + + private String getAccessToken() throws Exception { + if (hasValidCachedToken()) { + return cachedAccessToken; + } + + synchronized (tokenLock) { + if (hasValidCachedToken()) { + return cachedAccessToken; + } + return fetchAccessToken(); + } + } + + private boolean hasValidCachedToken() { + return cachedAccessToken != null && currentTimeMillis() < cachedAccessTokenValidUntilEpochMillis; + } + + private String fetchAccessToken() throws Exception { + Exchange tokenExchange = post(tokenUrl) + .contentType(APPLICATION_X_WWW_FORM_URLENCODED) + .header(ACCEPT, APPLICATION_JSON) + .header(AUTHORIZATION, buildBasicAuthorization()) + .body(buildTokenRequestBody()) + .buildExchange(); + + httpClient.call(tokenExchange); + + var response = tokenExchange.getResponse(); + String responseBody = response.getBodyAsStringDecoded(); + if (response.getStatusCode() != 200) { + throw new IllegalStateException("Authorization server returned status " + response.getStatusCode() + "."); + } + + var responseJson = objectMapper.readTree(responseBody); + String token = extractAccessToken(responseJson); + + updateTokenCache(token, responseJson.path("expires_in").asLong(-1)); + return token; + } + + private static @NotNull String extractAccessToken(JsonNode responseJson) { + String token = responseJson.path("access_token").asText(null); + if (token == null || token.isBlank()) { + throw new IllegalStateException("Authorization server did not return an access token."); + } + return token; + } + + private void updateTokenCache(String token, long expiresInSeconds) { + if (expiresInSeconds <= 0) { + cachedAccessToken = null; + cachedAccessTokenValidUntilEpochMillis = 0; + log.debug("Token response from {} has no usable expires_in. Token will not be cached.", tokenUrl); + return; + } + + // Refresh slightly before expiry to avoid sending a token that expires mid-request. + long refreshBufferSeconds = Math.min(30, max(1, expiresInSeconds / 10)); + + cachedAccessToken = token; + cachedAccessTokenValidUntilEpochMillis = currentTimeMillis() + max(1, expiresInSeconds - refreshBufferSeconds) * 1000; + } + + private String buildTokenRequestBody() { + return createQueryStringOmitNullValues( + "grant_type", "client_credentials", + "scope", scope == null || scope.isBlank() ? null : scope + ); + } + + private String buildBasicAuthorization() { + return createAuthorizationHeader(clientId, clientSecret, this::encodeClientCredential); + } + + private String encodeClientCredential(String value) { + return encode(value, UTF_8); + } + + /** + * @description The token endpoint used to obtain the OAuth2 access token. + * @required + * @example https://auth.example.com/oauth2/token + */ + @MCAttribute + @Required + public void setTokenUrl(String tokenUrl) { + this.tokenUrl = tokenUrl; + } + + public String getTokenUrl() { + return tokenUrl; + } + + /** + * @description The OAuth2 client id used for the token request. + * @required + * @example gateway + */ + @MCAttribute + @Required + public void setClientId(String clientId) { + this.clientId = clientId; + } + + public String getClientId() { + return clientId; + } + + /** + * @description The OAuth2 client secret used for the token request. + * @required + * @example secret + */ + @MCAttribute + @Required + public void setClientSecret(String clientSecret) { + this.clientSecret = clientSecret; + } + + public String getClientSecret() { + return clientSecret; + } + + /** + * @description Space-separated scopes requested for the access token. + * @example read write + */ + @MCAttribute + public void setScope(String scope) { + this.scope = scope; + } + + public String getScope() { + return scope; + } +} diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/oauth2/OAuth2TokenBody.java b/core/src/main/java/com/predic8/membrane/core/interceptor/oauth2/OAuth2TokenBody.java index 4af05e4eae..e3892d674e 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/oauth2/OAuth2TokenBody.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/oauth2/OAuth2TokenBody.java @@ -14,10 +14,7 @@ package com.predic8.membrane.core.interceptor.oauth2; -import java.util.function.Function; - -import static java.net.URLEncoder.encode; -import static java.nio.charset.StandardCharsets.UTF_8; +import com.predic8.membrane.core.util.URLParamUtil; public class OAuth2TokenBody { private String code; @@ -70,27 +67,18 @@ public OAuth2TokenBody clientAssertion(String type, String assertion) { } public String build() { - StringBuilder r = new StringBuilder("grant_type=" + grantType); - appendParam(r, "refresh_token", refreshToken); - appendParam(r, "code", code); - appendParam(r, "redirect_uri", redirectUri); - appendParam(r, "scope", scope); - appendParam(r, "code_verifier", codeVerifier); - appendParam(r, "client_id", clientId); - appendParam(r, "client_secret", clientSecret); - appendParam(r, "client_assertion_type", clientAssertionType); - appendParam(r, "client_assertion", clientAssertion); - return r.toString(); - } - - private void appendParam(StringBuilder sb, String paramName, String paramValue) { - appendParam(sb, paramName, paramValue, e -> encode(e, UTF_8)); - } - - private void appendParam(StringBuilder sb, String paramName, String paramValue, Function encoder) { - if (paramValue == null) - return; - sb.append("&").append(paramName).append("=").append(encoder.apply(paramValue)); + return URLParamUtil.createQueryStringOmitNullValues( + "grant_type", grantType, + "refresh_token", refreshToken, + "code", code, + "redirect_uri", redirectUri, + "scope", scope, + "code_verifier", codeVerifier, + "client_id", clientId, + "client_secret", clientSecret, + "client_assertion_type", clientAssertionType, + "client_assertion", clientAssertion + ); } public OAuth2TokenBody redirectUri(String redirectUri) { diff --git a/core/src/main/java/com/predic8/membrane/core/transport/http/HttpClient.java b/core/src/main/java/com/predic8/membrane/core/transport/http/HttpClient.java index c156f27b48..b5e88ef873 100644 --- a/core/src/main/java/com/predic8/membrane/core/transport/http/HttpClient.java +++ b/core/src/main/java/com/predic8/membrane/core/transport/http/HttpClient.java @@ -80,7 +80,7 @@ private boolean dispatchCall(Exchange exc, String target, int attempt) throws Ex var outConType = connectionFactory.getConnection(exc, hcp, attempt); setRequestURI(exc.getRequest(), target, outConType.con()); - if (configuration.getProxy() != null && outConType.sslProvider() == null) { + if (configuration.getProxy() != null && configuration.getProxy().isAuthentication() && outConType.sslProvider() == null) { // if we use a proxy for a plain HTTP (=non-HTTPS) request, attach the proxy credentials. exc.getRequest().getHeader().setProxyAuthorization(configuration.getProxy().getCredentials()); } @@ -179,4 +179,4 @@ void setRequestURI(Request req, String dest, @NotNull Connection con) throws Mal protected void finalize() { close(); } -} \ No newline at end of file +} diff --git a/core/src/main/java/com/predic8/membrane/core/transport/http/client/ProxyConfiguration.java b/core/src/main/java/com/predic8/membrane/core/transport/http/client/ProxyConfiguration.java index a88a44a6d5..be201f9b8f 100644 --- a/core/src/main/java/com/predic8/membrane/core/transport/http/client/ProxyConfiguration.java +++ b/core/src/main/java/com/predic8/membrane/core/transport/http/client/ProxyConfiguration.java @@ -16,11 +16,11 @@ import com.predic8.membrane.annot.*; import com.predic8.membrane.core.config.security.*; +import com.predic8.membrane.core.util.security.BasicAuthenticationUtil; import java.util.*; -import static java.nio.charset.StandardCharsets.*; -import static org.apache.commons.codec.binary.Base64.*; +import static com.predic8.membrane.core.util.security.BasicAuthenticationUtil.createAuthorizationHeader; /** * @description

Configuration for an outbound HTTP proxy used by the HTTP client. @@ -167,7 +167,7 @@ public void setSslParser(SSLParser sslParser) { * The "Basic" authentication scheme defined in RFC 2617 does not properly define how to treat non-ASCII characters. */ public String getCredentials() { - return "Basic " + new String(encodeBase64((username + ":" + password).getBytes(UTF_8)), UTF_8); + return createAuthorizationHeader(username, password); } } diff --git a/core/src/main/java/com/predic8/membrane/core/util/URLParamUtil.java b/core/src/main/java/com/predic8/membrane/core/util/URLParamUtil.java index 0f9acd4b5d..e204e33ecc 100644 --- a/core/src/main/java/com/predic8/membrane/core/util/URLParamUtil.java +++ b/core/src/main/java/com/predic8/membrane/core/util/URLParamUtil.java @@ -115,6 +115,26 @@ public static String createQueryString(String... params) { return buf.toString(); } + public static String createQueryStringOmitNullValues(String... params) { + if (params.length % 2 != 0) + throw new IllegalArgumentException("params must contain key/value pairs"); + + StringBuilder buf = new StringBuilder(); + boolean first = true; + for (int i = 0; i < params.length; i += 2) { + if (params[i + 1] == null) + continue; + if (first) + first = false; + else + buf.append('&'); + buf.append(URLEncoder.encode(params[i], UTF_8)); + buf.append('='); + buf.append(URLEncoder.encode(params[i + 1], UTF_8)); + } + return buf.toString(); + } + /** * Parse a URL query into parameter pairs. The query is expected to be application/x-www-form-urlencoded . *

diff --git a/core/src/main/java/com/predic8/membrane/core/util/security/BasicAuthenticationUtil.java b/core/src/main/java/com/predic8/membrane/core/util/security/BasicAuthenticationUtil.java index fb66686f65..fab0db78f8 100644 --- a/core/src/main/java/com/predic8/membrane/core/util/security/BasicAuthenticationUtil.java +++ b/core/src/main/java/com/predic8/membrane/core/util/security/BasicAuthenticationUtil.java @@ -17,8 +17,10 @@ import com.predic8.membrane.core.exchange.*; import java.util.*; +import java.util.function.UnaryOperator; import static java.nio.charset.StandardCharsets.*; +import static java.util.Objects.requireNonNull; public class BasicAuthenticationUtil { @@ -32,8 +34,8 @@ public class BasicAuthenticationUtil { */ public record BasicCredentials(String username, String password) { public BasicCredentials { - Objects.requireNonNull(username, "Username cannot be null"); - Objects.requireNonNull(password, "Password cannot be null"); + requireNonNull(username, "Username cannot be null"); + requireNonNull(password, "Password cannot be null"); } /** @@ -68,6 +70,19 @@ public static BasicCredentials getCredentials(Exchange exc) { return parseCredentials(decodeAuthorizationHeader(exc)); } + public static String createAuthorizationHeader(String username, String password) { + return createAuthorizationHeader(username, password, UnaryOperator.identity()); + } + + public static String createAuthorizationHeader(String username, String password, UnaryOperator credentialEncoder) { + requireNonNull(username, "username cannot be null"); + requireNonNull(password, "password cannot be null"); + requireNonNull(credentialEncoder, "credentialEncoder cannot be null"); + + String credentials = credentialEncoder.apply(username) + ":" + credentialEncoder.apply(password); + return BASIC_PREFIX + Base64.getEncoder().encodeToString(credentials.getBytes(UTF_8)); + } + /** * Decodes the Authorization header and returns the raw credentials string. * @@ -126,4 +141,4 @@ private static BasicCredentials parseCredentials(String credentials) { return new BasicCredentials(username, password); } -} \ No newline at end of file +}