Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/cli_service/cli_service.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

175 changes: 101 additions & 74 deletions internal/rows/arrowbased/arrowRecordIterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,34 +10,59 @@ import (
"github.com/apache/arrow/go/v12/arrow/ipc"
"github.com/databricks/databricks-sql-go/internal/cli_service"
"github.com/databricks/databricks-sql-go/internal/config"
dbsqlerr "github.com/databricks/databricks-sql-go/internal/errors"
"github.com/databricks/databricks-sql-go/internal/rows/rowscanner"
"github.com/databricks/databricks-sql-go/rows"
)

// NewArrowRecordIterator builds an arrowRecordIterator and returns it as a
// rows.ArrowBatchIterator. The iterator always carries ctx, cfg and
// arrowSchemaBytes; how its batchIterator is populated depends on which of
// bi and rpi are non-nil (see the branches below).
//
// NOTE(review): this span appears to show the removed and the added side of a
// diff together: the struct literal repeats cfg/ctx/arrowSchemaBytes and a
// `return &ari` precedes the bi/rpi branching. Confirm which lines belong to
// the current revision before relying on the branch behavior.
func NewArrowRecordIterator(ctx context.Context, rpi rowscanner.ResultPageIterator, bi BatchIterator, arrowSchemaBytes []byte, cfg config.Config) rows.ArrowBatchIterator {
ari := arrowRecordIterator{
cfg: cfg,
batchIterator: bi,
resultPageIterator: rpi,
ctx: ctx,
arrowSchemaBytes: arrowSchemaBytes,
cfg: cfg,
ctx: ctx,
arrowSchemaBytes: arrowSchemaBytes,
}

return &ari
if bi != nil && rpi != nil {
// Both initial batch iterator and result page iterator
// Extract the raw iterator from the initial batch iterator and create a composite
// The raw iterator is only reachable when bi is the concrete *batchIterator.
if batchIter, ok := bi.(*batchIterator); ok {
pagedRaw := &pagedRawBatchIterator{
ctx: ctx,
resultPageIterator: rpi,
cfg: &cfg,
startRowOffset: 0,
}
// Chain the initial raw iterator with the paged one and wrap the
// result in a fresh batch iterator.
compositeRaw := NewInitialThenPagedRawIterator(batchIter.rawIterator, pagedRaw)
ari.batchIterator = NewBatchIterator(compositeRaw, arrowSchemaBytes, &cfg)
} else {
// Fallback: use initial batch iterator, ignore pagination for now
// NOTE(review): rpi is not wired into the iterator on this path.
ari.batchIterator = bi
}
} else if bi != nil {
// Only initial batch iterator
ari.batchIterator = bi
} else if rpi != nil {
// Only result page iterator
pagedRawIter := &pagedRawBatchIterator{
ctx: ctx,
resultPageIterator: rpi,
cfg: &cfg,
startRowOffset: 0,
}
ari.batchIterator = NewBatchIterator(pagedRawIter, arrowSchemaBytes, &cfg)
}

return &ari
}

// arrowRecordIterator is this file's implementation of rows.ArrowBatchIterator
// (see the compile-time assertion that follows the type).
//
// NOTE(review): the field list appears twice, which looks like the removed and
// added sides of a diff shown together; the second list omits
// resultPageIterator. Confirm which list is current.
type arrowRecordIterator struct {
ctx context.Context
cfg config.Config
batchIterator BatchIterator
resultPageIterator rowscanner.ResultPageIterator
currentBatch SparkArrowBatch
isFinished bool
arrowSchemaBytes []byte
arrowSchema *arrow.Schema
ctx context.Context
cfg config.Config
// batchIterator supplies the batches that become currentBatch.
batchIterator BatchIterator
// currentBatch is the batch records are currently being read from.
currentBatch SparkArrowBatch
// isFinished is consulted by checkFinished to short-circuit the end-of-results test.
isFinished bool
// arrowSchemaBytes is the serialized schema handed to the constructor.
arrowSchemaBytes []byte
// arrowSchema caches the schema returned by Schema().
arrowSchema *arrow.Schema
}

var _ rows.ArrowBatchIterator = (*arrowRecordIterator)(nil)
Expand Down Expand Up @@ -80,18 +105,13 @@ func (ri *arrowRecordIterator) Close() {
if ri.batchIterator != nil {
ri.batchIterator.Close()
}

if ri.resultPageIterator != nil {
ri.resultPageIterator.Close()
}
}
}

func (ri *arrowRecordIterator) checkFinished() {
finished := ri.isFinished ||
((ri.currentBatch == nil || !ri.currentBatch.HasNext()) &&
(ri.batchIterator == nil || !ri.batchIterator.HasNext()) &&
(ri.resultPageIterator == nil || !ri.resultPageIterator.HasNext()))
(ri.batchIterator == nil || !ri.batchIterator.HasNext()))

