diff --git a/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudPreparedStatement.java b/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudPreparedStatement.java index 9519c70b..cf330c42 100644 --- a/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudPreparedStatement.java +++ b/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudPreparedStatement.java @@ -49,6 +49,9 @@ public class DataCloudPreparedStatement extends DataCloudStatement implements Pr private String sql; private final ParameterManager parameterManager; private final Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone("UTC")); + // True if we are currently fetching metadata from the server, this influences the query param generation + // to not return any data. + private boolean fetchingMetadata = false; DataCloudPreparedStatement(DataCloudConnection connection, ParameterManager parameterManager) { super(connection); @@ -78,7 +81,10 @@ public boolean execute(String sql) throws SQLException { } @Override - protected ExecuteQueryParamBuilder getQueryParamBuilder(QueryTimeout queryTimeout) throws SQLException { + protected QueryParam.Builder getQueryParamBuilder( + String sql, QueryTimeout queryTimeout, QueryParam.TransferMode transferMode) throws SQLException { + val builder = super.getQueryParamBuilder(sql, queryTimeout, transferMode); + final byte[] encodedRow; try { encodedRow = toArrowByteArray(parameterManager.getParameters(), calendar); @@ -86,14 +92,15 @@ protected ExecuteQueryParamBuilder getQueryParamBuilder(QueryTimeout queryTimeou throw new SQLException("Failed to encode parameters on prepared statement", e); } - val preparedQueryParams = QueryParam.newBuilder() - .setParamStyle(QueryParam.ParameterStyle.QUESTION_MARK) + if (fetchingMetadata) { + // Submit the query as metadata only query, with limit 0 Hyper will skip execution. + builder.setQueryRowLimit(0); + } + + return builder.setParamStyle(QueryParam.ParameterStyle.QUESTION_MARK) .setArrowParameters(QueryParameterArrow.newBuilder() .setData(ByteString.copyFrom(encodedRow)) - .build()) - .build(); - - return super.getQueryParamBuilder(queryTimeout).withQueryParams(preparedQueryParams); + .build()); } public boolean executeAsyncQuery() throws SQLException { @@ -265,7 +272,18 @@ public void setArray(int parameterIndex, Array x) throws SQLException { @Override public ResultSetMetaData getMetaData() throws SQLException { - throw new SQLException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + if ((resultSet != null) && !resultSet.isClosed()) { + return resultSet.getMetaData(); + } + try { + fetchingMetadata = true; + val result = super.executeQuery(sql); + val metadata = result.getMetaData(); + result.close(); + return metadata; + } finally { + fetchingMetadata = false; + } } @Override diff --git a/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudStatement.java b/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudStatement.java index 0860bde1..76b6ec4e 100644 --- a/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudStatement.java +++ b/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudStatement.java @@ -25,6 +25,7 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import salesforce.cdp.hyperdb.v1.QueryParam; +import salesforce.cdp.hyperdb.v1.ResultRange; @Slf4j public class DataCloudStatement implements Statement, AutoCloseable { @@ -58,13 +59,23 @@ public DataCloudStatement(@NonNull DataCloudConnection connection) { this.statementProperties = connection.getConnectionProperties().getStatementProperties(); } - protected ExecuteQueryParamBuilder getQueryParamBuilder(QueryTimeout queryTimeout) throws SQLException { + protected QueryParam.Builder getQueryParamBuilder( + String sql, QueryTimeout queryTimeout, QueryParam.TransferMode transferMode) throws SQLException { + val builder = QueryParam.newBuilder() + .setQuery(sql) + .setOutputFormat(QueryResultArrowStream.OUTPUT_FORMAT) + .setTransferMode(transferMode); + val querySettings = new HashMap<>(statementProperties.getQuerySettings()); if (!queryTimeout.getServerQueryTimeout().isZero()) { querySettings.put( "query_timeout", queryTimeout.getServerQueryTimeout().toMillis() + "ms"); } - return ExecuteQueryParamBuilder.of(querySettings); + if (!querySettings.isEmpty()) { + builder.putAllSettings(querySettings); + } + + return builder; } @Getter @@ -127,10 +138,14 @@ public ResultSet executeQuery(String sql) throws SQLException { private QueryResultIterator executeAdaptiveQuery(String sql) throws SQLException { val queryTimeout = QueryTimeout.of( statementProperties.getQueryTimeout(), statementProperties.getQueryTimeoutLocalEnforcementDelay()); - val paramBuilder = getQueryParamBuilder(queryTimeout); - val queryParam = targetMaxRows > 0 - ? paramBuilder.getAdaptiveRowLimitQueryParams(sql, targetMaxRows, targetMaxBytes) - : paramBuilder.getAdaptiveQueryParams(sql); + val paramBuilder = getQueryParamBuilder(sql, queryTimeout, QueryParam.TransferMode.ADAPTIVE); + if (targetMaxRows > 0) { + val range = ResultRange.newBuilder().setRowLimit(targetMaxRows).setByteLimit(targetMaxBytes); + paramBuilder.setResultRange(range); + log.info("setting row limit query. maxRows={}, maxBytes={}", (long) targetMaxRows, (long) targetMaxBytes); + } + QueryParam queryParam = paramBuilder.build(); + val stub = connection .getStub() .withDeadlineAfter( @@ -147,8 +162,8 @@ protected void executeAsyncQueryInternal(String sql) throws SQLException { try { val queryTimeout = QueryTimeout.of( statementProperties.getQueryTimeout(), statementProperties.getQueryTimeoutLocalEnforcementDelay()); - val paramBuilder = getQueryParamBuilder(queryTimeout); - val request = paramBuilder.getQueryParams(sql, QueryParam.TransferMode.ASYNC); + val paramBuilder = getQueryParamBuilder(sql, queryTimeout, QueryParam.TransferMode.ASYNC); + QueryParam queryParam = paramBuilder.build(); val stub = connection .getStub() .withDeadlineAfter( @@ -157,7 +172,7 @@ protected void executeAsyncQueryInternal(String sql) throws SQLException { // We set the deadline based off the query timeout here as the server-side doesn't properly enforce // the query timeout during the initial compilation phase. By setting the deadline, we can ensure // that the query timeout is enforced also when the server hangs during compilation. - queryHandle = AsyncQueryAccessHandle.of(stub, request); + queryHandle = AsyncQueryAccessHandle.of(stub, queryParam); log.info( "executeAsyncQuery completed. queryId={}", queryHandle.getQueryStatus().getQueryId()); @@ -185,6 +200,7 @@ public void close() throws SQLException { log.debug("Entering close"); if (resultSet != null) { resultSet.close(); + resultSet = null; } log.debug("Exiting close"); } diff --git a/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/ExecuteQueryParamBuilder.java b/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/ExecuteQueryParamBuilder.java deleted file mode 100644 index 01222f29..00000000 --- a/jdbc-core/src/main/java/com/salesforce/datacloud/jdbc/core/ExecuteQueryParamBuilder.java +++ /dev/null @@ -1,72 +0,0 @@ -/** - * This file is part of https://github.com/forcedotcom/datacloud-jdbc which is released under the - * Apache 2.0 license. See https://github.com/forcedotcom/datacloud-jdbc/blob/main/LICENSE.txt - */ -package com.salesforce.datacloud.jdbc.core; - -import com.salesforce.datacloud.jdbc.protocol.QueryResultArrowStream; -import com.salesforce.datacloud.jdbc.util.Unstable; -import java.util.Map; -import lombok.*; -import lombok.extern.slf4j.Slf4j; -import salesforce.cdp.hyperdb.v1.QueryParam; -import salesforce.cdp.hyperdb.v1.ResultRange; - -/** - * Although this class is public, we do not consider it to be part of our API. - * It is for internal use only until it stabilizes. - */ -@Builder(access = AccessLevel.PRIVATE) -@Slf4j -@Unstable -public class ExecuteQueryParamBuilder { - private final QueryParam settingsQueryParams; - - private QueryParam additionalQueryParams; - - public static ExecuteQueryParamBuilder of(Map querySettings) { - val builder = ExecuteQueryParamBuilder.builder(); - if (!querySettings.isEmpty()) { - builder.settingsQueryParams( - QueryParam.newBuilder().putAllSettings(querySettings).build()); - } - return builder.build(); - } - - public ExecuteQueryParamBuilder withQueryParams(QueryParam additionalQueryParams) { - this.additionalQueryParams = additionalQueryParams; - return this; - } - - private QueryParam completeBuilder(QueryParam.Builder builder) { - if (additionalQueryParams != null) { - builder.mergeFrom(additionalQueryParams); - } - if (settingsQueryParams != null) { - builder.mergeFrom(settingsQueryParams); - } - return builder.build(); - } - - public QueryParam getQueryParams(String sql, QueryParam.TransferMode transferMode) { - return completeBuilder(QueryParam.newBuilder() - .setQuery(sql) - .setOutputFormat(QueryResultArrowStream.OUTPUT_FORMAT) - .setTransferMode(transferMode)); - } - - public QueryParam getAdaptiveQueryParams(String sql) { - return getQueryParams(sql, QueryParam.TransferMode.ADAPTIVE); - } - - public QueryParam getAdaptiveRowLimitQueryParams(String sql, long maxRows, long maxBytes) { - val builder = QueryParam.newBuilder() - .setQuery(sql) - .setOutputFormat(QueryResultArrowStream.OUTPUT_FORMAT) - .setTransferMode(QueryParam.TransferMode.ADAPTIVE); - val range = ResultRange.newBuilder().setRowLimit(maxRows).setByteLimit(maxBytes); - builder.setResultRange(range); - log.info("setting row limit query. maxRows={}, maxBytes={}", maxRows, maxBytes); - return completeBuilder(builder); - } -} diff --git a/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudPreparedStatementHyperTest.java b/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudPreparedStatementHyperTest.java index e9de0e45..44527677 100644 --- a/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudPreparedStatementHyperTest.java +++ b/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudPreparedStatementHyperTest.java @@ -8,10 +8,14 @@ import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import com.salesforce.datacloud.jdbc.hyper.LocalHyperTestBase; +import com.salesforce.datacloud.jdbc.util.HyperLogScope; +import java.math.BigDecimal; import java.sql.Connection; import java.sql.Date; import java.sql.PreparedStatement; import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; import java.sql.Time; import java.sql.Timestamp; import java.time.LocalDate; @@ -22,6 +26,7 @@ import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; import lombok.val; +import org.assertj.core.api.Assertions; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -222,4 +227,103 @@ public void testPreparedStatementTimestampWithCalendarRange() { } } } + + @Test + @SneakyThrows + public void testGetMetaDataReturnsResultSetMetaData() { + try (HyperLogScope logScope = new HyperLogScope()) { + try (Connection connection = getHyperQueryConnection(logScope.getProperties())) { + try (PreparedStatement preparedStatement = + connection.prepareStatement("select 1 as id, 'test' as name, 3.14 as value, pg_sleep(100000) as" + + " would_timeout_in_execution")) { + ResultSetMetaData metadata = preparedStatement.getMetaData(); + + assertThat(metadata).isNotNull(); + assertThat(metadata.getColumnCount()).isEqualTo(4); + + assertThat(metadata.getColumnName(1)).isEqualTo("id"); + assertThat(metadata.getColumnTypeName(1)).isEqualTo("INTEGER"); + + assertThat(metadata.getColumnName(2)).isEqualTo("name"); + assertThat(metadata.getColumnTypeName(2)).isEqualTo("VARCHAR"); + + assertThat(metadata.getColumnName(3)).isEqualTo("value"); + assertThat(metadata.getColumnTypeName(3)).isEqualTo("DECIMAL"); + + // Verify that the query actually finished + ResultSet resultSet = logScope.executeQuery("SELECT COUNT(*) FROM hyper_log WHERE k='query-end'"); + resultSet.next(); + Assertions.assertThat(resultSet.getDouble(1)).isEqualTo(1); + } + } + } + } + + @Test + @SneakyThrows + public void testGetMetaDataFollowedByExecuteReturnsData() { + try (Connection connection = getHyperQueryConnection()) { + try (PreparedStatement preparedStatement = + connection.prepareStatement("select 1 as id, 'test' as name, 3.14 as value")) { + ResultSetMetaData metadata = preparedStatement.getMetaData(); + assertThat(metadata).isNotNull(); + assertThat(metadata.getColumnCount()).isEqualTo(3); + + try (ResultSet resultSet = preparedStatement.executeQuery()) { + assertThat(resultSet.next()).isTrue(); + assertThat(resultSet.getInt("id")).isEqualTo(1); + assertThat(resultSet.getString("name")).isEqualTo("test"); + assertThat(resultSet.getBigDecimal("value")).isEqualTo(BigDecimal.valueOf(3.14)); + assertThat(resultSet.next()).isFalse(); + } + } + } + } + + @Test + @SneakyThrows + public void testGetMetaDataWithInvalidQueryThrowsSQLException() { + try (Connection connection = getHyperQueryConnection()) { + try (PreparedStatement preparedStatement = + connection.prepareStatement("select * from non_existent_table")) { + Assertions.assertThatThrownBy(preparedStatement::getMetaData) + .isInstanceOf(SQLException.class) + .hasMessageContaining("table \"non_existent_table\" does not exist"); + } + } + } + + @Test + @SneakyThrows + public void testGetMetaDataAfterExecuteDoesNotQueryAgain() { + try (HyperLogScope logScope = new HyperLogScope()) { + try (Connection connection = getHyperQueryConnection(logScope.getProperties())) { + try (PreparedStatement preparedStatement = + connection.prepareStatement("select 1 as id, 'test' as name")) { + try (ResultSet resultSet = preparedStatement.executeQuery()) { + assertThat(resultSet.next()).isTrue(); + + ResultSetMetaData metadata = preparedStatement.getMetaData(); + assertThat(metadata).isNotNull(); + assertThat(metadata.getColumnCount()).isEqualTo(2); + } + + ResultSet logResult = logScope.executeQuery("SELECT COUNT(*) FROM hyper_log WHERE k='query-end'"); + logResult.next(); + Assertions.assertThat(logResult.getDouble(1)) + .as("Should only have one query execution, not two") + .isEqualTo(1); + + // Test that after closing the resultset it would query again + ResultSetMetaData metadata = preparedStatement.getMetaData(); + assertThat(metadata).isNotNull(); + assertThat(metadata.getColumnCount()).isEqualTo(2); + + ResultSet logResult2 = logScope.executeQuery("SELECT COUNT(*) FROM hyper_log WHERE k='query-end'"); + logResult2.next(); + Assertions.assertThat(logResult2.getDouble(1)).isEqualTo(2); + } + } + } + } } diff --git a/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudPreparedStatementTest.java b/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudPreparedStatementTest.java index 271e918f..829d2c5d 100644 --- a/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudPreparedStatementTest.java +++ b/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudPreparedStatementTest.java @@ -261,7 +261,6 @@ private static Stream unsupported() { impl("executeUpdate", s -> s.executeUpdate("", Statement.RETURN_GENERATED_KEYS)), impl("executeUpdate", s -> s.executeUpdate("", new int[] {})), impl("executeUpdate", s -> s.executeUpdate("", new String[] {})), - impl("getMetaData", DataCloudPreparedStatement::getMetaData), impl("getParameterMetaData", DataCloudPreparedStatement::getParameterMetaData)); } diff --git a/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/protocol/async/AsyncQueryResultIteratorTest.java b/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/protocol/async/AsyncQueryResultIteratorTest.java new file mode 100644 index 00000000..8c63e5f3 --- /dev/null +++ b/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/protocol/async/AsyncQueryResultIteratorTest.java @@ -0,0 +1,188 @@ +/** + * This file is part of https://github.com/forcedotcom/datacloud-jdbc which is released under the + * Apache 2.0 license. See https://github.com/forcedotcom/datacloud-jdbc/blob/main/LICENSE.txt + */ +package com.salesforce.datacloud.jdbc.protocol.async; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.salesforce.datacloud.jdbc.core.InterceptedHyperTestBase; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import lombok.val; +import org.grpcmock.GrpcMock; +import org.junit.jupiter.api.Test; +import salesforce.cdp.hyperdb.v1.ExecuteQueryResponse; +import salesforce.cdp.hyperdb.v1.HyperServiceGrpc; +import salesforce.cdp.hyperdb.v1.OutputFormat; +import salesforce.cdp.hyperdb.v1.QueryInfo; +import salesforce.cdp.hyperdb.v1.QueryParam; +import salesforce.cdp.hyperdb.v1.QueryResult; +import salesforce.cdp.hyperdb.v1.QueryStatus; + +class AsyncQueryResultIteratorTest extends InterceptedHyperTestBase { + + private static final String TEST_QUERY = "SELECT * FROM test_table asyncTest"; + private static final String TEST_QUERY_ID = "async-test-query-123"; + + private HyperServiceGrpc.HyperServiceStub setupStub() { + return getInterceptedStub().withDeadlineAfter(30000, TimeUnit.MILLISECONDS); + } + + /** + * Tests that the async iterator correctly handles a scenario where: + * 1. Initial executeQuery stream returns data quickly + * 2. Query info polling hangs (simulating slow server response) + * 3. Eventually completes when the polling returns + * + * This ensures the async nature allows other work to proceed while waiting. + */ + @Test + void whenQueryProducesDataThenHangsAtEnd_shouldCompleteSuccessfully() throws Exception { + val stub = setupStub(); + + val delayedInfoLatch = new CountDownLatch(1); + val iteratorReceivedFirstResult = new CountDownLatch(1); + + // Setup executeQuery to return initial data with RUNNING status + GrpcMock.stubFor(GrpcMock.serverStreamingMethod(HyperServiceGrpc.getExecuteQueryMethod()) + .withRequest(req -> req.getQuery().equals(TEST_QUERY)) + .willProxyTo((request, observer) -> { + // First response: query info with running status + observer.onNext(ExecuteQueryResponse.newBuilder() + .setQueryInfo(QueryInfo.newBuilder() + .setQueryStatus(QueryStatus.newBuilder() + .setQueryId(TEST_QUERY_ID) + .setCompletionStatus(QueryStatus.CompletionStatus.RUNNING_OR_UNSPECIFIED) + .setChunkCount(1) + .build()) + .build()) + .build()); + + // Second response: inline result data (fast) + observer.onNext(ExecuteQueryResponse.newBuilder() + .setQueryResult(QueryResult.newBuilder().build()) + .build()); + + observer.onCompleted(); + })); + + // Setup getQueryInfo to delay before returning finished status + GrpcMock.stubFor(GrpcMock.serverStreamingMethod(HyperServiceGrpc.getGetQueryInfoMethod()) + .withRequest(req -> req.getQueryId().equals(TEST_QUERY_ID)) + .willProxyTo((request, observer) -> { + // Wait for the iterator to receive the first result before continuing + try { + iteratorReceivedFirstResult.await(5, TimeUnit.SECONDS); + // Simulate a delay in the query info response + delayedInfoLatch.await(2, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + + // Return finished status + observer.onNext(QueryInfo.newBuilder() + .setQueryStatus(QueryStatus.newBuilder() + .setQueryId(TEST_QUERY_ID) + .setCompletionStatus(QueryStatus.CompletionStatus.FINISHED) + .setChunkCount(1) + .build()) + .build()); + observer.onCompleted(); + })); + + val queryParam = QueryParam.newBuilder() + .setQuery(TEST_QUERY) + .setOutputFormat(OutputFormat.ARROW_IPC) + .setTransferMode(QueryParam.TransferMode.ADAPTIVE) + .build(); + + try (val iterator = AsyncQueryResultIterator.of(stub, queryParam)) { + val resultCount = new AtomicInteger(0); + + // First call should return quickly with the inline result + CompletableFuture> firstFuture = + iterator.next().toCompletableFuture(); + + // Should complete quickly since data is available + Optional firstResult = firstFuture.get(5, TimeUnit.SECONDS); + assertThat(firstResult).isPresent(); + resultCount.incrementAndGet(); + + // Signal that we received the first result + iteratorReceivedFirstResult.countDown(); + + // Second call will need to poll for query info - this will hang initially + CompletableFuture> secondFuture = + iterator.next().toCompletableFuture(); + + // Verify the future is not completed yet (query info is delayed) + assertThat(secondFuture.isDone()).isFalse(); + + Thread.sleep(1000); + + // Release the delayed query info + delayedInfoLatch.countDown(); + + // Now it should complete with empty (no more results) + Optional secondResult = secondFuture.get(5, TimeUnit.SECONDS); + assertThat(secondResult).isEmpty(); + + // Verify final status + assertThat(iterator.getQueryStatus().getCompletionStatus()) + .isEqualTo(QueryStatus.CompletionStatus.FINISHED); + assertThat(resultCount.get()).isEqualTo(1); + } + } + + @Test + void whenExecuteQueryReturnsFinishedImmediately_shouldCompleteWithoutPolling() throws Exception { + val stub = setupStub(); + + // Setup executeQuery to return finished status immediately with inline data + GrpcMock.stubFor(GrpcMock.serverStreamingMethod(HyperServiceGrpc.getExecuteQueryMethod()) + .withRequest(req -> req.getQuery().equals(TEST_QUERY)) + .willProxyTo((request, observer) -> { + observer.onNext(ExecuteQueryResponse.newBuilder() + .setQueryInfo(QueryInfo.newBuilder() + .setQueryStatus(QueryStatus.newBuilder() + .setQueryId(TEST_QUERY_ID) + .setCompletionStatus(QueryStatus.CompletionStatus.FINISHED) + .setChunkCount(1) + .build()) + .build()) + .build()); + + observer.onNext(ExecuteQueryResponse.newBuilder() + .setQueryResult(QueryResult.newBuilder().build()) + .build()); + + observer.onCompleted(); + })); + + val queryParam = QueryParam.newBuilder() + .setQuery(TEST_QUERY) + .setOutputFormat(OutputFormat.ARROW_IPC) + .setTransferMode(QueryParam.TransferMode.ADAPTIVE) + .build(); + + try (val iterator = AsyncQueryResultIterator.of(stub, queryParam)) { + // First result should be available + Optional first = iterator.next().toCompletableFuture().get(5, TimeUnit.SECONDS); + assertThat(first).isPresent(); + + // Second call should return empty (finished) + Optional second = iterator.next().toCompletableFuture().get(5, TimeUnit.SECONDS); + assertThat(second).isEmpty(); + + assertThat(iterator.getQueryStatus().getCompletionStatus()) + .isEqualTo(QueryStatus.CompletionStatus.FINISHED); + } + + // Verify no query info polling was needed + verifyGetQueryInfo(0); + } +} diff --git a/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/protocol/async/core/SyncIteratorAdapterTest.java b/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/protocol/async/core/SyncIteratorAdapterTest.java new file mode 100644 index 00000000..ed19f47a --- /dev/null +++ b/jdbc-core/src/test/java/com/salesforce/datacloud/jdbc/protocol/async/core/SyncIteratorAdapterTest.java @@ -0,0 +1,127 @@ +/** + * This file is part of https://github.com/forcedotcom/datacloud-jdbc which is released under the + * Apache 2.0 license. See https://github.com/forcedotcom/datacloud-jdbc/blob/main/LICENSE.txt + */ +package com.salesforce.datacloud.jdbc.protocol.async.core; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.NoSuchElementException; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import lombok.val; +import org.junit.jupiter.api.Test; + +class SyncIteratorAdapterTest { + + @Test + void testInterruptHandlingRestoresInterruptFlag() throws Exception { + val blockingFuture = new CompletableFuture>(); + val closeCalled = new AtomicBoolean(false); + val iteratorStartedBlocking = new CountDownLatch(1); + + // Create an async iterator that blocks indefinitely until closed + AsyncIterator asyncIterator = new AsyncIterator() { + @Override + public CompletionStage> next() { + iteratorStartedBlocking.countDown(); + return blockingFuture; + } + + @Override + public void close() { + closeCalled.set(true); + // Simulate gRPC cancellation completing the future with error + blockingFuture.completeExceptionally(new RuntimeException("Stream cancelled")); + } + }; + + val adapter = new SyncIteratorAdapter<>(asyncIterator); + val threadInterrupted = new AtomicBoolean(false); + val hasNextResult = new AtomicBoolean(true); + + // Run hasNext() in a separate thread and interrupt it + Thread thread = new Thread(() -> { + try { + hasNextResult.set(adapter.hasNext()); + } catch (RuntimeException e) { + // Expected - stream was cancelled + } + threadInterrupted.set(Thread.currentThread().isInterrupted()); + }); + + thread.start(); + + // Wait for the thread to start blocking on the future + assertThat(iteratorStartedBlocking.await(5, TimeUnit.SECONDS)).isTrue(); + + // Interrupt the thread + thread.interrupt(); + + // Wait for thread to finish + thread.join(5000); + assertThat(thread.isAlive()).isFalse(); + + // Verify close was called due to interrupt + assertThat(closeCalled.get()).isTrue(); + + // Verify interrupt flag was restored + assertThat(threadInterrupted.get()).isTrue(); + } + + @Test + void testNormalIteration() { + val values = new String[] {"a", "b", "c"}; + val index = new AtomicInteger(0); + + AsyncIterator asyncIterator = new AsyncIterator() { + @Override + public CompletionStage> next() { + int i = index.getAndIncrement(); + if (i < values.length) { + return CompletableFuture.completedFuture(Optional.of(values[i])); + } + return CompletableFuture.completedFuture(Optional.empty()); + } + + @Override + public void close() {} + }; + + val adapter = new SyncIteratorAdapter<>(asyncIterator); + + assertThat(adapter.hasNext()).isTrue(); + assertThat(adapter.next()).isEqualTo("a"); + assertThat(adapter.hasNext()).isTrue(); + assertThat(adapter.next()).isEqualTo("b"); + assertThat(adapter.hasNext()).isTrue(); + assertThat(adapter.next()).isEqualTo("c"); + assertThat(adapter.hasNext()).isFalse(); + // Check that repeated calls stay false + assertThat(adapter.hasNext()).isFalse(); + // Check that next() throws an exception + assertThatThrownBy(adapter::next).isInstanceOf(NoSuchElementException.class); + } + + @Test + void testEmptyIterator() { + AsyncIterator asyncIterator = new AsyncIterator() { + @Override + public CompletionStage> next() { + return CompletableFuture.completedFuture(Optional.empty()); + } + + @Override + public void close() {} + }; + + val adapter = new SyncIteratorAdapter<>(asyncIterator); + assertThat(adapter.hasNext()).isFalse(); + } +}