Skip to content
1 change: 1 addition & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### Added
- Support for fetching tables and views across all catalogs using SHOW TABLES FROM/IN ALL CATALOGS in the SQL Exec API.
- Support for Token Exchange in OAuth flows, wherein third-party tokens are exchanged for InHouse tokens.
- Support for fetching schemas across all catalogs in the SQL Exec API client.
- Added support for polling of statementStatus and sqlState for async SQL execution.

### Updated
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import com.databricks.jdbc.api.internal.IDatabricksSession;
import com.databricks.jdbc.common.*;
import com.databricks.jdbc.common.util.DriverUtil;
import com.databricks.jdbc.common.util.JdbcThreadUtils;
import com.databricks.jdbc.dbclient.impl.common.MetadataResultSetBuilder;
import com.databricks.jdbc.dbclient.impl.common.StatementId;
import com.databricks.jdbc.exception.DatabricksSQLException;
Expand All @@ -19,6 +20,8 @@
import com.databricks.sdk.service.sql.StatementState;
import java.sql.*;
import java.util.*;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class DatabricksDatabaseMetaData implements DatabaseMetaData {

Expand All @@ -37,6 +40,9 @@ public class DatabricksDatabaseMetaData implements DatabaseMetaData {
public static final String SYSTEM_FUNCTIONS = "DATABASE,IFNULL,USER";
public static final String TIME_DATE_FUNCTIONS =
"CURDATE,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURTIME,DAYNAME,DAYOFMONTH,DAYOFWEEK,DAYOFYEAR,HOUR,MINUTE,MONTH,MONTHNAME,NOW,QUARTER,SECOND,TIMESTAMPADD,TIMESTAMPDIFF,WEEK,YEAR";
private static final Object THREAD_POOL_LOCK = new Object();
private static ExecutorService schemasThreadPool = null;
private static final int DEFAULT_MAX_THREADS = 10;
private final IDatabricksConnectionInternal connection;
private final IDatabricksSession session;
private final MetadataResultSetBuilder metadataResultSetBuilder;
Expand Down Expand Up @@ -994,9 +1000,6 @@ public ResultSet getTables(
@Override
public ResultSet getSchemas() throws SQLException {
LOGGER.debug("public ResultSet getSchemas()");
if (session.getConnectionContext().getClientType() == DatabricksClientType.SEA) {
return metadataResultSetBuilder.getSchemasResult(null);
}
return getSchemas(null /* catalog */, null /* schema pattern */);
Comment thread
jayantsing-db marked this conversation as resolved.
Outdated
}

Expand Down Expand Up @@ -1498,11 +1501,52 @@ public RowIdLifetime getRowIdLifetime() throws SQLException {
@Override
public ResultSet getSchemas(String catalog, String schemaPattern) throws SQLException {
LOGGER.debug(
String.format(
"public ResultSet getSchemas(String catalog = {}, String schemaPattern = {})",
catalog,
schemaPattern));
throwExceptionIfConnectionIsClosed();
"public ResultSet getSchemas(String catalog = %s, String schemaPattern = %s)",
catalog, schemaPattern);
throwExceptionIfConnectionIsClosed();

if (session.getConnectionContext().getClientType() == DatabricksClientType.SEA
&& (catalog == null || catalog.equals("*") || catalog.equals("%"))) {
// Fetch catalogs from the metadata client
List<String> catalogList = new ArrayList<>();
try (ResultSet catalogs = getCatalogs()) {
while (catalogs.next()) {
String c = catalogs.getString(1);
if (c != null && !c.isEmpty()) {
catalogList.add(c);
}
}
}

// Process catalogs in parallel, gathering schema information
List<List<Object>> schemaRows =
JdbcThreadUtils.parallelFlatMap(
catalogList,
session.getConnectionContext(),
DEFAULT_MAX_THREADS, // Not significant since the executor is provided as a parameter
90, // 90 seconds timeout

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a basis for selecting this particular value for the timeout? Had the same concerns for the max threads variable, but that is configurable.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think i added based on if timeout is long enough and below is one data-point:

On the e2-dogfood environment, which has around 9,000 schemas, this approach took approximately 12 seconds—compared to 9 seconds with the existing driver—an acceptable difference.

c -> {
List<List<Object>> rows = new ArrayList<>();
try (ResultSet catalogSchemas =
session.getDatabricksMetadataClient().listSchemas(session, c, schemaPattern)) {
while (catalogSchemas.next()) {
Comment thread
gopalldb marked this conversation as resolved.
Outdated
List<Object> schemaRow = new ArrayList<>();
schemaRow.add(catalogSchemas.getString(1)); // TABLE_SCHEM
schemaRow.add(catalogSchemas.getString(2)); // TABLE_CATALOG
rows.add(schemaRow);
}
} catch (SQLException e) {
LOGGER.warn("Error fetching schemas for catalog %s %s", c, e.getMessage());
}
return rows;
},
getOrCreateSchemasThreadPool());

// Convert combined data into a result set
return metadataResultSetBuilder.getResultSetWithGivenRowsAndColumns(
SCHEMA_COLUMNS, schemaRows, METADATA_STATEMENT_ID, CommandName.LIST_SCHEMAS);
}

return session.getDatabricksMetadataClient().listSchemas(session, catalog, schemaPattern);
}

Expand Down Expand Up @@ -1615,6 +1659,23 @@ public boolean isWrapperFor(Class<?> iface) throws SQLException {
return iface != null && iface.isAssignableFrom(this.getClass());
}

private static ExecutorService getOrCreateSchemasThreadPool() {
Comment thread
jayantsing-db marked this conversation as resolved.
Outdated
synchronized (THREAD_POOL_LOCK) {
if (schemasThreadPool == null || schemasThreadPool.isShutdown()) {
// Could read max threads from a configuration property
schemasThreadPool =
Executors.newFixedThreadPool(
DEFAULT_MAX_THREADS,
r -> {
Thread t = new Thread(r, "jdbc-schemas-fetcher");
t.setDaemon(true);
return t;
});
}
return schemasThreadPool;
}
}

private void throwExceptionIfConnectionIsClosed() throws SQLException {
LOGGER.debug("private void throwExceptionIfConnectionIsClosed()");
if (!connection.getSession().isOpen()) {
Expand Down
166 changes: 166 additions & 0 deletions src/main/java/com/databricks/jdbc/common/util/JdbcThreadUtils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
package com.databricks.jdbc.common.util;

import com.databricks.jdbc.api.internal.IDatabricksConnectionContext;
import com.databricks.jdbc.exception.DatabricksSQLException;
import com.databricks.jdbc.model.telemetry.enums.DatabricksDriverErrorCode;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.*;
import java.util.function.Function;

/** Utility class for executing tasks in parallel with proper context handling. */
public class JdbcThreadUtils {

  /**
   * Executes tasks concurrently with appropriate context management, utilizing a provided executor
   * service (which can be null, in which case a new one will be created).
   *
   * <p>Each task runs with {@code connectionContext} installed in {@code
   * DatabricksThreadContextHolder} for the duration of the task and cleared afterwards. Results
   * are returned in the iteration order of {@code items}.
   *
   * <p>{@code timeoutSeconds} bounds the whole call: a single deadline is computed up front and
   * every result is awaited against the time remaining until that deadline. If the call fails for
   * any reason (timeout, task error, interruption), tasks that have not finished yet are
   * cancelled so they do not keep running on a shared executor.
   *
   * @param items The items to process; {@code null} or empty yields an empty list
   * @param connectionContext The connection context to propagate to worker threads
   * @param maxThreads Maximum number of threads to use (when creating internal executor); values
   *     below 1 are treated as 1
   * @param timeoutSeconds Overall timeout in seconds for all tasks together
   * @param task The task to execute for each item
   * @param executor Optional executor service to use; if null, an internal one will be created
   *     and shut down before returning
   * @param <T> Type of input items
   * @param <R> Type of result
   * @return List of results from all tasks
   * @throws SQLException If a task fails, the deadline passes, or the calling thread is
   *     interrupted. A {@link SQLException} found in a task's cause chain is rethrown as is.
   */
  public static <T, R> List<R> parallelMap(
      Collection<T> items,
      IDatabricksConnectionContext connectionContext,
      int maxThreads,
      int timeoutSeconds,
      Function<T, R> task,
      ExecutorService executor)
      throws SQLException {

    if (items == null || items.isEmpty()) {
      return Collections.emptyList();
    }

    // Create an executor if one wasn't provided; clamp to >= 1 because
    // Executors.newFixedThreadPool rejects a non-positive thread count.
    final boolean createdExecutor = executor == null;
    final ExecutorService executorToUse =
        createdExecutor
            ? Executors.newFixedThreadPool(Math.max(1, Math.min(items.size(), maxThreads)))
            : executor;

    // One deadline for the whole call, so the total wait cannot exceed timeoutSeconds
    final long deadlineNanos = System.nanoTime() + TimeUnit.SECONDS.toNanos(timeoutSeconds);
    final List<Future<R>> futures = new ArrayList<>(items.size());

    try {
      // Submit tasks to the executor
      for (T item : items) {
        futures.add(
            executorToUse.submit(
                () -> {
                  // Set connection context for this thread
                  DatabricksThreadContextHolder.setConnectionContext(connectionContext);
                  try {
                    // Execute the task
                    return task.apply(item);
                  } finally {
                    // Clear connection context
                    DatabricksThreadContextHolder.clearConnectionContext();
                  }
                }));
      }

      // Collect results
      List<R> results = new ArrayList<>(items.size());
      for (Future<R> future : futures) {
        long remainingNanos = Math.max(0L, deadlineNanos - System.nanoTime());
        try {
          results.add(future.get(remainingNanos, TimeUnit.NANOSECONDS));
        } catch (InterruptedException e) {
          Thread.currentThread().interrupt();
          throw new DatabricksSQLException(
              "Parallel execution interrupted",
              e,
              DatabricksDriverErrorCode.THREAD_INTERRUPTED_ERROR);
        } catch (ExecutionException e) {
          SQLException sqlEx = findSQLExceptionInCauseChain(e);
          if (sqlEx != null) {
            throw sqlEx;
          } else {
            throw new DatabricksSQLException(
                "Error in parallel execution", e, DatabricksDriverErrorCode.INVALID_STATE);
          }
        } catch (TimeoutException e) {
          throw new DatabricksSQLException(
              "Parallel execution timed out after " + timeoutSeconds + " seconds",
              e,
              DatabricksDriverErrorCode.OPERATION_TIMEOUT_ERROR);
        }
      }

      return results;
    } finally {
      // On the success path every future is already done, so this is a no-op. On failure it
      // stops work that nobody will collect, which matters when the executor is shared.
      for (Future<R> future : futures) {
        if (!future.isDone()) {
          future.cancel(true);
        }
      }
      // Only shut down the executor if we created it
      if (createdExecutor) {
        executorToUse.shutdownNow();
      }
    }
  }

  /**
   * Executes tasks in parallel, collecting and flattening all results, utilizing a provided
   * executor service (which can be null, in which case a new one will be created).
   *
   * <p>This delegates to {@link #parallelMap} and concatenates the returned collections in order;
   * a task that returns {@code null} contributes nothing.
   *
   * @param items The items to process
   * @param connectionContext The connection context to propagate to worker threads
   * @param maxThreads Maximum number of threads to use
   * @param timeoutSeconds Overall timeout in seconds for all tasks together
   * @param task The task to execute for each item, producing a collection of results
   * @param executor Optional executor service to use; if null, an internal one will be created
   * @param <T> Type of input items
   * @param <R> Type of result
   * @return Flattened list of all results
   * @throws SQLException If an error occurs during execution
   */
  public static <T, R> List<R> parallelFlatMap(
      Collection<T> items,
      IDatabricksConnectionContext connectionContext,
      int maxThreads,
      int timeoutSeconds,
      Function<T, Collection<R>> task,
      ExecutorService executor)
      throws SQLException {

    List<Collection<R>> collections =
        parallelMap(items, connectionContext, maxThreads, timeoutSeconds, task, executor);

    // Flatten the results
    List<R> allResults = new ArrayList<>();
    for (Collection<R> collection : collections) {
      if (collection != null) {
        allResults.addAll(collection);
      }
    }

    return allResults;
  }

  /**
   * Searches for a SQLException in the exception cause chain.
   *
   * @param throwable The exception to search
   * @return The first SQLException found in the cause chain, or null if none
   */
  private static SQLException findSQLExceptionInCauseChain(Throwable throwable) {
    // Iterative walk: same result as a recursive search without growing the call stack
    for (Throwable current = throwable; current != null; current = current.getCause()) {
      if (current instanceof SQLException) {
        return (SQLException) current;
      }
    }
    return null;
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ public enum DatabricksDriverErrorCode {
CATALOG_OR_SCHEMA_FETCH_ERROR,
INTEGRATION_TEST_ERROR,
SDK_CLIENT_ERROR,
OPERATION_TIMEOUT_ERROR,
SSL_HANDSHAKE_ERROR,
AUTH_ERROR
}
Loading
Loading