Skip to content

Commit 279a6d0

Browse files
Implement OAuth2RedirectUrlPort
1 parent bda478c commit 279a6d0

6 files changed

Lines changed: 258 additions & 1 deletion

File tree

src/main/java/com/databricks/jdbc/api/IDatabricksConnectionContext.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,9 @@ public interface IDatabricksConnectionContext {
181181
*/
182182
String getOAuthRefreshToken();
183183

184+
/** Returns the list of OAuth2 redirect URL ports used for OAuth authentication. */
185+
List<Integer> getOAuth2RedirectUrlPorts();
186+
184187
String getGcpAuthType() throws DatabricksParsingException;
185188

186189
String getGoogleServiceAccount();

src/main/java/com/databricks/jdbc/api/impl/DatabricksConnectionContext.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,27 @@ public String getOAuthRefreshToken() {
664664
getParameter(DatabricksJdbcUrlParams.OAUTH_REFRESH_TOKEN_2));
665665
}
666666

667+
@Override
668+
public List<Integer> getOAuth2RedirectUrlPorts() {
669+
String portsStr = getParameter(DatabricksJdbcUrlParams.OAUTH_REDIRECT_URL_PORT);
670+
671+
try {
672+
// Parse comma-separated list of ports
673+
return Arrays.stream(portsStr.split(","))
674+
.map(String::trim)
675+
.filter(s -> !s.isEmpty())
676+
.map(Integer::parseInt)
677+
.collect(Collectors.toList());
678+
} catch (NumberFormatException e) {
679+
LOGGER.warn(
680+
"Invalid port format in OAuth2RedirectUrlPort: {}. Using default port {}.",
681+
portsStr,
682+
DatabricksJdbcUrlParams.OAUTH_REDIRECT_URL_PORT.getDefaultValue());
683+
return List.of(
684+
Integer.parseInt(DatabricksJdbcUrlParams.OAUTH_REDIRECT_URL_PORT.getDefaultValue()));
685+
}
686+
}
687+
667688
@Override
668689
public Boolean getUseEmptyMetadata() {
669690
String param = getParameter(DatabricksJdbcUrlParams.USE_EMPTY_METADATA);

src/main/java/com/databricks/jdbc/common/DatabricksJdbcUrlParams.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ public enum DatabricksJdbcUrlParams {
3333
AUTH_FLOW("auth_flow", "Authentication flow"),
3434
OAUTH_REFRESH_TOKEN("Auth_RefreshToken", "OAuth2 Refresh Token"),
3535
OAUTH_REFRESH_TOKEN_2("OAuthRefreshToken", "OAuth2 Refresh Token"), // Same as OAUTH_REFRESH_TOKEN
36+
OAUTH_REDIRECT_URL_PORT("OAuth2RedirectUrlPort", "OAuth2 Redirect URL port", "8020"),
3637
PWD("pwd", "Password (used when AUTH_MECH = 3)", true),
3738
POLL_INTERVAL("asyncexecpollinterval", "Async execution poll interval", "200"),
3839
HTTP_PATH("httppath", "HTTP path", true),

src/main/java/com/databricks/jdbc/dbclient/impl/common/ClientConfigurator.java

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,12 @@
2121
import com.databricks.sdk.core.ProxyConfig;
2222
import com.databricks.sdk.core.commons.CommonsHttpClient;
2323
import com.databricks.sdk.core.utils.Cloud;
24+
import java.io.IOException;
25+
import java.net.ServerSocket;
2426
import java.security.cert.*;
27+
import java.util.ArrayList;
2528
import java.util.Arrays;
29+
import java.util.List;
2630
import java.util.stream.Collectors;
2731
import org.apache.http.impl.conn.PoolingHttpClientConnectionManager;
2832

@@ -129,17 +133,82 @@ public void setupOAuthConfig() throws DatabricksParsingException {
129133

130134
/** Setup the OAuth U2M authentication settings in the databricks config. */
131135
public void setupU2MConfig() throws DatabricksParsingException {
136+
int redirectPort = findAvailablePort(connectionContext.getOAuth2RedirectUrlPorts());
137+
String redirectUrl = String.format("http://localhost:%d", redirectPort);
138+
132139
databricksConfig
133140
.setAuthType(DatabricksJdbcConstants.U2M_AUTH_TYPE)
134141
.setHost(connectionContext.getHostForOAuth())
135142
.setClientId(connectionContext.getClientId())
136143
.setClientSecret(connectionContext.getClientSecret())
137-
.setOAuthRedirectUrl(DatabricksJdbcConstants.U2M_AUTH_REDIRECT_URL);
144+
.setOAuthRedirectUrl(redirectUrl);
145+
146+
LOGGER.info("Using OAuth redirect URL: {}", redirectUrl);
147+
138148
if (!databricksConfig.isAzure()) {
139149
databricksConfig.setScopes(connectionContext.getOAuthScopesForU2M());
140150
}
141151
}
142152

153+
/**
154+
* Finds the first available port from the provided list of ports. If a single port is provided,
155+
* it tries incremental ports (port, port+1, port+2, etc.) If multiple ports are provided, it
156+
* tries each port in the list.
157+
*
158+
* @param initialPorts List of ports to try
159+
* @return The first available port
160+
* @throws DatabricksException if no available port is found
161+
*/
162+
int findAvailablePort(List<Integer> initialPorts) {
163+
List<Integer> portsToTry;
164+
165+
// If single port provided, generate sequence of ports to try
166+
if (initialPorts.size() == 1) {
167+
int startPort = initialPorts.get(0);
168+
int maxAttempts = 20;
169+
portsToTry = new ArrayList<>(maxAttempts);
170+
for (int i = 0; i < maxAttempts; i++) {
171+
portsToTry.add(startPort + i);
172+
}
173+
LOGGER.debug(
174+
"Single port provided ({}), will try ports {} through {}",
175+
startPort,
176+
startPort,
177+
startPort + maxAttempts - 1);
178+
} else {
179+
portsToTry = initialPorts;
180+
LOGGER.debug("Multiple ports provided, will try: {}", portsToTry);
181+
}
182+
183+
// Try each port in the list
184+
for (int port : portsToTry) {
185+
if (isPortAvailable(port)) {
186+
return port;
187+
}
188+
LOGGER.debug("Port {} is not available, trying next port", port);
189+
}
190+
191+
// No available ports found
192+
LOGGER.error("No available ports found among: {}", portsToTry);
193+
throw new DatabricksException(
194+
"No available port found for OAuth redirect URL. Tried ports: " + portsToTry);
195+
}
196+
197+
/**
198+
* Checks if a port is available by trying to open a server socket on it.
199+
*
200+
* @param port Port to check
201+
* @return true if the port is available, false otherwise
202+
*/
203+
boolean isPortAvailable(int port) {
204+
try (ServerSocket serverSocket = new ServerSocket(port)) {
205+
serverSocket.setReuseAddress(true);
206+
return true;
207+
} catch (IOException e) {
208+
return false;
209+
}
210+
}
211+
143212
/** Setup the PAT authentication settings in the databricks config. */
144213
public void setupAccessTokenConfig() throws DatabricksParsingException {
145214
databricksConfig

src/test/java/com/databricks/jdbc/api/impl/DatabricksConnectionContextTest.java

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,4 +450,48 @@ void testLogLevels() {
450450
assertEquals(getLogLevel(5), LogLevel.DEBUG);
451451
assertEquals(getLogLevel(6), LogLevel.TRACE);
452452
}
453+
454+
@Test
455+
public void testGetOAuth2RedirectUrlPorts() throws DatabricksSQLException {
456+
// Test default value
457+
Properties props = new Properties();
458+
DatabricksConnectionContext context =
459+
(DatabricksConnectionContext)
460+
DatabricksConnectionContext.parse(TestConstants.VALID_URL_1, props);
461+
List<Integer> ports = context.getOAuth2RedirectUrlPorts();
462+
assertEquals(1, ports.size());
463+
assertEquals(8020, ports.get(0)); // Default value
464+
465+
// Test single port
466+
props = new Properties();
467+
props.setProperty("OAuth2RedirectUrlPort", "9090");
468+
context =
469+
(DatabricksConnectionContext)
470+
DatabricksConnectionContext.parse(TestConstants.VALID_URL_1, props);
471+
ports = context.getOAuth2RedirectUrlPorts();
472+
assertEquals(1, ports.size());
473+
assertEquals(9090, ports.get(0));
474+
475+
// Test multiple ports
476+
props = new Properties();
477+
props.setProperty("OAuth2RedirectUrlPort", "9090,9091,9092");
478+
context =
479+
(DatabricksConnectionContext)
480+
DatabricksConnectionContext.parse(TestConstants.VALID_URL_1, props);
481+
ports = context.getOAuth2RedirectUrlPorts();
482+
assertEquals(3, ports.size());
483+
assertEquals(9090, ports.get(0));
484+
assertEquals(9091, ports.get(1));
485+
assertEquals(9092, ports.get(2));
486+
487+
// Test invalid format - should fallback to default
488+
props = new Properties();
489+
props.setProperty("OAuth2RedirectUrlPort", "invalid");
490+
context =
491+
(DatabricksConnectionContext)
492+
DatabricksConnectionContext.parse(TestConstants.VALID_URL_1, props);
493+
ports = context.getOAuth2RedirectUrlPorts();
494+
assertEquals(1, ports.size());
495+
assertEquals(8020, ports.get(0)); // Default value when format is invalid
496+
}
453497
}

src/test/java/com/databricks/jdbc/dbclient/impl/common/ClientConfiguratorTest.java

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import com.databricks.sdk.core.commons.CommonsHttpClient;
2323
import com.databricks.sdk.core.utils.Cloud;
2424
import java.io.IOException;
25+
import java.net.ServerSocket;
2526
import java.util.List;
2627
import java.util.Properties;
2728
import org.junit.jupiter.api.Test;
@@ -169,6 +170,7 @@ void getWorkspaceClient_OAuthWithBrowserBasedAuthentication_AuthenticatesCorrect
169170
when(mockContext.getClientSecret()).thenReturn("browser-client-secret");
170171
when(mockContext.getOAuthScopesForU2M()).thenReturn(List.of(new String[] {"scope1", "scope2"}));
171172
when(mockContext.getHttpConnectionPoolSize()).thenReturn(100);
173+
when(mockContext.getOAuth2RedirectUrlPorts()).thenReturn(List.of(8020));
172174
configurator = new ClientConfigurator(mockContext);
173175
WorkspaceClient client = configurator.getWorkspaceClient();
174176
assertNotNull(client);
@@ -195,6 +197,7 @@ void getWorkspaceClient_OAuthWithBrowserBasedAuthentication_AuthenticatesCorrect
195197
when(mockContext.isOAuthDiscoveryModeEnabled()).thenReturn(true);
196198
when(mockContext.getOAuthDiscoveryURL()).thenReturn(TEST_DISCOVERY_URL);
197199
when(mockContext.getHttpConnectionPoolSize()).thenReturn(100);
200+
when(mockContext.getOAuth2RedirectUrlPorts()).thenReturn(List.of(8020));
198201
configurator = new ClientConfigurator(mockContext);
199202
WorkspaceClient client = configurator.getWorkspaceClient();
200203
assertNotNull(client);
@@ -315,4 +318,120 @@ void setupM2MConfig_WithAzureTenantIdButNonAzureCloud_ThrowsException()
315318
verify(mockContext).getAzureTenantId();
316319
verify(mockContext, times(2)).getCloud();
317320
}
321+
322+
@Test
323+
void testFindAvailablePort() throws Exception {
324+
// Create a mockContext for the ClientConfigurator constructor
325+
when(mockContext.getAuthMech()).thenReturn(AuthMech.PAT);
326+
when(mockContext.getHostUrl()).thenReturn("https://test.databricks.com");
327+
when(mockContext.getToken()).thenReturn("test-token");
328+
when(mockContext.getHttpConnectionPoolSize()).thenReturn(100);
329+
configurator = new ClientConfigurator(mockContext);
330+
331+
// Test with a single available port
332+
int availablePort = findFreePort();
333+
List<Integer> ports = List.of(availablePort);
334+
int result = configurator.findAvailablePort(ports);
335+
assertEquals(availablePort, result);
336+
337+
// Test with multiple ports, first unavailable
338+
int secondAvailablePort = findFreePort();
339+
try (ServerSocket serverSocket = new ServerSocket(availablePort)) {
340+
serverSocket.setReuseAddress(true);
341+
ports = List.of(availablePort, secondAvailablePort);
342+
result = configurator.findAvailablePort(ports);
343+
assertEquals(secondAvailablePort, result);
344+
}
345+
346+
// Test incremental search - first port unavailable, second available
347+
try (ServerSocket serverSocket = new ServerSocket(availablePort)) {
348+
serverSocket.setReuseAddress(true);
349+
ports = List.of(availablePort);
350+
result = configurator.findAvailablePort(ports);
351+
assertEquals(availablePort + 1, result);
352+
}
353+
}
354+
355+
@Test
356+
void testFindAvailablePortThrowsExceptionWhenNoPortsAvailable() throws Exception {
357+
// Create a mockContext for the ClientConfigurator constructor
358+
when(mockContext.getAuthMech()).thenReturn(AuthMech.PAT);
359+
when(mockContext.getHostUrl()).thenReturn("https://test.databricks.com");
360+
when(mockContext.getToken()).thenReturn("test-token");
361+
when(mockContext.getHttpConnectionPoolSize()).thenReturn(100);
362+
configurator = new ClientConfigurator(mockContext);
363+
364+
// Use a port that is likely to be available
365+
int port1 = findFreePort();
366+
int port2 = findFreePort();
367+
if (port1 == port2) {
368+
port2 = port1 + 1;
369+
}
370+
371+
// Occupy the ports to make them unavailable
372+
try (ServerSocket socket1 = new ServerSocket(port1);
373+
ServerSocket socket2 = new ServerSocket(port2)) {
374+
socket1.setReuseAddress(true);
375+
socket2.setReuseAddress(true);
376+
377+
// First test with multiple specified ports
378+
List<Integer> unavailablePorts = List.of(port1, port2);
379+
DatabricksException exception =
380+
assertThrows(
381+
DatabricksException.class, () -> configurator.findAvailablePort(unavailablePorts));
382+
assertTrue(exception.getMessage().contains("No available port found"));
383+
384+
// Now test with single port and verify it tries incremental ports
385+
// We need to create a subclass to control isPortAvailable behavior
386+
ClientConfigurator testConfigurator =
387+
new ClientConfigurator(mockContext) {
388+
@Override
389+
protected boolean isPortAvailable(int port) {
390+
return false; // All ports are unavailable
391+
}
392+
};
393+
394+
exception =
395+
assertThrows(
396+
DatabricksException.class, () -> testConfigurator.findAvailablePort(List.of(port1)));
397+
assertTrue(exception.getMessage().contains("No available port found"));
398+
}
399+
}
400+
401+
/** Utility method to find a free port */
402+
private int findFreePort() {
403+
try (ServerSocket socket = new ServerSocket(0)) {
404+
socket.setReuseAddress(true);
405+
return socket.getLocalPort();
406+
} catch (IOException e) {
407+
throw new RuntimeException("Failed to find free port", e);
408+
}
409+
}
410+
411+
@Test
412+
void getWorkspaceClient_OAuthWithBrowserBasedAuthentication_SetsCustomRedirectUrl()
413+
throws Exception {
414+
// We'll mock getOAuth2RedirectUrlPorts to return a predefined list
415+
int testPort = findFreePort();
416+
when(mockContext.getAuthMech()).thenReturn(AuthMech.OAUTH);
417+
when(mockContext.getAuthFlow()).thenReturn(AuthFlow.BROWSER_BASED_AUTHENTICATION);
418+
when(mockContext.getHostForOAuth()).thenReturn("https://oauth-browser.databricks.com");
419+
when(mockContext.getClientId()).thenReturn("browser-client-id");
420+
when(mockContext.getClientSecret()).thenReturn("browser-client-secret");
421+
when(mockContext.getOAuthScopesForU2M()).thenReturn(List.of(new String[] {"scope1", "scope2"}));
422+
when(mockContext.getOAuth2RedirectUrlPorts()).thenReturn(List.of(testPort));
423+
when(mockContext.getHttpConnectionPoolSize()).thenReturn(100);
424+
425+
configurator = new ClientConfigurator(mockContext);
426+
WorkspaceClient client = configurator.getWorkspaceClient();
427+
assertNotNull(client);
428+
DatabricksConfig config = client.config();
429+
430+
assertEquals("https://oauth-browser.databricks.com", config.getHost());
431+
assertEquals("browser-client-id", config.getClientId());
432+
assertEquals("browser-client-secret", config.getClientSecret());
433+
assertEquals(List.of(new String[] {"scope1", "scope2"}), config.getScopes());
434+
assertEquals("http://localhost:" + testPort, config.getOAuthRedirectUrl());
435+
assertEquals(DatabricksJdbcConstants.U2M_AUTH_TYPE, config.getAuthType());
436+
}
318437
}

0 commit comments

Comments
 (0)