diff --git a/internal/cli_service/cli_service.go b/internal/cli_service/cli_service.go index 71952c69..8622b936 100644 --- a/internal/cli_service/cli_service.go +++ b/internal/cli_service/cli_service.go @@ -6926,7 +6926,7 @@ func (p *TDBSqlResultFormat) String() string { } func (p *TDBSqlResultFormat) Validate() error { - return nil + return nil } // Attributes: // - Batch diff --git a/internal/rows/arrowbased/arrowRecordIterator.go b/internal/rows/arrowbased/arrowRecordIterator.go index 787a0bab..416afd09 100644 --- a/internal/rows/arrowbased/arrowRecordIterator.go +++ b/internal/rows/arrowbased/arrowRecordIterator.go @@ -10,34 +10,59 @@ import ( "github.com/apache/arrow/go/v12/arrow/ipc" "github.com/databricks/databricks-sql-go/internal/cli_service" "github.com/databricks/databricks-sql-go/internal/config" - dbsqlerr "github.com/databricks/databricks-sql-go/internal/errors" "github.com/databricks/databricks-sql-go/internal/rows/rowscanner" "github.com/databricks/databricks-sql-go/rows" ) func NewArrowRecordIterator(ctx context.Context, rpi rowscanner.ResultPageIterator, bi BatchIterator, arrowSchemaBytes []byte, cfg config.Config) rows.ArrowBatchIterator { ari := arrowRecordIterator{ - cfg: cfg, - batchIterator: bi, - resultPageIterator: rpi, - ctx: ctx, - arrowSchemaBytes: arrowSchemaBytes, + cfg: cfg, + ctx: ctx, + arrowSchemaBytes: arrowSchemaBytes, } - return &ari + if bi != nil && rpi != nil { + // Both initial batch iterator and result page iterator + // Extract the raw iterator from the initial batch iterator and create a composite + if batchIter, ok := bi.(*batchIterator); ok { + pagedRaw := &pagedRawBatchIterator{ + ctx: ctx, + resultPageIterator: rpi, + cfg: &cfg, + startRowOffset: 0, + } + compositeRaw := NewInitialThenPagedRawIterator(batchIter.rawIterator, pagedRaw) + ari.batchIterator = NewBatchIterator(compositeRaw, arrowSchemaBytes, &cfg) + } else { + // Fallback: use initial batch iterator, ignore pagination for now + ari.batchIterator = bi + } + } else if bi != nil { + // Only initial batch iterator + ari.batchIterator = bi + } else if rpi != nil { + // Only result page iterator + pagedRawIter := &pagedRawBatchIterator{ + ctx: ctx, + resultPageIterator: rpi, + cfg: &cfg, + startRowOffset: 0, + } + ari.batchIterator = NewBatchIterator(pagedRawIter, arrowSchemaBytes, &cfg) + } + return &ari } // A type implemented DBSQLArrowBatchIterator type arrowRecordIterator struct { - ctx context.Context - cfg config.Config - batchIterator BatchIterator - resultPageIterator rowscanner.ResultPageIterator - currentBatch SparkArrowBatch - isFinished bool - arrowSchemaBytes []byte - arrowSchema *arrow.Schema + ctx context.Context + cfg config.Config + batchIterator BatchIterator + currentBatch SparkArrowBatch + isFinished bool + arrowSchemaBytes []byte + arrowSchema *arrow.Schema } var _ rows.ArrowBatchIterator = (*arrowRecordIterator)(nil) @@ -80,18 +105,13 @@ func (ri *arrowRecordIterator) Close() { if ri.batchIterator != nil { ri.batchIterator.Close() } - - if ri.resultPageIterator != nil { - ri.resultPageIterator.Close() - } } } func (ri *arrowRecordIterator) checkFinished() { finished := ri.isFinished || ((ri.currentBatch == nil || !ri.currentBatch.HasNext()) && - (ri.batchIterator == nil || !ri.batchIterator.HasNext()) && - (ri.resultPageIterator == nil || !ri.resultPageIterator.HasNext())) + (ri.batchIterator == nil || !ri.batchIterator.HasNext())) if finished { // Reached end of result set so Close @@ -101,80 +121,39 @@ func (ri *arrowRecordIterator) checkFinished() { // Update 
the current batch if necessary func (ri *arrowRecordIterator) getCurrentBatch() error { - // only need to update if no current batch or current batch has no more records if ri.currentBatch == nil || !ri.currentBatch.HasNext() { - - // ensure up to date batch iterator - err := ri.getBatchIterator() - if err != nil { - return err - } - // release current batch if ri.currentBatch != nil { ri.currentBatch.Close() } // Get next batch from batch iterator - ri.currentBatch, err = ri.batchIterator.Next() - if err != nil { - return err - } - } - - return nil -} - -// Update batch iterator if necessary -func (ri *arrowRecordIterator) getBatchIterator() error { - // only need to update if there is no batch iterator or the - // batch iterator has no more batches - if ri.batchIterator == nil || !ri.batchIterator.HasNext() { - if ri.batchIterator != nil { - // release any resources held by the current batch iterator - ri.batchIterator.Close() - ri.batchIterator = nil + if ri.batchIterator == nil { + return io.EOF } - // Get the next page of the result set - resp, err := ri.resultPageIterator.Next() + var err error + ri.currentBatch, err = ri.batchIterator.Next() if err != nil { return err } - // Check the result format - resultFormat := resp.ResultSetMetadata.GetResultFormat() - if resultFormat != cli_service.TSparkRowSetType_ARROW_BASED_SET && resultFormat != cli_service.TSparkRowSetType_URL_BASED_SET { - return dbsqlerr.NewDriverError(ri.ctx, errArrowRowsNotArrowFormat, nil) - } - + // Update schema bytes if we don't have them yet and the batch iterator got them if ri.arrowSchemaBytes == nil { - ri.arrowSchemaBytes = resp.ResultSetMetadata.ArrowSchema - } - - // Create a new batch iterator for the batches in the result page - bi, err := ri.newBatchIterator(resp) - if err != nil { - return err + if batchIter, ok := ri.batchIterator.(*batchIterator); ok { + if pagedIter, ok := batchIter.rawIterator.(*pagedRawBatchIterator); ok { + if schemaBytes := pagedIter.GetSchemaBytes(); schemaBytes != nil { + ri.arrowSchemaBytes = schemaBytes + } + } + } } - - ri.batchIterator = bi } return nil } -// Create a new batch iterator from a page of the result set -func (ri *arrowRecordIterator) newBatchIterator(fr *cli_service.TFetchResultsResp) (BatchIterator, error) { - rowSet := fr.Results - if len(rowSet.ResultLinks) > 0 { - return NewCloudBatchIterator(ri.ctx, rowSet.ResultLinks, rowSet.StartRowOffset, &ri.cfg) - } else { - return NewLocalBatchIterator(ri.ctx, rowSet.ArrowBatches, rowSet.StartRowOffset, ri.arrowSchemaBytes, &ri.cfg) - } -} - // Return the schema of the records. 
func (ri *arrowRecordIterator) Schema() (*arrow.Schema, error) { // Return cached schema if available @@ -207,3 +186,51 @@ func (ri *arrowRecordIterator) Schema() (*arrow.Schema, error) { ri.arrowSchema = reader.Schema() return ri.arrowSchema, nil } + +// InitialThenPagedRawIterator handles initial raw iterator first, then paged raw iterator +type InitialThenPagedRawIterator struct { + InitialRaw RawBatchIterator + PagedRaw RawBatchIterator +} + +// NewInitialThenPagedRawIterator creates a composite iterator +func NewInitialThenPagedRawIterator(initial, paged RawBatchIterator) RawBatchIterator { + return &InitialThenPagedRawIterator{ + InitialRaw: initial, + PagedRaw: paged, + } +} + +func (i *InitialThenPagedRawIterator) Next() (*cli_service.TSparkArrowBatch, error) { + if i.InitialRaw != nil && i.InitialRaw.HasNext() { + return i.InitialRaw.Next() + } + if i.PagedRaw != nil { + return i.PagedRaw.Next() + } + return nil, io.EOF +} + +func (i *InitialThenPagedRawIterator) HasNext() bool { + return (i.InitialRaw != nil && i.InitialRaw.HasNext()) || + (i.PagedRaw != nil && i.PagedRaw.HasNext()) +} + +func (i *InitialThenPagedRawIterator) Close() { + if i.InitialRaw != nil { + i.InitialRaw.Close() + } + if i.PagedRaw != nil { + i.PagedRaw.Close() + } +} + +func (i *InitialThenPagedRawIterator) GetStartRowOffset() int64 { + if i.InitialRaw != nil && i.InitialRaw.HasNext() { + return i.InitialRaw.GetStartRowOffset() + } + if i.PagedRaw != nil { + return i.PagedRaw.GetStartRowOffset() + } + return 0 +} diff --git a/internal/rows/arrowbased/arrowRows.go b/internal/rows/arrowbased/arrowRows.go index f6a60c58..71ccc0db 100644 --- a/internal/rows/arrowbased/arrowRows.go +++ b/internal/rows/arrowbased/arrowRows.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "database/sql/driver" + "fmt" "io" "time" @@ -42,8 +43,8 @@ type colInfo struct { dbType cli_service.TTypeId } -// arrowRowScanner handles extracting values from arrow records -type arrowRowScanner struct { +// ArrowRowScanner handles extracting values from arrow records +type ArrowRowScanner struct { rowscanner.Delimiter valueContainerMaker @@ -74,11 +75,12 @@ type arrowRowScanner struct { ctx context.Context - batchIterator BatchIterator + rawBatchIterator RawBatchIterator // Direct results raw batch iterator + batchIterator BatchIterator // Lazy-initialized for row scanning } -// Make sure arrowRowScanner fulfills the RowScanner interface -var _ rowscanner.RowScanner = (*arrowRowScanner)(nil) +// Make sure ArrowRowScanner fulfills the RowScanner interface +var _ rowscanner.RowScanner = (*ArrowRowScanner)(nil) // NewArrowRowScanner returns an instance of RowScanner which handles arrow format results func NewArrowRowScanner(resultSetMetadata *cli_service.TGetResultSetMetadataResp, rowSet *cli_service.TRowSet, cfg *config.Config, logger *dbsqllog.DBSQLLogger, ctx context.Context) (rowscanner.RowScanner, dbsqlerr.DBError) { @@ -112,19 +114,23 @@ func NewArrowRowScanner(resultSetMetadata *cli_service.TGetResultSetMetadataResp return nil, dbsqlerrint.NewDriverError(ctx, errArrowRowsToTimestampFn, err) } - var bi BatchIterator - var err2 dbsqlerr.DBError + var rawBi RawBatchIterator if len(rowSet.ResultLinks) > 0 { logger.Debug().Msgf("Initialize CloudFetch loader, row set start offset: %d, file list:", rowSet.StartRowOffset) for _, resultLink := range rowSet.ResultLinks { logger.Debug().Msgf("- start row offset: %d, row count: %d", resultLink.StartRowOffset, resultLink.RowCount) } - bi, err2 = NewCloudBatchIterator(context.Background(), 
rowSet.ResultLinks, rowSet.StartRowOffset, cfg) - } else { - bi, err2 = NewLocalBatchIterator(context.Background(), rowSet.ArrowBatches, rowSet.StartRowOffset, schemaBytes, cfg) - } - if err2 != nil { - return nil, err2 + var err2 dbsqlerr.DBError + rawBi, err2 = NewCloudRawBatchIterator(context.Background(), rowSet.ResultLinks, rowSet.StartRowOffset, cfg) + if err2 != nil { + return nil, err2 + } + } else if len(rowSet.ArrowBatches) > 0 { + var err2 dbsqlerr.DBError + rawBi, err2 = NewLocalRawBatchIterator(context.Background(), rowSet.ArrowBatches, rowSet.StartRowOffset, schemaBytes, cfg) + if err2 != nil { + return nil, err2 + } } var location *time.Location = time.UTC @@ -134,7 +140,7 @@ func NewArrowRowScanner(resultSetMetadata *cli_service.TGetResultSetMetadataResp } } - rs := &arrowRowScanner{ + rs := &ArrowRowScanner{ Delimiter: rowscanner.NewDelimiter(rowSet.StartRowOffset, rowscanner.CountRows(rowSet)), valueContainerMaker: &arrowValueContainerMaker{}, ArrowConfig: arrowConfig, @@ -144,20 +150,22 @@ func NewArrowRowScanner(resultSetMetadata *cli_service.TGetResultSetMetadataResp colInfo: colInfos, DBSQLLogger: logger, location: location, - batchIterator: bi, + rawBatchIterator: rawBi, } return rs, nil } // Close is called when the Rows instance is closed. -func (ars *arrowRowScanner) Close() { +func (ars *ArrowRowScanner) Close() { if ars.rowValues != nil { ars.rowValues.Close() } if ars.batchIterator != nil { ars.batchIterator.Close() + } else if ars.rawBatchIterator != nil { + ars.rawBatchIterator.Close() } if ars.currentBatch != nil { @@ -166,7 +174,7 @@ func (ars *arrowRowScanner) Close() { } // NRows returns the number of rows in the current set of batches -func (ars *arrowRowScanner) NRows() int64 { +func (ars *ArrowRowScanner) NRows() int64 { if ars != nil { return ars.Count() } @@ -189,7 +197,7 @@ var intervalTypes map[cli_service.TTypeId]struct{} = map[cli_service.TTypeId]str // The dest should not be written to outside of ScanRow. Care // should be taken when closing a RowScanner not to modify // a buffer held in dest. -func (ars *arrowRowScanner) ScanRow( +func (ars *ArrowRowScanner) ScanRow( destination []driver.Value, rowNumber int64) dbsqlerr.DBError { @@ -241,12 +249,19 @@ func isIntervalType(typeId cli_service.TTypeId) bool { } // loadBatchFor loads the batch containing the specified row if necessary -func (ars *arrowRowScanner) loadBatchFor(rowNumber int64) dbsqlerr.DBError { +func (ars *ArrowRowScanner) loadBatchFor(rowNumber int64) dbsqlerr.DBError { if ars == nil { return dbsqlerrint.NewDriverError(context.Background(), errArrowRowsNoArrowBatches, nil) } + // Create batch iterator on demand from raw iterator + if ars.batchIterator == nil && ars.rawBatchIterator != nil { + ars.batchIterator = NewBatchIterator(ars.rawBatchIterator, ars.arrowSchemaBytes, &config.Config{ + ArrowConfig: ars.ArrowConfig, + }) + } + if ars.batchIterator == nil { return dbsqlerrint.NewDriverError(ars.ctx, errArrowRowsNoArrowBatches, nil) } @@ -317,15 +332,42 @@ func (ars *arrowRowScanner) loadBatchFor(rowNumber int64) dbsqlerr.DBError { // Check that the row number falls within the range of this row scanner and that // it is not moving backwards. 
-func (ars *arrowRowScanner) validateRowNumber(rowNumber int64) dbsqlerr.DBError { +func (ars *ArrowRowScanner) validateRowNumber(rowNumber int64) dbsqlerr.DBError { if rowNumber < 0 || rowNumber > ars.End() || (ars.currentBatch != nil && ars.currentBatch.Direction(rowNumber) == rowscanner.DirBack) { return dbsqlerrint.NewDriverError(ars.ctx, errArrowRowsInvalidRowNumber(rowNumber), nil) } return nil } -func (ars *arrowRowScanner) GetArrowBatches(ctx context.Context, cfg config.Config, rpi rowscanner.ResultPageIterator) (dbsqlrows.ArrowBatchIterator, error) { - ri := NewArrowRecordIterator(ctx, rpi, ars.batchIterator, ars.arrowSchemaBytes, cfg) +// GetRawBatchIterator returns the raw batch iterator for direct results +func (ars *ArrowRowScanner) GetRawBatchIterator() RawBatchIterator { + return ars.rawBatchIterator +} + +func (ars *ArrowRowScanner) GetArrowBatches(ctx context.Context, cfg config.Config, rpi rowscanner.ResultPageIterator) (dbsqlrows.ArrowBatchIterator, error) { + // Create a unified raw batch iterator that combines direct results and pagination + var unifiedRawIterator RawBatchIterator + + if ars.rawBatchIterator != nil && rpi != nil { + // Both direct results and pagination - compose them + pagedRawIterator := NewPagedRawBatchIterator(ctx, rpi, &cfg) + unifiedRawIterator = NewInitialThenPagedRawIterator(ars.rawBatchIterator, pagedRawIterator) + } else if ars.rawBatchIterator != nil { + // Only direct results + unifiedRawIterator = ars.rawBatchIterator + } else if rpi != nil { + // Only pagination + unifiedRawIterator = NewPagedRawBatchIterator(ctx, rpi, &cfg) + } else { + // No data + return nil, fmt.Errorf("no data available") + } + + // Create batch iterator wrapper + batchIterator := NewBatchIterator(unifiedRawIterator, ars.arrowSchemaBytes, &cfg) + + // Create arrow record iterator with the unified batch iterator + ri := NewArrowRecordIterator(ctx, nil, batchIterator, ars.arrowSchemaBytes, cfg) return ri, nil } @@ -556,7 +598,7 @@ type arrowValueContainerMaker struct{} var _ valueContainerMaker = (*arrowValueContainerMaker)(nil) // makeColumnValuesContainers creates appropriately typed column values holders for each column -func (vcm *arrowValueContainerMaker) makeColumnValuesContainers(ars *arrowRowScanner, d rowscanner.Delimiter) error { +func (vcm *arrowValueContainerMaker) makeColumnValuesContainers(ars *ArrowRowScanner, d rowscanner.Delimiter) error { if ars.rowValues == nil { columnValueHolders := make([]columnValues, len(ars.colInfo)) for i, field := range ars.arrowSchema.Fields() { diff --git a/internal/rows/arrowbased/arrowRows_test.go b/internal/rows/arrowbased/arrowRows_test.go index c43674eb..7c5f1bda 100644 --- a/internal/rows/arrowbased/arrowRows_test.go +++ b/internal/rows/arrowbased/arrowRows_test.go @@ -205,7 +205,7 @@ func TestArrowRowScanner(t *testing.T) { t.Run("NRows", func(t *testing.T) { // test counting the number of rows by summing individual batches - var dummy *arrowRowScanner + var dummy *ArrowRowScanner assert.Equal(t, int64(0), dummy.NRows()) rowSet := &cli_service.TRowSet{} @@ -238,7 +238,7 @@ func TestArrowRowScanner(t *testing.T) { d, _ := NewArrowRowScanner(metadataResp, rowSet, nil, nil, context.Background()) - var ars *arrowRowScanner = d.(*arrowRowScanner) + var ars *ArrowRowScanner = d.(*ArrowRowScanner) err := ars.makeColumnValuesContainers(ars, rowscanner.NewDelimiter(0, 1)) require.Nil(t, err) @@ -314,7 +314,7 @@ func TestArrowRowScanner(t *testing.T) { d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, 
context.Background()) - var ars *arrowRowScanner = d.(*arrowRowScanner) + var ars *ArrowRowScanner = d.(*ArrowRowScanner) err := ars.makeColumnValuesContainers(ars, rowscanner.NewDelimiter(0, 1)) require.Nil(t, err) @@ -416,7 +416,7 @@ func TestArrowRowScanner(t *testing.T) { d, err1 := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) require.Nil(t, err1) - var ars *arrowRowScanner = d.(*arrowRowScanner) + var ars *ArrowRowScanner = d.(*ArrowRowScanner) err := ars.makeColumnValuesContainers(ars, rowscanner.NewDelimiter(0, 0)) require.Nil(t, err) @@ -424,7 +424,7 @@ func TestArrowRowScanner(t *testing.T) { dest := make([]driver.Value, 1) err = ars.ScanRow(dest, 0) require.NotNil(t, err) - assert.True(t, strings.Contains(err.Error(), "databricks: driver error: "+errArrowRowsInvalidRowNumber(0))) + assert.True(t, strings.Contains(err.Error(), errArrowRowsNoArrowBatches)) }) t.Run("Close releases column values", func(t *testing.T) { @@ -447,7 +447,7 @@ func TestArrowRowScanner(t *testing.T) { require.Nil(t, err) d.Close() - ars := d.(*arrowRowScanner) + ars := d.(*ArrowRowScanner) var releaseCount int fc := &fakeColumnValues{fnRelease: func() { releaseCount++ }} ars.rowValues = NewRowValues(rowscanner.NewDelimiter(0, 1), []columnValues{fc, fc, fc}) @@ -456,12 +456,12 @@ func TestArrowRowScanner(t *testing.T) { }) t.Run("loadBatch invalid row scanner", func(t *testing.T) { - var ars *arrowRowScanner + var ars *ArrowRowScanner err := ars.loadBatchFor(0) assert.NotNil(t, err) assert.ErrorContains(t, err, errArrowRowsNoArrowBatches) - ars = &arrowRowScanner{} + ars = &ArrowRowScanner{} ars.DBSQLLogger = dbsqllog.Logger err = ars.loadBatchFor(0) assert.NotNil(t, err) @@ -484,7 +484,7 @@ func TestArrowRowScanner(t *testing.T) { d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) - var ars *arrowRowScanner = d.(*arrowRowScanner) + var ars *ArrowRowScanner = d.(*ArrowRowScanner) assert.Nil(t, ars.rowValues) @@ -502,7 +502,7 @@ func TestArrowRowScanner(t *testing.T) { ars.batchIterator = fbi var callCount int - ars.valueContainerMaker = &fakeValueContainerMaker{fnMakeColumnValuesContainers: func(ars *arrowRowScanner, d rowscanner.Delimiter) dbsqlerr.DBError { + ars.valueContainerMaker = &fakeValueContainerMaker{fnMakeColumnValuesContainers: func(ars *ArrowRowScanner, d rowscanner.Delimiter) dbsqlerr.DBError { callCount += 1 columnValueHolders := make([]columnValues, len(ars.arrowSchema.Fields())) for i := range ars.arrowSchema.Fields() { @@ -554,7 +554,7 @@ func TestArrowRowScanner(t *testing.T) { d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, nil) - var ars *arrowRowScanner = d.(*arrowRowScanner) + var ars *ArrowRowScanner = d.(*ArrowRowScanner) fbi := &fakeBatchIterator{ batches: []SparkArrowBatch{ @@ -592,7 +592,7 @@ func TestArrowRowScanner(t *testing.T) { d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, nil) - var ars *arrowRowScanner = d.(*arrowRowScanner) + var ars *ArrowRowScanner = d.(*ArrowRowScanner) fbi := &fakeBatchIterator{ batches: []SparkArrowBatch{ @@ -631,7 +631,7 @@ func TestArrowRowScanner(t *testing.T) { d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, nil) - var ars *arrowRowScanner = d.(*arrowRowScanner) + var ars *ArrowRowScanner = d.(*ArrowRowScanner) fbi := &fakeBatchIterator{ batches: []SparkArrowBatch{ @@ -645,7 +645,7 @@ func TestArrowRowScanner(t *testing.T) { ars.batchIterator = fbi ars.valueContainerMaker = &fakeValueContainerMaker{ - fnMakeColumnValuesContainers: func(ars 
*arrowRowScanner, d rowscanner.Delimiter) dbsqlerr.DBError { + fnMakeColumnValuesContainers: func(ars *ArrowRowScanner, d rowscanner.Delimiter) dbsqlerr.DBError { return dbsqlerrint.NewDriverError(context.TODO(), "error making containers", nil) }, } @@ -672,7 +672,7 @@ func TestArrowRowScanner(t *testing.T) { d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, nil) - var ars *arrowRowScanner = d.(*arrowRowScanner) + var ars *ArrowRowScanner = d.(*ArrowRowScanner) fbi := &fakeBatchIterator{ batches: []SparkArrowBatch{ @@ -708,7 +708,7 @@ func TestArrowRowScanner(t *testing.T) { d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) - var ars *arrowRowScanner = d.(*arrowRowScanner) + var ars *ArrowRowScanner = d.(*ArrowRowScanner) fbi := &fakeBatchIterator{ batches: []SparkArrowBatch{ @@ -857,7 +857,7 @@ func TestArrowRowScanner(t *testing.T) { d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) - var ars *arrowRowScanner = d.(*arrowRowScanner) + var ars *ArrowRowScanner = d.(*ArrowRowScanner) ars.UseArrowNativeComplexTypes = true ars.UseArrowNativeDecimal = true ars.UseArrowNativeIntervalTypes = true @@ -873,7 +873,7 @@ func TestArrowRowScanner(t *testing.T) { } ars.batchIterator = fbi - ars.valueContainerMaker = &fakeValueContainerMaker{fnMakeColumnValuesContainers: func(ars *arrowRowScanner, d rowscanner.Delimiter) dbsqlerr.DBError { + ars.valueContainerMaker = &fakeValueContainerMaker{fnMakeColumnValuesContainers: func(ars *ArrowRowScanner, d rowscanner.Delimiter) dbsqlerr.DBError { columnValueHolders := make([]columnValues, len(ars.arrowSchema.Fields())) for i := range ars.arrowSchema.Fields() { columnValueHolders[i] = &fakeColumnValues{} @@ -935,7 +935,7 @@ func TestArrowRowScanner(t *testing.T) { d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) assert.Nil(t, err) - ars := d.(*arrowRowScanner) + ars := d.(*ArrowRowScanner) dest := make([]driver.Value, len(executeStatementResp.DirectResults.ResultSetMetadata.Schema.Columns)) err = ars.ScanRow(dest, 0) @@ -963,7 +963,7 @@ func TestArrowRowScanner(t *testing.T) { d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) assert.Nil(t, err) - ars := d.(*arrowRowScanner) + ars := d.(*ArrowRowScanner) dest := make([]driver.Value, len(executeStatementResp.DirectResults.ResultSetMetadata.Schema.Columns)) err = ars.ScanRow(dest, 1) @@ -986,7 +986,7 @@ func TestArrowRowScanner(t *testing.T) { d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) assert.Nil(t, err) - ars := d.(*arrowRowScanner) + ars := d.(*ArrowRowScanner) dest := []driver.Value{ true, int8(4), int16(3), int32(2), int64(1), float32(3.3), float64(2.2), "stringval", @@ -1018,7 +1018,7 @@ func TestArrowRowScanner(t *testing.T) { d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) assert.Nil(t, err) - ars := d.(*arrowRowScanner) + ars := d.(*ArrowRowScanner) dest := make([]driver.Value, len(executeStatementResp.DirectResults.ResultSetMetadata.Schema.Columns)) err = ars.ScanRow(dest, 2) @@ -1038,19 +1038,24 @@ func TestArrowRowScanner(t *testing.T) { d, 
err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) assert.Nil(t, err) - ars := d.(*arrowRowScanner) + ars := d.(*ArrowRowScanner) assert.Equal(t, int64(53940), ars.NRows()) - bi, ok := ars.batchIterator.(*localBatchIterator) - assert.True(t, ok) + dest := make([]driver.Value, len(executeStatementResp.DirectResults.ResultSetMetadata.Schema.Columns)) + + // Trigger creation of batchIterator by attempting to scan the first row + err = ars.ScanRow(dest, 0) + assert.Nil(t, err) + + // Now wrap the initialized batchIterator fbi := &batchIteratorWrapper{ - bi: bi, + bi: ars.batchIterator, + callCount: 1, // Already loaded one batch for row 0 } - ars.batchIterator = fbi - dest := make([]driver.Value, len(executeStatementResp.DirectResults.ResultSetMetadata.Schema.Columns)) - for i := int64(0); i < ars.NRows(); i = i + 1 { + // Continue from row 1 since we already scanned row 0 + for i := int64(1); i < ars.NRows(); i = i + 1 { err := ars.ScanRow(dest, i) assert.Nil(t, err) assert.Equal(t, int32(i+1), dest[0]) @@ -1110,7 +1115,7 @@ func TestArrowRowScanner(t *testing.T) { d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) assert.Nil(t, err) - ars := d.(*arrowRowScanner) + ars := d.(*ArrowRowScanner) dest := make([]driver.Value, len(executeStatementResp.DirectResults.ResultSetMetadata.Schema.Columns)) // err = ars.ScanRow(dest, 0) @@ -1138,7 +1143,7 @@ func TestArrowRowScanner(t *testing.T) { d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) assert.Nil(t, err) - ars := d.(*arrowRowScanner) + ars := d.(*ArrowRowScanner) dest := make([]driver.Value, len(executeStatementResp.DirectResults.ResultSetMetadata.Schema.Columns)) err = ars.ScanRow(dest, 0) @@ -1190,7 +1195,7 @@ func TestArrowRowScanner(t *testing.T) { d, err1 := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) assert.Nil(t, err1) - ars := d.(*arrowRowScanner) + ars := d.(*ArrowRowScanner) dest := make([]driver.Value, len(executeStatementResp.DirectResults.ResultSetMetadata.Schema.Columns)) err1 = ars.ScanRow(dest, 1) @@ -1231,7 +1236,7 @@ func TestArrowRowScanner(t *testing.T) { d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) assert.Nil(t, err) - ars := d.(*arrowRowScanner) + ars := d.(*ArrowRowScanner) dest := make([]driver.Value, len(executeStatementResp.DirectResults.ResultSetMetadata.Schema.Columns)) err = ars.ScanRow(dest, 0) @@ -1290,7 +1295,7 @@ func TestArrowRowScanner(t *testing.T) { d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) assert.Nil(t, err) - ars := d.(*arrowRowScanner) + ars := d.(*ArrowRowScanner) dest := make([]driver.Value, len(executeStatementResp.DirectResults.ResultSetMetadata.Schema.Columns)) err = ars.ScanRow(dest, 0) @@ -1323,7 +1328,7 @@ func TestArrowRowScanner(t *testing.T) { d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, 
executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) assert.Nil(t, err) - ars := d.(*arrowRowScanner) + ars := d.(*ArrowRowScanner) dest := make([]driver.Value, len(executeStatementResp.DirectResults.ResultSetMetadata.Schema.Columns)) err = ars.ScanRow(dest, 0) @@ -1362,7 +1367,7 @@ func TestArrowRowScanner(t *testing.T) { d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) assert.Nil(t, err) - ars := d.(*arrowRowScanner) + ars := d.(*ArrowRowScanner) dest := make([]driver.Value, len(executeStatementResp.DirectResults.ResultSetMetadata.Schema.Columns)) @@ -1401,7 +1406,7 @@ func TestArrowRowScanner(t *testing.T) { d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) assert.Nil(t, err) - ars := d.(*arrowRowScanner) + ars := d.(*ArrowRowScanner) dest := make([]driver.Value, len(executeStatementResp.DirectResults.ResultSetMetadata.Schema.Columns)) err = ars.ScanRow(dest, 0) @@ -1437,7 +1442,7 @@ func TestArrowRowScanner(t *testing.T) { d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) assert.Nil(t, err) - ars := d.(*arrowRowScanner) + ars := d.(*ArrowRowScanner) dest := make([]driver.Value, len(executeStatementResp.DirectResults.ResultSetMetadata.Schema.Columns)) err = ars.ScanRow(dest, 0) @@ -1997,12 +2002,12 @@ func getAllTypesSchema() *cli_service.TTableSchema { } type fakeValueContainerMaker struct { - fnMakeColumnValuesContainers func(ars *arrowRowScanner, d rowscanner.Delimiter) dbsqlerr.DBError + fnMakeColumnValuesContainers func(ars *ArrowRowScanner, d rowscanner.Delimiter) dbsqlerr.DBError } var _ valueContainerMaker = (*fakeValueContainerMaker)(nil) -func (vcm *fakeValueContainerMaker) makeColumnValuesContainers(ars *arrowRowScanner, d rowscanner.Delimiter) error { +func (vcm *fakeValueContainerMaker) makeColumnValuesContainers(ars *ArrowRowScanner, d rowscanner.Delimiter) error { if vcm.fnMakeColumnValuesContainers != nil { return vcm.fnMakeColumnValuesContainers(ars, d) } diff --git a/internal/rows/arrowbased/batchloader.go b/internal/rows/arrowbased/batchloader.go index 4f7ef0be..d9c855c4 100644 --- a/internal/rows/arrowbased/batchloader.go +++ b/internal/rows/arrowbased/batchloader.go @@ -21,24 +21,97 @@ import ( "github.com/databricks/databricks-sql-go/logger" ) +type RawBatchIterator interface { + Next() (*cli_service.TSparkArrowBatch, error) + HasNext() bool + Close() + GetStartRowOffset() int64 +} + +// BatchIterator provides parsed Arrow batches type BatchIterator interface { Next() (SparkArrowBatch, error) HasNext() bool Close() } -func NewCloudBatchIterator( +// batchIterator wraps a RawBatchIterator and parses the raw batches +type batchIterator struct { + rawIterator RawBatchIterator + arrowSchemaBytes []byte + cfg *config.Config +} + +var _ BatchIterator = (*batchIterator)(nil) + +func (bi *batchIterator) Next() (SparkArrowBatch, error) { + rawBatch, err := bi.rawIterator.Next() + if err != nil { + return nil, err + } + + // GetStartRowOffset returns the start offset of the last returned batch + startOffset := bi.rawIterator.GetStartRowOffset() + + // Get schema bytes dynamically if not available + schemaBytes := bi.arrowSchemaBytes + if schemaBytes == nil { + if pagedIter, ok := 
bi.rawIterator.(*pagedRawBatchIterator); ok { + schemaBytes = pagedIter.GetSchemaBytes() + } + } + + reader := io.MultiReader( + bytes.NewReader(schemaBytes), + getReader(bytes.NewReader(rawBatch.Batch), bi.cfg.UseLz4Compression), + ) + + records, err := getArrowRecords(reader, startOffset) + if err != nil { + return nil, err + } + + batch := sparkArrowBatch{ + Delimiter: rowscanner.NewDelimiter(startOffset, rawBatch.RowCount), + arrowRecords: records, + } + + return &batch, nil +} + +func (bi *batchIterator) HasNext() bool { + return bi.rawIterator.HasNext() +} + +func (bi *batchIterator) Close() { + bi.rawIterator.Close() +} + +// NewBatchIterator creates a BatchIterator from a RawBatchIterator +func NewBatchIterator( + rawIterator RawBatchIterator, + arrowSchemaBytes []byte, + cfg *config.Config, +) BatchIterator { + return &batchIterator{ + rawIterator: rawIterator, + arrowSchemaBytes: arrowSchemaBytes, + cfg: cfg, + } +} + +func NewCloudRawBatchIterator( ctx context.Context, files []*cli_service.TSparkArrowResultLink, startRowOffset int64, cfg *config.Config, -) (BatchIterator, dbsqlerr.DBError) { +) (RawBatchIterator, dbsqlerr.DBError) { bi := &cloudBatchIterator{ - ctx: ctx, - cfg: cfg, - startRowOffset: startRowOffset, - pendingLinks: NewQueue[cli_service.TSparkArrowResultLink](), - downloadTasks: NewQueue[cloudFetchDownloadTask](), + ctx: ctx, + cfg: cfg, + currentRowOffset: startRowOffset, + pendingLinks: NewQueue[cli_service.TSparkArrowResultLink](), + downloadTasks: NewQueue[cloudFetchDownloadTask](), } for _, link := range files { @@ -48,16 +121,16 @@ func NewCloudBatchIterator( return bi, nil } -func NewLocalBatchIterator( +func NewLocalRawBatchIterator( ctx context.Context, batches []*cli_service.TSparkArrowBatch, startRowOffset int64, arrowSchemaBytes []byte, cfg *config.Config, -) (BatchIterator, dbsqlerr.DBError) { +) (RawBatchIterator, dbsqlerr.DBError) { bi := &localBatchIterator{ cfg: cfg, - startRowOffset: startRowOffset, + currentRowOffset: startRowOffset, arrowSchemaBytes: arrowSchemaBytes, batches: batches, index: -1, @@ -68,38 +141,22 @@ func NewLocalBatchIterator( type localBatchIterator struct { cfg *config.Config - startRowOffset int64 + currentRowOffset int64 // Tracks the start offset of the last returned batch arrowSchemaBytes []byte batches []*cli_service.TSparkArrowBatch index int } -var _ BatchIterator = (*localBatchIterator)(nil) +var _ RawBatchIterator = (*localBatchIterator)(nil) -func (bi *localBatchIterator) Next() (SparkArrowBatch, error) { +func (bi *localBatchIterator) Next() (*cli_service.TSparkArrowBatch, error) { cnt := len(bi.batches) bi.index++ if bi.index < cnt { ab := bi.batches[bi.index] - - reader := io.MultiReader( - bytes.NewReader(bi.arrowSchemaBytes), - getReader(bytes.NewReader(ab.Batch), bi.cfg.UseLz4Compression), - ) - - records, err := getArrowRecords(reader, bi.startRowOffset) - if err != nil { - return &sparkArrowBatch{}, err - } - - batch := sparkArrowBatch{ - Delimiter: rowscanner.NewDelimiter(bi.startRowOffset, ab.RowCount), - arrowRecords: records, - } - - bi.startRowOffset += ab.RowCount // advance to beginning of the next batch - - return &batch, nil + // Update offset after returning the batch + bi.currentRowOffset += ab.RowCount + return ab, nil } bi.index = cnt @@ -116,17 +173,28 @@ func (bi *localBatchIterator) Close() { bi.index = len(bi.batches) } +func (bi *localBatchIterator) GetStartRowOffset() int64 { + // Return the offset of the last returned batch + if bi.index >= 0 && bi.index < len(bi.batches) { + // 
currentRowOffset points to after the last returned batch + // so subtract the current batch's row count to get its start + return bi.currentRowOffset - bi.batches[bi.index].RowCount + } + return bi.currentRowOffset +} + type cloudBatchIterator struct { - ctx context.Context - cfg *config.Config - startRowOffset int64 - pendingLinks Queue[cli_service.TSparkArrowResultLink] - downloadTasks Queue[cloudFetchDownloadTask] + ctx context.Context + cfg *config.Config + currentRowOffset int64 // Tracks the offset after the last returned batch + lastBatchRowCount int64 // Tracks the row count of the last returned batch + pendingLinks Queue[cli_service.TSparkArrowResultLink] + downloadTasks Queue[cloudFetchDownloadTask] } -var _ BatchIterator = (*cloudBatchIterator)(nil) +var _ RawBatchIterator = (*cloudBatchIterator)(nil) -func (bi *cloudBatchIterator) Next() (SparkArrowBatch, error) { +func (bi *cloudBatchIterator) Next() (*cli_service.TSparkArrowBatch, error) { for (bi.downloadTasks.Len() < bi.cfg.MaxDownloadThreads) && (bi.pendingLinks.Len() > 0) { link := bi.pendingLinks.Dequeue() logger.Debug().Msgf( @@ -153,7 +221,7 @@ func (bi *cloudBatchIterator) Next() (SparkArrowBatch, error) { return nil, io.EOF } - batch, err := task.GetResult() + batchData, rowCount, err := task.GetResult() // once we've got an errored out task - cancel the remaining ones if err != nil { @@ -163,6 +231,14 @@ func (bi *cloudBatchIterator) Next() (SparkArrowBatch, error) { // explicitly call cancel function on successfully completed task to avoid context leak task.cancel() + + // Create TSparkArrowBatch from the downloaded data + batch := &cli_service.TSparkArrowBatch{ + Batch: batchData, + RowCount: rowCount, + } + bi.lastBatchRowCount = rowCount + bi.currentRowOffset += rowCount return batch, nil } @@ -178,9 +254,15 @@ func (bi *cloudBatchIterator) Close() { } } +func (bi *cloudBatchIterator) GetStartRowOffset() int64 { + // Return the start offset of the last returned batch + return bi.currentRowOffset - bi.lastBatchRowCount +} + type cloudFetchDownloadTaskResult struct { - batch SparkArrowBatch - err error + batchData []byte + rowCount int64 + err error } type cloudFetchDownloadTask struct { @@ -192,7 +274,7 @@ type cloudFetchDownloadTask struct { resultChan chan cloudFetchDownloadTaskResult } -func (cft *cloudFetchDownloadTask) GetResult() (SparkArrowBatch, error) { +func (cft *cloudFetchDownloadTask) GetResult() ([]byte, int64, error) { link := cft.link result, ok := <-cft.resultChan @@ -204,14 +286,14 @@ func (cft *cloudFetchDownloadTask) GetResult() (SparkArrowBatch, error) { link.RowCount, result.err.Error(), ) - return nil, result.err + return nil, 0, result.err } logger.Debug().Msgf( "CloudFetch: received data for link at offset %d row count %d", link.StartRowOffset, link.RowCount, ) - return result.batch, nil + return result.batchData, result.rowCount, nil } // This branch should never be reached. 
If you see this message - something got really wrong @@ -220,7 +302,7 @@ func (cft *cloudFetchDownloadTask) GetResult() (SparkArrowBatch, error) { link.StartRowOffset, link.RowCount, ) - return nil, nil + return nil, 0, nil } func (cft *cloudFetchDownloadTask) Run() { @@ -234,31 +316,29 @@ func (cft *cloudFetchDownloadTask) Run() { ) data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry) if err != nil { - cft.resultChan <- cloudFetchDownloadTaskResult{batch: nil, err: err} + cft.resultChan <- cloudFetchDownloadTaskResult{batchData: nil, rowCount: 0, err: err} return } - // io.ReadCloser.Close() may return an error, but in this case it should be safe to ignore (I hope so) - defer data.Close() + // Read all data into memory + batchData, err := io.ReadAll(data) + data.Close() // Close after reading + if err != nil { + cft.resultChan <- cloudFetchDownloadTaskResult{batchData: nil, rowCount: 0, err: err} + return + } logger.Debug().Msgf( - "CloudFetch: reading records for link at offset %d row count %d", + "CloudFetch: received batch data for link at offset %d row count %d", cft.link.StartRowOffset, cft.link.RowCount, ) - reader := getReader(data, cft.useLz4Compression) - - records, err := getArrowRecords(reader, cft.link.StartRowOffset) - if err != nil { - cft.resultChan <- cloudFetchDownloadTaskResult{batch: nil, err: err} - return - } - batch := sparkArrowBatch{ - Delimiter: rowscanner.NewDelimiter(cft.link.StartRowOffset, cft.link.RowCount), - arrowRecords: records, + cft.resultChan <- cloudFetchDownloadTaskResult{ + batchData: batchData, + rowCount: cft.link.RowCount, + err: nil, } - cft.resultChan <- cloudFetchDownloadTaskResult{batch: &batch, err: nil} }() } @@ -308,6 +388,125 @@ func isLinkExpired(expiryTime int64, linkExpiryBuffer time.Duration) bool { return expiryTime-bufferSecs < time.Now().Unix() } +// pagedRawBatchIterator wraps a result page iterator and provides raw batches +type pagedRawBatchIterator struct { + ctx context.Context + resultPageIterator rowscanner.ResultPageIterator + currentIterator RawBatchIterator + cfg *config.Config + startRowOffset int64 + schemaBytes []byte // Schema bytes extracted from first page +} + +// NewPagedRawBatchIterator creates a raw batch iterator from a result page iterator +func NewPagedRawBatchIterator( + ctx context.Context, + resultPageIterator rowscanner.ResultPageIterator, + cfg *config.Config, +) RawBatchIterator { + return &pagedRawBatchIterator{ + ctx: ctx, + resultPageIterator: resultPageIterator, + cfg: cfg, + startRowOffset: 0, + } +} + +func (pi *pagedRawBatchIterator) Next() (*cli_service.TSparkArrowBatch, error) { + // If we have a current iterator and it has more batches, use it + if pi.currentIterator != nil && pi.currentIterator.HasNext() { + return pi.currentIterator.Next() + } + + // Need to fetch next page + if pi.resultPageIterator == nil || !pi.resultPageIterator.HasNext() { + return nil, io.EOF + } + + fetchResult, err := pi.resultPageIterator.Next() + if err != nil { + return nil, err + } + + if fetchResult == nil || fetchResult.Results == nil { + return nil, io.EOF + } + + rowSet := fetchResult.Results + + // Extract schema bytes from the first page if not already extracted + if pi.schemaBytes == nil && fetchResult.ResultSetMetadata != nil && fetchResult.ResultSetMetadata.ArrowSchema != nil { + pi.schemaBytes = fetchResult.ResultSetMetadata.ArrowSchema + } + + // Create appropriate iterator based on result type + if len(rowSet.ResultLinks) > 0 { + pi.currentIterator, err = 
NewCloudRawBatchIterator(pi.ctx, rowSet.ResultLinks, rowSet.StartRowOffset, pi.cfg) + } else if len(rowSet.ArrowBatches) > 0 { + // Pass schema bytes to local iterator if available + pi.currentIterator, err = NewLocalRawBatchIterator(pi.ctx, rowSet.ArrowBatches, rowSet.StartRowOffset, pi.schemaBytes, pi.cfg) + } else { + return nil, io.EOF + } + + if err != nil { + return nil, err + } + + // Now get the first batch from the new iterator + return pi.currentIterator.Next() +} + +func (pi *pagedRawBatchIterator) HasNext() bool { + return (pi.currentIterator != nil && pi.currentIterator.HasNext()) || + (pi.resultPageIterator != nil && pi.resultPageIterator.HasNext()) +} + +func (pi *pagedRawBatchIterator) Close() { + if pi.currentIterator != nil { + pi.currentIterator.Close() + } +} + +func (pi *pagedRawBatchIterator) GetStartRowOffset() int64 { + if pi.currentIterator != nil { + return pi.currentIterator.GetStartRowOffset() + } + return pi.startRowOffset +} + +// GetSchemaBytes returns the schema bytes extracted from the first fetched page +func (pi *pagedRawBatchIterator) GetSchemaBytes() []byte { + return pi.schemaBytes +} + +// Legacy wrapper functions for backward compatibility with tests +func NewCloudBatchIterator( + ctx context.Context, + files []*cli_service.TSparkArrowResultLink, + startRowOffset int64, + cfg *config.Config, +) (BatchIterator, dbsqlerr.DBError) { + // For cloud iterator, we don't have schema bytes at this level + // The schema will need to be provided when creating the BatchIterator wrapper + // This is a temporary compatibility function for tests + return nil, dbsqlerrint.NewDriverError(ctx, "NewCloudBatchIterator is deprecated, use NewCloudRawBatchIterator + NewBatchIterator", nil) +} + +func NewLocalBatchIterator( + ctx context.Context, + batches []*cli_service.TSparkArrowBatch, + startRowOffset int64, + arrowSchemaBytes []byte, + cfg *config.Config, +) (BatchIterator, dbsqlerr.DBError) { + rawIterator, err := NewLocalRawBatchIterator(ctx, batches, startRowOffset, arrowSchemaBytes, cfg) + if err != nil { + return nil, err + } + return NewBatchIterator(rawIterator, arrowSchemaBytes, cfg), nil +} + func getArrowRecords(r io.Reader, startRowOffset int64) ([]SparkArrowRecord, error) { ipcReader, err := ipc.NewReader(r) if err != nil { diff --git a/internal/rows/arrowbased/batchloader_test.go b/internal/rows/arrowbased/batchloader_test.go index d02d2992..1b2fde3e 100644 --- a/internal/rows/arrowbased/batchloader_test.go +++ b/internal/rows/arrowbased/batchloader_test.go @@ -72,7 +72,7 @@ func TestCloudFetchIterator(t *testing.T) { cfg.UseLz4Compression = false cfg.MaxDownloadThreads = 1 - bi, err := NewCloudBatchIterator( + rawBi, err := NewCloudRawBatchIterator( context.Background(), links, startRowOffset, @@ -82,34 +82,36 @@ func TestCloudFetchIterator(t *testing.T) { panic(err) } - cbi := bi.(*cloudBatchIterator) + cbi := rawBi.(*cloudBatchIterator) - assert.True(t, bi.HasNext()) + assert.True(t, rawBi.HasNext()) assert.Equal(t, cbi.pendingLinks.Len(), len(links)) assert.Equal(t, cbi.downloadTasks.Len(), 0) // get first link - should succeed - sab1, err2 := bi.Next() + sab1, err2 := rawBi.Next() if err2 != nil { panic(err2) } assert.Equal(t, cbi.pendingLinks.Len(), len(links)-1) assert.Equal(t, cbi.downloadTasks.Len(), 0) - assert.Equal(t, sab1.Start(), startRowOffset) + // Raw batch doesn't have Start() method, check row count instead + assert.Equal(t, sab1.RowCount, int64(1)) // get second link - should succeed - sab2, err3 := bi.Next() + sab2, err3 := 
rawBi.Next() if err3 != nil { panic(err3) } assert.Equal(t, cbi.pendingLinks.Len(), len(links)-2) assert.Equal(t, cbi.downloadTasks.Len(), 0) - assert.Equal(t, sab2.Start(), startRowOffset+sab1.Count()) + // Check second batch row count + assert.Equal(t, sab2.RowCount, int64(1)) // all links downloaded, should be no more data - assert.False(t, bi.HasNext()) + assert.False(t, rawBi.HasNext()) }) t.Run("should fail on expired link", func(t *testing.T) { @@ -142,7 +144,7 @@ func TestCloudFetchIterator(t *testing.T) { cfg.UseLz4Compression = false cfg.MaxDownloadThreads = 1 - bi, err := NewCloudBatchIterator( + rawBi, err := NewCloudRawBatchIterator( context.Background(), links, startRowOffset, @@ -152,24 +154,25 @@ func TestCloudFetchIterator(t *testing.T) { panic(err) } - cbi := bi.(*cloudBatchIterator) + cbi := rawBi.(*cloudBatchIterator) - assert.True(t, bi.HasNext()) + assert.True(t, rawBi.HasNext()) assert.Equal(t, cbi.pendingLinks.Len(), len(links)) assert.Equal(t, cbi.downloadTasks.Len(), 0) // get first link - should succeed - sab1, err2 := bi.Next() + sab1, err2 := rawBi.Next() if err2 != nil { panic(err2) } assert.Equal(t, cbi.pendingLinks.Len(), len(links)-1) assert.Equal(t, cbi.downloadTasks.Len(), 0) - assert.Equal(t, sab1.Start(), startRowOffset) + // Raw batch doesn't have Start() method, check row count instead + assert.Equal(t, sab1.RowCount, int64(1)) // get second link - should fail - _, err3 := bi.Next() + _, err3 := rawBi.Next() assert.NotNil(t, err3) assert.ErrorContains(t, err3, dbsqlerr.ErrLinkExpired) }) @@ -196,7 +199,7 @@ func TestCloudFetchIterator(t *testing.T) { cfg.UseLz4Compression = false cfg.MaxDownloadThreads = 1 - bi, err := NewCloudBatchIterator( + rawBi, err := NewCloudRawBatchIterator( context.Background(), links, startRowOffset, @@ -206,9 +209,9 @@ func TestCloudFetchIterator(t *testing.T) { panic(err) } - cbi := bi.(*cloudBatchIterator) + cbi := rawBi.(*cloudBatchIterator) - assert.True(t, bi.HasNext()) + assert.True(t, rawBi.HasNext()) assert.Equal(t, cbi.pendingLinks.Len(), len(links)) assert.Equal(t, cbi.downloadTasks.Len(), 0) @@ -222,14 +225,15 @@ func TestCloudFetchIterator(t *testing.T) { } // get first link - should succeed - sab1, err2 := bi.Next() + sab1, err2 := rawBi.Next() if err2 != nil { panic(err2) } assert.Equal(t, cbi.pendingLinks.Len(), len(links)-1) assert.Equal(t, cbi.downloadTasks.Len(), 0) - assert.Equal(t, sab1.Start(), startRowOffset) + // Raw batch doesn't have Start() method, check row count instead + assert.Equal(t, sab1.RowCount, int64(1)) // set handler for the first link, which fails with some non-retryable HTTP error handler = func(w http.ResponseWriter, r *http.Request) { @@ -237,7 +241,7 @@ func TestCloudFetchIterator(t *testing.T) { } // get second link - should fail - _, err3 := bi.Next() + _, err3 := rawBi.Next() assert.NotNil(t, err3) assert.ErrorContains(t, err3, fmt.Sprintf("%s %d", "HTTP error", http.StatusNotFound)) }) diff --git a/internal/rows/arrowbased/columnValues.go b/internal/rows/arrowbased/columnValues.go index 0b6fc7d8..7c235d57 100644 --- a/internal/rows/arrowbased/columnValues.go +++ b/internal/rows/arrowbased/columnValues.go @@ -74,7 +74,7 @@ func (rv *rowValues) NColumns() int { return len(rv.columnValueHolders) } func (rv *rowValues) SetDelimiter(d rowscanner.Delimiter) { rv.Delimiter = d } type valueContainerMaker interface { - makeColumnValuesContainers(ars *arrowRowScanner, d rowscanner.Delimiter) error + makeColumnValuesContainers(ars *ArrowRowScanner, d rowscanner.Delimiter) error } // 
columnValues is the interface for accessing the values for a column diff --git a/internal/rows/arrowbased/ipc_stream_iterator.go b/internal/rows/arrowbased/ipc_stream_iterator.go new file mode 100644 index 00000000..2262a61d --- /dev/null +++ b/internal/rows/arrowbased/ipc_stream_iterator.go @@ -0,0 +1,73 @@ +package arrowbased + +import ( + "bytes" + "io" + + "github.com/databricks/databricks-sql-go/internal/config" + dbsqlrows "github.com/databricks/databricks-sql-go/rows" + "github.com/pierrec/lz4/v4" +) + +// ipcStreamIterator provides access to raw Arrow IPC streams without deserialization +type ipcStreamIterator struct { + rawBatchIterator RawBatchIterator + arrowSchemaBytes []byte + useLz4 bool +} + +// NewIPCStreamIterator creates an iterator that returns raw IPC streams +func NewIPCStreamIterator( + rawIterator RawBatchIterator, + schemaBytes []byte, + cfg *config.Config, +) dbsqlrows.IPCStreamIterator { + var useLz4 bool + if cfg != nil { + useLz4 = cfg.UseLz4Compression + } + + return &ipcStreamIterator{ + rawBatchIterator: rawIterator, + arrowSchemaBytes: schemaBytes, + useLz4: useLz4, + } +} + +// NextIPCStream returns the next Arrow batch as a raw IPC stream +func (it *ipcStreamIterator) NextIPCStream() (io.Reader, error) { + rawBatch, err := it.rawBatchIterator.Next() + if err != nil { + return nil, err + } + + // Create reader for the batch data + var batchReader io.Reader = bytes.NewReader(rawBatch.Batch) + + // Handle LZ4 decompression if needed + if it.useLz4 { + batchReader = lz4.NewReader(batchReader) + } + + // Combine schema and batch data into a complete IPC stream + // Arrow IPC format expects: [Schema][Batch1][Batch2]... + return io.MultiReader( + bytes.NewReader(it.arrowSchemaBytes), + batchReader, + ), nil +} + +// HasNext returns true if there are more batches +func (it *ipcStreamIterator) HasNext() bool { + return it.rawBatchIterator.HasNext() +} + +// Close releases any resources +func (it *ipcStreamIterator) Close() { + it.rawBatchIterator.Close() +} + +// GetSchemaBytes returns the Arrow schema in IPC format +func (it *ipcStreamIterator) GetSchemaBytes() ([]byte, error) { + return it.arrowSchemaBytes, nil +} diff --git a/internal/rows/rows.go b/internal/rows/rows.go index cddf2f15..0e8d0b41 100644 --- a/internal/rows/rows.go +++ b/internal/rows/rows.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "database/sql/driver" + "fmt" "math" "reflect" "time" @@ -527,6 +528,14 @@ func (r *rows) logger() *dbsqllog.DBSQLLogger { return r.logger_ } +// getArrowSchemaBytes converts the table schema to Arrow IPC format bytes +func (r *rows) getArrowSchemaBytes(schema *cli_service.TTableSchema) ([]byte, error) { + // We need to use the arrow-based row scanner's conversion methods + // This is a temporary solution - ideally this would be refactored to share code + // For now, delegate to the arrowbased package + return nil, fmt.Errorf("schema conversion not yet implemented - use ArrowSchema from metadata") +} + func (r *rows) GetArrowBatches(ctx context.Context) (dbsqlrows.ArrowBatchIterator, error) { // update context with correlationId and connectionId which will be used in logging and errors ctx = driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(ctx, r.connId), r.correlationId) @@ -539,3 +548,65 @@ func (r *rows) GetArrowBatches(ctx context.Context) (dbsqlrows.ArrowBatchIterato return arrowbased.NewArrowRecordIterator(ctx, r.ResultPageIterator, nil, nil, *r.config), nil } + +// GetIPCStreams returns an iterator that provides raw Arrow IPC streams 
+func (r *rows) GetIPCStreams(ctx context.Context) (dbsqlrows.IPCStreamIterator, error) {
+	// Update context with correlationId and connectionId
+	ctx = driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(ctx, r.connId), r.correlationId)
+
+	// First try to get Arrow schema bytes from metadata if available
+	var schemaBytes []byte
+	if r.resultSetMetadata != nil && r.resultSetMetadata.ArrowSchema != nil {
+		schemaBytes = r.resultSetMetadata.ArrowSchema
+	} else {
+		// Fall back to generating from table schema
+		schema, err := r.getResultSetSchema()
+		if err != nil {
+			return nil, dbsqlerr_int.NewDriverError(ctx, "failed to get result set schema", err)
+		}
+
+		// Convert schema to IPC format bytes
+		var err2 error
+		schemaBytes, err2 = r.getArrowSchemaBytes(schema)
+		if err2 != nil {
+			return nil, dbsqlerr_int.NewDriverError(ctx, "failed to convert schema to IPC format", err2)
+		}
+	}
+
+	// Create a unified raw batch iterator
+	var rawIterator arrowbased.RawBatchIterator
+
+	// Check if we have direct results from the row scanner
+	if r.RowScanner != nil {
+		// Try to get the raw batch iterator from the row scanner
+		if arrowScanner, ok := r.RowScanner.(*arrowbased.ArrowRowScanner); ok && arrowScanner.GetRawBatchIterator() != nil {
+			directRawIterator := arrowScanner.GetRawBatchIterator()
+
+			if r.ResultPageIterator != nil {
+				// Compose direct results with pagination
+				pagedIterator := arrowbased.NewPagedRawBatchIterator(ctx, r.ResultPageIterator, r.config)
+				rawIterator = arrowbased.NewInitialThenPagedRawIterator(directRawIterator, pagedIterator)
+			} else {
+				// Only direct results
+				rawIterator = directRawIterator
+			}
+		} else if r.ResultPageIterator != nil {
+			// No direct results, only pagination
+			rawIterator = arrowbased.NewPagedRawBatchIterator(ctx, r.ResultPageIterator, r.config)
+		}
+	} else if r.ResultPageIterator != nil {
+		// No row scanner, only pagination
+		rawIterator = arrowbased.NewPagedRawBatchIterator(ctx, r.ResultPageIterator, r.config)
+	}
+
+	if rawIterator == nil {
+		return nil, dbsqlerr_int.NewDriverError(ctx, "no data available", nil)
+	}
+
+	// Create IPC stream iterator
+	return arrowbased.NewIPCStreamIterator(
+		rawIterator,
+		schemaBytes,
+		r.config,
+	), nil
+}
diff --git a/pagination_architecture.md b/pagination_architecture.md
new file mode 100644
index 00000000..b8ad3f83
--- /dev/null
+++ b/pagination_architecture.md
@@ -0,0 +1,305 @@
+# Databricks SQL Go Driver - Pagination Architecture
+
+## UML Class Diagram
+
+```mermaid
+classDiagram
+    %% Interfaces
+    class ResultPageIterator {
+        <<interface>>
+        +Next() (TFetchResultsResp, error)
+        +HasNext() bool
+        +Close() error
+        +Start() int64
+        +End() int64
+        +Count() int64
+    }
+
+    class RawBatchIterator {
+        <<interface>>
+        +Next() (TSparkArrowBatch, error)
+        +HasNext() bool
+        +Close()
+        +GetStartRowOffset() int64
+    }
+
+    class BatchIterator {
+        <<interface>>
+        +Next() (SparkArrowBatch, error)
+        +HasNext() bool
+        +Close()
+    }
+
+    class IPCStreamIterator {
+        <<interface>>
+        +NextIPCStream() (io.Reader, error)
+        +HasNext() bool
+        +Close()
+        +GetSchemaBytes() ([]byte, error)
+    }
+
+    class ArrowBatchIterator {
+        <<interface>>
+        +Next() (arrow.Record, error)
+        +HasNext() bool
+        +Close()
+        +Schema() (*arrow.Schema, error)
+    }
+
+    %% Concrete Implementations
+    class resultPageIterator {
+        -opHandle: TOperationHandle
+        -client: TCLIService
+        -maxPageSize: int64
+        -isFinished: bool
+        -nextResultPage: TFetchResultsResp
+        -closedOnServer: bool
+        +getNextPage() (TFetchResultsResp, error)
+    }
+
+    class pagedRawBatchIterator {
+        -resultPageIterator: ResultPageIterator
+        -currentIterator: RawBatchIterator
+        -cfg: Config
+        +Next() (TSparkArrowBatch, error)
+        +HasNext() bool
+    }
+
+    class localRawBatchIterator {
+        -batches: []TSparkArrowBatch
+        -index: int
+        -currentRowOffset: int64
+        +Next() (TSparkArrowBatch, error)
+        +HasNext() bool
+    }
+
+    class cloudRawBatchIterator {
+        -pendingLinks: Queue[TSparkArrowResultLink]
+        -downloadTasks: Queue[downloadTask]
+        -currentRowOffset: int64
+        +Next() (TSparkArrowBatch, error)
+        +HasNext() bool
+    }
+
+    class batchIterator {
+        -rawIterator: RawBatchIterator
+        -arrowSchemaBytes: []byte
+        -cfg: Config
+        +Next() (SparkArrowBatch, error)
+        +HasNext() bool
+    }
+
+    class ipcStreamIterator {
+        -rawBatchIterator: RawBatchIterator
+        -arrowSchemaBytes: []byte
+        -useLz4: bool
+        +NextIPCStream() (io.Reader, error)
+        +HasNext() bool
+    }
+
+    class arrowRecordIterator {
+        -resultPageIterator: ResultPageIterator
+        -batchIterator: BatchIterator
+        -currentBatch: SparkArrowBatch
+        -currentRecord: SparkArrowRecord
+        -arrowSchemaBytes: []byte
+        -cfg: Config
+        +Next() (arrow.Record, error)
+        +HasNext() bool
+        +Schema() (*arrow.Schema, error)
+        +newBatchIterator(resp) BatchIterator
+    }
+
+    %% Relationships
+    ResultPageIterator <|.. resultPageIterator : implements
+    RawBatchIterator <|.. pagedRawBatchIterator : implements
+    RawBatchIterator <|.. localRawBatchIterator : implements
+    RawBatchIterator <|.. cloudRawBatchIterator : implements
+    BatchIterator <|.. batchIterator : implements
+    IPCStreamIterator <|.. ipcStreamIterator : implements
+    ArrowBatchIterator <|.. arrowRecordIterator : implements
+
+    pagedRawBatchIterator --> ResultPageIterator : uses
+    pagedRawBatchIterator --> RawBatchIterator : creates
+    batchIterator --> RawBatchIterator : wraps
+    ipcStreamIterator --> RawBatchIterator : uses
+
+    arrowRecordIterator --> ResultPageIterator : uses
+    arrowRecordIterator --> BatchIterator : creates per page
+    arrowRecordIterator --> localRawBatchIterator : creates via factory
+    arrowRecordIterator --> cloudRawBatchIterator : creates via factory
+
+    %% External dependencies
+    class TCLIService {
+        <<external>>
+        +FetchResults(req) (TFetchResultsResp)
+    }
+
+    class TFetchResultsResp {
+        <<external>>
+        +Results: TRowSet
+        +HasMoreRows: bool
+    }
+
+    class TRowSet {
+        <<external>>
+        +ArrowBatches: []TSparkArrowBatch
+        +ResultLinks: []TSparkArrowResultLink
+        +StartRowOffset: int64
+    }
+
+    resultPageIterator --> TCLIService : calls
+    resultPageIterator --> TFetchResultsResp : returns
+    TFetchResultsResp --> TRowSet : contains
+```
+
+## Sequence Diagram - How Pagination Works
+
+```mermaid
+sequenceDiagram
+    participant Client
+    participant IPCStreamIterator
+    participant PagedRawBatchIterator
+    participant ResultPageIterator
+    participant LocalRawBatchIterator
+    participant Server
+
+    Client->>IPCStreamIterator: NextIPCStream()
+    IPCStreamIterator->>PagedRawBatchIterator: Next()
+
+    alt currentIterator is null or exhausted
+        PagedRawBatchIterator->>ResultPageIterator: HasNext()
+
+        alt nextResultPage is null
+            ResultPageIterator->>ResultPageIterator: getNextPage()
+            ResultPageIterator->>Server: FetchResults(opHandle, maxRows)
+            Server-->>ResultPageIterator: TFetchResultsResp
+            ResultPageIterator->>ResultPageIterator: cache in nextResultPage
+        end
+
+        ResultPageIterator-->>PagedRawBatchIterator: true
+        PagedRawBatchIterator->>ResultPageIterator: Next()
+        ResultPageIterator-->>PagedRawBatchIterator: TFetchResultsResp (cached)
+
+        alt has ArrowBatches
+            PagedRawBatchIterator->>LocalRawBatchIterator: new(batches)
+        else has ResultLinks
+            PagedRawBatchIterator->>CloudRawBatchIterator: new(links)
+        end
+
+        PagedRawBatchIterator->>LocalRawBatchIterator: Next()
+        LocalRawBatchIterator-->>PagedRawBatchIterator: TSparkArrowBatch
+    else currentIterator has more
+        PagedRawBatchIterator->>LocalRawBatchIterator: Next()
+        LocalRawBatchIterator-->>PagedRawBatchIterator: TSparkArrowBatch
+    end
+
+    PagedRawBatchIterator-->>IPCStreamIterator: TSparkArrowBatch
+    IPCStreamIterator->>IPCStreamIterator: wrap with schema
+    IPCStreamIterator-->>Client: io.Reader (IPC stream)
+```
+
+## Component Interaction Flow
+
+```mermaid
+flowchart TB
+    subgraph Server Side
+        DB[(Database)]
+        API[Thrift API FetchResults]
+    end
+
+    subgraph Client Side - Pagination Layer
+        RPI[ResultPageIterator]
+        RPI -->|fetches pages| API
+        RPI -->|pre-fetches on HasNext| Cache[Page Cache]
+    end
+
+    subgraph Client Side - Raw Batch Layer
+        PRBI[PagedRawBatchIterator]
+        LRBI[LocalRawBatchIterator]
+        CRBI[CloudRawBatchIterator]
+
+        PRBI -->|uses| RPI
+        PRBI -->|creates per page| LRBI
+        PRBI -->|creates per page| CRBI
+    end
+
+    subgraph Client Side - Processing Layer
+        BI[BatchIterator]
+        IPCI[IPCStreamIterator]
+        ARI[ArrowRecordIterator]
+
+        BI -->|wraps| LRBI
+        BI -->|wraps| CRBI
+        IPCI -->|uses| PRBI
+        ARI -->|uses| RPI
+        ARI -->|creates| BI
+    end
+
+    subgraph Client Side - User API
+        Rows["rows.GetIPCStreams()"]
+        ArrowAPI["rows.GetArrowBatches()"]
+
+        Rows -->|creates| IPCI
+        ArrowAPI -->|creates| ARI
+    end
+```
+
+## ArrowRecordIterator Flow (New Architecture)
+
+```mermaid
+sequenceDiagram
+    participant Client
+    participant ArrowRecordIterator
+    participant ResultPageIterator
+    participant RawBatchIterator
+    participant BatchIterator
+    participant Server
+
+    Client->>ArrowRecordIterator: Next()
+
+    alt batchIterator is null or exhausted
+        ArrowRecordIterator->>ResultPageIterator: Next()
+        ResultPageIterator->>Server: FetchResults()
+        Server-->>ResultPageIterator: TFetchResultsResp
+        ResultPageIterator-->>ArrowRecordIterator: TFetchResultsResp
+
+        alt has ArrowBatches
+            ArrowRecordIterator->>localRawBatchIterator: NewLocalRawBatchIterator(batches)
+            ArrowRecordIterator->>BatchIterator: NewBatchIterator(rawIterator, schema)
+        else has ResultLinks
+            ArrowRecordIterator->>cloudRawBatchIterator: NewCloudRawBatchIterator(links)
+            ArrowRecordIterator->>BatchIterator: NewBatchIterator(rawIterator, schema)
+        end
+
+        ArrowRecordIterator->>ArrowRecordIterator: batchIterator = new BatchIterator
+    end
+
+    alt currentBatch is null or exhausted
+        ArrowRecordIterator->>BatchIterator: Next()
+        BatchIterator->>RawBatchIterator: Next()
+        RawBatchIterator-->>BatchIterator: TSparkArrowBatch
+        BatchIterator->>BatchIterator: Parse Arrow data
+        BatchIterator-->>ArrowRecordIterator: SparkArrowBatch
+        ArrowRecordIterator->>ArrowRecordIterator: currentBatch = batch
+    end
+
+    ArrowRecordIterator->>SparkArrowBatch: Next()
+    SparkArrowBatch-->>ArrowRecordIterator: SparkArrowRecord
+    ArrowRecordIterator-->>Client: arrow.Record
+```
+
+## Key Design Patterns
+
+1. **Iterator Pattern**: Multiple levels of iterators, each handling a specific concern
+2. **Decorator Pattern**: BatchIterator decorates RawBatchIterator with parsing (sketched below)
+3. **Strategy Pattern**: Different strategies for local vs cloud batch fetching
+4. **Lazy Loading**: Pages fetched only when needed
+5. **Pre-fetching**: HasNext() pre-fetches the next page to check availability without consuming it
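+
+The decorator layering (patterns 2 and 3) can be sketched in code. The fragment below is illustrative only and is not part of this change: `composeIterators` is a hypothetical helper inside the `arrowbased` package, wired from the constructors introduced in this PR, with error handling elided.
+
+```go
+package arrowbased
+
+import (
+	"context"
+
+	"github.com/databricks/databricks-sql-go/internal/config"
+	"github.com/databricks/databricks-sql-go/internal/rows/rowscanner"
+	dbsqlrows "github.com/databricks/databricks-sql-go/rows"
+)
+
+// composeIterators (hypothetical) shows how the layers stack. A real
+// caller would pick ONE of the two decorators; both are built here
+// only to show them side by side.
+func composeIterators(
+	ctx context.Context,
+	rpi rowscanner.ResultPageIterator,
+	schemaBytes []byte,
+	cfg config.Config,
+) (BatchIterator, dbsqlrows.IPCStreamIterator) {
+	// Layer 1: raw TSparkArrowBatch values, fetched page by page.
+	raw := NewPagedRawBatchIterator(ctx, rpi, &cfg)
+
+	// Layer 2, option A: decorate with Arrow parsing (SparkArrowBatch).
+	parsed := NewBatchIterator(raw, schemaBytes, &cfg)
+
+	// Layer 2, option B: expose undecoded Arrow IPC streams instead.
+	streams := NewIPCStreamIterator(raw, schemaBytes, &cfg)
+
+	return parsed, streams
+}
+```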
+
+## Benefits of This Architecture
+
+- **Separation of Concerns**: Each layer handles one responsibility
+- **Flexibility**: Easy to add new iterator types or change implementations
+- **Performance**: Pre-fetching and lazy loading optimize network calls
+- **Reusability**: Raw batch iterators can be used by multiple consumers
\ No newline at end of file
diff --git a/rows/ipc_stream.go b/rows/ipc_stream.go
new file mode 100644
index 00000000..d6520b1b
--- /dev/null
+++ b/rows/ipc_stream.go
@@ -0,0 +1,28 @@
+package rows
+
+import (
+	"context"
+	"io"
+)
+
+// IPCStreamIterator provides access to raw Arrow IPC streams
+type IPCStreamIterator interface {
+	// NextIPCStream returns the next Arrow batch as an IPC stream reader.
+	// It returns io.EOF when no more batches are available.
+	NextIPCStream() (io.Reader, error)
+
+	// HasNext returns true if there are more batches
+	HasNext() bool
+
+	// Close releases any resources
+	Close()
+
+	// GetSchemaBytes returns the Arrow schema in IPC format
+	GetSchemaBytes() ([]byte, error)
+}
+
+// RowsWithIPCStream extends the existing Rows interface with access to
+// raw Arrow IPC streams.
+type RowsWithIPCStream interface {
+	Rows
+	// GetIPCStreams returns an iterator for raw Arrow IPC streams
+	GetIPCStreams(ctx context.Context) (IPCStreamIterator, error)
+}
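
For review context, here is a consumer-side sketch of the new API; it is not part of the diff above. It assumes the caller has already obtained the driver-level `driver.Rows` (e.g. via `sql.Conn.Raw`, the same route used to reach `GetArrowBatches` today) and that those rows implement `RowsWithIPCStream`; `dumpIPCStreams` is a hypothetical function.

```go
package main

import (
	"context"
	"database/sql/driver"
	"errors"
	"fmt"
	"io"

	dbsqlrows "github.com/databricks/databricks-sql-go/rows"
)

// dumpIPCStreams (hypothetical) drains every Arrow IPC stream in the
// result set and reports only its size. A real consumer would hand each
// reader to an Arrow IPC reader, or forward the bytes unparsed.
func dumpIPCStreams(ctx context.Context, rows driver.Rows) error {
	r, ok := rows.(dbsqlrows.RowsWithIPCStream)
	if !ok {
		return errors.New("driver rows do not support IPC streams")
	}

	it, err := r.GetIPCStreams(ctx)
	if err != nil {
		return err
	}
	defer it.Close()

	for it.HasNext() {
		stream, err := it.NextIPCStream()
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			return err
		}
		// Each reader is a complete IPC stream: the schema message
		// followed by one record batch.
		n, err := io.Copy(io.Discard, stream)
		if err != nil {
			return err
		}
		fmt.Printf("IPC stream: %d bytes\n", n)
	}
	return nil
}
```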