Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions core/src/main/java/com/predic8/membrane/core/http/Header.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import com.predic8.membrane.annot.Constants;
import com.predic8.membrane.core.http.cookie.*;
import com.predic8.membrane.core.util.security.BasicAuthenticationUtil;
import com.predic8.membrane.core.util.*;
import jakarta.mail.internet.*;
import org.jetbrains.annotations.*;
Expand All @@ -31,6 +32,7 @@

import static com.predic8.membrane.core.http.MimeType.*;
import static com.predic8.membrane.core.util.HttpUtil.*;
import static com.predic8.membrane.core.util.security.BasicAuthenticationUtil.createAuthorizationHeader;
import static java.nio.charset.StandardCharsets.*;
import static java.util.Arrays.*;
import static java.util.Collections.*;
Expand Down Expand Up @@ -428,9 +430,7 @@ public String toString() {
* @param password the password for authentication
*/
public void setAuthorization(String user, String password) {
setValue("Authorization", "Basic "
+ new String(encodeBase64((user + ":" + password)
.getBytes(UTF_8)), UTF_8));
setValue("Authorization", createAuthorizationHeader(user, password));
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
package com.predic8.membrane.core.interceptor.oauth2;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.predic8.membrane.annot.MCAttribute;
import com.predic8.membrane.annot.MCElement;
import com.predic8.membrane.annot.Required;
import com.predic8.membrane.core.exchange.Exchange;
import com.predic8.membrane.core.interceptor.AbstractInterceptor;
import com.predic8.membrane.core.interceptor.Outcome;
import com.predic8.membrane.core.transport.http.HttpClient;
import com.predic8.membrane.core.util.security.BasicAuthenticationUtil;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import static com.predic8.membrane.core.exceptions.ProblemDetails.gateway;
import static com.predic8.membrane.core.http.Header.ACCEPT;
import static com.predic8.membrane.core.http.Header.AUTHORIZATION;
import static com.predic8.membrane.core.http.MimeType.APPLICATION_JSON;
import static com.predic8.membrane.core.http.MimeType.APPLICATION_X_WWW_FORM_URLENCODED;
import static com.predic8.membrane.core.http.Request.post;
import static com.predic8.membrane.core.interceptor.Interceptor.Flow.Set.REQUEST_FLOW;
import static com.predic8.membrane.core.interceptor.Outcome.ABORT;
import static com.predic8.membrane.core.interceptor.Outcome.CONTINUE;
import static com.predic8.membrane.core.util.URLParamUtil.createQueryStringOmitNullValues;
import static com.predic8.membrane.core.util.security.BasicAuthenticationUtil.createAuthorizationHeader;
import static java.lang.Math.max;
import static java.lang.System.currentTimeMillis;
import static java.net.URLEncoder.encode;
import static java.nio.charset.StandardCharsets.UTF_8;

/**
* @description Obtains an OAuth2 access token using the client credentials flow and forwards the request with a Bearer token.
* @yaml <pre><code>
* api:
* port: 2000
* flow:
* - oauth2Client:
* tokenUrl: https://auth.example.com/oauth2/token
* clientId: gateway
* clientSecret: secret
* scope: read write
* target:
* url: https://api.example.com
* </code></pre>
*/
@MCElement(name="oauth2Client")
public class OAuth2ClientInterceptor extends AbstractInterceptor {

private static final Logger log = LoggerFactory.getLogger(OAuth2ClientInterceptor.class);

private String tokenUrl;

private String clientId;

private String clientSecret;

private String scope;

private final ObjectMapper objectMapper = new ObjectMapper();

private HttpClient httpClient;
private final Object tokenLock = new Object();

private volatile String cachedAccessToken;
private volatile long cachedAccessTokenValidUntilEpochMillis;

@Override
public void init() {
super.init();
name = "OAuth2 Client";
setAppliedFlow(REQUEST_FLOW);

httpClient = router.getHttpClientFactory().createClient(router.getHttpClientConfig());
}

@Override
public Outcome handleRequest(Exchange exc) {
try {
String token = getAccessToken();
exc.getRequest().getHeader().setValue(AUTHORIZATION, "Bearer " + token);
Comment thread
coderabbitai[bot] marked this conversation as resolved.
return CONTINUE;
} catch (Exception e) {
log.warn("Could not obtain OAuth2 access token from {}: {}", tokenUrl, e.getMessage());
log.debug("OAuth2 token request failed.", e);
gateway(router.getConfiguration().isProduction(), getDisplayName())
.title("Bad Gateway")
.status(502)
.addSubSee("oauth2-token")
.detail("Could not obtain an OAuth2 access token.")
.buildAndSetResponse(exc);
Comment thread
christiangoerdes marked this conversation as resolved.
return ABORT;
}
}

private String getAccessToken() throws Exception {
if (hasValidCachedToken()) {
return cachedAccessToken;
}

synchronized (tokenLock) {
if (hasValidCachedToken()) {
return cachedAccessToken;
}
return fetchAccessToken();
}
}

private boolean hasValidCachedToken() {
return cachedAccessToken != null && currentTimeMillis() < cachedAccessTokenValidUntilEpochMillis;
}

private String fetchAccessToken() throws Exception {
Exchange tokenExchange = post(tokenUrl)
.contentType(APPLICATION_X_WWW_FORM_URLENCODED)
.header(ACCEPT, APPLICATION_JSON)
.header(AUTHORIZATION, buildBasicAuthorization())
.body(buildTokenRequestBody())
.buildExchange();

httpClient.call(tokenExchange);

var response = tokenExchange.getResponse();
String responseBody = response.getBodyAsStringDecoded();
if (response.getStatusCode() != 200) {
throw new IllegalStateException("Authorization server returned status " + response.getStatusCode() + ".");
}

var responseJson = objectMapper.readTree(responseBody);
Comment thread
christiangoerdes marked this conversation as resolved.
String token = extractAccessToken(responseJson);

updateTokenCache(token, responseJson.path("expires_in").asLong(-1));
return token;
}

private static @NotNull String extractAccessToken(JsonNode responseJson) {
String token = responseJson.path("access_token").asText(null);
if (token == null || token.isBlank()) {
throw new IllegalStateException("Authorization server did not return an access token.");
}
return token;
}

private void updateTokenCache(String token, long expiresInSeconds) {
if (expiresInSeconds <= 0) {
cachedAccessToken = null;
cachedAccessTokenValidUntilEpochMillis = 0;
log.debug("Token response from {} has no usable expires_in. Token will not be cached.", tokenUrl);
return;
}

// Refresh slightly before expiry to avoid sending a token that expires mid-request.
long refreshBufferSeconds = Math.min(30, max(1, expiresInSeconds / 10));

cachedAccessToken = token;
cachedAccessTokenValidUntilEpochMillis = currentTimeMillis() + max(1, expiresInSeconds - refreshBufferSeconds) * 1000;
}

private String buildTokenRequestBody() {
return createQueryStringOmitNullValues(
"grant_type", "client_credentials",
"scope", scope == null || scope.isBlank() ? null : scope
);
}

private String buildBasicAuthorization() {
Comment thread
christiangoerdes marked this conversation as resolved.
return createAuthorizationHeader(clientId, clientSecret, this::encodeClientCredential);
}

private String encodeClientCredential(String value) {
return encode(value, UTF_8);
}

/**
* @description The token endpoint used to obtain the OAuth2 access token.
* @required
* @example https://auth.example.com/oauth2/token
*/
@MCAttribute
@Required
public void setTokenUrl(String tokenUrl) {
this.tokenUrl = tokenUrl;
}

public String getTokenUrl() {
return tokenUrl;
}

/**
* @description The OAuth2 client id used for the token request.
* @required
* @example gateway
*/
@MCAttribute
@Required
public void setClientId(String clientId) {
this.clientId = clientId;
}

public String getClientId() {
return clientId;
}

/**
* @description The OAuth2 client secret used for the token request.
* @required
* @example secret
*/
@MCAttribute
@Required
public void setClientSecret(String clientSecret) {
this.clientSecret = clientSecret;
}

public String getClientSecret() {
return clientSecret;
}

/**
* @description Space-separated scopes requested for the access token.
* @example read write
*/
@MCAttribute
public void setScope(String scope) {
this.scope = scope;
}

public String getScope() {
return scope;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@

package com.predic8.membrane.core.interceptor.oauth2;

import java.util.function.Function;

import static java.net.URLEncoder.encode;
import static java.nio.charset.StandardCharsets.UTF_8;
import com.predic8.membrane.core.util.URLParamUtil;

public class OAuth2TokenBody {
private String code;
Expand Down Expand Up @@ -70,27 +67,18 @@ public OAuth2TokenBody clientAssertion(String type, String assertion) {
}

public String build() {
StringBuilder r = new StringBuilder("grant_type=" + grantType);
appendParam(r, "refresh_token", refreshToken);
appendParam(r, "code", code);
appendParam(r, "redirect_uri", redirectUri);
appendParam(r, "scope", scope);
appendParam(r, "code_verifier", codeVerifier);
appendParam(r, "client_id", clientId);
appendParam(r, "client_secret", clientSecret);
appendParam(r, "client_assertion_type", clientAssertionType);
appendParam(r, "client_assertion", clientAssertion);
return r.toString();
}

private void appendParam(StringBuilder sb, String paramName, String paramValue) {
appendParam(sb, paramName, paramValue, e -> encode(e, UTF_8));
}

private void appendParam(StringBuilder sb, String paramName, String paramValue, Function<String, String> encoder) {
if (paramValue == null)
return;
sb.append("&").append(paramName).append("=").append(encoder.apply(paramValue));
return URLParamUtil.createQueryStringOmitNullValues(
"grant_type", grantType,
"refresh_token", refreshToken,
"code", code,
"redirect_uri", redirectUri,
"scope", scope,
"code_verifier", codeVerifier,
"client_id", clientId,
"client_secret", clientSecret,
"client_assertion_type", clientAssertionType,
"client_assertion", clientAssertion
);
}

public OAuth2TokenBody redirectUri(String redirectUri) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ private boolean dispatchCall(Exchange exc, String target, int attempt) throws Ex
var outConType = connectionFactory.getConnection(exc, hcp, attempt);
setRequestURI(exc.getRequest(), target, outConType.con());

if (configuration.getProxy() != null && outConType.sslProvider() == null) {
if (configuration.getProxy() != null && configuration.getProxy().isAuthentication() && outConType.sslProvider() == null) {
// if we use a proxy for a plain HTTP (=non-HTTPS) request, attach the proxy credentials.
exc.getRequest().getHeader().setProxyAuthorization(configuration.getProxy().getCredentials());
}
Expand Down Expand Up @@ -179,4 +179,4 @@ void setRequestURI(Request req, String dest, @NotNull Connection con) throws Mal
protected void finalize() {
close();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@

import com.predic8.membrane.annot.*;
import com.predic8.membrane.core.config.security.*;
import com.predic8.membrane.core.util.security.BasicAuthenticationUtil;

import java.util.*;

import static java.nio.charset.StandardCharsets.*;
import static org.apache.commons.codec.binary.Base64.*;
import static com.predic8.membrane.core.util.security.BasicAuthenticationUtil.createAuthorizationHeader;

/**
* @description <p>Configuration for an outbound HTTP proxy used by the HTTP client.
Expand Down Expand Up @@ -167,7 +167,7 @@ public void setSslParser(SSLParser sslParser) {
* The "Basic" authentication scheme defined in RFC 2617 does not properly define how to treat non-ASCII characters.
*/
public String getCredentials() {
return "Basic " + new String(encodeBase64((username + ":" + password).getBytes(UTF_8)), UTF_8);
return createAuthorizationHeader(username, password);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,26 @@ public static String createQueryString(String... params) {
return buf.toString();
}

public static String createQueryStringOmitNullValues(String... params) {
if (params.length % 2 != 0)
throw new IllegalArgumentException("params must contain key/value pairs");

StringBuilder buf = new StringBuilder();
boolean first = true;
for (int i = 0; i < params.length; i += 2) {
if (params[i + 1] == null)
continue;
if (first)
first = false;
else
buf.append('&');
buf.append(URLEncoder.encode(params[i], UTF_8));
buf.append('=');
buf.append(URLEncoder.encode(params[i + 1], UTF_8));
}
return buf.toString();
}

/**
* Parse a URL query into parameter pairs. The query is expected to be application/x-www-form-urlencoded .
* <p>
Expand Down
Loading
Loading