diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 2c911672da..d59484cabd 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -5,6 +5,7 @@ ### Added ### Updated +- Implemented lazy loading for inline Arrow results, fetching arrow batches on demand instead of all at once. This improves memory usage and initial response time for large result sets when using the Thrift protocol with Arrow format. ### Fixed - Fixed complex data type metadata support when retrieving 0 rows in Arrow format diff --git a/src/main/java/com/databricks/jdbc/api/impl/DatabricksResultSet.java b/src/main/java/com/databricks/jdbc/api/impl/DatabricksResultSet.java index 8b41cb5df2..3df89f8729 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/DatabricksResultSet.java +++ b/src/main/java/com/databricks/jdbc/api/impl/DatabricksResultSet.java @@ -7,6 +7,7 @@ import com.databricks.jdbc.api.IExecutionStatus; import com.databricks.jdbc.api.impl.arrow.ArrowStreamResult; import com.databricks.jdbc.api.impl.arrow.ChunkProvider; +import com.databricks.jdbc.api.impl.arrow.LazyThriftInlineArrowResult; import com.databricks.jdbc.api.impl.converters.ConverterHelper; import com.databricks.jdbc.api.impl.converters.ObjectConverter; import com.databricks.jdbc.api.impl.volume.VolumeOperationResult; @@ -155,6 +156,8 @@ public DatabricksResultSet( List arrowMetadata = null; if (executionResult instanceof ArrowStreamResult) { arrowMetadata = ((ArrowStreamResult) executionResult).getArrowMetadata(); + } else if (executionResult instanceof LazyThriftInlineArrowResult) { + arrowMetadata = ((LazyThriftInlineArrowResult) executionResult).getArrowMetadata(); } this.resultSetMetaData = new DatabricksResultSetMetaData( diff --git a/src/main/java/com/databricks/jdbc/api/impl/ExecutionResultFactory.java b/src/main/java/com/databricks/jdbc/api/impl/ExecutionResultFactory.java index c93a3f2879..6caa7135d2 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/ExecutionResultFactory.java +++ b/src/main/java/com/databricks/jdbc/api/impl/ExecutionResultFactory.java @@ -1,6 +1,7 @@ package com.databricks.jdbc.api.impl; import com.databricks.jdbc.api.impl.arrow.ArrowStreamResult; +import com.databricks.jdbc.api.impl.arrow.LazyThriftInlineArrowResult; import com.databricks.jdbc.api.impl.volume.VolumeOperationResult; import com.databricks.jdbc.api.internal.IDatabricksSession; import com.databricks.jdbc.api.internal.IDatabricksStatementInternal; @@ -96,9 +97,9 @@ private static IExecutionResult getResultHandler( case COLUMN_BASED_SET: return new LazyThriftResult(resultsResp, parentStatement, session); case ARROW_BASED_SET: - return new ArrowStreamResult(resultsResp, true, parentStatement, session); + return new LazyThriftInlineArrowResult(resultsResp, parentStatement, session); case URL_BASED_SET: - return new ArrowStreamResult(resultsResp, false, parentStatement, session); + return new ArrowStreamResult(resultsResp, parentStatement, session); case ROW_BASED_SET: throw new DatabricksSQLFeatureNotSupportedException( "Invalid state - row based set cannot be received"); diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResult.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResult.java index 1d7067a3e8..61e7ba83fb 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResult.java +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResult.java @@ -143,13 +143,11 @@ private static ChunkProvider createRemoteChunkProvider( public ArrowStreamResult( TFetchResultsResp resultsResp, - boolean isInlineArrow, IDatabricksStatementInternal parentStatementId, IDatabricksSession session) throws DatabricksSQLException { this( resultsResp, - isInlineArrow, parentStatementId, session, DatabricksHttpClientFactory.getInstance().getClient(session.getConnectionContext())); @@ -158,19 +156,14 @@ public ArrowStreamResult( @VisibleForTesting ArrowStreamResult( TFetchResultsResp resultsResp, - boolean isInlineArrow, IDatabricksStatementInternal parentStatement, IDatabricksSession session, IDatabricksHttpClient httpClient) throws DatabricksSQLException { this.session = session; setColumnInfo(resultsResp.getResultSetMetadata()); - if (isInlineArrow) { - this.chunkProvider = new InlineChunkProvider(resultsResp, parentStatement, session); - } else { - this.chunkProvider = - createThriftRemoteChunkProvider(resultsResp, parentStatement, session, httpClient); - } + this.chunkProvider = + createThriftRemoteChunkProvider(resultsResp, parentStatement, session, httpClient); } /** @@ -238,48 +231,15 @@ public List getArrowMetadata() throws DatabricksSQLException { /** {@inheritDoc} */ @Override public Object getObject(int columnIndex) throws DatabricksSQLException { - ColumnInfoTypeName requiredType = columnInfos.get(columnIndex).getTypeName(); + ColumnInfo columnInfo = columnInfos.get(columnIndex); + ColumnInfoTypeName requiredType = columnInfo.getTypeName(); String arrowMetadata = chunkIterator.getType(columnIndex); if (arrowMetadata == null) { - arrowMetadata = columnInfos.get(columnIndex).getTypeText(); + arrowMetadata = columnInfo.getTypeText(); } - // Handle complex type conversion when complex datatype support is disabled - boolean isComplexDatatypeSupportEnabled = - this.session.getConnectionContext().isComplexDatatypeSupportEnabled(); - boolean isGeoSpatialSupportEnabled = - this.session.getConnectionContext().isGeoSpatialSupportEnabled(); - - // Check if we need to convert geospatial types to string when geospatial support is disabled - // This check must come before the general complex type check - if (!isGeoSpatialSupportEnabled && isGeospatialType(requiredType)) { - LOGGER.debug("Geospatial support is disabled, converting {} to STRING", requiredType); - - Object result = - chunkIterator.getColumnObjectAtCurrentRow( - columnIndex, ColumnInfoTypeName.STRING, "STRING", columnInfos.get(columnIndex)); - if (result == null) { - return null; - } - // Return raw string for geospatial types when support is disabled - return result; - } - - if (!isComplexDatatypeSupportEnabled && isComplexType(requiredType)) { - LOGGER.debug("Complex datatype support is disabled, converting complex type to STRING"); - - Object result = - chunkIterator.getColumnObjectAtCurrentRow( - columnIndex, ColumnInfoTypeName.STRING, "STRING", columnInfos.get(columnIndex)); - if (result == null) { - return null; - } - ComplexDataTypeParser parser = new ComplexDataTypeParser(); - return parser.formatComplexTypeString(result.toString(), requiredType.name(), arrowMetadata); - } - - return chunkIterator.getColumnObjectAtCurrentRow( - columnIndex, requiredType, arrowMetadata, columnInfos.get(columnIndex)); + return getObjectWithComplexTypeHandling( + session, chunkIterator, columnIndex, requiredType, arrowMetadata, columnInfo); } /** @@ -384,6 +344,66 @@ private void setColumnInfo(TGetResultSetMetadataResp resultManifest) { } } + /** + * Helper method to handle complex type and geospatial type conversion when support is disabled. + * + *

