Skip to content

Commit 149d867

Browse files
authored
Token Federation Support (#795)
* Basic working example for External Browser Auth * Working prototype [MI, GCP PrivateKey left for testing] * working state * M2M Auth integration fake service tests * MVN clean package * Checked Private Key cred * Fixed issue of repeated refreshing of tokens * Added e2e tests * Removed the private key test * Added identity federation client id * Reverted fake service stubs * Reverted fake service * Merged Main and bypassed fake service stubs * Updated changelog * PR comments * added java doc * Tested the token exchange params * Refractored code * Added to telemetry model * Removed Inhouse Token exchange file * Added the OIDC endpoint code * Removed the expiry comment * Fixed private key test failing * Updated next changelog * Changed param description * Added OAuth Tests * Updated error handling * Fixed the logging * Added bug catcher creds
1 parent 9cc1bd4 commit 149d867

15 files changed

Lines changed: 536 additions & 19 deletions

File tree

.github/workflows/bugCatcher.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,9 @@ jobs:
4040
DATABRICKS_JDBC_M2M_CLIENT_SECRET: ${{ secrets.DATABRICKS_JDBC_M2M_CLIENT_SECRET }}
4141
DATABRICKS_JDBC_M2M_HTTP_PATH: ${{ secrets.DATABRICKS_JDBC_M2M_HTTP_PATH }}
4242
DATABRICKS_JDBC_M2M_HOST: ${{ secrets.DATABRICKS_JDBC_M2M_HOST }}
43+
DATABRICKS_JDBC_SP_TOKEN_FED_HOST : ${{secrets.DATABRICKS_JDBC_SP_TOKEN_FED_HOST}}
44+
DATABRICKS_JDBC_SP_TOKEN_FED_HTTP_PATH : ${{secrets.DATABRICKS_JDBC_SP_TOKEN_FED_HTTP_PATH}}
45+
DATABRICKS_JDBC_SP_TOKEN_FED_CLIENT_ID : ${{secrets.DATABRICKS_JDBC_SP_TOKEN_FED_CLIENT_ID}}
46+
DATABRICKS_JDBC_SP_TOKEN_FED_CLIENT_SECRET : ${{secrets.DATABRICKS_JDBC_SP_TOKEN_FED_CLIENT_SECRET}}
47+
DATABRICKS_SP_TOKEN_FED_FEDERATION_ID : ${{secrets.DATABRICKS_SP_TOKEN_FED_FEDERATION_ID}}
48+
DATABRICKS_SP_TOKEN_FED_AZURE_TENANT_ID : ${{secrets.DATABRICKS_SP_TOKEN_FED_AZURE_TENANT_ID}}

NEXT_CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
## [Unreleased]
44

55
### Added
6-
-
6+
- Support for Token Exchange in OAuth flows where in third party tokens are exchanged for InHouse tokens.
77

88
### Updated
99
-

src/main/java/com/databricks/jdbc/api/impl/DatabricksConnectionContext.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,11 @@ public boolean isOAuthDiscoveryModeEnabled() {
662662
.equals("1");
663663
}
664664

665+
@Override
666+
public String getIdentityFederationClientId() {
667+
return getParameter(DatabricksJdbcUrlParams.IDENTITY_FEDERATION_CLIENT_ID);
668+
}
669+
665670
@Override
666671
public String getOAuthDiscoveryURL() {
667672
return getParameter(

src/main/java/com/databricks/jdbc/api/internal/IDatabricksConnectionContext.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,12 @@ public interface IDatabricksConnectionContext {
176176
/** Returns whether OAuth2 discovery mode is enabled, which fetches endpoints dynamically. */
177177
boolean isOAuthDiscoveryModeEnabled();
178178

179+
/**
180+
* OAuth Client Id for identity federation which is used in exchanging the access token with
181+
* Databricks in-house token
182+
*/
183+
String getIdentityFederationClientId();
184+
179185
/** Returns the discovery URL used to obtain the OAuth2 token and authorization endpoints. */
180186
String getOAuthDiscoveryURL();
181187

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
package com.databricks.jdbc.auth;
2+
3+
import com.databricks.jdbc.api.internal.IDatabricksConnectionContext;
4+
import com.databricks.jdbc.common.DatabricksJdbcConstants;
5+
import com.databricks.jdbc.common.util.DriverUtil;
6+
import com.databricks.jdbc.common.util.JsonUtil;
7+
import com.databricks.jdbc.dbclient.IDatabricksHttpClient;
8+
import com.databricks.jdbc.dbclient.impl.http.DatabricksHttpClientFactory;
9+
import com.databricks.jdbc.exception.DatabricksDriverException;
10+
import com.databricks.jdbc.log.JdbcLogger;
11+
import com.databricks.jdbc.log.JdbcLoggerFactory;
12+
import com.databricks.jdbc.model.telemetry.enums.DatabricksDriverErrorCode;
13+
import com.databricks.sdk.core.*;
14+
import com.databricks.sdk.core.oauth.OAuthResponse;
15+
import com.databricks.sdk.core.oauth.RefreshableTokenSource;
16+
import com.databricks.sdk.core.oauth.Token;
17+
import com.google.common.annotations.VisibleForTesting;
18+
import com.nimbusds.jwt.JWTClaimsSet;
19+
import com.nimbusds.jwt.SignedJWT;
20+
import java.net.MalformedURLException;
21+
import java.net.URL;
22+
import java.nio.charset.StandardCharsets;
23+
import java.text.ParseException;
24+
import java.time.Instant;
25+
import java.time.LocalDateTime;
26+
import java.time.ZoneId;
27+
import java.util.HashMap;
28+
import java.util.Map;
29+
import java.util.Optional;
30+
import java.util.stream.Collectors;
31+
import org.apache.http.HttpHeaders;
32+
import org.apache.http.HttpResponse;
33+
import org.apache.http.client.entity.UrlEncodedFormEntity;
34+
import org.apache.http.client.methods.HttpPost;
35+
import org.apache.http.client.utils.URIBuilder;
36+
import org.apache.http.message.BasicNameValuePair;
37+
38+
/**
39+
* Implementation of the Credential Provider that exchanges the third party access token for a
40+
* Databricks InHouse Token This class exchanges the access token if the issued token is not from
41+
* the same host as the Databricks host.
42+
*
43+
* <p>Note: In future this class will be replaced with the Databricks SDK implementation
44+
*/
45+
public class DatabricksTokenFederationProvider extends RefreshableTokenSource
46+
implements CredentialsProvider {
47+
48+
private static final JdbcLogger LOGGER =
49+
JdbcLoggerFactory.getLogger(DatabricksTokenFederationProvider.class);
50+
private static final Map<String, String> TOKEN_EXCHANGE_PARAMS =
51+
Map.of(
52+
"grant_type",
53+
"urn:ietf:params:oauth:grant-type:token-exchange",
54+
"scope",
55+
"sql",
56+
"subject_token_type",
57+
"urn:ietf:params:oauth:token-type:jwt",
58+
"return_original_token_if_authenticated",
59+
"true");
60+
private static final String TOKEN_EXCHANGE_ENDPOINT = "/oidc/v1/token";
61+
private final IDatabricksConnectionContext connectionContext;
62+
private final CredentialsProvider credentialsProvider;
63+
private DatabricksConfig config;
64+
private Map<String, String> externalProviderHeaders;
65+
private IDatabricksHttpClient hc;
66+
67+
public DatabricksTokenFederationProvider(
68+
IDatabricksConnectionContext connectionContext, CredentialsProvider credentialsProvider) {
69+
this.connectionContext = connectionContext;
70+
this.credentialsProvider = credentialsProvider;
71+
this.externalProviderHeaders = new HashMap<>();
72+
this.hc = DatabricksHttpClientFactory.getInstance().getClient(connectionContext);
73+
this.token =
74+
new Token(
75+
DatabricksJdbcConstants.EMPTY_STRING,
76+
DatabricksJdbcConstants.EMPTY_STRING,
77+
DatabricksJdbcConstants.EMPTY_STRING,
78+
LocalDateTime.now().minusMinutes(1));
79+
}
80+
81+
@VisibleForTesting
82+
DatabricksTokenFederationProvider(
83+
IDatabricksConnectionContext connectionContext,
84+
CredentialsProvider credentialsProvider,
85+
DatabricksConfig config) {
86+
this.connectionContext = connectionContext;
87+
this.credentialsProvider = credentialsProvider;
88+
this.config = config;
89+
this.externalProviderHeaders = new HashMap<>();
90+
this.token =
91+
new Token(
92+
DatabricksJdbcConstants.EMPTY_STRING,
93+
DatabricksJdbcConstants.EMPTY_STRING,
94+
DatabricksJdbcConstants.EMPTY_STRING,
95+
LocalDateTime.now().minusMinutes(1));
96+
}
97+
98+
public String authType() {
99+
return this.credentialsProvider.authType();
100+
}
101+
102+
public CredentialsProvider getCredentialsProvider() {
103+
return this.credentialsProvider;
104+
}
105+
106+
public HeaderFactory configure(DatabricksConfig databricksConfig) {
107+
LOGGER.debug("DatabricksTokenFederation configure");
108+
109+
// ByPassing the token exchange for fake service test
110+
// Issue: Unable to map token exchange URL to localhost (WireMock host)
111+
// because URLs are generated inside SDK
112+
if (DriverUtil.isRunningAgainstFake()) {
113+
return this.credentialsProvider.configure(databricksConfig);
114+
}
115+
116+
this.config = databricksConfig;
117+
return () -> {
118+
Token exchangedToken = getToken();
119+
Map<String, String> headers = new HashMap<>(this.externalProviderHeaders);
120+
headers.put(
121+
HttpHeaders.AUTHORIZATION,
122+
exchangedToken.getTokenType() + " " + exchangedToken.getAccessToken());
123+
return headers;
124+
};
125+
}
126+
127+
protected Token refresh() {
128+
this.externalProviderHeaders = this.credentialsProvider.configure(this.config).headers();
129+
String[] tokenInfo = extractTokenInfoFromHeader(this.externalProviderHeaders);
130+
String accessTokenType = tokenInfo[0];
131+
String accessToken = tokenInfo[1];
132+
133+
try {
134+
SignedJWT signedJWT = SignedJWT.parse(accessToken);
135+
JWTClaimsSet claims = signedJWT.getJWTClaimsSet();
136+
137+
Optional<Token> optionalToken = Optional.empty();
138+
if (!isSameHost(claims.getIssuer(), this.config.getHost())) {
139+
optionalToken = tryTokenExchange(accessToken, accessTokenType);
140+
}
141+
if (optionalToken.isEmpty()) {
142+
optionalToken = Optional.of(createToken(accessToken, accessTokenType));
143+
}
144+
return optionalToken.get();
145+
} catch (Exception e) {
146+
LOGGER.error(e, "Failed to refresh access token");
147+
throw new DatabricksDriverException(
148+
"Failed to refresh access token", e, DatabricksDriverErrorCode.AUTH_ERROR);
149+
}
150+
}
151+
152+
@VisibleForTesting
153+
Optional<Token> tryTokenExchange(String accessToken, String accessTokenType) {
154+
LOGGER.debug(
155+
"Token tryTokenExchange(String accessToken, String accessTokenType = {})", accessTokenType);
156+
try {
157+
return Optional.of(exchangeToken(accessToken));
158+
} catch (Exception e) {
159+
LOGGER.error(e, "Token exchange failed, falling back to using external token");
160+
return Optional.empty();
161+
}
162+
}
163+
164+
@VisibleForTesting
165+
Token createToken(String accessToken, String accessTokenType) throws ParseException {
166+
SignedJWT signedJWT = SignedJWT.parse(accessToken);
167+
JWTClaimsSet claims = signedJWT.getJWTClaimsSet();
168+
169+
Instant expirationTimeInstant = Instant.ofEpochMilli(claims.getExpirationTime().getTime());
170+
ZoneId zoneId = ZoneId.systemDefault();
171+
LocalDateTime expiry = expirationTimeInstant.atZone(zoneId).toLocalDateTime();
172+
return new Token(accessToken, accessTokenType, DatabricksJdbcConstants.EMPTY_STRING, expiry);
173+
}
174+
175+
@VisibleForTesting
176+
Token exchangeToken(String accessToken) {
177+
LOGGER.debug("Token exchangeToken( String accessToken )");
178+
final String tokenUrl = this.config.getHost() + TOKEN_EXCHANGE_ENDPOINT;
179+
180+
Map<String, String> params = new HashMap<>(TOKEN_EXCHANGE_PARAMS);
181+
params.put("subject_token", accessToken);
182+
183+
if (connectionContext.getIdentityFederationClientId() != null) {
184+
params.put("client_id", connectionContext.getIdentityFederationClientId());
185+
}
186+
187+
Map<String, String> headers = new HashMap<>();
188+
headers.put(HttpHeaders.ACCEPT, "*/*");
189+
headers.put(HttpHeaders.CONTENT_TYPE, "application/x-www-form-urlencoded");
190+
191+
return retrieveToken(hc, tokenUrl, params, headers);
192+
}
193+
194+
@VisibleForTesting
195+
Token retrieveToken(
196+
IDatabricksHttpClient hc,
197+
String tokenUrl,
198+
Map<String, String> params,
199+
Map<String, String> headers) {
200+
try {
201+
URIBuilder uriBuilder = new URIBuilder(tokenUrl);
202+
HttpPost postRequest = new HttpPost(uriBuilder.build());
203+
postRequest.setEntity(
204+
new UrlEncodedFormEntity(
205+
params.entrySet().stream()
206+
.map(e -> new BasicNameValuePair(e.getKey(), e.getValue()))
207+
.collect(Collectors.toList()),
208+
StandardCharsets.UTF_8));
209+
headers.forEach(postRequest::setHeader);
210+
HttpResponse response = hc.execute(postRequest);
211+
OAuthResponse resp =
212+
JsonUtil.getMapper().readValue(response.getEntity().getContent(), OAuthResponse.class);
213+
return createToken(resp.getAccessToken(), resp.getTokenType());
214+
} catch (Exception e) {
215+
LOGGER.error(e, "Failed to retrieve the exchanged token");
216+
throw new DatabricksDriverException(
217+
"Failed to retrieve the exchanged token", e, DatabricksDriverErrorCode.AUTH_ERROR);
218+
}
219+
}
220+
221+
private boolean isSameHost(String url1, String url2) {
222+
try {
223+
String host1 = new URL(url1).getHost();
224+
String host2 = new URL(url2).getHost();
225+
return host1.equals(host2);
226+
} catch (MalformedURLException e) {
227+
LOGGER.error(e, "Unable to parse URL String");
228+
}
229+
return false;
230+
}
231+
232+
private String[] extractTokenInfoFromHeader(Map<String, String> headers) {
233+
String authHeader = headers.get(HttpHeaders.AUTHORIZATION);
234+
try {
235+
return authHeader.split(" ", 2);
236+
} catch (NullPointerException e) {
237+
LOGGER.error(e, "Failed to extract token info from header");
238+
throw new DatabricksDriverException(
239+
"Failed to extract token info from header", e, DatabricksDriverErrorCode.AUTH_ERROR);
240+
}
241+
}
242+
}

src/main/java/com/databricks/jdbc/common/DatabricksJdbcUrlParams.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ public enum DatabricksJdbcUrlParams {
5151
AUTH_SCOPE("Auth_Scope", "Authentication scope", "all-apis"),
5252
OIDC_DISCOVERY_ENDPOINT("OIDCDiscoveryEndpoint", "OIDC Discovery Endpoint"),
5353
DISCOVERY_URL("OAuthDiscoveryURL", "OAuth discovery URL"), // Same as OIDC_DISCOVERY_ENDPOINT
54+
IDENTITY_FEDERATION_CLIENT_ID(
55+
"Identity_Federation_Client_Id", "OAuth Client ID for Token Federation"),
5456
ENABLE_ARROW("EnableArrow", "Enable Arrow", "1"),
5557
DIRECT_RESULT("EnableDirectResults", "Enable direct results", "1"),
5658
LZ4_COMPRESSION_FLAG(

src/main/java/com/databricks/jdbc/common/util/DatabricksDriverPropertyUtil.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ public static List<DriverPropertyInfo> buildMissingPropertiesList(
9191
addMissingProperty(missingPropertyInfos, connectionContext, PWD, true);
9292
} else if (authMech == OAUTH) {
9393
AuthFlow authFlow = connectionContext.getAuthFlow();
94+
addMissingProperty(
95+
missingPropertyInfos, connectionContext, IDENTITY_FEDERATION_CLIENT_ID, false);
9496

9597
if (connectionContext.isPropertyPresent(AUTH_FLOW)) {
9698
switch (authFlow) {

0 commit comments

Comments
 (0)