diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 2e17e6a158..01bb1b860c 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -7,6 +7,7 @@ ### Updated ### Fixed +- Fixed state leaking between requests in the Thrift client: the request buffer is now cleared before the request is sent, and each call uses a fresh client. - Fixed timestamp values returning only milliseconds instead of the full nanosecond precision. --- *Note: When making changes, please add your change under the appropriate section with a brief description.* diff --git a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksHttpTTransport.java b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksHttpTTransport.java index dcbbc42372..7c00a8c6e4 100644 --- a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksHttpTTransport.java +++ b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksHttpTTransport.java @@ -87,7 +87,9 @@ public int read(byte[] buf, int off, int len) throws TTransportException { @Override public void write(byte[] buf, int off, int len) { - requestBuffer.write(buf, off, len); + synchronized (requestBuffer) { + requestBuffer.write(buf, off, len); + } } @Override @@ -115,9 +117,13 @@ public void flush() throws TTransportException { LOGGER.debug("Thrift tracing header: " + traceHeader); request.addHeader(TracingUtil.TRACE_HEADER, traceHeader); } - + byte[] requestPayload; + synchronized (requestBuffer) { + requestPayload = requestBuffer.toByteArray(); + requestBuffer.reset(); + } // Set the request entity - request.setEntity(new ByteArrayEntity(requestBuffer.toByteArray())); + request.setEntity(new ByteArrayEntity(requestPayload)); // Execute the request and handle the response long httpRequestStartTime = System.currentTimeMillis(); @@ -145,9 +151,6 @@ public void flush() throws TTransportException { LOGGER.error(e, errorMessage); throw new TTransportException(TTransportException.UNKNOWN, errorMessage, e); } - - // Reset the request buffer - requestBuffer.reset(); } @Override diff --git 
a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java index f4ca5c1aef..413fe488d0 100644 --- a/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java +++ b/src/main/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessor.java @@ -15,10 +15,7 @@ import com.databricks.jdbc.dbclient.impl.common.StatementId; import com.databricks.jdbc.dbclient.impl.common.TimeoutHandler; import com.databricks.jdbc.dbclient.impl.http.DatabricksHttpClientFactory; -import com.databricks.jdbc.exception.DatabricksHttpException; -import com.databricks.jdbc.exception.DatabricksParsingException; -import com.databricks.jdbc.exception.DatabricksSQLException; -import com.databricks.jdbc.exception.DatabricksSQLFeatureNotSupportedException; +import com.databricks.jdbc.exception.*; import com.databricks.jdbc.log.JdbcLogger; import com.databricks.jdbc.log.JdbcLoggerFactory; import com.databricks.jdbc.model.client.thrift.generated.*; @@ -27,7 +24,6 @@ import com.databricks.jdbc.telemetry.latency.TelemetryCollector; import com.databricks.sdk.core.DatabricksConfig; import com.databricks.sdk.service.sql.StatementState; -import com.google.common.annotations.VisibleForTesting; import java.sql.SQLException; import java.util.Arrays; import java.util.concurrent.TimeUnit; @@ -51,13 +47,15 @@ final class DatabricksThriftAccessor { TExecuteStatementResp._Fields.OPERATION_HANDLE.getThriftFieldId(); private static final short statusFieldId = TExecuteStatementResp._Fields.STATUS.getThriftFieldId(); - private final ThreadLocal thriftClient; private final DatabricksConfig databricksConfig; private final boolean enableDirectResults; private final int asyncPollIntervalMillis; private final int maxRowsPerBlock; private final String connectionUuid; + private final String endpointUrl; + private final IDatabricksConnectionContext connectionContext; private 
TProtocolVersion serverProtocolVersion = JDBC_THRIFT_VERSION; + private ThreadLocal FAKE_SHARED_CLIENT; DatabricksThriftAccessor(IDatabricksConnectionContext connectionContext) throws DatabricksParsingException { @@ -66,35 +64,17 @@ final class DatabricksThriftAccessor { DatabricksClientConfiguratorManager.getInstance() .getConfigurator(connectionContext) .getDatabricksConfig(); - String endPointUrl = connectionContext.getEndpointURL(); + this.endpointUrl = connectionContext.getEndpointURL(); this.asyncPollIntervalMillis = connectionContext.getAsyncExecPollInterval(); this.maxRowsPerBlock = connectionContext.getRowsFetchedPerBlock(); this.connectionUuid = connectionContext.getConnectionUuid(); - - if (!DriverUtil.isRunningAgainstFake()) { - // Create a new thrift client for each thread as client state is not thread safe. Note that - // the underlying protocol uses the same http client which is thread safe - this.thriftClient = - ThreadLocal.withInitial( - () -> createThriftClient(endPointUrl, databricksConfig, connectionContext)); - } else { - TCLIService.Client client = - createThriftClient(endPointUrl, databricksConfig, connectionContext); - this.thriftClient = ThreadLocal.withInitial(() -> client); + this.connectionContext = connectionContext; + if (DriverUtil.isRunningAgainstFake()) { + TCLIService.Client client = newThriftClient(); + this.FAKE_SHARED_CLIENT = ThreadLocal.withInitial(() -> client); } } - @VisibleForTesting - DatabricksThriftAccessor( - TCLIService.Client client, IDatabricksConnectionContext connectionContext) { - this.databricksConfig = null; - this.thriftClient = ThreadLocal.withInitial(() -> client); - this.enableDirectResults = connectionContext.getDirectResultMode(); - this.asyncPollIntervalMillis = connectionContext.getAsyncExecPollInterval(); - this.maxRowsPerBlock = connectionContext.getRowsFetchedPerBlock(); - this.connectionUuid = connectionContext.getConnectionUuid(); - } - @SuppressWarnings("rawtypes") TBase getThriftResponse(TBase 
request) throws DatabricksSQLException { LOGGER.debug("Fetching thrift response for request {}", request.toString()); @@ -491,10 +471,6 @@ DatabricksResultSet getStatementResult( executionStatus, statementId, resultSet, StatementType.SQL, parentStatement, session); } - TCLIService.Client getThriftClient() { - return thriftClient.get(); - } - DatabricksConfig getDatabricksConfig() { return databricksConfig; } @@ -602,25 +578,22 @@ private TFetchResultsResp listColumns(TGetColumnsReq request) return fetchMetadataResults(response, response.toString()); } - /** - * Creates a new thrift client for the given endpoint URL and authentication headers. - * - * @param endPointUrl endpoint URL - * @param databricksConfig SDK config object required for authentication headers - */ - private TCLIService.Client createThriftClient( - String endPointUrl, - DatabricksConfig databricksConfig, - IDatabricksConnectionContext connectionContext) { + /** Returns the shared client when DriverUtil.isRunningAgainstFake() is true; otherwise creates a new thrift client for the configured endpoint URL on every call. 
*/ + TCLIService.Client getThriftClient() { + if (DriverUtil.isRunningAgainstFake()) { + return FAKE_SHARED_CLIENT.get(); + } + return newThriftClient(); + } + + private TCLIService.Client newThriftClient() { DatabricksHttpTTransport transport = new DatabricksHttpTTransport( DatabricksHttpClientFactory.getInstance().getClient(connectionContext), - endPointUrl, + endpointUrl, databricksConfig, connectionContext); - TBinaryProtocol protocol = new TBinaryProtocol(transport); - - return new TCLIService.Client(protocol); + return new TCLIService.Client(new TBinaryProtocol(transport)); } /** diff --git a/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessorTest.java b/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessorTest.java index c09bc43f81..768b3416cf 100644 --- a/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessorTest.java +++ b/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftAccessorTest.java @@ -9,20 +9,27 @@ import com.databricks.jdbc.api.internal.IDatabricksConnectionContext; import com.databricks.jdbc.api.internal.IDatabricksSession; import com.databricks.jdbc.api.internal.IDatabricksStatementInternal; +import com.databricks.jdbc.common.DatabricksClientConfiguratorManager; import com.databricks.jdbc.common.StatementType; +import com.databricks.jdbc.dbclient.impl.common.ClientConfigurator; import com.databricks.jdbc.dbclient.impl.common.StatementId; import com.databricks.jdbc.exception.DatabricksHttpException; +import com.databricks.jdbc.exception.DatabricksParsingException; import com.databricks.jdbc.exception.DatabricksSQLException; import com.databricks.jdbc.exception.DatabricksTimeoutException; import com.databricks.jdbc.model.client.thrift.generated.*; +import com.databricks.sdk.core.DatabricksConfig; import com.databricks.sdk.service.sql.StatementState; import java.sql.SQLException; import java.sql.Statement; import java.util.ArrayList; import 
org.apache.thrift.TException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; +import org.mockito.MockedStatic; import org.mockito.junit.jupiter.MockitoExtension; @ExtendWith(MockitoExtension.class) @@ -68,10 +75,43 @@ public class DatabricksThriftAccessorTest { .setStatus(new TStatus().setStatusCode(TStatusCode.SUCCESS_STATUS)) .setOperationState(TOperationState.RUNNING_STATE); - void setup(Boolean directResultsEnabled) { - when(connectionContext.getDirectResultMode()).thenReturn(directResultsEnabled); - when(connectionContext.getRowsFetchedPerBlock()).thenReturn(DEFAULT_ROW_LIMIT_PER_BLOCK); - accessor = new DatabricksThriftAccessor(thriftClient, connectionContext); + private MockedStatic configuratorManagerStatic; + private DatabricksClientConfiguratorManager configuratorManager; + + @BeforeEach + void initConfiguratorManager() throws DatabricksParsingException { + configuratorManagerStatic = mockStatic(DatabricksClientConfiguratorManager.class); + configuratorManager = mock(DatabricksClientConfiguratorManager.class); + configuratorManagerStatic + .when(DatabricksClientConfiguratorManager::getInstance) + .thenReturn(configuratorManager); + ClientConfigurator mockConfigurator = mock(ClientConfigurator.class); + lenient().when(mockConfigurator.getDatabricksConfig()).thenReturn(new DatabricksConfig()); + lenient() + .when(configuratorManager.getConfigurator(any(IDatabricksConnectionContext.class))) + .thenReturn(mockConfigurator); + // Provide common defaults used in constructor and various tests + lenient() + .when(connectionContext.getRowsFetchedPerBlock()) + .thenReturn(DEFAULT_ROW_LIMIT_PER_BLOCK); + lenient().when(connectionContext.getAsyncExecPollInterval()).thenReturn(1000); + lenient().when(connectionContext.getEndpointURL()).thenReturn("http://localhost"); + } + + @AfterEach + void cleanupConfiguratorManager() { 
+ if (configuratorManagerStatic != null) { + configuratorManagerStatic.close(); + } + } + + void setup(Boolean directResultsEnabled) throws DatabricksParsingException { + lenient().when(connectionContext.getDirectResultMode()).thenReturn(directResultsEnabled); + lenient() + .when(connectionContext.getRowsFetchedPerBlock()) + .thenReturn(DEFAULT_ROW_LIMIT_PER_BLOCK); + accessor = spy(new DatabricksThriftAccessor(connectionContext)); + doReturn(thriftClient).when(accessor).getThriftClient(); } @Test @@ -131,7 +171,7 @@ void testExecuteAsync() throws TException, SQLException { } @Test - void testExecuteAsync_error() throws TException { + void testExecuteAsync_error() throws TException, DatabricksParsingException { setup(true); TExecuteStatementReq request = new TExecuteStatementReq(); @@ -142,7 +182,7 @@ void testExecuteAsync_error() throws TException { } @Test - void testExecuteAsync_SQLState() throws TException { + void testExecuteAsync_SQLState() throws TException, DatabricksParsingException { setup(true); TExecuteStatementReq request = new TExecuteStatementReq(); @@ -159,9 +199,8 @@ void testExecuteAsync_SQLState() throws TException { } @Test - void testExecuteThrowsThriftError() throws TException { + void testExecuteThrowsThriftError() throws TException, DatabricksParsingException { setup(true); - accessor = new DatabricksThriftAccessor(thriftClient, connectionContext); TExecuteStatementReq request = new TExecuteStatementReq(); when(thriftClient.ExecuteStatement(request)).thenThrow(TException.class); assertThrows( @@ -172,7 +211,6 @@ void testExecuteThrowsThriftError() throws TException { @Test void testExecuteWithParentStatement() throws TException, SQLException { setup(true); - accessor = new DatabricksThriftAccessor(thriftClient, connectionContext); TExecuteStatementReq request = new TExecuteStatementReq(); TExecuteStatementResp tExecuteStatementResp = new TExecuteStatementResp() @@ -193,7 +231,6 @@ void testExecuteWithParentStatement() throws TException, 
SQLException { @Test void testExecuteWithDirectResults() throws TException, SQLException { setup(true); - accessor = new DatabricksThriftAccessor(thriftClient, connectionContext); TExecuteStatementReq request = new TExecuteStatementReq(); TExecuteStatementResp tExecuteStatementResp = new TExecuteStatementResp() @@ -209,8 +246,12 @@ void testExecuteWithDirectResults() throws TException, SQLException { @Test void testExecuteWithoutDirectResults() throws TException, SQLException { - setup(false); - accessor = new DatabricksThriftAccessor(thriftClient, connectionContext); + lenient().when(connectionContext.getDirectResultMode()).thenReturn(false); + lenient() + .when(connectionContext.getRowsFetchedPerBlock()) + .thenReturn(DEFAULT_ROW_LIMIT_PER_BLOCK); + accessor = spy(new DatabricksThriftAccessor(connectionContext)); + doReturn(thriftClient).when(accessor).getThriftClient(); TExecuteStatementReq request = new TExecuteStatementReq(); TExecuteStatementResp tExecuteStatementResp = new TExecuteStatementResp() @@ -225,10 +266,11 @@ void testExecuteWithoutDirectResults() throws TException, SQLException { } @Test - void testExecute_throwsException() throws TException { - setup(true); - - accessor = new DatabricksThriftAccessor(thriftClient, connectionContext); + void testExecute_throwsException() throws TException, DatabricksParsingException { + when(connectionContext.getDirectResultMode()).thenReturn(false); + when(connectionContext.getRowsFetchedPerBlock()).thenReturn(DEFAULT_ROW_LIMIT_PER_BLOCK); + accessor = spy(new DatabricksThriftAccessor(connectionContext)); + doReturn(thriftClient).when(accessor).getThriftClient(); TExecuteStatementReq request = new TExecuteStatementReq(); TExecuteStatementResp tExecuteStatementResp = new TExecuteStatementResp() @@ -246,7 +288,7 @@ void testExecute_throwsException() throws TException { } @Test - void testExecuteThrowsSQLExceptionWithSqlState() throws TException { + void testExecuteThrowsSQLExceptionWithSqlState() throws TException, 
DatabricksParsingException { setup(true); TExecuteStatementReq request = new TExecuteStatementReq(); TExecuteStatementResp tExecuteStatementResp = @@ -302,7 +344,7 @@ void testCloseOperation() throws TException, DatabricksSQLException { } @Test - void testCancelOperation_error() throws TException { + void testCancelOperation_error() throws TException, DatabricksParsingException { setup(true); TCancelOperationReq request = @@ -316,7 +358,7 @@ void testCancelOperation_error() throws TException { } @Test - void testCloseOperation_error() throws TException { + void testCloseOperation_error() throws TException, DatabricksParsingException { setup(true); TCloseOperationReq request = @@ -331,9 +373,9 @@ void testCloseOperation_error() throws TException { @Test void testIncludeResultSetMetadataNotSetForOldProtocol() - throws TException, DatabricksHttpException { - DatabricksThriftAccessor accessor = - new DatabricksThriftAccessor(thriftClient, connectionContext); + throws TException, DatabricksHttpException, DatabricksParsingException { + DatabricksThriftAccessor accessor = spy(new DatabricksThriftAccessor(connectionContext)); + doReturn(thriftClient).when(accessor).getThriftClient(); accessor.setServerProtocolVersion(TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4); TFetchResultsReq expectedReq = getFetchResultsRequest(false); when(thriftClient.FetchResults(expectedReq)) @@ -360,7 +402,8 @@ void testIncludeResultSetMetadataNotSetForOldProtocol() @Test void testGetStatementResult_success() throws Exception { when(connectionContext.getDirectResultMode()).thenReturn(false); - accessor = new DatabricksThriftAccessor(thriftClient, connectionContext); + accessor = spy(new DatabricksThriftAccessor(connectionContext)); + doReturn(thriftClient).when(accessor).getThriftClient(); when(thriftClient.GetOperationStatus(operationStatusReq)) .thenReturn(operationStatusFinishedResp); TFetchResultsReq fetchReq = @@ -381,7 +424,8 @@ void testGetStatementResult_success() throws Exception { 
@Test void testGetStatementResult_pending() throws Exception { when(connectionContext.getDirectResultMode()).thenReturn(false); - accessor = new DatabricksThriftAccessor(thriftClient, connectionContext); + accessor = spy(new DatabricksThriftAccessor(connectionContext)); + doReturn(thriftClient).when(accessor).getThriftClient(); TGetOperationStatusResp resp = new TGetOperationStatusResp() .setStatus(new TStatus().setStatusCode(TStatusCode.STILL_EXECUTING_STATUS)) @@ -635,7 +679,7 @@ void testTypeInfoWithDirectResults() throws TException, DatabricksSQLException { } @Test - void testAccessorWhenFetchResultsThrowsError() throws TException { + void testAccessorWhenFetchResultsThrowsError() throws TException, DatabricksParsingException { setup(false); TGetTablesReq request = new TGetTablesReq(); @@ -651,7 +695,7 @@ void testAccessorWhenFetchResultsThrowsError() throws TException { } @Test - void testAccessorDuringThriftError() throws TException { + void testAccessorDuringThriftError() throws TException, DatabricksParsingException { setup(true); TGetTablesReq request = new TGetTablesReq(); @@ -660,7 +704,7 @@ void testAccessorDuringThriftError() throws TException { } @Test - void testAccessorDuringHTTPError() throws TException { + void testAccessorDuringHTTPError() throws TException, DatabricksParsingException { setup(true); TGetTablesReq request = new TGetTablesReq(); @@ -713,7 +757,8 @@ void testExecuteWithTimeout() throws TException, SQLException { // Set the async poll interval to 200 ms when(connectionContext.getAsyncExecPollInterval()).thenReturn(200); - accessor = new DatabricksThriftAccessor(thriftClient, connectionContext); + accessor = spy(new DatabricksThriftAccessor(connectionContext)); + doReturn(thriftClient).when(accessor).getThriftClient(); // Create statement execution mocks TExecuteStatementReq request = new TExecuteStatementReq(); @@ -751,7 +796,8 @@ void testExecuteWithTimeoutExpired() throws TException, SQLException { // Set the async poll interval to 
1 second to facilitate testing when(connectionContext.getAsyncExecPollInterval()).thenReturn(1000); - accessor = new DatabricksThriftAccessor(thriftClient, connectionContext); + accessor = spy(new DatabricksThriftAccessor(connectionContext)); + doReturn(thriftClient).when(accessor).getThriftClient(); // Create statement execution mocks TExecuteStatementReq request = new TExecuteStatementReq(); @@ -796,7 +842,15 @@ void testFetchResultsWithCustomMaxRowsPerBlock() throws TException, SQLException IDatabricksConnectionContext mockConnectionContext = mock(IDatabricksConnectionContext.class); when(mockConnectionContext.getDirectResultMode()).thenReturn(true); when(mockConnectionContext.getRowsFetchedPerBlock()).thenReturn(customMaxRows); - accessor = new DatabricksThriftAccessor(thriftClient, mockConnectionContext); + // Ensure configurator manager returns a configurator for this separate mock context + ClientConfigurator customMockConfigurator = mock(ClientConfigurator.class); + when(customMockConfigurator.getDatabricksConfig()).thenReturn(new DatabricksConfig()); + when(configuratorManager.getConfigurator(mockConnectionContext)) + .thenReturn(customMockConfigurator); + lenient().when(mockConnectionContext.getAsyncExecPollInterval()).thenReturn(1000); + lenient().when(mockConnectionContext.getEndpointURL()).thenReturn("http://localhost"); + accessor = spy(new DatabricksThriftAccessor(mockConnectionContext)); + doReturn(thriftClient).when(accessor).getThriftClient(); TExecuteStatementReq executeRequest = new TExecuteStatementReq(); TExecuteStatementResp executeResponse = diff --git a/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClientTest.java b/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClientTest.java index dfa5f545a7..c9a480735c 100644 --- a/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClientTest.java +++ 
b/src/test/java/com/databricks/jdbc/dbclient/impl/thrift/DatabricksThriftServiceClientTest.java @@ -20,6 +20,7 @@ import com.databricks.jdbc.api.internal.IDatabricksStatementInternal; import com.databricks.jdbc.common.StatementType; import com.databricks.jdbc.dbclient.impl.common.StatementId; +import com.databricks.jdbc.exception.DatabricksParsingException; import com.databricks.jdbc.exception.DatabricksSQLException; import com.databricks.jdbc.model.client.thrift.generated.*; import com.databricks.jdbc.model.core.ExternalLink; @@ -894,7 +895,7 @@ void testGetDatabricksConfig() { } @Test - void testResetAccessToken() { + void testResetAccessToken() throws DatabricksParsingException { DatabricksThriftServiceClient client = new DatabricksThriftServiceClient(thriftAccessor, connectionContext); DatabricksHttpTTransport mockDatabricksHttpTTransport =