Skip to content
1 change: 1 addition & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### Added

### Updated
- Implemented lazy loading for inline Arrow results: Arrow batches are now fetched on demand instead of all at once. This reduces memory usage and improves initial response time for large result sets when using the Thrift protocol with the Arrow format.

### Fixed
- Fixed complex data type metadata support when retrieving 0 rows in Arrow format
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import com.databricks.jdbc.api.IExecutionStatus;
import com.databricks.jdbc.api.impl.arrow.ArrowStreamResult;
import com.databricks.jdbc.api.impl.arrow.ChunkProvider;
import com.databricks.jdbc.api.impl.arrow.LazyThriftInlineArrowResult;
import com.databricks.jdbc.api.impl.converters.ConverterHelper;
import com.databricks.jdbc.api.impl.converters.ObjectConverter;
import com.databricks.jdbc.api.impl.volume.VolumeOperationResult;
Expand Down Expand Up @@ -155,6 +156,8 @@ public DatabricksResultSet(
List<String> arrowMetadata = null;
if (executionResult instanceof ArrowStreamResult) {
arrowMetadata = ((ArrowStreamResult) executionResult).getArrowMetadata();
} else if (executionResult instanceof LazyThriftInlineArrowResult) {
arrowMetadata = ((LazyThriftInlineArrowResult) executionResult).getArrowMetadata();
}
this.resultSetMetaData =
new DatabricksResultSetMetaData(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.databricks.jdbc.api.impl;

import com.databricks.jdbc.api.impl.arrow.ArrowStreamResult;
import com.databricks.jdbc.api.impl.arrow.LazyThriftInlineArrowResult;
import com.databricks.jdbc.api.impl.volume.VolumeOperationResult;
import com.databricks.jdbc.api.internal.IDatabricksSession;
import com.databricks.jdbc.api.internal.IDatabricksStatementInternal;
Expand Down Expand Up @@ -96,9 +97,9 @@ private static IExecutionResult getResultHandler(
case COLUMN_BASED_SET:
return new LazyThriftResult(resultsResp, parentStatement, session);
case ARROW_BASED_SET:
return new ArrowStreamResult(resultsResp, true, parentStatement, session);
return new LazyThriftInlineArrowResult(resultsResp, parentStatement, session);
case URL_BASED_SET:
return new ArrowStreamResult(resultsResp, false, parentStatement, session);
return new ArrowStreamResult(resultsResp, parentStatement, session);
case ROW_BASED_SET:
throw new DatabricksSQLFeatureNotSupportedException(
"Invalid state - row based set cannot be received");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,11 @@ private static ChunkProvider createRemoteChunkProvider(

public ArrowStreamResult(
TFetchResultsResp resultsResp,
boolean isInlineArrow,
IDatabricksStatementInternal parentStatementId,
IDatabricksSession session)
throws DatabricksSQLException {
this(
resultsResp,
isInlineArrow,
parentStatementId,
session,
DatabricksHttpClientFactory.getInstance().getClient(session.getConnectionContext()));
Expand All @@ -158,19 +156,14 @@ public ArrowStreamResult(
@VisibleForTesting
ArrowStreamResult(
TFetchResultsResp resultsResp,
boolean isInlineArrow,
IDatabricksStatementInternal parentStatement,
IDatabricksSession session,
IDatabricksHttpClient httpClient)
throws DatabricksSQLException {
this.session = session;
setColumnInfo(resultsResp.getResultSetMetadata());
if (isInlineArrow) {
this.chunkProvider = new InlineChunkProvider(resultsResp, parentStatement, session);
} else {
this.chunkProvider =
createThriftRemoteChunkProvider(resultsResp, parentStatement, session, httpClient);
}
this.chunkProvider =
createThriftRemoteChunkProvider(resultsResp, parentStatement, session, httpClient);
}

/**
Expand Down Expand Up @@ -238,48 +231,15 @@ public List<String> getArrowMetadata() throws DatabricksSQLException {
/** {@inheritDoc} */
@Override
public Object getObject(int columnIndex) throws DatabricksSQLException {
ColumnInfoTypeName requiredType = columnInfos.get(columnIndex).getTypeName();
ColumnInfo columnInfo = columnInfos.get(columnIndex);
ColumnInfoTypeName requiredType = columnInfo.getTypeName();
String arrowMetadata = chunkIterator.getType(columnIndex);
if (arrowMetadata == null) {
arrowMetadata = columnInfos.get(columnIndex).getTypeText();
arrowMetadata = columnInfo.getTypeText();
}

// Handle complex type conversion when complex datatype support is disabled
boolean isComplexDatatypeSupportEnabled =
this.session.getConnectionContext().isComplexDatatypeSupportEnabled();
boolean isGeoSpatialSupportEnabled =
this.session.getConnectionContext().isGeoSpatialSupportEnabled();

// Check if we need to convert geospatial types to string when geospatial support is disabled
// This check must come before the general complex type check
if (!isGeoSpatialSupportEnabled && isGeospatialType(requiredType)) {
LOGGER.debug("Geospatial support is disabled, converting {} to STRING", requiredType);

Object result =
chunkIterator.getColumnObjectAtCurrentRow(
columnIndex, ColumnInfoTypeName.STRING, "STRING", columnInfos.get(columnIndex));
if (result == null) {
return null;
}
// Return raw string for geospatial types when support is disabled
return result;
}

if (!isComplexDatatypeSupportEnabled && isComplexType(requiredType)) {
LOGGER.debug("Complex datatype support is disabled, converting complex type to STRING");

Object result =
chunkIterator.getColumnObjectAtCurrentRow(
columnIndex, ColumnInfoTypeName.STRING, "STRING", columnInfos.get(columnIndex));
if (result == null) {
return null;
}
ComplexDataTypeParser parser = new ComplexDataTypeParser();
return parser.formatComplexTypeString(result.toString(), requiredType.name(), arrowMetadata);
}

return chunkIterator.getColumnObjectAtCurrentRow(
columnIndex, requiredType, arrowMetadata, columnInfos.get(columnIndex));
return getObjectWithComplexTypeHandling(
session, chunkIterator, columnIndex, requiredType, arrowMetadata, columnInfo);
}

/**
Expand Down Expand Up @@ -384,6 +344,66 @@ private void setColumnInfo(TGetResultSetMetadataResp resultManifest) {
}
}

/**
* Helper method to handle complex type and geospatial type conversion when support is disabled.
*
* <p>This method is also used by LazyThriftInlineArrowResult for consistent type handling.
*
* @param session The databricks session
* @param chunkIterator The chunk iterator
* @param columnIndex The column index
* @param requiredType The required column type
* @param arrowMetadata The arrow metadata
* @param columnInfo The column info
* @return The object value (converted if complex/geospatial type and support disabled)
* @throws DatabricksSQLException if an error occurs
*/
protected static Object getObjectWithComplexTypeHandling(
IDatabricksSession session,
ArrowResultChunkIterator chunkIterator,
int columnIndex,
ColumnInfoTypeName requiredType,
String arrowMetadata,
ColumnInfo columnInfo)
throws DatabricksSQLException {
boolean isComplexDatatypeSupportEnabled =
session.getConnectionContext().isComplexDatatypeSupportEnabled();
boolean isGeoSpatialSupportEnabled =
session.getConnectionContext().isGeoSpatialSupportEnabled();

// Check if we need to convert geospatial types to string when geospatial support is disabled
// This check must come before the general complex type check
if (!isGeoSpatialSupportEnabled && isGeospatialType(requiredType)) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sreekanth is making changes in similar code; make sure that you don't override his geospatial changes.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ack.

LOGGER.debug("Geospatial support is disabled, converting {} to STRING", requiredType);

Object result =
chunkIterator.getColumnObjectAtCurrentRow(
columnIndex, ColumnInfoTypeName.STRING, "STRING", columnInfo);
if (result == null) {
return null;
}
// Return raw string for geospatial types when support is disabled
return result;
}

if (!isComplexDatatypeSupportEnabled && isComplexType(requiredType)) {
LOGGER.debug("Complex datatype support is disabled, converting complex type to STRING");

Object result =
chunkIterator.getColumnObjectAtCurrentRow(
columnIndex, ColumnInfoTypeName.STRING, "STRING", columnInfo);
if (result == null) {
return null;
}
ComplexDataTypeParser parser = new ComplexDataTypeParser();
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to create this for every getObject call?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, this is just existing code from the method getObject(int columnIndex) in the same class. I moved it to a separate method because getObject(int columnIndex) was unreadable.

To answer the question: no, we shouldn't create such objects in hot paths like getObject. But I don't want to change the scope of this PR; I will create a separate change.


return parser.formatComplexTypeString(result.toString(), requiredType.name(), arrowMetadata);
}

return chunkIterator.getColumnObjectAtCurrentRow(
columnIndex, requiredType, arrowMetadata, columnInfo);
}

/**
* Converts a collection of ExternalLinks to a ChunkLinkFetchResult.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,31 +1,17 @@
package com.databricks.jdbc.api.impl.arrow;

import static com.databricks.jdbc.common.util.DatabricksTypeUtil.*;
import static com.databricks.jdbc.common.util.DecompressionUtil.decompress;

import com.databricks.jdbc.api.internal.IDatabricksSession;
import com.databricks.jdbc.api.internal.IDatabricksStatementInternal;
import com.databricks.jdbc.common.CompressionCodec;
import com.databricks.jdbc.exception.DatabricksParsingException;
import com.databricks.jdbc.exception.DatabricksSQLException;
import com.databricks.jdbc.log.JdbcLogger;
import com.databricks.jdbc.log.JdbcLoggerFactory;
import com.databricks.jdbc.model.client.thrift.generated.*;
import com.databricks.jdbc.model.core.ResultData;
import com.databricks.jdbc.model.core.ResultManifest;
import com.databricks.jdbc.model.telemetry.enums.DatabricksDriverErrorCode;
import com.google.common.annotations.VisibleForTesting;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.SchemaUtility;

/** Class to manage inline Arrow chunks */
public class InlineChunkProvider implements ChunkProvider {
Expand All @@ -37,23 +23,6 @@ public class InlineChunkProvider implements ChunkProvider {
private final ArrowResultChunk
arrowResultChunk; // There is only one packet of data in case of inline arrow

/**
 * Creates an inline chunk provider from a Thrift fetch response.
 *
 * <p>All inline data is first gathered into one in-memory stream, which then backs a single
 * {@link ArrowResultChunk}.
 *
 * @param resultsResp the first Thrift fetch response
 * @param parentStatement the owning statement; may be {@code null}, in which case no statement
 *     id is attached to the chunk
 * @param session the session used to fetch any further result pages
 * @throws DatabricksParsingException if the inline data cannot be assembled
 */
InlineChunkProvider(
    TFetchResultsResp resultsResp,
    IDatabricksStatementInternal parentStatement,
    IDatabricksSession session)
    throws DatabricksParsingException {
  this.currentChunkIndex = -1;
  this.totalRows = 0;

  // Building the stream accumulates totalRows as a side effect, so it has to happen before
  // totalRows is handed to the chunk builder.
  final ByteArrayInputStream inlineData =
      initializeByteStream(resultsResp, session, parentStatement);
  final ArrowResultChunk.Builder chunkBuilder =
      ArrowResultChunk.builder().withInputStream(inlineData, totalRows);
  if (parentStatement != null) {
    chunkBuilder.withStatementId(parentStatement.getStatementId());
  }

  this.arrowResultChunk = chunkBuilder.build();
}

/**
* Constructor for inline arrow chunk provider from {@link ResultData} and {@link ResultManifest}.
*
Expand Down Expand Up @@ -123,97 +92,6 @@ public boolean isClosed() {
return isClosed;
}

/**
 * Collects the serialized schema and every inline Arrow batch of the result into one stream.
 *
 * <p>The batches of the initial response are written first; further pages are then requested
 * from the server for as long as the latest response reports more rows.
 *
 * @param resultsResp the initial fetch response
 * @param session the session whose client is used to request further pages
 * @param parentStatement the statement the results belong to
 * @return a stream over the schema bytes followed by the decompressed batch bytes; {@code null}
 *     only if {@code handleError} returns normally after a failure
 * @throws DatabricksParsingException declared because failures are routed through {@code
 *     handleError}
 */
private ByteArrayInputStream initializeByteStream(
    TFetchResultsResp resultsResp,
    IDatabricksSession session,
    IDatabricksStatementInternal parentStatement)
    throws DatabricksParsingException {
  final ByteArrayOutputStream buffer = new ByteArrayOutputStream();
  final CompressionCodec codec =
      CompressionCodec.getCompressionMapping(resultsResp.getResultSetMetadata());
  try {
    final byte[] schemaBytes = getSerializedSchema(resultsResp.getResultSetMetadata());
    if (schemaBytes != null) {
      buffer.write(schemaBytes);
    }

    TFetchResultsResp page = resultsResp;
    while (true) {
      writeToByteOutputStream(
          codec, parentStatement, page.getResults().getArrowBatches(), buffer);
      if (!page.hasMoreRows) {
        break;
      }
      page = session.getDatabricksClient().getMoreResults(parentStatement);
    }
    return new ByteArrayInputStream(buffer.toByteArray());
  } catch (DatabricksSQLException | IOException e) {
    handleError(e);
  }
  return null;
}

/**
 * Decompresses each Arrow batch and appends its bytes to the given output stream.
 *
 * <p>The row count of every batch is added to {@code totalRows} as a side effect.
 *
 * @param compressionCodec the codec the batches were compressed with
 * @param parentStatement the statement the batches belong to (used in the diagnostic context)
 * @param arrowBatchList the batches to append, in order
 * @param baos the stream receiving the decompressed bytes
 * @throws DatabricksSQLException if decompression fails
 * @throws IOException if writing to the stream fails
 */
private void writeToByteOutputStream(
    CompressionCodec compressionCodec,
    IDatabricksStatementInternal parentStatement,
    List<TSparkArrowBatch> arrowBatchList,
    ByteArrayOutputStream baos)
    throws DatabricksSQLException, IOException {
  for (TSparkArrowBatch batch : arrowBatchList) {
    final byte[] compressed = batch.getBatch();
    // Context string passed along to decompress for diagnostics.
    final String context =
        String.format(
            "Data fetch for inline arrow batch [%d] and statement [%s] with decompression algorithm : [%s]",
            batch.getRowCount(), parentStatement, compressionCodec);
    final byte[] payload = decompress(compressed, compressionCodec, context);
    totalRows += batch.getRowCount();
    baos.write(payload);
  }
}

/**
 * Returns the serialized Arrow schema for the result set.
 *
 * <p>The schema carried in the metadata is used when present. Otherwise an Arrow schema is
 * derived from the Hive table schema and serialized.
 *
 * @param metadata the result set metadata
 * @return the serialized schema bytes; {@code null} only if {@code handleError} returns normally
 *     after a serialization failure
 * @throws DatabricksSQLException if the schema cannot be derived or serialized
 */
private byte[] getSerializedSchema(TGetResultSetMetadataResp metadata)
    throws DatabricksSQLException {
  if (metadata.getArrowSchema() == null) {
    // No ready-made Arrow schema in the metadata: derive one from the Hive schema.
    final Schema derivedSchema = hiveSchemaToArrowSchema(metadata.getSchema());
    try {
      return SchemaUtility.serialize(derivedSchema);
    } catch (IOException e) {
      handleError(e);
    }
    // should never reach here;
    return null;
  }
  return metadata.getArrowSchema();
}

/**
 * Builds an Arrow {@link Schema} with one field per column of the given Hive table schema.
 *
 * @param hiveSchema the Thrift table schema; {@code null} yields a schema without fields
 * @return the Arrow schema assembled from the columns, in column order
 * @throws DatabricksParsingException raised through {@code handleError} when a column cannot be
 *     converted
 */
private Schema hiveSchemaToArrowSchema(TTableSchema hiveSchema)
    throws DatabricksParsingException {
  List<Field> fields = new ArrayList<>();
  if (hiveSchema == null) {
    return new Schema(fields);
  }
  try {
    // A plain loop lets the checked SQLException propagate as itself; no RuntimeException
    // wrapper is needed to get it out of a lambda.
    for (TColumnDesc columnDesc : hiveSchema.getColumns()) {
      fields.add(getArrowField(columnDesc));
    }
  } catch (SQLException | RuntimeException e) {
    // Unchecked failures are still routed to handleError, as before; the SQLException now
    // arrives there directly instead of wrapped.
    handleError(e);
  }
  return new Schema(fields);
}

/**
 * Converts one Thrift column descriptor into an Arrow {@link Field} of the same name.
 *
 * @param columnDesc the Thrift column descriptor
 * @return the field, declared nullable, carrying the mapped Arrow type
 * @throws SQLException if the Thrift type cannot be mapped to an Arrow type
 */
private Field getArrowField(TColumnDesc columnDesc) throws SQLException {
  final ArrowType mappedType =
      mapThriftToArrowType(getTPrimitiveTypeOrDefault(columnDesc.getTypeDesc()).getType());
  // nullable = true; the trailing null arguments presumably mean "no dictionary encoding" and
  // "no child fields" - confirm against the Arrow API.
  final FieldType nullableType = new FieldType(true, mappedType, null);
  return new Field(columnDesc.getColumnName(), nullableType, null);
}

@VisibleForTesting
void handleError(Exception e) throws DatabricksParsingException {
String errorMessage =
Expand Down
Loading
Loading