From fc65c7c3017677e9b19155acd1b395878fed14ce Mon Sep 17 00:00:00 2001 From: samikshya-chand_data Date: Mon, 15 Sep 2025 17:58:45 +0530 Subject: [PATCH 1/6] Fix race condition in thrift accessor --- .../jdbc/api/impl/DatabricksSession.java | 3 +- .../jdbc/api/internal/IDatabricksSession.java | 1 + .../jdbc/dbclient/IDatabricksClient.java | 1 + .../impl/thrift/DatabricksHttpTTransport.java | 15 ++++--- .../impl/thrift/DatabricksThriftAccessor.java | 41 +++++++------------ .../DatabricksThriftServiceClientTest.java | 3 +- 6 files changed, 30 insertions(+), 34 deletions(-) diff --git a/src/main/java/com/databricks/jdbc/api/impl/DatabricksSession.java b/src/main/java/com/databricks/jdbc/api/impl/DatabricksSession.java index 22f21b4ba4..bb03511c06 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/DatabricksSession.java +++ b/src/main/java/com/databricks/jdbc/api/impl/DatabricksSession.java @@ -16,6 +16,7 @@ import com.databricks.jdbc.dbclient.impl.sqlexec.DatabricksSdkClient; import com.databricks.jdbc.dbclient.impl.thrift.DatabricksThriftServiceClient; import com.databricks.jdbc.exception.DatabricksHttpException; +import com.databricks.jdbc.exception.DatabricksParsingException; import com.databricks.jdbc.exception.DatabricksSQLException; import com.databricks.jdbc.exception.DatabricksTemporaryRedirectException; import com.databricks.jdbc.log.JdbcLogger; @@ -253,7 +254,7 @@ public String getConfigValue(String name) { } @Override - public void setClientInfoProperty(String name, String value) { + public void setClientInfoProperty(String name, String value) { LOGGER.debug( String.format( "public void setClientInfoProperty(String name = {%s}, String value = {%s})", diff --git a/src/main/java/com/databricks/jdbc/api/internal/IDatabricksSession.java b/src/main/java/com/databricks/jdbc/api/internal/IDatabricksSession.java index 9a06427272..bc50288630 100644 --- a/src/main/java/com/databricks/jdbc/api/internal/IDatabricksSession.java +++ b/src/main/java/com/databricks/jdbc/api/internal/IDatabricksSession.java @@ -5,6 +5,7 @@ import com.databricks.jdbc.common.IDatabricksComputeResource; import com.databricks.jdbc.dbclient.IDatabricksClient; import com.databricks.jdbc.dbclient.IDatabricksMetadataClient; +import com.databricks.jdbc.exception.DatabricksParsingException; import com.databricks.jdbc.exception.DatabricksSQLException; import java.util.Map; import javax.annotation.Nullable; diff --git a/src/main/java/com/databricks/jdbc/dbclient/IDatabricksClient.java b/src/main/java/com/databricks/jdbc/dbclient/IDatabricksClient.java index f9c434d198..31c0657299 100644 --- a/src/main/java/com/databricks/jdbc/dbclient/IDatabricksClient.java +++ b/src/main/java/com/databricks/jdbc/dbclient/IDatabricksClient.java @@ -7,6 +7,7 @@ import com.databricks.jdbc.common.IDatabricksComputeResource; import com.databricks.jdbc.common.StatementType; import com.databricks.jdbc.dbclient.impl.common.StatementId; +import com.databricks.jdbc.exception.DatabricksParsingException; import com.databricks.jdbc.exception.DatabricksSQLException; import com.databricks.jdbc.model.client.thrift.generated.TFetchResultsResp; import com.databricks.jdbc.model.core.ExternalLink; diff --git a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksHttpTTransport.java b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksHttpTTransport.java index dcbbc42372..7c00a8c6e4 100644 --- a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksHttpTTransport.java +++ b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksHttpTTransport.java @@ -87,7 +87,9 @@ public int read(byte[] buf, int off, int len) throws TTransportException { @Override public void write(byte[] buf, int off, int len) { - requestBuffer.write(buf, off, len); + synchronized (requestBuffer) { + requestBuffer.write(buf, off, len); + } } @Override @@ -115,9 +117,13 @@ public void flush() throws TTransportException { LOGGER.debug("Thrift tracing header: " + traceHeader); request.addHeader(TracingUtil.TRACE_HEADER, traceHeader); } - + byte[] requestPayload; + synchronized (requestBuffer) { + requestPayload = requestBuffer.toByteArray(); + requestBuffer.reset(); + } // Set the request entity - request.setEntity(new ByteArrayEntity(requestBuffer.toByteArray())); + request.setEntity(new ByteArrayEntity(requestPayload)); // Execute the request and handle the response long httpRequestStartTime = System.currentTimeMillis(); @@ -145,9 +151,6 @@ public void flush() throws TTransportException { LOGGER.error(e, errorMessage); throw new TTransportException(TTransportException.UNKNOWN, errorMessage, e); } - - // Reset the request buffer - requestBuffer.reset(); } @Override diff --git a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java index f4ca5c1aef..72ad11ed2b 100644 --- a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java +++ b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java @@ -15,10 +15,7 @@ import com.databricks.jdbc.dbclient.impl.common.StatementId; import com.databricks.jdbc.dbclient.impl.common.TimeoutHandler; import com.databricks.jdbc.dbclient.impl.http.DatabricksHttpClientFactory; -import com.databricks.jdbc.exception.DatabricksHttpException; -import com.databricks.jdbc.exception.DatabricksParsingException; -import com.databricks.jdbc.exception.DatabricksSQLException; -import com.databricks.jdbc.exception.DatabricksSQLFeatureNotSupportedException; +import com.databricks.jdbc.exception.*; import com.databricks.jdbc.log.JdbcLogger; import com.databricks.jdbc.log.JdbcLoggerFactory; import com.databricks.jdbc.model.client.thrift.generated.*; @@ -51,44 +48,31 @@ final class DatabricksThriftAccessor { TExecuteStatementResp._Fields.OPERATION_HANDLE.getThriftFieldId(); private static final short statusFieldId = TExecuteStatementResp._Fields.STATUS.getThriftFieldId(); - private final ThreadLocal thriftClient; private final DatabricksConfig databricksConfig; private final boolean enableDirectResults; private final int asyncPollIntervalMillis; private final int maxRowsPerBlock; private final String connectionUuid; + private final IDatabricksConnectionContext connectionContext; private TProtocolVersion serverProtocolVersion = JDBC_THRIFT_VERSION; - DatabricksThriftAccessor(IDatabricksConnectionContext connectionContext) - throws DatabricksParsingException { + DatabricksThriftAccessor(IDatabricksConnectionContext connectionContext) { this.enableDirectResults = connectionContext.getDirectResultMode(); this.databricksConfig = DatabricksClientConfiguratorManager.getInstance() .getConfigurator(connectionContext) .getDatabricksConfig(); - String endPointUrl = connectionContext.getEndpointURL(); this.asyncPollIntervalMillis = connectionContext.getAsyncExecPollInterval(); this.maxRowsPerBlock = connectionContext.getRowsFetchedPerBlock(); this.connectionUuid = connectionContext.getConnectionUuid(); - - if (!DriverUtil.isRunningAgainstFake()) { - // Create a new thrift client for each thread as client state is not thread safe. Note that - // the underlying protocol uses the same http client which is thread safe - this.thriftClient = - ThreadLocal.withInitial( - () -> createThriftClient(endPointUrl, databricksConfig, connectionContext)); - } else { - TCLIService.Client client = - createThriftClient(endPointUrl, databricksConfig, connectionContext); - this.thriftClient = ThreadLocal.withInitial(() -> client); - } + this.connectionContext=connectionContext; } @VisibleForTesting DatabricksThriftAccessor( TCLIService.Client client, IDatabricksConnectionContext connectionContext) { this.databricksConfig = null; - this.thriftClient = ThreadLocal.withInitial(() -> client); + this.connectionContext=connectionContext; this.enableDirectResults = connectionContext.getDirectResultMode(); this.asyncPollIntervalMillis = connectionContext.getAsyncExecPollInterval(); this.maxRowsPerBlock = connectionContext.getRowsFetchedPerBlock(); @@ -491,8 +475,13 @@ DatabricksResultSet getStatementResult( executionStatus, statementId, resultSet, StatementType.SQL, parentStatement, session); } - TCLIService.Client getThriftClient() { - return thriftClient.get(); + TCLIService.Client getThriftClient() { + try { + return createThriftClient(databricksConfig, connectionContext); + }catch (DatabricksParsingException e) { + String errorMessage = String.format( "Can't create thrift client as Endpoint URL cannot be parsed. Error: %s", e.getMessage()); + throw new DatabricksDriverException(errorMessage, DatabricksDriverErrorCode.INVALID_STATE); + } } DatabricksConfig getDatabricksConfig() { @@ -605,13 +594,13 @@ private TFetchResultsResp listColumns(TGetColumnsReq request) /** * Creates a new thrift client for the given endpoint URL and authentication headers. * - * @param endPointUrl endpoint URL * @param databricksConfig SDK config object required for authentication headers + * @param connectionContext connection configuration of the driver */ private TCLIService.Client createThriftClient( - String endPointUrl, DatabricksConfig databricksConfig, - IDatabricksConnectionContext connectionContext) { + IDatabricksConnectionContext connectionContext) throws DatabricksParsingException { + String endPointUrl = connectionContext.getEndpointURL(); DatabricksHttpTTransport transport = new DatabricksHttpTTransport( DatabricksHttpClientFactory.getInstance().getClient(connectionContext), diff --git a/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClientTest.java b/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClientTest.java index 54c7921be7..ba9f1746b7 100644 --- a/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClientTest.java +++ b/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClientTest.java @@ -20,6 +20,7 @@ import com.databricks.jdbc.api.internal.IDatabricksStatementInternal; import com.databricks.jdbc.common.StatementType; import com.databricks.jdbc.dbclient.impl.common.StatementId; +import com.databricks.jdbc.exception.DatabricksParsingException; import com.databricks.jdbc.exception.DatabricksSQLException; import com.databricks.jdbc.model.client.thrift.generated.*; import com.databricks.jdbc.model.core.ExternalLink; @@ -785,7 +786,7 @@ void testGetDatabricksConfig() { } @Test - void testResetAccessToken() { + void testResetAccessToken() throws DatabricksParsingException { DatabricksThriftServiceClient client = new DatabricksThriftServiceClient(thriftAccessor, connectionContext); DatabricksHttpTTransport mockDatabricksHttpTTransport = From 37e3b187769d081598d76e6dc1e88a795f7aa8f1 Mon Sep 17 00:00:00 2001 From: samikshya-chand_data Date: Mon, 15 Sep 2025 23:37:27 +0530 Subject: [PATCH 2/6] Fix client issue --- .../jdbc/api/impl/DatabricksSession.java | 3 +- .../jdbc/api/internal/IDatabricksSession.java | 1 - .../jdbc/dbclient/IDatabricksClient.java | 1 - .../impl/thrift/DatabricksThriftAccessor.java | 18 ++++---- .../thrift/DatabricksThriftAccessorTest.java | 45 ++++++++++++------- 5 files changed, 39 insertions(+), 29 deletions(-) diff --git a/src/main/java/com/databricks/jdbc/api/impl/DatabricksSession.java b/src/main/java/com/databricks/jdbc/api/impl/DatabricksSession.java index bb03511c06..22f21b4ba4 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/DatabricksSession.java +++ b/src/main/java/com/databricks/jdbc/api/impl/DatabricksSession.java @@ -16,7 +16,6 @@ import com.databricks.jdbc.dbclient.impl.sqlexec.DatabricksSdkClient; import com.databricks.jdbc.dbclient.impl.thrift.DatabricksThriftServiceClient; import com.databricks.jdbc.exception.DatabricksHttpException; -import com.databricks.jdbc.exception.DatabricksParsingException; import com.databricks.jdbc.exception.DatabricksSQLException; import com.databricks.jdbc.exception.DatabricksTemporaryRedirectException; import com.databricks.jdbc.log.JdbcLogger; @@ -254,7 +253,7 @@ public String getConfigValue(String name) { } @Override - public void setClientInfoProperty(String name, String value) { + public void setClientInfoProperty(String name, String value) { LOGGER.debug( String.format( "public void setClientInfoProperty(String name = {%s}, String value = {%s})", diff --git a/src/main/java/com/databricks/jdbc/api/internal/IDatabricksSession.java b/src/main/java/com/databricks/jdbc/api/internal/IDatabricksSession.java index bc50288630..9a06427272 100644 --- a/src/main/java/com/databricks/jdbc/api/internal/IDatabricksSession.java +++ b/src/main/java/com/databricks/jdbc/api/internal/IDatabricksSession.java @@ -5,7 +5,6 @@ import com.databricks.jdbc.common.IDatabricksComputeResource; import com.databricks.jdbc.dbclient.IDatabricksClient; import com.databricks.jdbc.dbclient.IDatabricksMetadataClient; -import com.databricks.jdbc.exception.DatabricksParsingException; import com.databricks.jdbc.exception.DatabricksSQLException; import java.util.Map; import javax.annotation.Nullable; diff --git a/src/main/java/com/databricks/jdbc/dbclient/IDatabricksClient.java b/src/main/java/com/databricks/jdbc/dbclient/IDatabricksClient.java index 31c0657299..f9c434d198 100644 --- a/src/main/java/com/databricks/jdbc/dbclient/IDatabricksClient.java +++ b/src/main/java/com/databricks/jdbc/dbclient/IDatabricksClient.java @@ -7,7 +7,6 @@ import com.databricks.jdbc.common.IDatabricksComputeResource; import com.databricks.jdbc.common.StatementType; import com.databricks.jdbc.dbclient.impl.common.StatementId; -import com.databricks.jdbc.exception.DatabricksParsingException; import com.databricks.jdbc.exception.DatabricksSQLException; import com.databricks.jdbc.model.client.thrift.generated.TFetchResultsResp; import com.databricks.jdbc.model.core.ExternalLink; diff --git a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java index 72ad11ed2b..5fe3af8560 100644 --- a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java +++ b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java @@ -10,7 +10,6 @@ import com.databricks.jdbc.common.DatabricksClientConfiguratorManager; import com.databricks.jdbc.common.StatementType; import com.databricks.jdbc.common.util.DatabricksThreadContextHolder; -import com.databricks.jdbc.common.util.DriverUtil; import com.databricks.jdbc.common.util.ProtocolFeatureUtil; import com.databricks.jdbc.dbclient.impl.common.StatementId; import com.databricks.jdbc.dbclient.impl.common.TimeoutHandler; @@ -65,14 +64,14 @@ final class DatabricksThriftAccessor { this.asyncPollIntervalMillis = connectionContext.getAsyncExecPollInterval(); this.maxRowsPerBlock = connectionContext.getRowsFetchedPerBlock(); this.connectionUuid = connectionContext.getConnectionUuid(); - this.connectionContext=connectionContext; + this.connectionContext = connectionContext; } @VisibleForTesting DatabricksThriftAccessor( TCLIService.Client client, IDatabricksConnectionContext connectionContext) { this.databricksConfig = null; - this.connectionContext=connectionContext; + this.connectionContext = connectionContext; this.enableDirectResults = connectionContext.getDirectResultMode(); this.asyncPollIntervalMillis = connectionContext.getAsyncExecPollInterval(); this.maxRowsPerBlock = connectionContext.getRowsFetchedPerBlock(); @@ -475,11 +474,14 @@ DatabricksResultSet getStatementResult( executionStatus, statementId, resultSet, StatementType.SQL, parentStatement, session); } - TCLIService.Client getThriftClient() { + TCLIService.Client getThriftClient() { try { return createThriftClient(databricksConfig, connectionContext); - }catch (DatabricksParsingException e) { - String errorMessage = String.format( "Can't create thrift client as Endpoint URL cannot be parsed. Error: %s", e.getMessage()); + } catch (DatabricksParsingException e) { + String errorMessage = + String.format( + "Can't create thrift client as Endpoint URL cannot be parsed. Error: %s", + e.getMessage()); throw new DatabricksDriverException(errorMessage, DatabricksDriverErrorCode.INVALID_STATE); } } @@ -598,8 +600,8 @@ private TFetchResultsResp listColumns(TGetColumnsReq request) * @param connectionContext connection configuration of the driver */ private TCLIService.Client createThriftClient( - DatabricksConfig databricksConfig, - IDatabricksConnectionContext connectionContext) throws DatabricksParsingException { + DatabricksConfig databricksConfig, IDatabricksConnectionContext connectionContext) + throws DatabricksParsingException { String endPointUrl = connectionContext.getEndpointURL(); DatabricksHttpTTransport transport = new DatabricksHttpTTransport( diff --git a/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessorTest.java b/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessorTest.java index c09bc43f81..709e08781f 100644 --- a/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessorTest.java +++ b/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessorTest.java @@ -69,9 +69,12 @@ public class DatabricksThriftAccessorTest { .setOperationState(TOperationState.RUNNING_STATE); void setup(Boolean directResultsEnabled) { - when(connectionContext.getDirectResultMode()).thenReturn(directResultsEnabled); - when(connectionContext.getRowsFetchedPerBlock()).thenReturn(DEFAULT_ROW_LIMIT_PER_BLOCK); - accessor = new DatabricksThriftAccessor(thriftClient, connectionContext); + lenient().when(connectionContext.getDirectResultMode()).thenReturn(directResultsEnabled); + lenient() + .when(connectionContext.getRowsFetchedPerBlock()) + .thenReturn(DEFAULT_ROW_LIMIT_PER_BLOCK); + accessor = spy(new DatabricksThriftAccessor(thriftClient, connectionContext)); + doReturn(thriftClient).when(accessor).getThriftClient(); } @Test @@ -161,7 +164,6 @@ void testExecuteAsync_SQLState() throws TException { @Test void testExecuteThrowsThriftError() throws TException { setup(true); - accessor = new DatabricksThriftAccessor(thriftClient, connectionContext); TExecuteStatementReq request = new TExecuteStatementReq(); when(thriftClient.ExecuteStatement(request)).thenThrow(TException.class); assertThrows( @@ -172,7 +174,6 @@ void testExecuteThrowsThriftError() throws TException { @Test void testExecuteWithParentStatement() throws TException, SQLException { setup(true); - accessor = new DatabricksThriftAccessor(thriftClient, connectionContext); TExecuteStatementReq request = new TExecuteStatementReq(); TExecuteStatementResp tExecuteStatementResp = new TExecuteStatementResp() @@ -193,7 +194,6 @@ void testExecuteWithParentStatement() throws TException, SQLException { @Test void testExecuteWithDirectResults() throws TException, SQLException { setup(true); - accessor = new DatabricksThriftAccessor(thriftClient, connectionContext); TExecuteStatementReq request = new TExecuteStatementReq(); TExecuteStatementResp tExecuteStatementResp = new TExecuteStatementResp() @@ -209,8 +209,12 @@ void testExecuteWithDirectResults() throws TException, SQLException { @Test void testExecuteWithoutDirectResults() throws TException, SQLException { - setup(false); - accessor = new DatabricksThriftAccessor(thriftClient, connectionContext); + lenient().when(connectionContext.getDirectResultMode()).thenReturn(false); + lenient() + .when(connectionContext.getRowsFetchedPerBlock()) + .thenReturn(DEFAULT_ROW_LIMIT_PER_BLOCK); + accessor = spy(new DatabricksThriftAccessor(thriftClient, connectionContext)); + doReturn(thriftClient).when(accessor).getThriftClient(); TExecuteStatementReq request = new TExecuteStatementReq(); TExecuteStatementResp tExecuteStatementResp = new TExecuteStatementResp() @@ -226,9 +230,10 @@ void testExecuteWithoutDirectResults() throws TException, SQLException { @Test void testExecute_throwsException() throws TException { - setup(true); - - accessor = new DatabricksThriftAccessor(thriftClient, connectionContext); + when(connectionContext.getDirectResultMode()).thenReturn(false); + when(connectionContext.getRowsFetchedPerBlock()).thenReturn(DEFAULT_ROW_LIMIT_PER_BLOCK); + accessor = spy(new DatabricksThriftAccessor(thriftClient, connectionContext)); + doReturn(thriftClient).when(accessor).getThriftClient(); TExecuteStatementReq request = new TExecuteStatementReq(); TExecuteStatementResp tExecuteStatementResp = new TExecuteStatementResp() @@ -333,7 +338,8 @@ void testCloseOperation_error() throws TException { void testIncludeResultSetMetadataNotSetForOldProtocol() throws TException, DatabricksHttpException { DatabricksThriftAccessor accessor = - new DatabricksThriftAccessor(thriftClient, connectionContext); + spy(new DatabricksThriftAccessor(thriftClient, connectionContext)); + doReturn(thriftClient).when(accessor).getThriftClient(); accessor.setServerProtocolVersion(TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4); TFetchResultsReq expectedReq = getFetchResultsRequest(false); when(thriftClient.FetchResults(expectedReq)) @@ -360,7 +366,8 @@ void testIncludeResultSetMetadataNotSetForOldProtocol() @Test void testGetStatementResult_success() throws Exception { when(connectionContext.getDirectResultMode()).thenReturn(false); - accessor = new DatabricksThriftAccessor(thriftClient, connectionContext); + accessor = spy(new DatabricksThriftAccessor(thriftClient, connectionContext)); + doReturn(thriftClient).when(accessor).getThriftClient(); when(thriftClient.GetOperationStatus(operationStatusReq)) .thenReturn(operationStatusFinishedResp); TFetchResultsReq fetchReq = @@ -381,7 +388,8 @@ void testGetStatementResult_success() throws Exception { @Test void testGetStatementResult_pending() throws Exception { when(connectionContext.getDirectResultMode()).thenReturn(false); - accessor = new DatabricksThriftAccessor(thriftClient, connectionContext); + accessor = spy(new DatabricksThriftAccessor(thriftClient, connectionContext)); + doReturn(thriftClient).when(accessor).getThriftClient(); TGetOperationStatusResp resp = new TGetOperationStatusResp() .setStatus(new TStatus().setStatusCode(TStatusCode.STILL_EXECUTING_STATUS)) @@ -713,7 +721,8 @@ void testExecuteWithTimeout() throws TException, SQLException { // Set the async poll interval to 200 ms when(connectionContext.getAsyncExecPollInterval()).thenReturn(200); - accessor = new DatabricksThriftAccessor(thriftClient, connectionContext); + accessor = spy(new DatabricksThriftAccessor(thriftClient, connectionContext)); + doReturn(thriftClient).when(accessor).getThriftClient(); // Create statement execution mocks TExecuteStatementReq request = new TExecuteStatementReq(); @@ -751,7 +760,8 @@ void testExecuteWithTimeoutExpired() throws TException, SQLException { // Set the async poll interval to 1 second to facilitate testing when(connectionContext.getAsyncExecPollInterval()).thenReturn(1000); - accessor = new DatabricksThriftAccessor(thriftClient, connectionContext); + accessor = spy(new DatabricksThriftAccessor(thriftClient, connectionContext)); + doReturn(thriftClient).when(accessor).getThriftClient(); // Create statement execution mocks TExecuteStatementReq request = new TExecuteStatementReq(); @@ -796,7 +806,8 @@ void testFetchResultsWithCustomMaxRowsPerBlock() throws TException, SQLException IDatabricksConnectionContext mockConnectionContext = mock(IDatabricksConnectionContext.class); when(mockConnectionContext.getDirectResultMode()).thenReturn(true); when(mockConnectionContext.getRowsFetchedPerBlock()).thenReturn(customMaxRows); - accessor = new DatabricksThriftAccessor(thriftClient, mockConnectionContext); + accessor = spy(new DatabricksThriftAccessor(thriftClient, mockConnectionContext)); + doReturn(thriftClient).when(accessor).getThriftClient(); TExecuteStatementReq executeRequest = new TExecuteStatementReq(); TExecuteStatementResp executeResponse = From fae8c12c29e889399cef7b3ca64049a437faa61b Mon Sep 17 00:00:00 2001 From: samikshya-chand_data Date: Tue, 16 Sep 2025 00:18:16 +0530 Subject: [PATCH 3/6] update code to be better --- .../impl/thrift/DatabricksThriftAccessor.java | 35 ++++++------------- .../thrift/DatabricksThriftAccessorTest.java | 25 ++++++------- 2 files changed, 23 insertions(+), 37 deletions(-) diff --git a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java index 5fe3af8560..cfca1ae9da 100644 --- a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java +++ b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java @@ -52,15 +52,18 @@ final class DatabricksThriftAccessor { private final int asyncPollIntervalMillis; private final int maxRowsPerBlock; private final String connectionUuid; + private final String endpointUrl; private final IDatabricksConnectionContext connectionContext; private TProtocolVersion serverProtocolVersion = JDBC_THRIFT_VERSION; - DatabricksThriftAccessor(IDatabricksConnectionContext connectionContext) { + DatabricksThriftAccessor(IDatabricksConnectionContext connectionContext) + throws DatabricksParsingException { this.enableDirectResults = connectionContext.getDirectResultMode(); this.databricksConfig = DatabricksClientConfiguratorManager.getInstance() .getConfigurator(connectionContext) .getDatabricksConfig(); + this.endpointUrl = connectionContext.getEndpointURL(); this.asyncPollIntervalMillis = connectionContext.getAsyncExecPollInterval(); this.maxRowsPerBlock = connectionContext.getRowsFetchedPerBlock(); this.connectionUuid = connectionContext.getConnectionUuid(); @@ -69,8 +72,10 @@ final class DatabricksThriftAccessor { @VisibleForTesting DatabricksThriftAccessor( - TCLIService.Client client, IDatabricksConnectionContext connectionContext) { + TCLIService.Client client, IDatabricksConnectionContext connectionContext) + throws DatabricksParsingException { this.databricksConfig = null; + this.endpointUrl = connectionContext.getEndpointURL(); this.connectionContext = connectionContext; this.enableDirectResults = connectionContext.getDirectResultMode(); this.asyncPollIntervalMillis = connectionContext.getAsyncExecPollInterval(); @@ -474,18 +479,6 @@ DatabricksResultSet getStatementResult( executionStatus, statementId, resultSet, StatementType.SQL, parentStatement, session); } - TCLIService.Client getThriftClient() { - try { - return createThriftClient(databricksConfig, connectionContext); - } catch (DatabricksParsingException e) { - String errorMessage = - String.format( - "Can't create thrift client as Endpoint URL cannot be parsed. Error: %s", - e.getMessage()); - throw new DatabricksDriverException(errorMessage, DatabricksDriverErrorCode.INVALID_STATE); - } - } - DatabricksConfig getDatabricksConfig() { return databricksConfig; } @@ -593,20 +586,12 @@ private TFetchResultsResp listColumns(TGetColumnsReq request) return fetchMetadataResults(response, response.toString()); } - /** - * Creates a new thrift client for the given endpoint URL and authentication headers. - * - * @param databricksConfig SDK config object required for authentication headers - * @param connectionContext connection configuration of the driver - */ - private TCLIService.Client createThriftClient( - DatabricksConfig databricksConfig, IDatabricksConnectionContext connectionContext) - throws DatabricksParsingException { - String endPointUrl = connectionContext.getEndpointURL(); + /** Creates a new thrift client for the given endpoint URL and authentication headers. */ + TCLIService.Client getThriftClient() { DatabricksHttpTTransport transport = new DatabricksHttpTTransport( DatabricksHttpClientFactory.getInstance().getClient(connectionContext), - endPointUrl, + endpointUrl, databricksConfig, connectionContext); TBinaryProtocol protocol = new TBinaryProtocol(transport); diff --git a/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessorTest.java b/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessorTest.java index 709e08781f..b1a51648f3 100644 --- a/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessorTest.java +++ b/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessorTest.java @@ -12,6 +12,7 @@ import com.databricks.jdbc.common.StatementType; import com.databricks.jdbc.dbclient.impl.common.StatementId; import com.databricks.jdbc.exception.DatabricksHttpException; +import com.databricks.jdbc.exception.DatabricksParsingException; import com.databricks.jdbc.exception.DatabricksSQLException; import com.databricks.jdbc.exception.DatabricksTimeoutException; import com.databricks.jdbc.model.client.thrift.generated.*; @@ -68,7 +69,7 @@ public class DatabricksThriftAccessorTest { .setStatus(new TStatus().setStatusCode(TStatusCode.SUCCESS_STATUS)) .setOperationState(TOperationState.RUNNING_STATE); - void setup(Boolean directResultsEnabled) { + void setup(Boolean directResultsEnabled) throws DatabricksParsingException { lenient().when(connectionContext.getDirectResultMode()).thenReturn(directResultsEnabled); lenient() .when(connectionContext.getRowsFetchedPerBlock()) @@ -134,7 +135,7 @@ void testExecuteAsync() throws TException, SQLException { } @Test - void testExecuteAsync_error() throws TException { + void testExecuteAsync_error() throws TException, DatabricksParsingException { setup(true); TExecuteStatementReq request = new TExecuteStatementReq(); @@ -145,7 +146,7 @@ void testExecuteAsync_error() throws TException { } @Test - void testExecuteAsync_SQLState() throws TException { + void testExecuteAsync_SQLState() throws TException, DatabricksParsingException { setup(true); TExecuteStatementReq request = new TExecuteStatementReq(); @@ -162,7 +163,7 @@ void testExecuteAsync_SQLState() throws TException { } @Test - void testExecuteThrowsThriftError() throws TException { + void testExecuteThrowsThriftError() throws TException, DatabricksParsingException { setup(true); TExecuteStatementReq request = new TExecuteStatementReq(); when(thriftClient.ExecuteStatement(request)).thenThrow(TException.class); @@ -229,7 +230,7 @@ void testExecuteWithoutDirectResults() throws TException, SQLException { } @Test - void testExecute_throwsException() throws TException { + void testExecute_throwsException() throws TException, DatabricksParsingException { when(connectionContext.getDirectResultMode()).thenReturn(false); when(connectionContext.getRowsFetchedPerBlock()).thenReturn(DEFAULT_ROW_LIMIT_PER_BLOCK); accessor = spy(new DatabricksThriftAccessor(thriftClient, connectionContext)); @@ -251,7 +252,7 @@ void testExecute_throwsException() throws TException { } @Test - void testExecuteThrowsSQLExceptionWithSqlState() throws TException { + void testExecuteThrowsSQLExceptionWithSqlState() throws TException, DatabricksParsingException { setup(true); TExecuteStatementReq request = new TExecuteStatementReq(); TExecuteStatementResp tExecuteStatementResp = @@ -307,7 +308,7 @@ void testCloseOperation() throws TException, DatabricksSQLException { } @Test - void testCancelOperation_error() throws TException { + void testCancelOperation_error() throws TException, DatabricksParsingException { setup(true); TCancelOperationReq request = @@ -321,7 +322,7 @@ void testCancelOperation_error() throws TException { } @Test - void testCloseOperation_error() throws TException { + void testCloseOperation_error() throws TException, DatabricksParsingException { setup(true); TCloseOperationReq request = @@ -336,7 +337,7 @@ void testCloseOperation_error() throws TException { @Test void testIncludeResultSetMetadataNotSetForOldProtocol() - throws TException, DatabricksHttpException { + throws TException, DatabricksHttpException, DatabricksParsingException { DatabricksThriftAccessor accessor = spy(new DatabricksThriftAccessor(thriftClient, connectionContext)); doReturn(thriftClient).when(accessor).getThriftClient(); @@ -643,7 +644,7 @@ void testTypeInfoWithDirectResults() throws TException, DatabricksSQLException { } @Test - void testAccessorWhenFetchResultsThrowsError() throws TException { + void testAccessorWhenFetchResultsThrowsError() throws TException, DatabricksParsingException { setup(false); TGetTablesReq request = new TGetTablesReq(); @@ -659,7 +660,7 @@ void testAccessorWhenFetchResultsThrowsError() throws TException { } @Test - void testAccessorDuringThriftError() throws TException { + void testAccessorDuringThriftError() throws TException, DatabricksParsingException { setup(true); TGetTablesReq request = new TGetTablesReq(); @@ -668,7 +669,7 @@ void testAccessorDuringThriftError() throws TException { } @Test - void testAccessorDuringHTTPError() throws TException { + void testAccessorDuringHTTPError() throws TException, DatabricksParsingException { setup(true); TGetTablesReq request = new TGetTablesReq(); From d83d672e92077861e40cca876f427c1bc778dd9c Mon Sep 17 00:00:00 2001 From: samikshya-chand_data Date: Tue, 16 Sep 2025 11:52:27 +0530 Subject: [PATCH 4/6] Fix PR checks --- NEXT_CHANGELOG.md | 1 + .../impl/thrift/DatabricksThriftAccessor.java | 18 +++++++++++++++--- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 0c0cd7b609..2c9a2abc1d 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -25,5 +25,6 @@ - Fixed a bug in the JDBC driver's metadata parsing for nested decimal fields within struct types. - Fixed case sensitive table search in `connection.getMetadata().getTables()` - Fixed `connection.getMetadata().getColumns()` to return the correct scale. +- Fixed state leaking issue in thrift client. --- *Note: When making changes, please add your change under the appropriate section with a brief description.* diff --git a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java index cfca1ae9da..cf89c89a99 100644 --- a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java +++ b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java @@ -10,6 +10,7 @@ import com.databricks.jdbc.common.DatabricksClientConfiguratorManager; import com.databricks.jdbc.common.StatementType; import com.databricks.jdbc.common.util.DatabricksThreadContextHolder; +import com.databricks.jdbc.common.util.DriverUtil; import com.databricks.jdbc.common.util.ProtocolFeatureUtil; import com.databricks.jdbc.dbclient.impl.common.StatementId; import com.databricks.jdbc.dbclient.impl.common.TimeoutHandler; @@ -55,6 +56,7 @@ final class DatabricksThriftAccessor { private final String endpointUrl; private final IDatabricksConnectionContext connectionContext; private TProtocolVersion serverProtocolVersion = JDBC_THRIFT_VERSION; + private static TCLIService.Client FAKE_SHARED_CLIENT; DatabricksThriftAccessor(IDatabricksConnectionContext connectionContext) throws DatabricksParsingException { @@ -588,15 +590,25 @@ private TFetchResultsResp listColumns(TGetColumnsReq request) /** Creates a new thrift client for the given endpoint URL and authentication headers. */ TCLIService.Client getThriftClient() { + if (DriverUtil.isRunningAgainstFake()) { + synchronized (DatabricksThriftAccessor.class) { + if (FAKE_SHARED_CLIENT == null) { + FAKE_SHARED_CLIENT = newThriftClient(); + } + return FAKE_SHARED_CLIENT; + } + } + return newThriftClient(); + } + + private TCLIService.Client newThriftClient() { DatabricksHttpTTransport transport = new DatabricksHttpTTransport( DatabricksHttpClientFactory.getInstance().getClient(connectionContext), endpointUrl, databricksConfig, connectionContext); - TBinaryProtocol protocol = new TBinaryProtocol(transport); - - return new TCLIService.Client(protocol); + return new TCLIService.Client(new TBinaryProtocol(transport)); } /** From fd547c49b22839b33bcf1692e52673f8ca089123 Mon Sep 17 00:00:00 2001 From: samikshya-chand_data Date: Tue, 16 Sep 2025 14:44:32 +0530 Subject: [PATCH 5/6] Fix tests --- .../impl/thrift/DatabricksThriftAccessor.java | 24 ++----- .../thrift/DatabricksThriftAccessorTest.java | 62 ++++++++++++++++--- 2 files changed, 56 insertions(+), 30 deletions(-) diff --git a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java index cf89c89a99..b2e0f5c8ed 100644 --- a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java +++ b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java @@ -24,7 +24,6 @@ import com.databricks.jdbc.telemetry.latency.TelemetryCollector; import com.databricks.sdk.core.DatabricksConfig; import com.databricks.sdk.service.sql.StatementState; -import com.google.common.annotations.VisibleForTesting; import java.sql.SQLException; import java.util.Arrays; import java.util.concurrent.TimeUnit; @@ -70,19 +69,9 @@ final class DatabricksThriftAccessor { this.maxRowsPerBlock = connectionContext.getRowsFetchedPerBlock(); this.connectionUuid = connectionContext.getConnectionUuid(); this.connectionContext = connectionContext; - } - - @VisibleForTesting - DatabricksThriftAccessor( - TCLIService.Client client, IDatabricksConnectionContext connectionContext) - throws DatabricksParsingException { - this.databricksConfig = null; - this.endpointUrl = connectionContext.getEndpointURL(); - this.connectionContext = connectionContext; - this.enableDirectResults = connectionContext.getDirectResultMode(); - this.asyncPollIntervalMillis = connectionContext.getAsyncExecPollInterval(); - this.maxRowsPerBlock = connectionContext.getRowsFetchedPerBlock(); - this.connectionUuid = connectionContext.getConnectionUuid(); + if (DriverUtil.isRunningAgainstFake()) { + FAKE_SHARED_CLIENT = newThriftClient(); + } } @SuppressWarnings("rawtypes") @@ -591,12 +580,7 @@ private TFetchResultsResp listColumns(TGetColumnsReq request) /** Creates a new thrift client for the given endpoint URL and authentication headers. */ TCLIService.Client getThriftClient() { if (DriverUtil.isRunningAgainstFake()) { - synchronized (DatabricksThriftAccessor.class) { - if (FAKE_SHARED_CLIENT == null) { - FAKE_SHARED_CLIENT = newThriftClient(); - } - return FAKE_SHARED_CLIENT; - } + return FAKE_SHARED_CLIENT; } return newThriftClient(); } diff --git a/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessorTest.java b/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessorTest.java index b1a51648f3..768b3416cf 100644 --- a/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessorTest.java +++ b/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessorTest.java @@ -9,21 +9,27 @@ import com.databricks.jdbc.api.internal.IDatabricksConnectionContext; import com.databricks.jdbc.api.internal.IDatabricksSession; import com.databricks.jdbc.api.internal.IDatabricksStatementInternal; +import com.databricks.jdbc.common.DatabricksClientConfiguratorManager; import com.databricks.jdbc.common.StatementType; +import com.databricks.jdbc.dbclient.impl.common.ClientConfigurator; import com.databricks.jdbc.dbclient.impl.common.StatementId; import com.databricks.jdbc.exception.DatabricksHttpException; import com.databricks.jdbc.exception.DatabricksParsingException; import com.databricks.jdbc.exception.DatabricksSQLException; import com.databricks.jdbc.exception.DatabricksTimeoutException; import com.databricks.jdbc.model.client.thrift.generated.*; +import com.databricks.sdk.core.DatabricksConfig; import com.databricks.sdk.service.sql.StatementState; import java.sql.SQLException; import java.sql.Statement; import java.util.ArrayList; import org.apache.thrift.TException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; +import org.mockito.MockedStatic; import org.mockito.junit.jupiter.MockitoExtension; @ExtendWith(MockitoExtension.class) @@ -69,12 +75,42 @@ public class DatabricksThriftAccessorTest { .setStatus(new TStatus().setStatusCode(TStatusCode.SUCCESS_STATUS)) .setOperationState(TOperationState.RUNNING_STATE); + private MockedStatic configuratorManagerStatic; + private DatabricksClientConfiguratorManager configuratorManager; + + @BeforeEach + void initConfiguratorManager() throws DatabricksParsingException { + configuratorManagerStatic = mockStatic(DatabricksClientConfiguratorManager.class); + configuratorManager = mock(DatabricksClientConfiguratorManager.class); + configuratorManagerStatic + .when(DatabricksClientConfiguratorManager::getInstance) + .thenReturn(configuratorManager); + ClientConfigurator mockConfigurator = mock(ClientConfigurator.class); + lenient().when(mockConfigurator.getDatabricksConfig()).thenReturn(new DatabricksConfig()); + lenient() + .when(configuratorManager.getConfigurator(any(IDatabricksConnectionContext.class))) + .thenReturn(mockConfigurator); + // Provide common defaults used in constructor and various tests + lenient() + .when(connectionContext.getRowsFetchedPerBlock()) + .thenReturn(DEFAULT_ROW_LIMIT_PER_BLOCK); + lenient().when(connectionContext.getAsyncExecPollInterval()).thenReturn(1000); + lenient().when(connectionContext.getEndpointURL()).thenReturn("http://localhost"); + } + + @AfterEach + void cleanupConfiguratorManager() { + if (configuratorManagerStatic != null) { + configuratorManagerStatic.close(); + } + } + void setup(Boolean directResultsEnabled) throws DatabricksParsingException { lenient().when(connectionContext.getDirectResultMode()).thenReturn(directResultsEnabled); lenient() .when(connectionContext.getRowsFetchedPerBlock()) .thenReturn(DEFAULT_ROW_LIMIT_PER_BLOCK); - accessor = spy(new DatabricksThriftAccessor(thriftClient, connectionContext)); + accessor = spy(new DatabricksThriftAccessor(connectionContext)); doReturn(thriftClient).when(accessor).getThriftClient(); } @@ -214,7 +250,7 @@ void testExecuteWithoutDirectResults() throws TException, SQLException { lenient() .when(connectionContext.getRowsFetchedPerBlock()) .thenReturn(DEFAULT_ROW_LIMIT_PER_BLOCK); - accessor = spy(new DatabricksThriftAccessor(thriftClient, connectionContext)); + accessor = spy(new DatabricksThriftAccessor(connectionContext)); doReturn(thriftClient).when(accessor).getThriftClient(); TExecuteStatementReq request = new TExecuteStatementReq(); TExecuteStatementResp tExecuteStatementResp = @@ -233,7 +269,7 @@ void testExecuteWithoutDirectResults() throws TException, SQLException { void testExecute_throwsException() throws TException, DatabricksParsingException { when(connectionContext.getDirectResultMode()).thenReturn(false); when(connectionContext.getRowsFetchedPerBlock()).thenReturn(DEFAULT_ROW_LIMIT_PER_BLOCK); - accessor = spy(new DatabricksThriftAccessor(thriftClient, connectionContext)); + accessor = spy(new DatabricksThriftAccessor(connectionContext)); doReturn(thriftClient).when(accessor).getThriftClient(); TExecuteStatementReq request = new TExecuteStatementReq(); TExecuteStatementResp tExecuteStatementResp = @@ -338,8 +374,7 @@ void testCloseOperation_error() throws TException, DatabricksParsingException { @Test void testIncludeResultSetMetadataNotSetForOldProtocol() throws TException, DatabricksHttpException, DatabricksParsingException { - DatabricksThriftAccessor accessor = - spy(new DatabricksThriftAccessor(thriftClient, connectionContext)); + DatabricksThriftAccessor accessor = spy(new DatabricksThriftAccessor(connectionContext)); doReturn(thriftClient).when(accessor).getThriftClient(); accessor.setServerProtocolVersion(TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4); TFetchResultsReq expectedReq = getFetchResultsRequest(false); @@ -367,7 +402,7 @@ void testIncludeResultSetMetadataNotSetForOldProtocol() @Test void testGetStatementResult_success() throws Exception { when(connectionContext.getDirectResultMode()).thenReturn(false); - accessor = spy(new DatabricksThriftAccessor(thriftClient, connectionContext)); + accessor = spy(new DatabricksThriftAccessor(connectionContext)); doReturn(thriftClient).when(accessor).getThriftClient(); when(thriftClient.GetOperationStatus(operationStatusReq)) .thenReturn(operationStatusFinishedResp); @@ -389,7 +424,7 @@ void testGetStatementResult_success() throws Exception { @Test void testGetStatementResult_pending() throws Exception { when(connectionContext.getDirectResultMode()).thenReturn(false); - accessor = spy(new DatabricksThriftAccessor(thriftClient, connectionContext)); + accessor = spy(new DatabricksThriftAccessor(connectionContext)); doReturn(thriftClient).when(accessor).getThriftClient(); TGetOperationStatusResp resp = new TGetOperationStatusResp() @@ -722,7 +757,7 @@ void testExecuteWithTimeout() throws TException, SQLException { // Set the async poll interval to 200 ms when(connectionContext.getAsyncExecPollInterval()).thenReturn(200); - accessor = spy(new DatabricksThriftAccessor(thriftClient, connectionContext)); + accessor = spy(new DatabricksThriftAccessor(connectionContext)); doReturn(thriftClient).when(accessor).getThriftClient(); // Create statement execution mocks @@ -761,7 +796,7 @@ void testExecuteWithTimeoutExpired() throws TException, SQLException { // Set the async poll interval to 1 second to facilitate testing when(connectionContext.getAsyncExecPollInterval()).thenReturn(1000); - accessor = spy(new DatabricksThriftAccessor(thriftClient, connectionContext)); + accessor = spy(new DatabricksThriftAccessor(connectionContext)); doReturn(thriftClient).when(accessor).getThriftClient(); // Create statement execution mocks @@ -807,7 +842,14 @@ void testFetchResultsWithCustomMaxRowsPerBlock() throws TException, SQLException IDatabricksConnectionContext mockConnectionContext = mock(IDatabricksConnectionContext.class); when(mockConnectionContext.getDirectResultMode()).thenReturn(true); when(mockConnectionContext.getRowsFetchedPerBlock()).thenReturn(customMaxRows); - accessor = spy(new DatabricksThriftAccessor(thriftClient, mockConnectionContext)); + // Ensure configurator manager returns a configurator for this separate mock context + ClientConfigurator customMockConfigurator = mock(ClientConfigurator.class); + when(customMockConfigurator.getDatabricksConfig()).thenReturn(new DatabricksConfig()); + when(configuratorManager.getConfigurator(mockConnectionContext)) + .thenReturn(customMockConfigurator); + lenient().when(mockConnectionContext.getAsyncExecPollInterval()).thenReturn(1000); + lenient().when(mockConnectionContext.getEndpointURL()).thenReturn("http://localhost"); + accessor = spy(new DatabricksThriftAccessor(mockConnectionContext)); doReturn(thriftClient).when(accessor).getThriftClient(); TExecuteStatementReq executeRequest = new TExecuteStatementReq(); From 479fd663fc0a2b9a40a0f6de063a507b92ae67d4 Mon Sep 17 00:00:00 2001 From: samikshya-chand_data Date: Tue, 16 Sep 2025 15:10:30 +0530 Subject: [PATCH 6/6] Revert back to original thrift method --- .../dbclient/impl/thrift/DatabricksThriftAccessor.java | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java index b2e0f5c8ed..413fe488d0 100644 --- a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java +++ b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java @@ -55,7 +55,7 @@ final class DatabricksThriftAccessor { private final String endpointUrl; private final IDatabricksConnectionContext connectionContext; private TProtocolVersion serverProtocolVersion = JDBC_THRIFT_VERSION; - private static TCLIService.Client FAKE_SHARED_CLIENT; + private ThreadLocal FAKE_SHARED_CLIENT; DatabricksThriftAccessor(IDatabricksConnectionContext connectionContext) throws DatabricksParsingException { @@ -70,7 +70,8 @@ final class DatabricksThriftAccessor { this.connectionUuid = connectionContext.getConnectionUuid(); this.connectionContext = connectionContext; if (DriverUtil.isRunningAgainstFake()) { - FAKE_SHARED_CLIENT = newThriftClient(); + TCLIService.Client client = newThriftClient(); + this.FAKE_SHARED_CLIENT = ThreadLocal.withInitial(() -> client); } } @@ -580,7 +581,7 @@ private TFetchResultsResp listColumns(TGetColumnsReq request) /** Creates a new thrift client for the given endpoint URL and authentication headers. */ TCLIService.Client getThriftClient() { if (DriverUtil.isRunningAgainstFake()) { - return FAKE_SHARED_CLIENT; + return FAKE_SHARED_CLIENT.get(); } return newThriftClient(); }