if finished {
// Reached end of result set so Close
Expand All @@ -101,80 +121,39 @@ func (ri *arrowRecordIterator) checkFinished() {

// getCurrentBatch makes sure ri.currentBatch points at a batch that still has
// records. When the current batch is missing or exhausted it refreshes the
// batch iterator, closes the old batch, and pulls the next batch from the
// batch iterator. Any error from either step is returned unchanged.
func (ri *arrowRecordIterator) getCurrentBatch() error {
	// Nothing to do while the current batch still has records.
	if ri.currentBatch != nil && ri.currentBatch.HasNext() {
		return nil
	}

	// Make sure the batch iterator itself is up to date first.
	if err := ri.getBatchIterator(); err != nil {
		return err
	}

	// Release the exhausted batch before replacing it.
	if ri.currentBatch != nil {
		ri.currentBatch.Close()
	}

	// The field is assigned even when Next fails, as before.
	next, err := ri.batchIterator.Next()
	ri.currentBatch = next
	return err
}

// Update batch iterator if necessary
func (ri *arrowRecordIterator) getBatchIterator() error {
// only need to update if there is no batch iterator or the
// batch iterator has no more batches
if ri.batchIterator == nil || !ri.batchIterator.HasNext() {
if ri.batchIterator != nil {
// release any resources held by the current batch iterator
ri.batchIterator.Close()
ri.batchIterator = nil
if ri.batchIterator == nil {
return io.EOF
}

// Get the next page of the result set
resp, err := ri.resultPageIterator.Next()
var err error
ri.currentBatch, err = ri.batchIterator.Next()
if err != nil {
return err
}

// Check the result format
resultFormat := resp.ResultSetMetadata.GetResultFormat()
if resultFormat != cli_service.TSparkRowSetType_ARROW_BASED_SET && resultFormat != cli_service.TSparkRowSetType_URL_BASED_SET {
return dbsqlerr.NewDriverError(ri.ctx, errArrowRowsNotArrowFormat, nil)
}

// Update schema bytes if we don't have them yet and the batch iterator got them
if ri.arrowSchemaBytes == nil {
ri.arrowSchemaBytes = resp.ResultSetMetadata.ArrowSchema
}

// Create a new batch iterator for the batches in the result page
bi, err := ri.newBatchIterator(resp)
if err != nil {
return err
if batchIter, ok := ri.batchIterator.(*batchIterator); ok {
if pagedIter, ok := batchIter.rawIterator.(*pagedRawBatchIterator); ok {
if schemaBytes := pagedIter.GetSchemaBytes(); schemaBytes != nil {
ri.arrowSchemaBytes = schemaBytes
}
}
}
}

ri.batchIterator = bi
}

return nil
}

// newBatchIterator builds a BatchIterator for one fetched page of results.
// A page that carries result links is handed to NewCloudBatchIterator; a page
// without links is handed to NewLocalBatchIterator together with the
// iterator's Arrow schema bytes.
func (ri *arrowRecordIterator) newBatchIterator(fr *cli_service.TFetchResultsResp) (BatchIterator, error) {
	page := fr.Results

	if len(page.ResultLinks) == 0 {
		return NewLocalBatchIterator(ri.ctx, page.ArrowBatches, page.StartRowOffset, ri.arrowSchemaBytes, &ri.cfg)
	}

	return NewCloudBatchIterator(ri.ctx, page.ResultLinks, page.StartRowOffset, &ri.cfg)
}

// Return the schema of the records.
func (ri *arrowRecordIterator) Schema() (*arrow.Schema, error) {
// Return cached schema if available
Expand Down Expand Up @@ -207,3 +186,51 @@ func (ri *arrowRecordIterator) Schema() (*arrow.Schema, error) {
ri.arrowSchema = reader.Schema()
return ri.arrowSchema, nil
}

// InitialThenPagedRawIterator is a composite raw batch iterator that serves
// batches from an initial raw iterator first and from a paged raw iterator
// once the initial one reports no more batches. Either field may be nil; the
// methods check for nil before delegating.
type InitialThenPagedRawIterator struct {
// InitialRaw is consulted first for as long as its HasNext reports true.
InitialRaw RawBatchIterator
// PagedRaw is used whenever InitialRaw is nil or has nothing left.
PagedRaw RawBatchIterator
}

// NewInitialThenPagedRawIterator returns a RawBatchIterator that wraps the two
// given iterators in an InitialThenPagedRawIterator: initial becomes
// InitialRaw and paged becomes PagedRaw. Nil arguments are stored as-is.
func NewInitialThenPagedRawIterator(initial, paged RawBatchIterator) RawBatchIterator {
	composite := new(InitialThenPagedRawIterator)
	composite.InitialRaw = initial
	composite.PagedRaw = paged
	return composite
}

// Next returns the next raw batch. It delegates to InitialRaw while that
// iterator is set and reports more batches; otherwise it delegates to
// PagedRaw when that is set. With neither available it returns io.EOF.
func (i *InitialThenPagedRawIterator) Next() (*cli_service.TSparkArrowBatch, error) {
	switch {
	case i.InitialRaw != nil && i.InitialRaw.HasNext():
		return i.InitialRaw.Next()
	case i.PagedRaw != nil:
		return i.PagedRaw.Next()
	}

	return nil, io.EOF
}

// HasNext reports whether another batch is available from either iterator.
// InitialRaw is asked first; PagedRaw is only asked when InitialRaw is nil or
// reports no more batches.
func (i *InitialThenPagedRawIterator) HasNext() bool {
	if i.InitialRaw != nil && i.InitialRaw.HasNext() {
		return true
	}
	if i.PagedRaw == nil {
		return false
	}
	return i.PagedRaw.HasNext()
}

// Close closes InitialRaw and then PagedRaw, skipping whichever is nil.
func (i *InitialThenPagedRawIterator) Close() {
	for _, raw := range [...]RawBatchIterator{i.InitialRaw, i.PagedRaw} {
		if raw != nil {
			raw.Close()
		}
	}
}

// GetStartRowOffset returns the start row offset of whichever iterator would
// serve the next batch: InitialRaw while it is set and reports more batches,
// otherwise PagedRaw when it is set. With neither available it returns 0.
func (i *InitialThenPagedRawIterator) GetStartRowOffset() int64 {
	switch {
	case i.InitialRaw != nil && i.InitialRaw.HasNext():
		return i.InitialRaw.GetStartRowOffset()
	case i.PagedRaw != nil:
		return i.PagedRaw.GetStartRowOffset()
	default:
		return 0
	}
}
Loading
Loading