This method is also used by LazyThriftInlineArrowResult for consistent type handling. + * + * @param session The databricks session + * @param chunkIterator The chunk iterator + * @param columnIndex The column index + * @param requiredType The required column type + * @param arrowMetadata The arrow metadata + * @param columnInfo The column info + * @return The object value (converted if complex/geospatial type and support disabled) + * @throws DatabricksSQLException if an error occurs + */ + protected static Object getObjectWithComplexTypeHandling( + IDatabricksSession session, + ArrowResultChunkIterator chunkIterator, + int columnIndex, + ColumnInfoTypeName requiredType, + String arrowMetadata, + ColumnInfo columnInfo) + throws DatabricksSQLException { + boolean isComplexDatatypeSupportEnabled = + session.getConnectionContext().isComplexDatatypeSupportEnabled(); + boolean isGeoSpatialSupportEnabled = + session.getConnectionContext().isGeoSpatialSupportEnabled(); + + // Check if we need to convert geospatial types to string when geospatial support is disabled + // This check must come before the general complex type check + if (!isGeoSpatialSupportEnabled && isGeospatialType(requiredType)) { + LOGGER.debug("Geospatial support is disabled, converting {} to STRING", requiredType); + + Object result = + chunkIterator.getColumnObjectAtCurrentRow( + columnIndex, ColumnInfoTypeName.STRING, "STRING", columnInfo); + if (result == null) { + return null; + } + // Return raw string for geospatial types when support is disabled + return result; + } + + if (!isComplexDatatypeSupportEnabled && isComplexType(requiredType)) { + LOGGER.debug("Complex datatype support is disabled, converting complex type to STRING"); + + Object result = + chunkIterator.getColumnObjectAtCurrentRow( + columnIndex, ColumnInfoTypeName.STRING, "STRING", columnInfo); + if (result == null) { + return null; + } + ComplexDataTypeParser parser = new ComplexDataTypeParser(); + + return parser.formatComplexTypeString(result.toString(), requiredType.name(), arrowMetadata); + } + + return chunkIterator.getColumnObjectAtCurrentRow( + columnIndex, requiredType, arrowMetadata, columnInfo); + } + /** * Converts a collection of ExternalLinks to a ChunkLinkFetchResult. * diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProvider.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProvider.java index e22d974a40..32f5e1b803 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProvider.java +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProvider.java @@ -1,31 +1,17 @@ package com.databricks.jdbc.api.impl.arrow; -import static com.databricks.jdbc.common.util.DatabricksTypeUtil.*; import static com.databricks.jdbc.common.util.DecompressionUtil.decompress; -import com.databricks.jdbc.api.internal.IDatabricksSession; -import com.databricks.jdbc.api.internal.IDatabricksStatementInternal; import com.databricks.jdbc.common.CompressionCodec; import com.databricks.jdbc.exception.DatabricksParsingException; import com.databricks.jdbc.exception.DatabricksSQLException; import com.databricks.jdbc.log.JdbcLogger; import com.databricks.jdbc.log.JdbcLoggerFactory; -import com.databricks.jdbc.model.client.thrift.generated.*; import com.databricks.jdbc.model.core.ResultData; import com.databricks.jdbc.model.core.ResultManifest; import com.databricks.jdbc.model.telemetry.enums.DatabricksDriverErrorCode; import com.google.common.annotations.VisibleForTesting; import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.List; -import org.apache.arrow.vector.types.pojo.ArrowType; -import org.apache.arrow.vector.types.pojo.Field; -import org.apache.arrow.vector.types.pojo.FieldType; -import org.apache.arrow.vector.types.pojo.Schema; -import org.apache.arrow.vector.util.SchemaUtility; /** Class to manage inline Arrow chunks */ public class InlineChunkProvider implements ChunkProvider { @@ -37,23 +23,6 @@ public class InlineChunkProvider implements ChunkProvider { private final ArrowResultChunk arrowResultChunk; // There is only one packet of data in case of inline arrow - InlineChunkProvider( - TFetchResultsResp resultsResp, - IDatabricksStatementInternal parentStatement, - IDatabricksSession session) - throws DatabricksParsingException { - this.currentChunkIndex = -1; - this.totalRows = 0; - ByteArrayInputStream byteStream = initializeByteStream(resultsResp, session, parentStatement); - ArrowResultChunk.Builder builder = - ArrowResultChunk.builder().withInputStream(byteStream, totalRows); - - if (parentStatement != null) { - builder.withStatementId(parentStatement.getStatementId()); - } - arrowResultChunk = builder.build(); - } - /** * Constructor for inline arrow chunk provider from {@link ResultData} and {@link ResultManifest}. * @@ -123,97 +92,6 @@ public boolean isClosed() { return isClosed; } - private ByteArrayInputStream initializeByteStream( - TFetchResultsResp resultsResp, - IDatabricksSession session, - IDatabricksStatementInternal parentStatement) - throws DatabricksParsingException { - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - CompressionCodec compressionType = - CompressionCodec.getCompressionMapping(resultsResp.getResultSetMetadata()); - try { - byte[] serializedSchema = getSerializedSchema(resultsResp.getResultSetMetadata()); - if (serializedSchema != null) { - baos.write(serializedSchema); - } - writeToByteOutputStream( - compressionType, parentStatement, resultsResp.getResults().getArrowBatches(), baos); - while (resultsResp.hasMoreRows) { - resultsResp = session.getDatabricksClient().getMoreResults(parentStatement); - writeToByteOutputStream( - compressionType, parentStatement, resultsResp.getResults().getArrowBatches(), baos); - } - return new ByteArrayInputStream(baos.toByteArray()); - } catch (DatabricksSQLException | IOException e) { - handleError(e); - } - return null; - } - - private void writeToByteOutputStream( - CompressionCodec compressionCodec, - IDatabricksStatementInternal parentStatement, - List arrowBatchList, - ByteArrayOutputStream baos) - throws DatabricksSQLException, IOException { - for (TSparkArrowBatch arrowBatch : arrowBatchList) { - byte[] decompressedBytes = - decompress( - arrowBatch.getBatch(), - compressionCodec, - String.format( - "Data fetch for inline arrow batch [%d] and statement [%s] with decompression algorithm : [%s]", - arrowBatch.getRowCount(), parentStatement, compressionCodec)); - totalRows += arrowBatch.getRowCount(); - baos.write(decompressedBytes); - } - } - - private byte[] getSerializedSchema(TGetResultSetMetadataResp metadata) - throws DatabricksSQLException { - if (metadata.getArrowSchema() != null) { - return metadata.getArrowSchema(); - } - Schema arrowSchema = hiveSchemaToArrowSchema(metadata.getSchema()); - try { - return SchemaUtility.serialize(arrowSchema); - } catch (IOException e) { - handleError(e); - } - // should never reach here; - return null; - } - - private Schema hiveSchemaToArrowSchema(TTableSchema hiveSchema) - throws DatabricksParsingException { - List fields = new ArrayList<>(); - if (hiveSchema == null) { - return new Schema(fields); - } - try { - hiveSchema - .getColumns() - .forEach( - columnDesc -> { - try { - fields.add(getArrowField(columnDesc)); - } catch (SQLException e) { - throw new RuntimeException(e); - } - }); - } catch (RuntimeException e) { - handleError(e); - } - return new Schema(fields); - } - - private Field getArrowField(TColumnDesc columnDesc) throws SQLException { - TPrimitiveTypeEntry primitiveTypeEntry = getTPrimitiveTypeOrDefault(columnDesc.getTypeDesc()); - ArrowType arrowType = mapThriftToArrowType(primitiveTypeEntry.getType()); - FieldType fieldType = new FieldType(true, arrowType, null); - return new Field(columnDesc.getColumnName(), fieldType, null); - } - @VisibleForTesting void handleError(Exception e) throws DatabricksParsingException { String errorMessage = diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResult.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResult.java new file mode 100644 index 0000000000..12498fc3cb --- /dev/null +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResult.java @@ -0,0 +1,459 @@ +package com.databricks.jdbc.api.impl.arrow; + +import static com.databricks.jdbc.common.EnvironmentVariables.DEFAULT_RESULT_ROW_LIMIT; +import static com.databricks.jdbc.common.util.DatabricksTypeUtil.*; +import static com.databricks.jdbc.common.util.DecompressionUtil.decompress; + +import com.databricks.jdbc.api.impl.IExecutionResult; +import com.databricks.jdbc.api.internal.IDatabricksSession; +import com.databricks.jdbc.api.internal.IDatabricksStatementInternal; +import com.databricks.jdbc.common.CompressionCodec; +import com.databricks.jdbc.exception.DatabricksParsingException; +import com.databricks.jdbc.exception.DatabricksSQLException; +import com.databricks.jdbc.log.JdbcLogger; +import com.databricks.jdbc.log.JdbcLoggerFactory; +import com.databricks.jdbc.model.client.thrift.generated.*; +import com.databricks.jdbc.model.core.ColumnInfo; +import com.databricks.jdbc.model.core.ColumnInfoTypeName; +import com.databricks.jdbc.model.telemetry.enums.DatabricksDriverErrorCode; +import com.google.common.annotations.VisibleForTesting; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.SchemaUtility; + +/** + * Lazy implementation for thrift-based inline Arrow results that fetches arrow batches on demand. + * Similar to LazyThriftResult but processes Arrow data instead of columnar thrift data. + */ +public class LazyThriftInlineArrowResult implements IExecutionResult { + private static final JdbcLogger LOGGER = + JdbcLoggerFactory.getLogger(LazyThriftInlineArrowResult.class); + + private TFetchResultsResp currentResponse; + private ArrowResultChunk currentChunk; + private ArrowResultChunkIterator currentChunkIterator; + private long globalRowIndex; + private final IDatabricksSession session; + private final IDatabricksStatementInternal statement; + private final int maxRows; + private boolean hasReachedEnd; + private boolean isClosed; + private long totalRowsFetched; + private List columnInfos; + + /** + * Creates a new LazyThriftInlineArrowResult that lazily fetches arrow data on demand. + * + * @param initialResponse the initial response from the server + * @param statement the statement that generated this result + * @param session the session to use for fetching additional data + * @throws DatabricksSQLException if the initial response cannot be processed + */ + public LazyThriftInlineArrowResult( + TFetchResultsResp initialResponse, + IDatabricksStatementInternal statement, + IDatabricksSession session) + throws DatabricksSQLException { + this.currentResponse = initialResponse; + this.statement = statement; + this.session = session; + this.maxRows = statement != null ? statement.getMaxRows() : DEFAULT_RESULT_ROW_LIMIT; + this.globalRowIndex = -1; + this.hasReachedEnd = false; + this.isClosed = false; + this.totalRowsFetched = 0; + + // Initialize column info from metadata + setColumnInfo(initialResponse.getResultSetMetadata()); + + // Load initial chunk + loadCurrentChunk(); + LOGGER.debug( + "LazyThriftInlineArrowResult initialized with {} rows in first chunk, hasMoreRows: {}", + currentChunk.numRows, + currentResponse.hasMoreRows); + } + + /** + * Gets the value at the specified column index for the current row. + * + * @param columnIndex the zero-based column index + * @return the value at the specified column + * @throws DatabricksSQLException if the result is closed, cursor is invalid, or column index is + * out of bounds + */ + @Override + public Object getObject(int columnIndex) throws DatabricksSQLException { + validateGetObjectState(columnIndex); + + ColumnInfo columnInfo = columnInfos.get(columnIndex); + ColumnInfoTypeName requiredType = columnInfo.getTypeName(); + String arrowMetadata = currentChunkIterator.getType(columnIndex); + if (arrowMetadata == null) { + arrowMetadata = columnInfo.getTypeText(); + } + + return ArrowStreamResult.getObjectWithComplexTypeHandling( + session, currentChunkIterator, columnIndex, requiredType, arrowMetadata, columnInfo); + } + + /** + * Validates the state before getting an object at the specified column index. + * + * @param columnIndex the zero-based column index to validate + * @throws DatabricksSQLException if the result is closed, cursor is invalid, or column index is + * out of bounds + */ + private void validateGetObjectState(int columnIndex) throws DatabricksSQLException { + if (isClosed) { + LOGGER.warn("Attempted to get object from closed result"); + throw new DatabricksSQLException( + "Result is already closed", DatabricksDriverErrorCode.STATEMENT_CLOSED); + } + if (globalRowIndex == -1) { + LOGGER.warn("Attempted to get object before calling next()"); + throw new DatabricksSQLException( + "Cursor is before first row", DatabricksDriverErrorCode.INVALID_STATE); + } + if (currentChunkIterator == null) { + LOGGER.warn("No current chunk available when getting object"); + throw new DatabricksSQLException( + "No current chunk available", DatabricksDriverErrorCode.INVALID_STATE); + } + if (columnIndex < 0 || columnIndex >= columnInfos.size()) { + LOGGER.warn("Column index {} out of bounds (size: {})", columnIndex, columnInfos.size()); + throw new DatabricksSQLException( + "Column index out of bounds " + columnIndex, DatabricksDriverErrorCode.INVALID_STATE); + } + } + + /** + * Gets the current row index (0-based). Returns -1 if before the first row. + * + * @return the current row index + */ + @Override + public long getCurrentRow() { + return globalRowIndex; + } + + /** + * Moves the cursor to the next row. Fetches additional data from server if needed. + * + * @return true if there is a next row, false if at the end + * @throws DatabricksSQLException if an error occurs while fetching data + */ + @Override + public boolean next() throws DatabricksSQLException { + if (isClosed || hasReachedEnd) { + return false; + } + + if (!hasNext()) { + return false; + } + + // Check if we've reached the maxRows limit + boolean hasRowLimit = maxRows > 0; + if (hasRowLimit && globalRowIndex + 1 >= maxRows) { + hasReachedEnd = true; + return false; + } + + // Try to advance in current chunk + if (currentChunkIterator != null && currentChunkIterator.hasNextRow()) { + boolean advanced = currentChunkIterator.nextRow(); + if (advanced) { + globalRowIndex++; + return true; + } + } + + // Need to fetch next chunk + while (currentResponse.hasMoreRows) { + fetchNextChunk(); + + // If we got a chunk with data, advance to first row + if (currentChunkIterator != null && currentChunkIterator.hasNextRow()) { + boolean advanced = currentChunkIterator.nextRow(); + if (advanced) { + globalRowIndex++; + return true; + } + } + } + + // No more data available + hasReachedEnd = true; + return false; + } + + /** + * Checks if there are more rows available without advancing the cursor. + * + * @return true if there are more rows, false otherwise + */ + @Override + public boolean hasNext() { + if (isClosed || hasReachedEnd) { + return false; + } + + // Check maxRows limit + boolean hasRowLimit = maxRows > 0; + if (hasRowLimit && globalRowIndex + 1 >= maxRows) { + return false; + } + + // Check if there are more rows in current chunk + if (currentChunkIterator != null && currentChunkIterator.hasNextRow()) { + return true; + } + + // Check if there are more chunks to fetch + return currentResponse.hasMoreRows; + } + + /** Closes this result and releases associated resources. */ + @Override + public void close() { + this.isClosed = true; + if (currentChunk != null) { + currentChunk.releaseChunk(); + } + this.currentChunk = null; + this.currentChunkIterator = null; + this.currentResponse = null; + LOGGER.debug( + "LazyThriftInlineArrowResult closed after fetching {} total rows", totalRowsFetched); + } + + /** + * Gets the number of rows in the current chunk. + * + * @return the number of rows in the current chunk + */ + @Override + public long getRowCount() { + return currentChunk != null ? currentChunk.numRows : 0; + } + + /** + * Gets the chunk count. Always returns 0 for lazy thrift inline arrow results. + * + * @return 0 (lazy results don't use chunks in the same sense as buffered results) + */ + @Override + public long getChunkCount() { + return 0; + } + + /** + * Gets the Arrow metadata for the current chunk. + * + * @return list of arrow metadata strings, or null if no chunk is loaded + * @throws DatabricksSQLException if an error occurs + */ + public List getArrowMetadata() throws DatabricksSQLException { + if (currentChunk == null) { + return null; + } + return currentChunk.getArrowMetadata(); + } + + private void loadCurrentChunk() throws DatabricksSQLException { + try { + ByteArrayInputStream byteStream = createArrowByteStream(currentResponse); + long rowCount = getTotalRowsInResponse(currentResponse); + + ArrowResultChunk.Builder builder = + ArrowResultChunk.builder().withInputStream(byteStream, rowCount); + + if (statement != null) { + builder.withStatementId(statement.getStatementId()); + } + + currentChunk = builder.build(); + currentChunkIterator = currentChunk.getChunkIterator(); + totalRowsFetched += rowCount; + + LOGGER.debug( + "Loaded arrow chunk with {} rows, total fetched: {}", rowCount, totalRowsFetched); + } catch (DatabricksParsingException e) { + LOGGER.error("Failed to load current chunk: {}", e.getMessage()); + // Clean up any partially loaded chunk to prevent memory leaks + if (currentChunk != null) { + currentChunk.releaseChunk(); + currentChunk = null; + } + currentChunkIterator = null; + hasReachedEnd = true; + throw new DatabricksSQLException( + "Failed to process arrow data", DatabricksDriverErrorCode.INLINE_CHUNK_PARSING_ERROR); + } + } + + /** + * Fetches the next chunk of data from the server and creates arrow chunks. + * + * @throws DatabricksSQLException if the fetch operation fails + */ + private void fetchNextChunk() throws DatabricksSQLException { + try { + LOGGER.debug("Fetching next arrow chunk, current total rows fetched: {}", totalRowsFetched); + currentResponse = session.getDatabricksClient().getMoreResults(statement); + + // Release previous chunk to free memory + if (currentChunk != null) { + currentChunk.releaseChunk(); + } + + loadCurrentChunk(); + + LOGGER.debug( + "Fetched arrow chunk with {} rows, hasMoreRows: {}", + currentChunk.numRows, + currentResponse.hasMoreRows); + } catch (DatabricksSQLException e) { + LOGGER.error("Failed to fetch next arrow chunk: {}", e.getMessage()); + hasReachedEnd = true; + throw e; + } + } + + private ByteArrayInputStream createArrowByteStream(TFetchResultsResp resultsResp) + throws DatabricksParsingException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + CompressionCodec compressionType = + CompressionCodec.getCompressionMapping(resultsResp.getResultSetMetadata()); + try { + byte[] serializedSchema = getSerializedSchema(resultsResp.getResultSetMetadata()); + if (serializedSchema != null) { + baos.write(serializedSchema); + } + writeArrowBatchesToStream(compressionType, resultsResp.getResults().getArrowBatches(), baos); + return new ByteArrayInputStream(baos.toByteArray()); + } catch (DatabricksSQLException | IOException e) { + handleError(e); + } + return null; + } + + private void writeArrowBatchesToStream( + CompressionCodec compressionCodec, + List arrowBatchList, + ByteArrayOutputStream baos) + throws DatabricksSQLException, IOException { + for (TSparkArrowBatch arrowBatch : arrowBatchList) { + byte[] decompressedBytes = + decompress( + arrowBatch.getBatch(), + compressionCodec, + String.format( + "Data fetch for lazy inline arrow batch [%d] and statement [%s] with decompression algorithm : [%s]", + arrowBatch.getRowCount(), statement, compressionCodec)); + baos.write(decompressedBytes); + } + } + + private long getTotalRowsInResponse(TFetchResultsResp resultsResp) { + long totalRows = 0; + if (resultsResp.getResults() != null && resultsResp.getResults().getArrowBatches() != null) { + for (TSparkArrowBatch arrowBatch : resultsResp.getResults().getArrowBatches()) { + totalRows += arrowBatch.getRowCount(); + } + } + return totalRows; + } + + private byte[] getSerializedSchema(TGetResultSetMetadataResp metadata) + throws DatabricksSQLException { + if (metadata.getArrowSchema() != null) { + return metadata.getArrowSchema(); + } + Schema arrowSchema = hiveSchemaToArrowSchema(metadata.getSchema()); + try { + return SchemaUtility.serialize(arrowSchema); + } catch (IOException e) { + handleError(e); + } + return null; + } + + private Schema hiveSchemaToArrowSchema(TTableSchema hiveSchema) + throws DatabricksParsingException { + List fields = new ArrayList<>(); + if (hiveSchema == null) { + return new Schema(fields); + } + try { + hiveSchema + .getColumns() + .forEach( + columnDesc -> { + try { + fields.add(getArrowField(columnDesc)); + } catch (SQLException e) { + throw new RuntimeException(e); + } + }); + } catch (RuntimeException e) { + handleError(e); + } + return new Schema(fields); + } + + private Field getArrowField(TColumnDesc columnDesc) throws SQLException { + TPrimitiveTypeEntry primitiveTypeEntry = getTPrimitiveTypeOrDefault(columnDesc.getTypeDesc()); + ArrowType arrowType = mapThriftToArrowType(primitiveTypeEntry.getType()); + FieldType fieldType = new FieldType(true, arrowType, null); + return new Field(columnDesc.getColumnName(), fieldType, null); + } + + private void setColumnInfo(TGetResultSetMetadataResp resultManifest) { + columnInfos = new ArrayList<>(); + if (resultManifest.getSchema() == null) { + return; + } + for (TColumnDesc tColumnDesc : resultManifest.getSchema().getColumns()) { + columnInfos.add( + com.databricks.jdbc.common.util.DatabricksThriftUtil.getColumnInfoFromTColumnDesc( + tColumnDesc)); + } + } + + @VisibleForTesting + void handleError(Exception e) throws DatabricksParsingException { + String errorMessage = + String.format("Cannot process lazy thrift inline arrow format. Error: %s", e.getMessage()); + LOGGER.error(errorMessage); + throw new DatabricksParsingException( + errorMessage, e, DatabricksDriverErrorCode.INLINE_CHUNK_PARSING_ERROR); + } + + /** + * Gets the total number of rows fetched from the server so far. + * + * @return the total number of rows fetched from the server + */ + long getTotalRowsFetched() { + return totalRowsFetched; + } + + /** + * Checks if all data has been fetched from the server. + * + * @return true if all data has been fetched (either reached end or maxRows limit) + */ + boolean isCompletelyFetched() { + return hasReachedEnd || !currentResponse.hasMoreRows; + } +} diff --git a/src/test/java/com/databricks/jdbc/api/impl/ExecutionResultFactoryTest.java b/src/test/java/com/databricks/jdbc/api/impl/ExecutionResultFactoryTest.java index 2efb1e33a6..1e14615925 100644 --- a/src/test/java/com/databricks/jdbc/api/impl/ExecutionResultFactoryTest.java +++ b/src/test/java/com/databricks/jdbc/api/impl/ExecutionResultFactoryTest.java @@ -1,10 +1,10 @@ package com.databricks.jdbc.api.impl; -import static com.databricks.jdbc.TestConstants.ARROW_BATCH_LIST; import static org.junit.jupiter.api.Assertions.*; import static org.mockito.Mockito.when; import com.databricks.jdbc.api.impl.arrow.ArrowStreamResult; +import com.databricks.jdbc.api.impl.arrow.LazyThriftInlineArrowResult; import com.databricks.jdbc.api.impl.volume.VolumeOperationResult; import com.databricks.jdbc.api.internal.IDatabricksConnectionContext; import com.databricks.jdbc.api.internal.IDatabricksStatementInternal; @@ -128,14 +128,11 @@ public void testGetResultSet_thriftURL() throws SQLException { @Test public void testGetResultSet_thriftInlineArrow() throws SQLException { - when(connectionContext.getConnectionUuid()).thenReturn("sample-uuid"); when(resultSetMetadataResp.getResultFormat()).thenReturn(TSparkRowSetType.ARROW_BASED_SET); when(fetchResultsResp.getResultSetMetadata()).thenReturn(resultSetMetadataResp); when(fetchResultsResp.getResults()).thenReturn(tRowSet); - when(session.getConnectionContext()).thenReturn(connectionContext); - when(tRowSet.getArrowBatches()).thenReturn(ARROW_BATCH_LIST); IExecutionResult result = ExecutionResultFactory.getResultSet(fetchResultsResp, session, parentStatement); - assertInstanceOf(ArrowStreamResult.class, result); + assertInstanceOf(LazyThriftInlineArrowResult.class, result); } } diff --git a/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResultTest.java b/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResultTest.java index 6e32861fec..3ede51f430 100644 --- a/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResultTest.java +++ b/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowStreamResultTest.java @@ -133,25 +133,6 @@ public void testIteration() throws Exception { assertFalse(result.next()); } - @Test - public void testInlineArrow() throws DatabricksSQLException { - IDatabricksConnectionContext connectionContext = - DatabricksConnectionContextFactory.create(JDBC_URL, new Properties()); - when(session.getConnectionContext()).thenReturn(connectionContext); - when(metadataResp.getSchema()).thenReturn(TEST_TABLE_SCHEMA); - when(fetchResultsResp.getResults()).thenReturn(resultData); - when(fetchResultsResp.getResultSetMetadata()).thenReturn(metadataResp); - ArrowStreamResult result = - new ArrowStreamResult(fetchResultsResp, true, parentStatement, session); - assertEquals(-1, result.getCurrentRow()); - assertTrue(result.hasNext()); - assertFalse(result.next()); - assertEquals(0, result.getCurrentRow()); - assertFalse(result.hasNext()); - assertDoesNotThrow(result::close); - assertFalse(result.hasNext()); - } - @Test public void testCloudFetchArrow() throws Exception { IDatabricksConnectionContext connectionContext = @@ -164,7 +145,7 @@ public void testCloudFetchArrow() throws Exception { when(fetchResultsResp.getResultSetMetadata()).thenReturn(metadataResp); when(parentStatement.getStatementId()).thenReturn(STATEMENT_ID); ArrowStreamResult result = - new ArrowStreamResult(fetchResultsResp, false, parentStatement, session, mockHttpClient); + new ArrowStreamResult(fetchResultsResp, parentStatement, session, mockHttpClient); assertEquals(-1, result.getCurrentRow()); assertTrue(result.hasNext()); assertDoesNotThrow(result::close); @@ -586,7 +567,7 @@ public void testStreamingChunkProviderEnabledForThriftResult() throws Exception when(mockHttpClient.execute(isA(HttpUriRequest.class), eq(true))).thenReturn(httpResponse); ArrowStreamResult result = - new ArrowStreamResult(fetchResultsResp, false, parentStatement, session, mockHttpClient); + new ArrowStreamResult(fetchResultsResp, parentStatement, session, mockHttpClient); // Verify result was created successfully with StreamingChunkProvider for Thrift assertNotNull(result); diff --git a/src/test/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProviderTest.java b/src/test/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProviderTest.java index 86be512d4d..8392daf683 100644 --- a/src/test/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProviderTest.java +++ b/src/test/java/com/databricks/jdbc/api/impl/arrow/InlineChunkProviderTest.java @@ -1,27 +1,17 @@ package com.databricks.jdbc.api.impl.arrow; -import static com.databricks.jdbc.TestConstants.ARROW_BATCH_LIST; -import static com.databricks.jdbc.TestConstants.TEST_TABLE_SCHEMA; import static org.junit.jupiter.api.Assertions.*; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import com.databricks.jdbc.api.internal.IDatabricksSession; -import com.databricks.jdbc.api.internal.IDatabricksStatementInternal; import com.databricks.jdbc.common.CompressionCodec; -import com.databricks.jdbc.exception.DatabricksParsingException; import com.databricks.jdbc.exception.DatabricksSQLException; -import com.databricks.jdbc.model.client.thrift.generated.TFetchResultsResp; -import com.databricks.jdbc.model.client.thrift.generated.TGetResultSetMetadataResp; -import com.databricks.jdbc.model.client.thrift.generated.TRowSet; -import com.databricks.jdbc.model.client.thrift.generated.TSparkArrowBatch; import com.databricks.jdbc.model.core.ColumnInfo; import com.databricks.jdbc.model.core.ColumnInfoTypeName; import com.databricks.jdbc.model.core.ResultData; import com.databricks.jdbc.model.core.ResultManifest; import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.util.Collections; import net.jpountz.lz4.LZ4FrameOutputStream; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; @@ -37,41 +27,9 @@ public class InlineChunkProviderTest { private static final long TOTAL_ROWS = 2L; - @Mock TGetResultSetMetadataResp metadata; - @Mock TFetchResultsResp fetchResultsResp; - @Mock IDatabricksStatementInternal parentStatement; - @Mock IDatabricksSession session; @Mock private ResultData mockResultData; @Mock private ResultManifest mockResultManifest; - @Test - void testInitialisation() throws DatabricksParsingException { - when(fetchResultsResp.getResultSetMetadata()).thenReturn(metadata); - when(metadata.getArrowSchema()).thenReturn(null); - when(metadata.getSchema()).thenReturn(TEST_TABLE_SCHEMA); - when(fetchResultsResp.getResults()).thenReturn(new TRowSet().setArrowBatches(ARROW_BATCH_LIST)); - when(metadata.isSetLz4Compressed()).thenReturn(false); - InlineChunkProvider inlineChunkProvider = - new InlineChunkProvider(fetchResultsResp, parentStatement, session); - assertTrue(inlineChunkProvider.hasNextChunk()); - assertTrue(inlineChunkProvider.next()); - assertFalse(inlineChunkProvider.next()); - } - - @Test - void handleErrorTest() throws DatabricksParsingException { - TSparkArrowBatch arrowBatch = - new TSparkArrowBatch().setRowCount(0).setBatch(new byte[] {65, 66, 67}); - when(fetchResultsResp.getResultSetMetadata()).thenReturn(metadata); - when(fetchResultsResp.getResults()) - .thenReturn(new TRowSet().setArrowBatches(Collections.singletonList(arrowBatch))); - InlineChunkProvider inlineChunkProvider = - new InlineChunkProvider(fetchResultsResp, parentStatement, session); - assertThrows( - DatabricksParsingException.class, - () -> inlineChunkProvider.handleError(new RuntimeException())); - } - @Test void testConstructorSuccessfulCreation() throws DatabricksSQLException, IOException { // Create valid Arrow data with two rows and one column: [1, 2] diff --git a/src/test/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResultTest.java b/src/test/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResultTest.java new file mode 100644 index 0000000000..ddf4466644 --- /dev/null +++ b/src/test/java/com/databricks/jdbc/api/impl/arrow/LazyThriftInlineArrowResultTest.java @@ -0,0 +1,428 @@ +package com.databricks.jdbc.api.impl.arrow; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +import com.databricks.jdbc.api.impl.DatabricksConnectionContextFactory; +import com.databricks.jdbc.api.internal.IDatabricksConnectionContext; +import com.databricks.jdbc.api.internal.IDatabricksSession; +import com.databricks.jdbc.api.internal.IDatabricksStatementInternal; +import com.databricks.jdbc.dbclient.IDatabricksClient; +import com.databricks.jdbc.dbclient.impl.common.StatementId; +import com.databricks.jdbc.exception.DatabricksSQLException; +import com.databricks.jdbc.model.client.thrift.generated.*; +import com.databricks.jdbc.model.telemetry.enums.DatabricksDriverErrorCode; +import java.io.ByteArrayOutputStream; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Properties; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowStreamWriter; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +public class LazyThriftInlineArrowResultTest { + + private static final String JDBC_URL = + "jdbc:databricks://sample-host.18.azuredatabricks.net:9999/default;transportMode=http;ssl=1;" + + "AuthMech=3;httpPath=/sql/1.0/warehouses/99999999;"; + + @Mock private IDatabricksSession session; + @Mock private IDatabricksStatementInternal statement; + @Mock private IDatabricksClient databricksClient; + private IDatabricksConnectionContext connectionContext; + private static final StatementId STATEMENT_ID = new StatementId("test_statement_id"); + private static final TTableSchema TWO_COLUMN_SCHEMA = + createTableSchema(TTypeId.INT_TYPE, TTypeId.STRING_TYPE); + + @BeforeEach + void setUp() throws Exception { + connectionContext = DatabricksConnectionContextFactory.create(JDBC_URL, new Properties()); + } + + /** + * Creates valid Arrow IPC format bytes with two columns (int and string). This creates a complete + * Arrow IPC stream (with embedded schema) that can be parsed by ArrowStreamReader. + * + *

Int column values: batch_index * 100 + row_index (e.g., 0, 1, 2... for batch 0) String + * column values: "row_{batch_index}_{row_index}" (e.g., "row_0_0", "row_0_1"...) + * + * @param batchCount Number of record batches to create + * @param rowsPerBatch Number of rows in each batch + * @return Valid Arrow IPC bytes + */ + private static byte[] createValidArrowData(int batchCount, int rowsPerBatch) { + try (BufferAllocator allocator = new RootAllocator(); + ByteArrayOutputStream out = new ByteArrayOutputStream()) { + + try (IntVector intVector = new IntVector("int_column", allocator); + VarCharVector stringVector = new VarCharVector("string_column", allocator)) { + + intVector.allocateNew(rowsPerBatch); + stringVector.allocateNew(rowsPerBatch); + + try (VectorSchemaRoot root = VectorSchemaRoot.of(intVector, stringVector); + ArrowStreamWriter writer = new ArrowStreamWriter(root, null, out)) { + writer.start(); + + for (int batch = 0; batch < batchCount; batch++) { + for (int i = 0; i < rowsPerBatch; i++) { + intVector.set(i, batch * 100 + i); + stringVector.setSafe(i, ("row_" + batch + "_" + i).getBytes()); + } + intVector.setValueCount(rowsPerBatch); + stringVector.setValueCount(rowsPerBatch); + root.setRowCount(rowsPerBatch); + writer.writeBatch(); + } + + writer.end(); + } + } + + return out.toByteArray(); + } catch (Exception e) { + throw new RuntimeException("Failed to create test Arrow data", e); + } + } + + /** + * Creates a TTableSchema with the specified column types. The schema structure matches what + * hiveSchemaToArrowSchema expects. + */ + private static TTableSchema createTableSchema(TTypeId... types) { + List columns = new ArrayList<>(); + for (int i = 0; i < types.length; i++) { + TPrimitiveTypeEntry primitiveType = new TPrimitiveTypeEntry().setType(types[i]); + TTypeEntry typeEntry = new TTypeEntry(); + typeEntry.setPrimitiveEntry(primitiveType); + TTypeDesc typeDesc = new TTypeDesc().setTypes(Collections.singletonList(typeEntry)); + TColumnDesc columnDesc = + new TColumnDesc().setColumnName("col_" + i).setTypeDesc(typeDesc).setPosition(i); + columns.add(columnDesc); + } + return new TTableSchema().setColumns(columns); + } + + /** + * Creates a TFetchResultsResp with valid Arrow data. Uses empty arrowSchema bytes so that the + * complete Arrow IPC stream (with embedded schema) is used as-is. + * + * @param arrowData Valid Arrow IPC bytes (should include schema) + * @param rowCount Row count for the batch + * @param hasMoreRows Whether there are more rows to fetch + * @return Configured TFetchResultsResp + */ + private TFetchResultsResp createFetchResultsResp( + byte[] arrowData, int rowCount, boolean hasMoreRows) { + TSparkArrowBatch arrowBatch = new TSparkArrowBatch().setRowCount(rowCount).setBatch(arrowData); + TRowSet rowSet = new TRowSet().setArrowBatches(Collections.singletonList(arrowBatch)); + + // Use empty arrowSchema - this causes getSerializedSchema to return empty bytes, + // which effectively means nothing is prepended to our complete Arrow IPC stream + TGetResultSetMetadataResp metadata = + new TGetResultSetMetadataResp().setSchema(TWO_COLUMN_SCHEMA).setArrowSchema(new byte[0]); + + TFetchResultsResp response = + new TFetchResultsResp().setResultSetMetadata(metadata).setResults(rowSet); + response.hasMoreRows = hasMoreRows; + + return response; + } + + @Test + void testEmptyResultSet() throws DatabricksSQLException { + byte[] arrowData = createValidArrowData(1, 0); + TFetchResultsResp initialResponse = createFetchResultsResp(arrowData, 0, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + // Verify all initial state for empty result + assertEquals(-1, result.getCurrentRow()); + assertEquals(0, result.getRowCount()); + assertEquals(0, result.getTotalRowsFetched()); + assertEquals(0, result.getChunkCount()); + assertFalse(result.hasNext()); + assertFalse(result.next()); + assertTrue(result.isCompletelyFetched()); + } + + @Test + void testGetObjectThrowsWhenClosed() throws DatabricksSQLException { + byte[] arrowData = createValidArrowData(1, 1); + TFetchResultsResp initialResponse = createFetchResultsResp(arrowData, 1, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + result.close(); + + // Verify close behavior + assertFalse(result.hasNext()); + assertFalse(result.next()); + assertEquals(0, result.getRowCount()); + + DatabricksSQLException exception = + assertThrows(DatabricksSQLException.class, () -> result.getObject(0)); + assertEquals("Result is already closed", exception.getMessage()); + assertEquals(DatabricksDriverErrorCode.STATEMENT_CLOSED.name(), exception.getSQLState()); + } + + @Test + void testGetObjectThrowsWhenBeforeFirstRow() throws DatabricksSQLException { + byte[] arrowData = createValidArrowData(1, 1); + TFetchResultsResp initialResponse = createFetchResultsResp(arrowData, 1, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + DatabricksSQLException exception = + assertThrows(DatabricksSQLException.class, () -> result.getObject(0)); + assertEquals("Cursor is before first row", exception.getMessage()); + assertEquals(DatabricksDriverErrorCode.INVALID_STATE.name(), exception.getSQLState()); + } + + @Test + void testIsCompletelyFetchedWithMoreRows() throws DatabricksSQLException { + byte[] arrowData = createValidArrowData(1, 0); + TFetchResultsResp initialResponse = createFetchResultsResp(arrowData, 0, true); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + assertFalse(result.isCompletelyFetched()); + assertTrue(result.hasNext()); // hasNext is true because hasMoreRows is true + } + + @Test + void testIterateThroughRowsWithValidArrowData() throws DatabricksSQLException { + int rowCount = 5; + byte[] arrowData = createValidArrowData(1, rowCount); + TFetchResultsResp initialResponse = createFetchResultsResp(arrowData, rowCount, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + // Verify initial state + assertEquals(-1, result.getCurrentRow()); + assertTrue(result.hasNext()); + + // Iterate through all rows + for (int i = 0; i < rowCount; i++) { + assertTrue(result.hasNext(), "Should have next at row " + i); + assertTrue(result.next(), "next() should return true at row " + i); + assertEquals(i, result.getCurrentRow()); + } + + // After all rows + assertFalse(result.hasNext()); + assertFalse(result.next()); + assertEquals(rowCount, result.getTotalRowsFetched()); + } + + @Test + void testGetObjectReturnsCorrectIntegerValue() throws DatabricksSQLException { + int rowCount = 3; + byte[] arrowData = createValidArrowData(1, rowCount); + TFetchResultsResp initialResponse = createFetchResultsResp(arrowData, rowCount, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + when(session.getConnectionContext()).thenReturn(connectionContext); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + // Move to first row and get value + assertTrue(result.next()); + Object value = result.getObject(0); + assertNotNull(value); + assertInstanceOf(Integer.class, value); + assertEquals(0, value); // First row: batch_0 * 100 + row_0 = 0 + + // Move to second row + assertTrue(result.next()); + value = result.getObject(0); + assertEquals(1, value); // Second row: batch_0 * 100 + row_1 = 1 + + // Move to third row + assertTrue(result.next()); + value = result.getObject(0); + assertEquals(2, value); // Third row: batch_0 * 100 + row_2 = 2 + } + + @Test + void testGetObjectWithTwoColumns() throws DatabricksSQLException { + int rowCount = 2; + byte[] arrowData = createValidArrowData(1, rowCount); + TFetchResultsResp initialResponse = createFetchResultsResp(arrowData, rowCount, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + when(session.getConnectionContext()).thenReturn(connectionContext); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + assertTrue(result.next()); + + // Get integer column + Object intValue = result.getObject(0); + assertNotNull(intValue); + assertInstanceOf(Integer.class, intValue); + assertEquals(0, intValue); + + // Get string column + Object stringValue = result.getObject(1); + assertNotNull(stringValue); + assertEquals("row_0_0", stringValue.toString()); + } + + @Test + void testGetObjectThrowsForColumnIndexOutOfBounds() throws DatabricksSQLException { + int rowCount = 1; + byte[] arrowData = createValidArrowData(1, rowCount); + TFetchResultsResp initialResponse = createFetchResultsResp(arrowData, rowCount, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + // Note: session.getConnectionContext() is not stubbed here because the column index + // validation happens before the connection context is accessed + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + assertTrue(result.next()); + + // Test negative index + DatabricksSQLException negativeException = + assertThrows(DatabricksSQLException.class, () -> result.getObject(-1)); + assertTrue(negativeException.getMessage().contains("Column index out of bounds")); + assertEquals(DatabricksDriverErrorCode.INVALID_STATE.name(), negativeException.getSQLState()); + + // Test index beyond column count (we have 2 columns: 0 and 1) + DatabricksSQLException beyondException = + assertThrows(DatabricksSQLException.class, () -> result.getObject(2)); + assertTrue(beyondException.getMessage().contains("Column index out of bounds")); + assertEquals(DatabricksDriverErrorCode.INVALID_STATE.name(), beyondException.getSQLState()); + } + + @Test + void testMaxRowsLimitEnforced() throws DatabricksSQLException { + int totalRows = 10; + int maxRows = 3; + byte[] arrowData = createValidArrowData(1, totalRows); + TFetchResultsResp initialResponse = createFetchResultsResp(arrowData, totalRows, false); + + when(statement.getMaxRows()).thenReturn(maxRows); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + // Should only be able to iterate up to maxRows + int rowsRetrieved = 0; + while (result.next()) { + rowsRetrieved++; + } + + assertEquals(maxRows, rowsRetrieved); + assertFalse(result.hasNext()); + assertEquals(maxRows - 1, result.getCurrentRow()); // 0-based index + } + + @Test + void testGetArrowMetadataReturnsMetadata() throws DatabricksSQLException { + int rowCount = 1; + byte[] arrowData = createValidArrowData(1, rowCount); + TFetchResultsResp initialResponse = createFetchResultsResp(arrowData, rowCount, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + List metadata = result.getArrowMetadata(); + assertNotNull(metadata); + // The metadata list should have one entry per column (2 columns: int and string) + assertEquals(2, metadata.size()); + } + + @Test + void testFetchNextChunkFromServer() throws DatabricksSQLException { + int rowsPerChunk = 2; + byte[] arrowData1 = createValidArrowData(1, rowsPerChunk); + byte[] arrowData2 = createValidArrowData(1, rowsPerChunk); + + // First chunk with hasMoreRows = true + TFetchResultsResp initialResponse = createFetchResultsResp(arrowData1, rowsPerChunk, true); + + // Second chunk with hasMoreRows = false + TFetchResultsResp secondResponse = createFetchResultsResp(arrowData2, rowsPerChunk, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + when(session.getDatabricksClient()).thenReturn(databricksClient); + when(databricksClient.getMoreResults(statement)).thenReturn(secondResponse); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + // Iterate through first chunk + assertTrue(result.next()); + assertTrue(result.next()); + assertFalse(result.isCompletelyFetched()); // Still has more rows + + // This should trigger fetch of next chunk + assertTrue(result.next()); + assertTrue(result.next()); + + // After all rows + assertFalse(result.next()); + assertTrue(result.isCompletelyFetched()); + assertEquals(rowsPerChunk * 2, result.getTotalRowsFetched()); + + // Verify that getMoreResults was called + verify(databricksClient).getMoreResults(statement); + } + + @Test + void testGetRowCountReturnsCurrentChunkRowCount() throws DatabricksSQLException { + int rowCount = 5; + byte[] arrowData = createValidArrowData(1, rowCount); + TFetchResultsResp initialResponse = createFetchResultsResp(arrowData, rowCount, false); + + when(statement.getMaxRows()).thenReturn(0); + when(statement.getStatementId()).thenReturn(STATEMENT_ID); + + LazyThriftInlineArrowResult result = + new LazyThriftInlineArrowResult(initialResponse, statement, session); + + assertEquals(rowCount, result.getRowCount()); + } +} diff --git a/src/test/java/com/databricks/jdbc/dbclient/impl/sqlexec/DatabricksEmptyMetadataClientTest.java b/src/test/java/com/databricks/jdbc/dbclient/impl/sqlexec/DatabricksEmptyMetadataClientTest.java index dd2aca1ad6..1152900d27 100644 --- a/src/test/java/com/databricks/jdbc/dbclient/impl/sqlexec/DatabricksEmptyMetadataClientTest.java +++ b/src/test/java/com/databricks/jdbc/dbclient/impl/sqlexec/DatabricksEmptyMetadataClientTest.java @@ -115,4 +115,39 @@ void testListPrimaryKeys() throws SQLException { assertEquals(resultSet.getMetaData().getColumnName(2), "TABLE_SCHEM"); assertEquals(resultSet.getMetaData().getColumnName(3), "TABLE_NAME"); } + + @Test + void testListImportedKeys() throws SQLException { + ResultSet resultSet = mockClient.listImportedKeys(session, "catalog", "schema", "table"); + assertNotNull(resultSet); + assertFalse(resultSet.next()); // empty result set + assertEquals(14, resultSet.getMetaData().getColumnCount()); + assertEquals("PKTABLE_CAT", resultSet.getMetaData().getColumnName(1)); + } + + @Test + void testListExportedKeys() throws SQLException { + ResultSet resultSet = mockClient.listExportedKeys(session, "catalog", "schema", "table"); + assertNotNull(resultSet); + assertFalse(resultSet.next()); // empty result set + assertEquals(14, resultSet.getMetaData().getColumnCount()); + assertEquals("PKTABLE_CAT", resultSet.getMetaData().getColumnName(1)); + } + + @Test + void testListCrossReferences() throws SQLException { + ResultSet resultSet = + mockClient.listCrossReferences( + session, + "parentCatalog", + "parentSchema", + "parentTable", + "foreignCatalog", + "foreignSchema", + "foreignTable"); + assertNotNull(resultSet); + assertFalse(resultSet.next()); // empty result set + assertEquals(14, resultSet.getMetaData().getColumnCount()); + assertEquals("PKTABLE_CAT", resultSet.getMetaData().getColumnName(1)); + } }