From 821fe48376e433824315106ee0ad500f324d4131 Mon Sep 17 00:00:00 2001 From: Jayant Singh Date: Tue, 30 Sep 2025 13:42:04 +0530 Subject: [PATCH 1/5] Implement lazy loading for inline Arrow results This PR introduces lazy loading support for inline Arrow results to improve memory efficiency when handling large result sets. Previously, InlineChunkProvider would eagerly fetch all arrow batches upfront when results had hasMoreRows = true, which could lead to memory issues with large datasets. This change splits the handling into two separate paths: 1. Lazy path (new): For Thrift-based inline Arrow results (when ARROW_BASED_SET is returned), we now use LazyThriftInlineArrowResult, which fetches arrow batches on demand as the client iterates through rows. This is similar to how LazyThriftResult works for columnar data. 2. Remote path (existing): For URL-based Arrow results (URL_BASED_SET), we continue using ArrowStreamResult with RemoteChunkProvider, which downloads chunks from cloud storage. The InlineChunkProvider is now only used for SEA results with JSON_ARRAY format and INLINE disposition (which contain all data inline, with no hasMoreRows flag set). This should reduce memory consumption and improve performance when dealing with large inline Arrow result sets. 
--- .../jdbc/api/impl/ExecutionResultFactory.java | 5 +- .../api/impl/arrow/ArrowStreamResult.java | 92 ++-- .../api/impl/arrow/InlineChunkProvider.java | 122 ----- .../arrow/LazyThriftInlineArrowResult.java | 425 ++++++++++++++++++ .../api/impl/ExecutionResultFactoryTest.java | 7 +- .../api/impl/arrow/ArrowStreamResultTest.java | 21 +- .../impl/arrow/InlineChunkProviderTest.java | 42 -- .../LazyThriftInlineArrowResultTest.java | 285 ++++++++++++ 8 files changed, 771 insertions(+), 228 deletions(-) create mode 100644 src/main/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResult.java create mode 100644 src/test/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResultTest.java diff --git a/src/main/java/com/databricks/jdbc/api/impl/ExecutionResultFactory.java b/src/main/java/com/databricks/jdbc/api/impl/ExecutionResultFactory.java index 4c719731de..ba6d7acb7b 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/ExecutionResultFactory.java +++ b/src/main/java/com/databricks/jdbc/api/impl/ExecutionResultFactory.java @@ -1,6 +1,7 @@ package com.databricks.jdbc.api.impl; import com.databricks.jdbc.api.impl.arrow.ArrowStreamResult; +import com.databricks.jdbc.api.impl.arrow.LazyThriftInlineArrowResult; import com.databricks.jdbc.api.impl.volume.VolumeOperationResult; import com.databricks.jdbc.api.internal.IDatabricksSession; import com.databricks.jdbc.api.internal.IDatabricksStatementInternal; @@ -96,9 +97,9 @@ private static IExecutionResult getResultHandler( case COLUMN_BASED_SET: return new LazyThriftResult(resultsResp, parentStatement, session); case ARROW_BASED_SET: - return new ArrowStreamResult(resultsResp, true, parentStatement, session); + return new LazyThriftInlineArrowResult(resultsResp, parentStatement, session); case URL_BASED_SET: - return new ArrowStreamResult(resultsResp, false, parentStatement, session); + return new ArrowStreamResult(resultsResp, parentStatement, session); case ROW_BASED_SET: throw new 
DatabricksSQLFeatureNotSupportedException( "Invalid state - row based set cannot be received"); diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResult.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResult.java index 29a88fd6ba..4e011301ec 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResult.java +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResult.java @@ -85,13 +85,11 @@ public ArrowStreamResult( public ArrowStreamResult( TFetchResultsResp resultsResp, - boolean isInlineArrow, IDatabricksStatementInternal parentStatementId, IDatabricksSession session) throws DatabricksSQLException { this( resultsResp, - isInlineArrow, parentStatementId, session, DatabricksHttpClientFactory.getInstance().getClient(session.getConnectionContext())); @@ -100,27 +98,22 @@ public ArrowStreamResult( @VisibleForTesting ArrowStreamResult( TFetchResultsResp resultsResp, - boolean isInlineArrow, IDatabricksStatementInternal parentStatement, IDatabricksSession session, IDatabricksHttpClient httpClient) throws DatabricksSQLException { this.session = session; setColumnInfo(resultsResp.getResultSetMetadata()); - if (isInlineArrow) { - this.chunkProvider = new InlineChunkProvider(resultsResp, parentStatement, session); - } else { - CompressionCodec compressionCodec = - CompressionCodec.getCompressionMapping(resultsResp.getResultSetMetadata()); - this.chunkProvider = - new RemoteChunkProvider( - parentStatement, - resultsResp, - session, - httpClient, - session.getConnectionContext().getCloudFetchThreadPoolSize(), - compressionCodec); - } + CompressionCodec compressionCodec = + CompressionCodec.getCompressionMapping(resultsResp.getResultSetMetadata()); + this.chunkProvider = + new RemoteChunkProvider( + parentStatement, + resultsResp, + session, + httpClient, + session.getConnectionContext().getCloudFetchThreadPoolSize(), + compressionCodec); } public List getArrowMetadata() throws DatabricksSQLException { 
@@ -133,30 +126,15 @@ public List getArrowMetadata() throws DatabricksSQLException { /** {@inheritDoc} */ @Override public Object getObject(int columnIndex) throws DatabricksSQLException { - ColumnInfoTypeName requiredType = columnInfos.get(columnIndex).getTypeName(); + ColumnInfo columnInfo = columnInfos.get(columnIndex); + ColumnInfoTypeName requiredType = columnInfo.getTypeName(); String arrowMetadata = chunkIterator.getType(columnIndex); if (arrowMetadata == null) { - arrowMetadata = columnInfos.get(columnIndex).getTypeText(); - } - - // Handle complex type conversion when complex datatype support is disabled - boolean isComplexDatatypeSupportEnabled = - this.session.getConnectionContext().isComplexDatatypeSupportEnabled(); - if (!isComplexDatatypeSupportEnabled && isComplexType(requiredType)) { - LOGGER.debug("Complex datatype support is disabled, converting complex type to STRING"); - - Object result = - chunkIterator.getColumnObjectAtCurrentRow( - columnIndex, ColumnInfoTypeName.STRING, "STRING", columnInfos.get(columnIndex)); - if (result == null) { - return null; - } - ComplexDataTypeParser parser = new ComplexDataTypeParser(); - return parser.formatComplexTypeString(result.toString(), requiredType.name(), arrowMetadata); + arrowMetadata = columnInfo.getTypeText(); } - return chunkIterator.getColumnObjectAtCurrentRow( - columnIndex, requiredType, arrowMetadata, columnInfos.get(columnIndex)); + return getObjectWithComplexTypeHandling( + session, chunkIterator, columnIndex, requiredType, arrowMetadata, columnInfo); } /** @@ -237,4 +215,44 @@ private void setColumnInfo(TGetResultSetMetadataResp resultManifest) { columnInfos.add(getColumnInfoFromTColumnDesc(tColumnDesc)); } } + + /** + * Helper method to handle complex type conversion when complex datatype support is disabled. 
+ * + * @param session The databricks session + * @param chunkIterator The chunk iterator + * @param columnIndex The column index + * @param requiredType The required column type + * @param arrowMetadata The arrow metadata + * @param columnInfo The column info + * @return The object value (converted if complex type and support disabled) + * @throws DatabricksSQLException if an error occurs + */ + protected static Object getObjectWithComplexTypeHandling( + IDatabricksSession session, + ArrowResultChunkIterator chunkIterator, + int columnIndex, + ColumnInfoTypeName requiredType, + String arrowMetadata, + ColumnInfo columnInfo) + throws DatabricksSQLException { + boolean isComplexDatatypeSupportEnabled = + session.getConnectionContext().isComplexDatatypeSupportEnabled(); + + if (!isComplexDatatypeSupportEnabled && isComplexType(requiredType)) { + LOGGER.debug("Complex datatype support is disabled, converting complex type to STRING"); + Object result = + chunkIterator.getColumnObjectAtCurrentRow( + columnIndex, ColumnInfoTypeName.STRING, "STRING", columnInfo); + if (result == null) { + return null; + } + ComplexDataTypeParser parser = new ComplexDataTypeParser(); + + return parser.formatComplexTypeString(result.toString(), requiredType.name(), arrowMetadata); + } + + return chunkIterator.getColumnObjectAtCurrentRow( + columnIndex, requiredType, arrowMetadata, columnInfo); + } } diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProvider.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProvider.java index e22d974a40..32f5e1b803 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProvider.java +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProvider.java @@ -1,31 +1,17 @@ package com.databricks.jdbc.api.impl.arrow; -import static com.databricks.jdbc.common.util.DatabricksTypeUtil.*; import static com.databricks.jdbc.common.util.DecompressionUtil.decompress; -import 
com.databricks.jdbc.api.internal.IDatabricksSession; -import com.databricks.jdbc.api.internal.IDatabricksStatementInternal; import com.databricks.jdbc.common.CompressionCodec; import com.databricks.jdbc.exception.DatabricksParsingException; import com.databricks.jdbc.exception.DatabricksSQLException; import com.databricks.jdbc.log.JdbcLogger; import com.databricks.jdbc.log.JdbcLoggerFactory; -import com.databricks.jdbc.model.client.thrift.generated.*; import com.databricks.jdbc.model.core.ResultData; import com.databricks.jdbc.model.core.ResultManifest; import com.databricks.jdbc.model.telemetry.enums.DatabricksDriverErrorCode; import com.google.common.annotations.VisibleForTesting; import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.List; -import org.apache.arrow.vector.types.pojo.ArrowType; -import org.apache.arrow.vector.types.pojo.Field; -import org.apache.arrow.vector.types.pojo.FieldType; -import org.apache.arrow.vector.types.pojo.Schema; -import org.apache.arrow.vector.util.SchemaUtility; /** Class to manage inline Arrow chunks */ public class InlineChunkProvider implements ChunkProvider { @@ -37,23 +23,6 @@ public class InlineChunkProvider implements ChunkProvider { private final ArrowResultChunk arrowResultChunk; // There is only one packet of data in case of inline arrow - InlineChunkProvider( - TFetchResultsResp resultsResp, - IDatabricksStatementInternal parentStatement, - IDatabricksSession session) - throws DatabricksParsingException { - this.currentChunkIndex = -1; - this.totalRows = 0; - ByteArrayInputStream byteStream = initializeByteStream(resultsResp, session, parentStatement); - ArrowResultChunk.Builder builder = - ArrowResultChunk.builder().withInputStream(byteStream, totalRows); - - if (parentStatement != null) { - builder.withStatementId(parentStatement.getStatementId()); - } - arrowResultChunk = builder.build(); 
- } - /** * Constructor for inline arrow chunk provider from {@link ResultData} and {@link ResultManifest}. * @@ -123,97 +92,6 @@ public boolean isClosed() { return isClosed; } - private ByteArrayInputStream initializeByteStream( - TFetchResultsResp resultsResp, - IDatabricksSession session, - IDatabricksStatementInternal parentStatement) - throws DatabricksParsingException { - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - CompressionCodec compressionType = - CompressionCodec.getCompressionMapping(resultsResp.getResultSetMetadata()); - try { - byte[] serializedSchema = getSerializedSchema(resultsResp.getResultSetMetadata()); - if (serializedSchema != null) { - baos.write(serializedSchema); - } - writeToByteOutputStream( - compressionType, parentStatement, resultsResp.getResults().getArrowBatches(), baos); - while (resultsResp.hasMoreRows) { - resultsResp = session.getDatabricksClient().getMoreResults(parentStatement); - writeToByteOutputStream( - compressionType, parentStatement, resultsResp.getResults().getArrowBatches(), baos); - } - return new ByteArrayInputStream(baos.toByteArray()); - } catch (DatabricksSQLException | IOException e) { - handleError(e); - } - return null; - } - - private void writeToByteOutputStream( - CompressionCodec compressionCodec, - IDatabricksStatementInternal parentStatement, - List arrowBatchList, - ByteArrayOutputStream baos) - throws DatabricksSQLException, IOException { - for (TSparkArrowBatch arrowBatch : arrowBatchList) { - byte[] decompressedBytes = - decompress( - arrowBatch.getBatch(), - compressionCodec, - String.format( - "Data fetch for inline arrow batch [%d] and statement [%s] with decompression algorithm : [%s]", - arrowBatch.getRowCount(), parentStatement, compressionCodec)); - totalRows += arrowBatch.getRowCount(); - baos.write(decompressedBytes); - } - } - - private byte[] getSerializedSchema(TGetResultSetMetadataResp metadata) - throws DatabricksSQLException { - if (metadata.getArrowSchema() != null) { - 
return metadata.getArrowSchema(); - } - Schema arrowSchema = hiveSchemaToArrowSchema(metadata.getSchema()); - try { - return SchemaUtility.serialize(arrowSchema); - } catch (IOException e) { - handleError(e); - } - // should never reach here; - return null; - } - - private Schema hiveSchemaToArrowSchema(TTableSchema hiveSchema) - throws DatabricksParsingException { - List fields = new ArrayList<>(); - if (hiveSchema == null) { - return new Schema(fields); - } - try { - hiveSchema - .getColumns() - .forEach( - columnDesc -> { - try { - fields.add(getArrowField(columnDesc)); - } catch (SQLException e) { - throw new RuntimeException(e); - } - }); - } catch (RuntimeException e) { - handleError(e); - } - return new Schema(fields); - } - - private Field getArrowField(TColumnDesc columnDesc) throws SQLException { - TPrimitiveTypeEntry primitiveTypeEntry = getTPrimitiveTypeOrDefault(columnDesc.getTypeDesc()); - ArrowType arrowType = mapThriftToArrowType(primitiveTypeEntry.getType()); - FieldType fieldType = new FieldType(true, arrowType, null); - return new Field(columnDesc.getColumnName(), fieldType, null); - } - @VisibleForTesting void handleError(Exception e) throws DatabricksParsingException { String errorMessage = diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResult.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResult.java new file mode 100644 index 0000000000..08950339cf --- /dev/null +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResult.java @@ -0,0 +1,425 @@ +package com.databricks.jdbc.api.impl.arrow; + +import static com.databricks.jdbc.common.EnvironmentVariables.DEFAULT_RESULT_ROW_LIMIT; +import static com.databricks.jdbc.common.util.DatabricksTypeUtil.*; +import static com.databricks.jdbc.common.util.DecompressionUtil.decompress; + +import com.databricks.jdbc.api.impl.IExecutionResult; +import com.databricks.jdbc.api.internal.IDatabricksSession; +import 
com.databricks.jdbc.api.internal.IDatabricksStatementInternal; +import com.databricks.jdbc.common.CompressionCodec; +import com.databricks.jdbc.exception.DatabricksParsingException; +import com.databricks.jdbc.exception.DatabricksSQLException; +import com.databricks.jdbc.log.JdbcLogger; +import com.databricks.jdbc.log.JdbcLoggerFactory; +import com.databricks.jdbc.model.client.thrift.generated.*; +import com.databricks.jdbc.model.core.ColumnInfo; +import com.databricks.jdbc.model.core.ColumnInfoTypeName; +import com.databricks.jdbc.model.telemetry.enums.DatabricksDriverErrorCode; +import com.google.common.annotations.VisibleForTesting; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.SchemaUtility; + +/** + * Lazy implementation for thrift-based inline Arrow results that fetches arrow batches on demand. + * Similar to LazyThriftResult but processes Arrow data instead of columnar thrift data. 
+ */ +public class LazyThriftInlineArrowResult implements IExecutionResult { + private static final JdbcLogger LOGGER = + JdbcLoggerFactory.getLogger(LazyThriftInlineArrowResult.class); + + private TFetchResultsResp currentResponse; + private ArrowResultChunk currentChunk; + private ArrowResultChunkIterator currentChunkIterator; + private long globalRowIndex; + private final IDatabricksSession session; + private final IDatabricksStatementInternal statement; + private final int maxRows; + private boolean hasReachedEnd; + private boolean isClosed; + private long totalRowsFetched; + private List columnInfos; + + /** + * Creates a new LazyThriftInlineArrowResult that lazily fetches arrow data on demand. + * + * @param initialResponse the initial response from the server + * @param statement the statement that generated this result + * @param session the session to use for fetching additional data + * @throws DatabricksSQLException if the initial response cannot be processed + */ + public LazyThriftInlineArrowResult( + TFetchResultsResp initialResponse, + IDatabricksStatementInternal statement, + IDatabricksSession session) + throws DatabricksSQLException { + this.currentResponse = initialResponse; + this.statement = statement; + this.session = session; + this.maxRows = statement != null ? statement.getMaxRows() : DEFAULT_RESULT_ROW_LIMIT; + this.globalRowIndex = -1; + this.hasReachedEnd = false; + this.isClosed = false; + this.totalRowsFetched = 0; + + // Initialize column info from metadata + setColumnInfo(initialResponse.getResultSetMetadata()); + + // Load initial chunk + loadCurrentChunk(); + LOGGER.debug( + "LazyThriftInlineArrowResult initialized with {} rows in first chunk, hasMoreRows: {}", + currentChunk.numRows, + currentResponse.hasMoreRows); + } + + /** + * Gets the value at the specified column index for the current row. 
+ * + * @param columnIndex the zero-based column index + * @return the value at the specified column + * @throws DatabricksSQLException if the result is closed, cursor is invalid, or column index is + * out of bounds + */ + @Override + public Object getObject(int columnIndex) throws DatabricksSQLException { + if (isClosed) { + throw new DatabricksSQLException( + "Result is already closed", DatabricksDriverErrorCode.STATEMENT_CLOSED); + } + if (globalRowIndex == -1) { + throw new DatabricksSQLException( + "Cursor is before first row", DatabricksDriverErrorCode.INVALID_STATE); + } + if (currentChunkIterator == null) { + throw new DatabricksSQLException( + "No current chunk available", DatabricksDriverErrorCode.INVALID_STATE); + } + if (columnIndex < 0 || columnIndex >= columnInfos.size()) { + throw new DatabricksSQLException( + "Column index out of bounds " + columnIndex, DatabricksDriverErrorCode.INVALID_STATE); + } + + ColumnInfo columnInfo = columnInfos.get(columnIndex); + ColumnInfoTypeName requiredType = columnInfo.getTypeName(); + String arrowMetadata = currentChunkIterator.getType(columnIndex); + if (arrowMetadata == null) { + arrowMetadata = columnInfo.getTypeText(); + } + + return ArrowStreamResult.getObjectWithComplexTypeHandling( + session, currentChunkIterator, columnIndex, requiredType, arrowMetadata, columnInfo); + } + + /** + * Gets the current row index (0-based). Returns -1 if before the first row. + * + * @return the current row index + */ + @Override + public long getCurrentRow() { + return globalRowIndex; + } + + /** + * Moves the cursor to the next row. Fetches additional data from server if needed. 
+ * + * @return true if there is a next row, false if at the end + * @throws DatabricksSQLException if an error occurs while fetching data + */ + @Override + public boolean next() throws DatabricksSQLException { + if (isClosed || hasReachedEnd) { + return false; + } + + if (!hasNext()) { + return false; + } + + // Check if we've reached the maxRows limit + boolean hasRowLimit = maxRows > 0; + if (hasRowLimit && globalRowIndex + 1 >= maxRows) { + hasReachedEnd = true; + return false; + } + + // Try to advance in current chunk + if (currentChunkIterator != null && currentChunkIterator.hasNextRow()) { + boolean advanced = currentChunkIterator.nextRow(); + if (advanced) { + globalRowIndex++; + return true; + } + } + + // Need to fetch next chunk + while (currentResponse.hasMoreRows) { + fetchNextChunk(); + + // If we got a chunk with data, advance to first row + if (currentChunkIterator != null && currentChunkIterator.hasNextRow()) { + boolean advanced = currentChunkIterator.nextRow(); + if (advanced) { + globalRowIndex++; + return true; + } + } + } + + // No more data available + hasReachedEnd = true; + return false; + } + + /** + * Checks if there are more rows available without advancing the cursor. + * + * @return true if there are more rows, false otherwise + */ + @Override + public boolean hasNext() { + if (isClosed || hasReachedEnd) { + return false; + } + + // Check maxRows limit + boolean hasRowLimit = maxRows > 0; + if (hasRowLimit && globalRowIndex + 1 >= maxRows) { + return false; + } + + // Check if there are more rows in current chunk + if (currentChunkIterator != null && currentChunkIterator.hasNextRow()) { + return true; + } + + // Check if there are more chunks to fetch + return currentResponse.hasMoreRows; + } + + /** Closes this result and releases associated resources. 
*/ + @Override + public void close() { + this.isClosed = true; + if (currentChunk != null) { + currentChunk.releaseChunk(); + } + this.currentChunk = null; + this.currentChunkIterator = null; + this.currentResponse = null; + LOGGER.debug( + "LazyThriftInlineArrowResult closed after fetching {} total rows", totalRowsFetched); + } + + /** + * Gets the number of rows in the current chunk. + * + * @return the number of rows in the current chunk + */ + @Override + public long getRowCount() { + return currentChunk != null ? currentChunk.numRows : 0; + } + + /** + * Gets the chunk count. Always returns 0 for lazy thrift inline arrow results. + * + * @return 0 (lazy results don't use chunks in the same sense as buffered results) + */ + @Override + public long getChunkCount() { + return 0; + } + + private void loadCurrentChunk() throws DatabricksSQLException { + try { + ByteArrayInputStream byteStream = createArrowByteStream(currentResponse); + long rowCount = getTotalRowsInResponse(currentResponse); + + ArrowResultChunk.Builder builder = + ArrowResultChunk.builder().withInputStream(byteStream, rowCount); + + if (statement != null) { + builder.withStatementId(statement.getStatementId()); + } + + currentChunk = builder.build(); + currentChunkIterator = currentChunk.getChunkIterator(); + totalRowsFetched += rowCount; + + LOGGER.debug( + "Loaded arrow chunk with {} rows, total fetched: {}", rowCount, totalRowsFetched); + } catch (DatabricksParsingException e) { + LOGGER.error("Failed to load current chunk: {}", e.getMessage()); + hasReachedEnd = true; + throw new DatabricksSQLException( + "Failed to process arrow data", DatabricksDriverErrorCode.INLINE_CHUNK_PARSING_ERROR); + } + } + + /** + * Fetches the next chunk of data from the server and creates arrow chunks. 
+ * + * @throws DatabricksSQLException if the fetch operation fails + */ + private void fetchNextChunk() throws DatabricksSQLException { + try { + LOGGER.debug("Fetching next arrow chunk, current total rows fetched: {}", totalRowsFetched); + currentResponse = session.getDatabricksClient().getMoreResults(statement); + + // Release previous chunk to free memory + if (currentChunk != null) { + currentChunk.releaseChunk(); + } + + loadCurrentChunk(); + + LOGGER.debug( + "Fetched arrow chunk with {} rows, hasMoreRows: {}", + currentChunk.numRows, + currentResponse.hasMoreRows); + } catch (DatabricksSQLException e) { + LOGGER.error("Failed to fetch next arrow chunk: {}", e.getMessage()); + hasReachedEnd = true; + throw e; + } + } + + private ByteArrayInputStream createArrowByteStream(TFetchResultsResp resultsResp) + throws DatabricksParsingException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + CompressionCodec compressionType = + CompressionCodec.getCompressionMapping(resultsResp.getResultSetMetadata()); + try { + byte[] serializedSchema = getSerializedSchema(resultsResp.getResultSetMetadata()); + if (serializedSchema != null) { + baos.write(serializedSchema); + } + writeArrowBatchesToStream(compressionType, resultsResp.getResults().getArrowBatches(), baos); + return new ByteArrayInputStream(baos.toByteArray()); + } catch (DatabricksSQLException | IOException e) { + handleError(e); + } + return null; + } + + private void writeArrowBatchesToStream( + CompressionCodec compressionCodec, + List arrowBatchList, + ByteArrayOutputStream baos) + throws DatabricksSQLException, IOException { + for (TSparkArrowBatch arrowBatch : arrowBatchList) { + byte[] decompressedBytes = + decompress( + arrowBatch.getBatch(), + compressionCodec, + String.format( + "Data fetch for lazy inline arrow batch [%d] and statement [%s] with decompression algorithm : [%s]", + arrowBatch.getRowCount(), statement, compressionCodec)); + baos.write(decompressedBytes); + } + } + + private 
long getTotalRowsInResponse(TFetchResultsResp resultsResp) { + long totalRows = 0; + if (resultsResp.getResults() != null && resultsResp.getResults().getArrowBatches() != null) { + for (TSparkArrowBatch arrowBatch : resultsResp.getResults().getArrowBatches()) { + totalRows += arrowBatch.getRowCount(); + } + } + return totalRows; + } + + private byte[] getSerializedSchema(TGetResultSetMetadataResp metadata) + throws DatabricksSQLException { + if (metadata.getArrowSchema() != null) { + return metadata.getArrowSchema(); + } + Schema arrowSchema = hiveSchemaToArrowSchema(metadata.getSchema()); + try { + return SchemaUtility.serialize(arrowSchema); + } catch (IOException e) { + handleError(e); + } + return null; + } + + private Schema hiveSchemaToArrowSchema(TTableSchema hiveSchema) + throws DatabricksParsingException { + List fields = new ArrayList<>(); + if (hiveSchema == null) { + return new Schema(fields); + } + try { + hiveSchema + .getColumns() + .forEach( + columnDesc -> { + try { + fields.add(getArrowField(columnDesc)); + } catch (SQLException e) { + throw new RuntimeException(e); + } + }); + } catch (RuntimeException e) { + handleError(e); + } + return new Schema(fields); + } + + private Field getArrowField(TColumnDesc columnDesc) throws SQLException { + TPrimitiveTypeEntry primitiveTypeEntry = getTPrimitiveTypeOrDefault(columnDesc.getTypeDesc()); + ArrowType arrowType = mapThriftToArrowType(primitiveTypeEntry.getType()); + FieldType fieldType = new FieldType(true, arrowType, null); + return new Field(columnDesc.getColumnName(), fieldType, null); + } + + private void setColumnInfo(TGetResultSetMetadataResp resultManifest) { + columnInfos = new ArrayList<>(); + if (resultManifest.getSchema() == null) { + return; + } + for (TColumnDesc tColumnDesc : resultManifest.getSchema().getColumns()) { + columnInfos.add( + com.databricks.jdbc.common.util.DatabricksThriftUtil.getColumnInfoFromTColumnDesc( + tColumnDesc)); + } + } + + @VisibleForTesting + void 
handleError(Exception e) throws DatabricksParsingException { + String errorMessage = + String.format("Cannot process lazy thrift inline arrow format. Error: %s", e.getMessage()); + LOGGER.error(errorMessage); + throw new DatabricksParsingException( + errorMessage, e, DatabricksDriverErrorCode.INLINE_CHUNK_PARSING_ERROR); + } + + /** + * Gets the total number of rows fetched from the server so far. + * + * @return the total number of rows fetched from the server + */ + public long getTotalRowsFetched() { + return totalRowsFetched; + } + + /** + * Checks if all data has been fetched from the server. + * + * @return true if all data has been fetched (either reached end or maxRows limit) + */ + public boolean isCompletelyFetched() { + return hasReachedEnd || !currentResponse.hasMoreRows; + } +} diff --git a/src/test/java/com/databricks/jdbc/api/impl/ExecutionResultFactoryTest.java b/src/test/java/com/databricks/jdbc/api/impl/ExecutionResultFactoryTest.java index 2efb1e33a6..1e14615925 100644 --- a/src/test/java/com/databricks/jdbc/api/impl/ExecutionResultFactoryTest.java +++ b/src/test/java/com/databricks/jdbc/api/impl/ExecutionResultFactoryTest.java @@ -1,10 +1,10 @@ package com.databricks.jdbc.api.impl; -import static com.databricks.jdbc.TestConstants.ARROW_BATCH_LIST; import static org.junit.jupiter.api.Assertions.*; import static org.mockito.Mockito.when; import com.databricks.jdbc.api.impl.arrow.ArrowStreamResult; +import com.databricks.jdbc.api.impl.arrow.LazyThriftInlineArrowResult; import com.databricks.jdbc.api.impl.volume.VolumeOperationResult; import com.databricks.jdbc.api.internal.IDatabricksConnectionContext; import com.databricks.jdbc.api.internal.IDatabricksStatementInternal; @@ -128,14 +128,11 @@ public void testGetResultSet_thriftURL() throws SQLException { @Test public void testGetResultSet_thriftInlineArrow() throws SQLException { - when(connectionContext.getConnectionUuid()).thenReturn("sample-uuid"); 
when(resultSetMetadataResp.getResultFormat()).thenReturn(TSparkRowSetType.ARROW_BASED_SET); when(fetchResultsResp.getResultSetMetadata()).thenReturn(resultSetMetadataResp); when(fetchResultsResp.getResults()).thenReturn(tRowSet); - when(session.getConnectionContext()).thenReturn(connectionContext); - when(tRowSet.getArrowBatches()).thenReturn(ARROW_BATCH_LIST); IExecutionResult result = ExecutionResultFactory.getResultSet(fetchResultsResp, session, parentStatement); - assertInstanceOf(ArrowStreamResult.class, result); + assertInstanceOf(LazyThriftInlineArrowResult.class, result); } } diff --git a/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResultTest.java b/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResultTest.java index 5f42fbdf13..9f2eb213a3 100644 --- a/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResultTest.java +++ b/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResultTest.java @@ -133,25 +133,6 @@ public void testIteration() throws Exception { assertFalse(result.next()); } - @Test - public void testInlineArrow() throws DatabricksSQLException { - IDatabricksConnectionContext connectionContext = - DatabricksConnectionContextFactory.create(JDBC_URL, new Properties()); - when(session.getConnectionContext()).thenReturn(connectionContext); - when(metadataResp.getSchema()).thenReturn(TEST_TABLE_SCHEMA); - when(fetchResultsResp.getResults()).thenReturn(resultData); - when(fetchResultsResp.getResultSetMetadata()).thenReturn(metadataResp); - ArrowStreamResult result = - new ArrowStreamResult(fetchResultsResp, true, parentStatement, session); - assertEquals(-1, result.getCurrentRow()); - assertTrue(result.hasNext()); - assertFalse(result.next()); - assertEquals(0, result.getCurrentRow()); - assertFalse(result.hasNext()); - assertDoesNotThrow(result::close); - assertFalse(result.hasNext()); - } - @Test public void testCloudFetchArrow() throws Exception { IDatabricksConnectionContext connectionContext = @@ 
-164,7 +145,7 @@ public void testCloudFetchArrow() throws Exception { when(fetchResultsResp.getResultSetMetadata()).thenReturn(metadataResp); when(parentStatement.getStatementId()).thenReturn(STATEMENT_ID); ArrowStreamResult result = - new ArrowStreamResult(fetchResultsResp, false, parentStatement, session, mockHttpClient); + new ArrowStreamResult(fetchResultsResp, parentStatement, session, mockHttpClient); assertEquals(-1, result.getCurrentRow()); assertTrue(result.hasNext()); assertDoesNotThrow(result::close); diff --git a/src/test/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProviderTest.java b/src/test/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProviderTest.java index 86be512d4d..8392daf683 100644 --- a/src/test/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProviderTest.java +++ b/src/test/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProviderTest.java @@ -1,27 +1,17 @@ package com.databricks.jdbc.api.impl.arrow; -import static com.databricks.jdbc.TestConstants.ARROW_BATCH_LIST; -import static com.databricks.jdbc.TestConstants.TEST_TABLE_SCHEMA; import static org.junit.jupiter.api.Assertions.*; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import com.databricks.jdbc.api.internal.IDatabricksSession; -import com.databricks.jdbc.api.internal.IDatabricksStatementInternal; import com.databricks.jdbc.common.CompressionCodec; -import com.databricks.jdbc.exception.DatabricksParsingException; import com.databricks.jdbc.exception.DatabricksSQLException; -import com.databricks.jdbc.model.client.thrift.generated.TFetchResultsResp; -import com.databricks.jdbc.model.client.thrift.generated.TGetResultSetMetadataResp; -import com.databricks.jdbc.model.client.thrift.generated.TRowSet; -import com.databricks.jdbc.model.client.thrift.generated.TSparkArrowBatch; import com.databricks.jdbc.model.core.ColumnInfo; import com.databricks.jdbc.model.core.ColumnInfoTypeName; import com.databricks.jdbc.model.core.ResultData; 
import com.databricks.jdbc.model.core.ResultManifest; import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.util.Collections; import net.jpountz.lz4.LZ4FrameOutputStream; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; @@ -37,41 +27,9 @@ public class InlineChunkProviderTest { private static final long TOTAL_ROWS = 2L; - @Mock TGetResultSetMetadataResp metadata; - @Mock TFetchResultsResp fetchResultsResp; - @Mock IDatabricksStatementInternal parentStatement; - @Mock IDatabricksSession session; @Mock private ResultData mockResultData; @Mock private ResultManifest mockResultManifest; - @Test - void testInitialisation() throws DatabricksParsingException { - when(fetchResultsResp.getResultSetMetadata()).thenReturn(metadata); - when(metadata.getArrowSchema()).thenReturn(null); - when(metadata.getSchema()).thenReturn(TEST_TABLE_SCHEMA); - when(fetchResultsResp.getResults()).thenReturn(new TRowSet().setArrowBatches(ARROW_BATCH_LIST)); - when(metadata.isSetLz4Compressed()).thenReturn(false); - InlineChunkProvider inlineChunkProvider = - new InlineChunkProvider(fetchResultsResp, parentStatement, session); - assertTrue(inlineChunkProvider.hasNextChunk()); - assertTrue(inlineChunkProvider.next()); - assertFalse(inlineChunkProvider.next()); - } - - @Test - void handleErrorTest() throws DatabricksParsingException { - TSparkArrowBatch arrowBatch = - new TSparkArrowBatch().setRowCount(0).setBatch(new byte[] {65, 66, 67}); - when(fetchResultsResp.getResultSetMetadata()).thenReturn(metadata); - when(fetchResultsResp.getResults()) - .thenReturn(new TRowSet().setArrowBatches(Collections.singletonList(arrowBatch))); - InlineChunkProvider inlineChunkProvider = - new InlineChunkProvider(fetchResultsResp, parentStatement, session); - assertThrows( - DatabricksParsingException.class, - () -> inlineChunkProvider.handleError(new RuntimeException())); - } - @Test void testConstructorSuccessfulCreation() throws 
DatabricksSQLException, IOException { // Create valid Arrow data with two rows and one column: [1, 2] diff --git a/src/test/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResultTest.java b/src/test/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResultTest.java new file mode 100644 index 0000000000..9c43d3e785 --- /dev/null +++ b/src/test/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResultTest.java @@ -0,0 +1,285 @@ +package com.databricks.jdbc.api.impl.arrow; + +import static com.databricks.jdbc.TestConstants.TEST_TABLE_SCHEMA; +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +import com.databricks.jdbc.api.internal.IDatabricksSession; +import com.databricks.jdbc.api.internal.IDatabricksStatementInternal; +import com.databricks.jdbc.dbclient.impl.common.StatementId; +import com.databricks.jdbc.exception.DatabricksParsingException; +import com.databricks.jdbc.exception.DatabricksSQLException; +import com.databricks.jdbc.model.client.thrift.generated.*; +import com.databricks.jdbc.model.telemetry.enums.DatabricksDriverErrorCode; +import java.io.IOException; +import java.util.Collections; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +public class LazyThriftInlineArrowResultTest { + + @Mock private IDatabricksSession session; + @Mock private IDatabricksStatementInternal statement; + private static final StatementId STATEMENT_ID = new StatementId("test_statement_id"); + private static final byte[] DUMMY_ARROW_BYTES = new byte[] {65, 66, 67}; + + private TFetchResultsResp createFetchResultsResp(int rowCount, boolean hasMoreRows) { + TSparkArrowBatch arrowBatch = + new TSparkArrowBatch().setRowCount(rowCount).setBatch(DUMMY_ARROW_BYTES); + TRowSet rowSet = new TRowSet().setArrowBatches(Collections.singletonList(arrowBatch)); + + 
TGetResultSetMetadataResp metadata = + new TGetResultSetMetadataResp().setSchema(TEST_TABLE_SCHEMA); + + TFetchResultsResp response = + new TFetchResultsResp().setResultSetMetadata(metadata).setResults(rowSet); + response.hasMoreRows = hasMoreRows; + + return response; + } + + @Test + void testConstructorInitializesCorrectly() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + assertEquals(-1, result.getCurrentRow()); + assertEquals(0, result.getRowCount()); + assertEquals(0, result.getTotalRowsFetched()); + assertFalse(result.hasNext()); + assertTrue(result.isCompletelyFetched()); + } + + @Test + void testGetObjectThrowsWhenClosed() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + result.close(); + + DatabricksSQLException exception = + assertThrows(DatabricksSQLException.class, () -> result.getObject(0)); + assertEquals("Result is already closed", exception.getMessage()); + assertEquals(DatabricksDriverErrorCode.STATEMENT_CLOSED.name(), exception.getSQLState()); + } + + @Test + void testGetObjectThrowsWhenBeforeFirstRow() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + DatabricksSQLException exception = + 
assertThrows(DatabricksSQLException.class, () -> result.getObject(0)); + assertEquals("Cursor is before first row", exception.getMessage()); + assertEquals(DatabricksDriverErrorCode.INVALID_STATE.name(), exception.getSQLState()); + } + + @Test + void testCloseReleasesResources() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + result.close(); + + assertFalse(result.hasNext()); + assertFalse(result.next()); + } + + @Test + void testIsCompletelyFetchedWhenNoMoreRows() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + assertTrue(result.isCompletelyFetched()); + } + + @Test + void testIsCompletelyFetchedWithMoreRows() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, true); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + assertFalse(result.isCompletelyFetched()); + } + + @Test + void testGetChunkCount() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + assertEquals(0, result.getChunkCount()); + } + + @Test + void 
testHandleErrorThrowsParsingException() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + Exception testException = new IOException("Test error"); + DatabricksParsingException exception = + assertThrows(DatabricksParsingException.class, () -> result.handleError(testException)); + assertTrue(exception.getMessage().contains("Cannot process lazy thrift inline arrow format")); + assertEquals( + DatabricksDriverErrorCode.INLINE_CHUNK_PARSING_ERROR.name(), exception.getSQLState()); + } + + @Test + void testEmptyResultSet() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + assertEquals(-1, result.getCurrentRow()); + assertFalse(result.hasNext()); + assertFalse(result.next()); + assertEquals(0, result.getRowCount()); + assertTrue(result.isCompletelyFetched()); + } + + @Test + void testNullStatement() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, null, session); + + assertEquals(-1, result.getCurrentRow()); + assertEquals(0, result.getRowCount()); + } + + @Test + void testGetCurrentRowBeforeNext() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new 
LazyThriftInlineArrowResult(initialResponse, statement, session); + + assertEquals(-1, result.getCurrentRow()); + } + + @Test + void testGetTotalRowsFetched() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + assertEquals(0, result.getTotalRowsFetched()); + } + + @Test + void testNextReturnsFalseOnEmptyResultSet() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + assertFalse(result.next()); + } + + @Test + void testHasNextReturnsFalseOnEmptyResultSet() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + assertFalse(result.hasNext()); + } + + @Test + void testNextReturnsFalseAfterClose() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + result.close(); + + assertFalse(result.next()); + } + + @Test + void testHasNextReturnsFalseAfterClose() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + 
when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + result.close(); + + assertFalse(result.hasNext()); + } + + @Test + void testConstructorWithNullStatementUsesDefaultMaxRows() throws DatabricksSQLException { + TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, null, session); + + assertNotNull(result); + assertEquals(-1, result.getCurrentRow()); + } +} From 155cf8c65ba64a4895c0f6165b1398f585feea07 Mon Sep 17 00:00:00 2001 From: Jayant Singh Date: Wed, 7 Jan 2026 22:28:30 +0000 Subject: [PATCH 2/5] Address review comments --- .../impl/arrow/LazyThriftInlineArrowResult.java | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResult.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResult.java index 08950339cf..4c01751e71 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResult.java +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResult.java @@ -93,18 +93,22 @@ public LazyThriftInlineArrowResult( @Override public Object getObject(int columnIndex) throws DatabricksSQLException { if (isClosed) { + LOGGER.warn("Attempted to get object from closed result"); throw new DatabricksSQLException( "Result is already closed", DatabricksDriverErrorCode.STATEMENT_CLOSED); } if (globalRowIndex == -1) { + LOGGER.warn("Attempted to get object before calling next()"); throw new DatabricksSQLException( "Cursor is before first row", DatabricksDriverErrorCode.INVALID_STATE); } if (currentChunkIterator == null) { + LOGGER.warn("No current chunk available when getting object"); throw new DatabricksSQLException( "No current chunk 
available", DatabricksDriverErrorCode.INVALID_STATE); } if (columnIndex < 0 || columnIndex >= columnInfos.size()) { + LOGGER.warn("Column index {} out of bounds (size: {})", columnIndex, columnInfos.size()); throw new DatabricksSQLException( "Column index out of bounds " + columnIndex, DatabricksDriverErrorCode.INVALID_STATE); } @@ -261,6 +265,12 @@ private void loadCurrentChunk() throws DatabricksSQLException { "Loaded arrow chunk with {} rows, total fetched: {}", rowCount, totalRowsFetched); } catch (DatabricksParsingException e) { LOGGER.error("Failed to load current chunk: {}", e.getMessage()); + // Clean up any partially loaded chunk to prevent memory leaks + if (currentChunk != null) { + currentChunk.releaseChunk(); + currentChunk = null; + } + currentChunkIterator = null; hasReachedEnd = true; throw new DatabricksSQLException( "Failed to process arrow data", DatabricksDriverErrorCode.INLINE_CHUNK_PARSING_ERROR); @@ -410,7 +420,7 @@ void handleError(Exception e) throws DatabricksParsingException { * * @return the total number of rows fetched from the server */ - public long getTotalRowsFetched() { + long getTotalRowsFetched() { return totalRowsFetched; } @@ -419,7 +429,7 @@ public long getTotalRowsFetched() { * * @return true if all data has been fetched (either reached end or maxRows limit) */ - public boolean isCompletelyFetched() { + boolean isCompletelyFetched() { return hasReachedEnd || !currentResponse.hasMoreRows; } } From c2230c554bc39f77d0a7eb4a6f39743a26927890 Mon Sep 17 00:00:00 2001 From: Jayant Singh Date: Thu, 8 Jan 2026 07:34:31 +0000 Subject: [PATCH 3/5] Fix arrow metadata --- .../jdbc/api/impl/DatabricksResultSet.java | 3 +++ .../api/impl/arrow/LazyThriftInlineArrowResult.java | 13 +++++++++++++ 2 files changed, 16 insertions(+) diff --git a/src/main/java/com/databricks/jdbc/api/impl/DatabricksResultSet.java b/src/main/java/com/databricks/jdbc/api/impl/DatabricksResultSet.java index 8b41cb5df2..3df89f8729 100644 --- 
a/src/main/java/com/databricks/jdbc/api/impl/DatabricksResultSet.java +++ b/src/main/java/com/databricks/jdbc/api/impl/DatabricksResultSet.java @@ -7,6 +7,7 @@ import com.databricks.jdbc.api.IExecutionStatus; import com.databricks.jdbc.api.impl.arrow.ArrowStreamResult; import com.databricks.jdbc.api.impl.arrow.ChunkProvider; +import com.databricks.jdbc.api.impl.arrow.LazyThriftInlineArrowResult; import com.databricks.jdbc.api.impl.converters.ConverterHelper; import com.databricks.jdbc.api.impl.converters.ObjectConverter; import com.databricks.jdbc.api.impl.volume.VolumeOperationResult; @@ -155,6 +156,8 @@ public DatabricksResultSet( List arrowMetadata = null; if (executionResult instanceof ArrowStreamResult) { arrowMetadata = ((ArrowStreamResult) executionResult).getArrowMetadata(); + } else if (executionResult instanceof LazyThriftInlineArrowResult) { + arrowMetadata = ((LazyThriftInlineArrowResult) executionResult).getArrowMetadata(); } this.resultSetMetaData = new DatabricksResultSetMetaData( diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResult.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResult.java index 4c01751e71..cba0b2ba9a 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResult.java +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResult.java @@ -245,6 +245,19 @@ public long getChunkCount() { return 0; } + /** + * Gets the Arrow metadata for the current chunk. 
+ * + * @return list of arrow metadata strings, or null if no chunk is loaded + * @throws DatabricksSQLException if an error occurs + */ + public List getArrowMetadata() throws DatabricksSQLException { + if (currentChunk == null) { + return null; + } + return currentChunk.getArrowMetadata(); + } + private void loadCurrentChunk() throws DatabricksSQLException { try { ByteArrayInputStream byteStream = createArrowByteStream(currentResponse); From 1df5d7adcc70e4a6f36ea3b71bf0e0e7a5e1b8c0 Mon Sep 17 00:00:00 2001 From: Jayant Singh Date: Wed, 14 Jan 2026 13:53:38 +0530 Subject: [PATCH 4/5] Address review comments and add tests --- .../arrow/LazyThriftInlineArrowResult.java | 31 +- .../LazyThriftInlineArrowResultTest.java | 355 ++++++++++++------ .../DatabricksEmptyMetadataClientTest.java | 35 ++ 3 files changed, 305 insertions(+), 116 deletions(-) diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResult.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResult.java index cba0b2ba9a..12498fc3cb 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResult.java +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResult.java @@ -92,6 +92,27 @@ public LazyThriftInlineArrowResult( */ @Override public Object getObject(int columnIndex) throws DatabricksSQLException { + validateGetObjectState(columnIndex); + + ColumnInfo columnInfo = columnInfos.get(columnIndex); + ColumnInfoTypeName requiredType = columnInfo.getTypeName(); + String arrowMetadata = currentChunkIterator.getType(columnIndex); + if (arrowMetadata == null) { + arrowMetadata = columnInfo.getTypeText(); + } + + return ArrowStreamResult.getObjectWithComplexTypeHandling( + session, currentChunkIterator, columnIndex, requiredType, arrowMetadata, columnInfo); + } + + /** + * Validates the state before getting an object at the specified column index. 
+ * + * @param columnIndex the zero-based column index to validate + * @throws DatabricksSQLException if the result is closed, cursor is invalid, or column index is + * out of bounds + */ + private void validateGetObjectState(int columnIndex) throws DatabricksSQLException { if (isClosed) { LOGGER.warn("Attempted to get object from closed result"); throw new DatabricksSQLException( @@ -112,16 +133,6 @@ public Object getObject(int columnIndex) throws DatabricksSQLException { throw new DatabricksSQLException( "Column index out of bounds " + columnIndex, DatabricksDriverErrorCode.INVALID_STATE); } - - ColumnInfo columnInfo = columnInfos.get(columnIndex); - ColumnInfoTypeName requiredType = columnInfo.getTypeName(); - String arrowMetadata = currentChunkIterator.getType(columnIndex); - if (arrowMetadata == null) { - arrowMetadata = columnInfo.getTypeText(); - } - - return ArrowStreamResult.getObjectWithComplexTypeHandling( - session, currentChunkIterator, columnIndex, requiredType, arrowMetadata, columnInfo); } /** diff --git a/src/test/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResultTest.java b/src/test/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResultTest.java index 9c43d3e785..ddf4466644 100644 --- a/src/test/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResultTest.java +++ b/src/test/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResultTest.java @@ -1,18 +1,29 @@ package com.databricks.jdbc.api.impl.arrow; -import static com.databricks.jdbc.TestConstants.TEST_TABLE_SCHEMA; import static org.junit.jupiter.api.Assertions.*; import static org.mockito.Mockito.*; +import com.databricks.jdbc.api.impl.DatabricksConnectionContextFactory; +import com.databricks.jdbc.api.internal.IDatabricksConnectionContext; import com.databricks.jdbc.api.internal.IDatabricksSession; import com.databricks.jdbc.api.internal.IDatabricksStatementInternal; +import com.databricks.jdbc.dbclient.IDatabricksClient; import 
com.databricks.jdbc.dbclient.impl.common.StatementId; -import com.databricks.jdbc.exception.DatabricksParsingException; import com.databricks.jdbc.exception.DatabricksSQLException; import com.databricks.jdbc.model.client.thrift.generated.*; import com.databricks.jdbc.model.telemetry.enums.DatabricksDriverErrorCode; -import java.io.IOException; +import java.io.ByteArrayOutputStream; +import java.util.ArrayList; import java.util.Collections; +import java.util.List; +import java.util.Properties; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowStreamWriter; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -21,18 +32,105 @@ @ExtendWith(MockitoExtension.class) public class LazyThriftInlineArrowResultTest { + private static final String JDBC_URL = + "jdbc:databricks://sample-host.18.azuredatabricks.net:9999/default;transportMode=http;ssl=1;" + + "AuthMech=3;httpPath=/sql/1.0/warehouses/99999999;"; + @Mock private IDatabricksSession session; @Mock private IDatabricksStatementInternal statement; + @Mock private IDatabricksClient databricksClient; + private IDatabricksConnectionContext connectionContext; private static final StatementId STATEMENT_ID = new StatementId("test_statement_id"); - private static final byte[] DUMMY_ARROW_BYTES = new byte[] {65, 66, 67}; + private static final TTableSchema TWO_COLUMN_SCHEMA = + createTableSchema(TTypeId.INT_TYPE, TTypeId.STRING_TYPE); + + @BeforeEach + void setUp() throws Exception { + connectionContext = DatabricksConnectionContextFactory.create(JDBC_URL, new Properties()); + } - private TFetchResultsResp createFetchResultsResp(int rowCount, boolean hasMoreRows) { - TSparkArrowBatch arrowBatch = - new 
TSparkArrowBatch().setRowCount(rowCount).setBatch(DUMMY_ARROW_BYTES); + /** + * Creates valid Arrow IPC format bytes with two columns (int and string). This creates a complete + * Arrow IPC stream (with embedded schema) that can be parsed by ArrowStreamReader. + * + *

Int column values: batch_index * 100 + row_index (e.g., 0, 1, 2... for batch 0) String + * column values: "row_{batch_index}_{row_index}" (e.g., "row_0_0", "row_0_1"...) + * + * @param batchCount Number of record batches to create + * @param rowsPerBatch Number of rows in each batch + * @return Valid Arrow IPC bytes + */ + private static byte[] createValidArrowData(int batchCount, int rowsPerBatch) { + try (BufferAllocator allocator = new RootAllocator(); + ByteArrayOutputStream out = new ByteArrayOutputStream()) { + + try (IntVector intVector = new IntVector("int_column", allocator); + VarCharVector stringVector = new VarCharVector("string_column", allocator)) { + + intVector.allocateNew(rowsPerBatch); + stringVector.allocateNew(rowsPerBatch); + + try (VectorSchemaRoot root = VectorSchemaRoot.of(intVector, stringVector); + ArrowStreamWriter writer = new ArrowStreamWriter(root, null, out)) { + writer.start(); + + for (int batch = 0; batch < batchCount; batch++) { + for (int i = 0; i < rowsPerBatch; i++) { + intVector.set(i, batch * 100 + i); + stringVector.setSafe(i, ("row_" + batch + "_" + i).getBytes()); + } + intVector.setValueCount(rowsPerBatch); + stringVector.setValueCount(rowsPerBatch); + root.setRowCount(rowsPerBatch); + writer.writeBatch(); + } + + writer.end(); + } + } + + return out.toByteArray(); + } catch (Exception e) { + throw new RuntimeException("Failed to create test Arrow data", e); + } + } + + /** + * Creates a TTableSchema with the specified column types. The schema structure matches what + * hiveSchemaToArrowSchema expects. + */ + private static TTableSchema createTableSchema(TTypeId... 
types) { + List columns = new ArrayList<>(); + for (int i = 0; i < types.length; i++) { + TPrimitiveTypeEntry primitiveType = new TPrimitiveTypeEntry().setType(types[i]); + TTypeEntry typeEntry = new TTypeEntry(); + typeEntry.setPrimitiveEntry(primitiveType); + TTypeDesc typeDesc = new TTypeDesc().setTypes(Collections.singletonList(typeEntry)); + TColumnDesc columnDesc = + new TColumnDesc().setColumnName("col_" + i).setTypeDesc(typeDesc).setPosition(i); + columns.add(columnDesc); + } + return new TTableSchema().setColumns(columns); + } + + /** + * Creates a TFetchResultsResp with valid Arrow data. Uses empty arrowSchema bytes so that the + * complete Arrow IPC stream (with embedded schema) is used as-is. + * + * @param arrowData Valid Arrow IPC bytes (should include schema) + * @param rowCount Row count for the batch + * @param hasMoreRows Whether there are more rows to fetch + * @return Configured TFetchResultsResp + */ + private TFetchResultsResp createFetchResultsResp( + byte[] arrowData, int rowCount, boolean hasMoreRows) { + TSparkArrowBatch arrowBatch = new TSparkArrowBatch().setRowCount(rowCount).setBatch(arrowData); TRowSet rowSet = new TRowSet().setArrowBatches(Collections.singletonList(arrowBatch)); + // Use empty arrowSchema - this causes getSerializedSchema to return empty bytes, + // which effectively means nothing is prepended to our complete Arrow IPC stream TGetResultSetMetadataResp metadata = - new TGetResultSetMetadataResp().setSchema(TEST_TABLE_SCHEMA); + new TGetResultSetMetadataResp().setSchema(TWO_COLUMN_SCHEMA).setArrowSchema(new byte[0]); TFetchResultsResp response = new TFetchResultsResp().setResultSetMetadata(metadata).setResults(rowSet); @@ -42,8 +140,9 @@ private TFetchResultsResp createFetchResultsResp(int rowCount, boolean hasMoreRo } @Test - void testConstructorInitializesCorrectly() throws DatabricksSQLException { - TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + void testEmptyResultSet() throws 
DatabricksSQLException { + byte[] arrowData = createValidArrowData(1, 0); + TFetchResultsResp initialResponse = createFetchResultsResp(arrowData, 0, false); when(statement.getMaxRows()).thenReturn(0); when(statement.getStatementId()).thenReturn(STATEMENT_ID); @@ -51,16 +150,20 @@ void testConstructorInitializesCorrectly() throws DatabricksSQLException { LazyThriftInlineArrowResult result = new LazyThriftInlineArrowResult(initialResponse, statement, session); + // Verify all initial state for empty result assertEquals(-1, result.getCurrentRow()); assertEquals(0, result.getRowCount()); assertEquals(0, result.getTotalRowsFetched()); + assertEquals(0, result.getChunkCount()); assertFalse(result.hasNext()); + assertFalse(result.next()); assertTrue(result.isCompletelyFetched()); } @Test void testGetObjectThrowsWhenClosed() throws DatabricksSQLException { - TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + byte[] arrowData = createValidArrowData(1, 1); + TFetchResultsResp initialResponse = createFetchResultsResp(arrowData, 1, false); when(statement.getMaxRows()).thenReturn(0); when(statement.getStatementId()).thenReturn(STATEMENT_ID); @@ -69,6 +172,11 @@ void testGetObjectThrowsWhenClosed() throws DatabricksSQLException { new LazyThriftInlineArrowResult(initialResponse, statement, session); result.close(); + // Verify close behavior + assertFalse(result.hasNext()); + assertFalse(result.next()); + assertEquals(0, result.getRowCount()); + DatabricksSQLException exception = assertThrows(DatabricksSQLException.class, () -> result.getObject(0)); assertEquals("Result is already closed", exception.getMessage()); @@ -77,7 +185,8 @@ void testGetObjectThrowsWhenClosed() throws DatabricksSQLException { @Test void testGetObjectThrowsWhenBeforeFirstRow() throws DatabricksSQLException { - TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + byte[] arrowData = createValidArrowData(1, 1); + TFetchResultsResp initialResponse = 
createFetchResultsResp(arrowData, 1, false); when(statement.getMaxRows()).thenReturn(0); when(statement.getStatementId()).thenReturn(STATEMENT_ID); @@ -92,8 +201,9 @@ void testGetObjectThrowsWhenBeforeFirstRow() throws DatabricksSQLException { } @Test - void testCloseReleasesResources() throws DatabricksSQLException { - TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + void testIsCompletelyFetchedWithMoreRows() throws DatabricksSQLException { + byte[] arrowData = createValidArrowData(1, 0); + TFetchResultsResp initialResponse = createFetchResultsResp(arrowData, 0, true); when(statement.getMaxRows()).thenReturn(0); when(statement.getStatementId()).thenReturn(STATEMENT_ID); @@ -101,15 +211,15 @@ void testCloseReleasesResources() throws DatabricksSQLException { LazyThriftInlineArrowResult result = new LazyThriftInlineArrowResult(initialResponse, statement, session); - result.close(); - - assertFalse(result.hasNext()); - assertFalse(result.next()); + assertFalse(result.isCompletelyFetched()); + assertTrue(result.hasNext()); // hasNext is true because hasMoreRows is true } @Test - void testIsCompletelyFetchedWhenNoMoreRows() throws DatabricksSQLException { - TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + void testIterateThroughRowsWithValidArrowData() throws DatabricksSQLException { + int rowCount = 5; + byte[] arrowData = createValidArrowData(1, rowCount); + TFetchResultsResp initialResponse = createFetchResultsResp(arrowData, rowCount, false); when(statement.getMaxRows()).thenReturn(0); when(statement.getStatementId()).thenReturn(STATEMENT_ID); @@ -117,97 +227,139 @@ void testIsCompletelyFetchedWhenNoMoreRows() throws DatabricksSQLException { LazyThriftInlineArrowResult result = new LazyThriftInlineArrowResult(initialResponse, statement, session); - assertTrue(result.isCompletelyFetched()); - } - - @Test - void testIsCompletelyFetchedWithMoreRows() throws DatabricksSQLException { - TFetchResultsResp initialResponse = 
createFetchResultsResp(0, true); - - when(statement.getMaxRows()).thenReturn(0); - when(statement.getStatementId()).thenReturn(STATEMENT_ID); + // Verify initial state + assertEquals(-1, result.getCurrentRow()); + assertTrue(result.hasNext()); - LazyThriftInlineArrowResult result = - new LazyThriftInlineArrowResult(initialResponse, statement, session); + // Iterate through all rows + for (int i = 0; i < rowCount; i++) { + assertTrue(result.hasNext(), "Should have next at row " + i); + assertTrue(result.next(), "next() should return true at row " + i); + assertEquals(i, result.getCurrentRow()); + } - assertFalse(result.isCompletelyFetched()); + // After all rows + assertFalse(result.hasNext()); + assertFalse(result.next()); + assertEquals(rowCount, result.getTotalRowsFetched()); } @Test - void testGetChunkCount() throws DatabricksSQLException { - TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + void testGetObjectReturnsCorrectIntegerValue() throws DatabricksSQLException { + int rowCount = 3; + byte[] arrowData = createValidArrowData(1, rowCount); + TFetchResultsResp initialResponse = createFetchResultsResp(arrowData, rowCount, false); when(statement.getMaxRows()).thenReturn(0); when(statement.getStatementId()).thenReturn(STATEMENT_ID); + when(session.getConnectionContext()).thenReturn(connectionContext); LazyThriftInlineArrowResult result = new LazyThriftInlineArrowResult(initialResponse, statement, session); - assertEquals(0, result.getChunkCount()); + // Move to first row and get value + assertTrue(result.next()); + Object value = result.getObject(0); + assertNotNull(value); + assertInstanceOf(Integer.class, value); + assertEquals(0, value); // First row: batch_0 * 100 + row_0 = 0 + + // Move to second row + assertTrue(result.next()); + value = result.getObject(0); + assertEquals(1, value); // Second row: batch_0 * 100 + row_1 = 1 + + // Move to third row + assertTrue(result.next()); + value = result.getObject(0); + assertEquals(2, value); 
// Third row: batch_0 * 100 + row_2 = 2 } @Test - void testHandleErrorThrowsParsingException() throws DatabricksSQLException { - TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + void testGetObjectWithTwoColumns() throws DatabricksSQLException { + int rowCount = 2; + byte[] arrowData = createValidArrowData(1, rowCount); + TFetchResultsResp initialResponse = createFetchResultsResp(arrowData, rowCount, false); when(statement.getMaxRows()).thenReturn(0); when(statement.getStatementId()).thenReturn(STATEMENT_ID); + when(session.getConnectionContext()).thenReturn(connectionContext); LazyThriftInlineArrowResult result = new LazyThriftInlineArrowResult(initialResponse, statement, session); - Exception testException = new IOException("Test error"); - DatabricksParsingException exception = - assertThrows(DatabricksParsingException.class, () -> result.handleError(testException)); - assertTrue(exception.getMessage().contains("Cannot process lazy thrift inline arrow format")); - assertEquals( - DatabricksDriverErrorCode.INLINE_CHUNK_PARSING_ERROR.name(), exception.getSQLState()); + assertTrue(result.next()); + + // Get integer column + Object intValue = result.getObject(0); + assertNotNull(intValue); + assertInstanceOf(Integer.class, intValue); + assertEquals(0, intValue); + + // Get string column + Object stringValue = result.getObject(1); + assertNotNull(stringValue); + assertEquals("row_0_0", stringValue.toString()); } @Test - void testEmptyResultSet() throws DatabricksSQLException { - TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + void testGetObjectThrowsForColumnIndexOutOfBounds() throws DatabricksSQLException { + int rowCount = 1; + byte[] arrowData = createValidArrowData(1, rowCount); + TFetchResultsResp initialResponse = createFetchResultsResp(arrowData, rowCount, false); when(statement.getMaxRows()).thenReturn(0); when(statement.getStatementId()).thenReturn(STATEMENT_ID); + // Note: session.getConnectionContext() is not 
stubbed here because the column index + // validation happens before the connection context is accessed LazyThriftInlineArrowResult result = new LazyThriftInlineArrowResult(initialResponse, statement, session); - assertEquals(-1, result.getCurrentRow()); - assertFalse(result.hasNext()); - assertFalse(result.next()); - assertEquals(0, result.getRowCount()); - assertTrue(result.isCompletelyFetched()); - } - - @Test - void testNullStatement() throws DatabricksSQLException { - TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + assertTrue(result.next()); - LazyThriftInlineArrowResult result = - new LazyThriftInlineArrowResult(initialResponse, null, session); + // Test negative index + DatabricksSQLException negativeException = + assertThrows(DatabricksSQLException.class, () -> result.getObject(-1)); + assertTrue(negativeException.getMessage().contains("Column index out of bounds")); + assertEquals(DatabricksDriverErrorCode.INVALID_STATE.name(), negativeException.getSQLState()); - assertEquals(-1, result.getCurrentRow()); - assertEquals(0, result.getRowCount()); + // Test index beyond column count (we have 2 columns: 0 and 1) + DatabricksSQLException beyondException = + assertThrows(DatabricksSQLException.class, () -> result.getObject(2)); + assertTrue(beyondException.getMessage().contains("Column index out of bounds")); + assertEquals(DatabricksDriverErrorCode.INVALID_STATE.name(), beyondException.getSQLState()); } @Test - void testGetCurrentRowBeforeNext() throws DatabricksSQLException { - TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + void testMaxRowsLimitEnforced() throws DatabricksSQLException { + int totalRows = 10; + int maxRows = 3; + byte[] arrowData = createValidArrowData(1, totalRows); + TFetchResultsResp initialResponse = createFetchResultsResp(arrowData, totalRows, false); - when(statement.getMaxRows()).thenReturn(0); + when(statement.getMaxRows()).thenReturn(maxRows); 
when(statement.getStatementId()).thenReturn(STATEMENT_ID); LazyThriftInlineArrowResult result = new LazyThriftInlineArrowResult(initialResponse, statement, session); - assertEquals(-1, result.getCurrentRow()); + // Should only be able to iterate up to maxRows + int rowsRetrieved = 0; + while (result.next()) { + rowsRetrieved++; + } + + assertEquals(maxRows, rowsRetrieved); + assertFalse(result.hasNext()); + assertEquals(maxRows - 1, result.getCurrentRow()); // 0-based index } @Test - void testGetTotalRowsFetched() throws DatabricksSQLException { - TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + void testGetArrowMetadataReturnsMetadata() throws DatabricksSQLException { + int rowCount = 1; + byte[] arrowData = createValidArrowData(1, rowCount); + TFetchResultsResp initialResponse = createFetchResultsResp(arrowData, rowCount, false); when(statement.getMaxRows()).thenReturn(0); when(statement.getStatementId()).thenReturn(STATEMENT_ID); @@ -215,71 +367,62 @@ void testGetTotalRowsFetched() throws DatabricksSQLException { LazyThriftInlineArrowResult result = new LazyThriftInlineArrowResult(initialResponse, statement, session); - assertEquals(0, result.getTotalRowsFetched()); + List metadata = result.getArrowMetadata(); + assertNotNull(metadata); + // The metadata list should have one entry per column (2 columns: int and string) + assertEquals(2, metadata.size()); } @Test - void testNextReturnsFalseOnEmptyResultSet() throws DatabricksSQLException { - TFetchResultsResp initialResponse = createFetchResultsResp(0, false); - - when(statement.getMaxRows()).thenReturn(0); - when(statement.getStatementId()).thenReturn(STATEMENT_ID); - - LazyThriftInlineArrowResult result = - new LazyThriftInlineArrowResult(initialResponse, statement, session); + void testFetchNextChunkFromServer() throws DatabricksSQLException { + int rowsPerChunk = 2; + byte[] arrowData1 = createValidArrowData(1, rowsPerChunk); + byte[] arrowData2 = createValidArrowData(1, rowsPerChunk); 
- assertFalse(result.next()); - } + // First chunk with hasMoreRows = true + TFetchResultsResp initialResponse = createFetchResultsResp(arrowData1, rowsPerChunk, true); - @Test - void testHasNextReturnsFalseOnEmptyResultSet() throws DatabricksSQLException { - TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + // Second chunk with hasMoreRows = false + TFetchResultsResp secondResponse = createFetchResultsResp(arrowData2, rowsPerChunk, false); when(statement.getMaxRows()).thenReturn(0); when(statement.getStatementId()).thenReturn(STATEMENT_ID); + when(session.getDatabricksClient()).thenReturn(databricksClient); + when(databricksClient.getMoreResults(statement)).thenReturn(secondResponse); LazyThriftInlineArrowResult result = new LazyThriftInlineArrowResult(initialResponse, statement, session); - assertFalse(result.hasNext()); - } - - @Test - void testNextReturnsFalseAfterClose() throws DatabricksSQLException { - TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + // Iterate through first chunk + assertTrue(result.next()); + assertTrue(result.next()); + assertFalse(result.isCompletelyFetched()); // Still has more rows - when(statement.getMaxRows()).thenReturn(0); - when(statement.getStatementId()).thenReturn(STATEMENT_ID); - - LazyThriftInlineArrowResult result = - new LazyThriftInlineArrowResult(initialResponse, statement, session); - result.close(); + // This should trigger fetch of next chunk + assertTrue(result.next()); + assertTrue(result.next()); + // After all rows assertFalse(result.next()); + assertTrue(result.isCompletelyFetched()); + assertEquals(rowsPerChunk * 2, result.getTotalRowsFetched()); + + // Verify that getMoreResults was called + verify(databricksClient).getMoreResults(statement); } @Test - void testHasNextReturnsFalseAfterClose() throws DatabricksSQLException { - TFetchResultsResp initialResponse = createFetchResultsResp(0, false); + void testGetRowCountReturnsCurrentChunkRowCount() throws 
DatabricksSQLException { + int rowCount = 5; + byte[] arrowData = createValidArrowData(1, rowCount); + TFetchResultsResp initialResponse = createFetchResultsResp(arrowData, rowCount, false); when(statement.getMaxRows()).thenReturn(0); when(statement.getStatementId()).thenReturn(STATEMENT_ID); LazyThriftInlineArrowResult result = new LazyThriftInlineArrowResult(initialResponse, statement, session); - result.close(); - - assertFalse(result.hasNext()); - } - @Test - void testConstructorWithNullStatementUsesDefaultMaxRows() throws DatabricksSQLException { - TFetchResultsResp initialResponse = createFetchResultsResp(0, false); - - LazyThriftInlineArrowResult result = - new LazyThriftInlineArrowResult(initialResponse, null, session); - - assertNotNull(result); - assertEquals(-1, result.getCurrentRow()); + assertEquals(rowCount, result.getRowCount()); } } diff --git a/src/test/java/com/databricks/jdbc/dbclient/impl/sqlexec/DatabricksEmptyMetadataClientTest.java b/src/test/java/com/databricks/jdbc/dbclient/impl/sqlexec/DatabricksEmptyMetadataClientTest.java index dd2aca1ad6..1152900d27 100644 --- a/src/test/java/com/databricks/jdbc/dbclient/impl/sqlexec/DatabricksEmptyMetadataClientTest.java +++ b/src/test/java/com/databricks/jdbc/dbclient/impl/sqlexec/DatabricksEmptyMetadataClientTest.java @@ -115,4 +115,39 @@ void testListPrimaryKeys() throws SQLException { assertEquals(resultSet.getMetaData().getColumnName(2), "TABLE_SCHEM"); assertEquals(resultSet.getMetaData().getColumnName(3), "TABLE_NAME"); } + + @Test + void testListImportedKeys() throws SQLException { + ResultSet resultSet = mockClient.listImportedKeys(session, "catalog", "schema", "table"); + assertNotNull(resultSet); + assertFalse(resultSet.next()); // empty result set + assertEquals(14, resultSet.getMetaData().getColumnCount()); + assertEquals("PKTABLE_CAT", resultSet.getMetaData().getColumnName(1)); + } + + @Test + void testListExportedKeys() throws SQLException { + ResultSet resultSet = 
mockClient.listExportedKeys(session, "catalog", "schema", "table"); + assertNotNull(resultSet); + assertFalse(resultSet.next()); // empty result set + assertEquals(14, resultSet.getMetaData().getColumnCount()); + assertEquals("PKTABLE_CAT", resultSet.getMetaData().getColumnName(1)); + } + + @Test + void testListCrossReferences() throws SQLException { + ResultSet resultSet = + mockClient.listCrossReferences( + session, + "parentCatalog", + "parentSchema", + "parentTable", + "foreignCatalog", + "foreignSchema", + "foreignTable"); + assertNotNull(resultSet); + assertFalse(resultSet.next()); // empty result set + assertEquals(14, resultSet.getMetaData().getColumnCount()); + assertEquals("PKTABLE_CAT", resultSet.getMetaData().getColumnName(1)); + } } From e0613785873ab450b18d6ff02295cce214003e36 Mon Sep 17 00:00:00 2001 From: Jayant Singh Date: Mon, 19 Jan 2026 11:01:39 +0530 Subject: [PATCH 5/5] next_changelog --- NEXT_CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index d27a25e263..b95b94ee70 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -5,6 +5,7 @@ ### Added ### Updated +- Implemented lazy loading for inline Arrow results, fetching arrow batches on demand instead of all at once. This improves memory usage and initial response time for large result sets when using the Thrift protocol with Arrow format. ### Fixed