115 changes: 115 additions & 0 deletions internal/rows/arrowbased/ipc_stream_iterator.go
Collaborator: needs tests
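A minimal test sketch toward that request, exercising only the initial-row-set path (no page iterator). Constructor and field shapes are taken from the diff below; the schema bytes are a stand-in rather than a real serialized Arrow schema:

```go
package arrowbased

import (
	"bytes"
	"context"
	"io"
	"testing"

	"github.com/databricks/databricks-sql-go/internal/cli_service"
)

// TestIPCStreamIteratorExhaustsInitialRowSet checks that the iterator yields
// one IPC stream per Arrow batch in the initial row set, that each stream is
// prefixed with the schema bytes, and that it then returns io.EOF.
func TestIPCStreamIteratorExhaustsInitialRowSet(t *testing.T) {
	schemaBytes := []byte("schema") // stand-in for serialized Arrow schema bytes
	rowSet := &cli_service.TRowSet{
		ArrowBatches: []*cli_service.TSparkArrowBatch{
			{Batch: []byte("batch-0")},
			{Batch: []byte("batch-1")},
		},
	}

	it, err := NewIPCStreamIterator(context.Background(), nil, rowSet, schemaBytes, nil)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	for i := 0; i < 2; i++ {
		if !it.HasNext() {
			t.Fatalf("expected HasNext() == true before batch %d", i)
		}
		r, err := it.NextIPCStream()
		if err != nil {
			t.Fatalf("unexpected error on batch %d: %v", i, err)
		}
		var buf bytes.Buffer
		if _, err := buf.ReadFrom(r); err != nil {
			t.Fatal(err)
		}
		// Each stream should start with the schema bytes, per the MultiReader layout.
		if !bytes.HasPrefix(buf.Bytes(), schemaBytes) {
			t.Errorf("stream %d does not start with schema bytes", i)
		}
	}

	if it.HasNext() {
		t.Error("expected HasNext() == false after all batches")
	}
	if _, err := it.NextIPCStream(); err != io.EOF {
		t.Errorf("expected io.EOF, got %v", err)
	}
}
```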

@@ -0,0 +1,115 @@
package arrowbased

import (
	"bytes"
	"context"
	"io"

	"github.com/databricks/databricks-sql-go/internal/cli_service"
	"github.com/databricks/databricks-sql-go/internal/config"
	"github.com/databricks/databricks-sql-go/internal/rows/rowscanner"
	dbsqlrows "github.com/databricks/databricks-sql-go/rows"
	"github.com/pierrec/lz4/v4"
)

// ipcStreamIterator provides access to raw Arrow IPC streams without deserialization
type ipcStreamIterator struct {
	ctx                context.Context
	resultPageIterator rowscanner.ResultPageIterator
	currentBatches     []*cli_service.TSparkArrowBatch
	currentIndex       int
	arrowSchemaBytes   []byte
	useLz4             bool
	hasMorePages       bool
}

// NewIPCStreamIterator creates an iterator that returns raw IPC streams
func NewIPCStreamIterator(
	ctx context.Context,
	resultPageIterator rowscanner.ResultPageIterator,
	initialRowSet *cli_service.TRowSet,
	schemaBytes []byte,
	cfg *config.Config,
) (dbsqlrows.IPCStreamIterator, error) {
Collaborator: do we have a scenario where we could return error here?

	var useLz4 bool
	if cfg != nil {
		useLz4 = cfg.UseLz4Compression
	}

	var batches []*cli_service.TSparkArrowBatch
	if initialRowSet != nil {
		batches = initialRowSet.ArrowBatches
	}

	return &ipcStreamIterator{
		ctx:                ctx,
		resultPageIterator: resultPageIterator,
		currentBatches:     batches,
		currentIndex:       0,
		arrowSchemaBytes:   schemaBytes,
		useLz4:             useLz4,
		hasMorePages:       resultPageIterator != nil && resultPageIterator.HasNext(),
	}, nil
}

// NextIPCStream returns the next Arrow batch as a raw IPC stream
func (it *ipcStreamIterator) NextIPCStream() (io.Reader, error) {
	// Check if we need to load more batches from the next page
	if it.currentIndex >= len(it.currentBatches) {
		if !it.hasMorePages || it.resultPageIterator == nil {
			return nil, io.EOF
		}

		// Fetch next page
		fetchResult, err := it.resultPageIterator.Next()
		if err != nil {
			return nil, err
		}

		if fetchResult == nil || fetchResult.Results == nil || fetchResult.Results.ArrowBatches == nil {
			return nil, io.EOF
		}
Collaborator: this assumes that fetchResult will always have arrow batches, but we could also have cloud fetch links; we could use BatchIterator to abstract those details for us: https://github.com/databricks/databricks-sql-go/blob/main/internal/rows/arrowbased/arrowRecordIterator.go#L141-L162
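A sketch of the suggested refactor, as a helper in package arrowbased. The NewCloudBatchIterator / NewLocalBatchIterator helpers and their signatures are assumptions drawn from the linked arrowRecordIterator.go code, not verified API:

```go
// Sketch only (verify helper signatures against batchloader.go): pick a
// BatchIterator per fetched page so both inline Arrow batches and cloud
// fetch result links are handled uniformly.
func batchIteratorForPage(
	ctx context.Context,
	fetchResult *cli_service.TFetchResultsResp,
	schemaBytes []byte,
	cfg *config.Config,
) (BatchIterator, error) {
	results := fetchResult.Results
	if len(results.ResultLinks) > 0 {
		// Cloud fetch: batches live behind presigned links, not in the response.
		return NewCloudBatchIterator(ctx, results.ResultLinks, results.StartRowOffset, cfg)
	}
	// Inline path: Arrow batches arrived directly in the Thrift response.
	return NewLocalBatchIterator(ctx, results.ArrowBatches, results.StartRowOffset, schemaBytes, cfg)
}
```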


		it.currentBatches = fetchResult.Results.ArrowBatches
		it.currentIndex = 0
		it.hasMorePages = it.resultPageIterator.HasNext()

		// If no batches in this page, recurse to try next page
		if len(it.currentBatches) == 0 {
			return it.NextIPCStream()
		}
	}

	batch := it.currentBatches[it.currentIndex]
	it.currentIndex++

	// Create reader for the batch data
	var batchReader io.Reader = bytes.NewReader(batch.Batch)

	// Handle LZ4 decompression if needed
	if it.useLz4 {
		batchReader = lz4.NewReader(batchReader)
	}

	// Combine schema and batch data into a complete IPC stream
	// Arrow IPC format expects: [Schema][Batch1][Batch2]...
	return io.MultiReader(
		bytes.NewReader(it.arrowSchemaBytes),
		batchReader,
	), nil
}

// HasNext returns true if there are more batches
func (it *ipcStreamIterator) HasNext() bool {
	return it.currentIndex < len(it.currentBatches) || it.hasMorePages
}

// Close releases any resources
func (it *ipcStreamIterator) Close() {
	// Nothing to close for this implementation
}

// GetSchemaBytes returns the Arrow schema in IPC format
func (it *ipcStreamIterator) GetSchemaBytes() ([]byte, error) {
	return it.arrowSchemaBytes, nil
}
43 changes: 43 additions & 0 deletions internal/rows/rows.go
@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"math"
"reflect"
"time"
@@ -527,6 +528,14 @@ func (r *rows) logger() *dbsqllog.DBSQLLogger {
	return r.logger_
}

// getArrowSchemaBytes converts the table schema to Arrow IPC format bytes
func (r *rows) getArrowSchemaBytes(schema *cli_service.TTableSchema) ([]byte, error) {
	// We need to use the arrow-based row scanner's conversion methods.
	// This is a temporary solution - ideally this would be refactored to share code.
	// For now, delegate to the arrowbased package.
	return nil, fmt.Errorf("schema conversion not yet implemented - use ArrowSchema from metadata")
}

func (r *rows) GetArrowBatches(ctx context.Context) (dbsqlrows.ArrowBatchIterator, error) {
	// update context with correlationId and connectionId which will be used in logging and errors
	ctx = driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(ctx, r.connId), r.correlationId)
@@ -539,3 +548,37 @@ func (r *rows) GetArrowBatches(ctx context.Context) (dbsqlrows.ArrowBatchIterator, error) {

	return arrowbased.NewArrowRecordIterator(ctx, r.ResultPageIterator, nil, nil, *r.config), nil
}

// GetIPCStreams returns an iterator that provides raw Arrow IPC streams
func (r *rows) GetIPCStreams(ctx context.Context) (dbsqlrows.IPCStreamIterator, error) {
	// Update context with correlationId and connectionId
	ctx = driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(ctx, r.connId), r.correlationId)

	// First try to get Arrow schema bytes from metadata if available
	var schemaBytes []byte
	if r.resultSetMetadata != nil && r.resultSetMetadata.ArrowSchema != nil {
		schemaBytes = r.resultSetMetadata.ArrowSchema
	} else {
		// Fall back to generating from table schema
Collaborator: we already have tTableSchemaToArrowSchema in arrowRows
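A sketch of that reuse: convert the Thrift schema with tTableSchemaToArrowSchema (per the comment above), then serialize the resulting *arrow.Schema to IPC bytes. The arrow import version and the 8-byte end-of-stream strip mirror the existing arrowRows helper and should be verified against that file:

```go
package arrowbased

import (
	"bytes"

	"github.com/apache/arrow/go/v12/arrow"
	"github.com/apache/arrow/go/v12/arrow/ipc"
)

// arrowSchemaToIPCBytes serializes an *arrow.Schema to IPC stream bytes.
// The input would come from tTableSchemaToArrowSchema (assumed signature).
func arrowSchemaToIPCBytes(schema *arrow.Schema) ([]byte, error) {
	var buf bytes.Buffer
	w := ipc.NewWriter(&buf, ipc.WithSchema(schema))
	// Closing with no records still writes the schema message.
	if err := w.Close(); err != nil {
		return nil, err
	}
	b := buf.Bytes()
	// Strip the trailing 8-byte end-of-stream marker so record batches can
	// be appended after the schema (as NextIPCStream does).
	return b[:len(b)-8], nil
}
```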

		schema, err := r.getResultSetSchema()
		if err != nil {
			return nil, dbsqlerr_int.NewDriverError(ctx, "failed to get result set schema", err)
		}

		// Convert schema to IPC format bytes
		var err2 error
		schemaBytes, err2 = r.getArrowSchemaBytes(schema)
		if err2 != nil {
			return nil, dbsqlerr_int.NewDriverError(ctx, "failed to convert schema to IPC format", err2)
		}
	}

	// Create IPC stream iterator
	return arrowbased.NewIPCStreamIterator(
		ctx,
		r.ResultPageIterator,
		nil, // We don't have access to initial row set here
		schemaBytes,
		r.config,
	)
}
28 changes: 28 additions & 0 deletions rows/ipc_stream.go
@@ -0,0 +1,28 @@
package rows

import (
	"context"
	"io"
)

// IPCStreamIterator provides access to raw Arrow IPC streams
type IPCStreamIterator interface {
	// NextIPCStream returns the next Arrow batch as an IPC stream reader.
	// Returns io.EOF when no more batches are available.
	NextIPCStream() (io.Reader, error)

	// HasNext returns true if there are more batches
	HasNext() bool

	// Close releases any resources
	Close()

	// GetSchemaBytes returns the Arrow schema in IPC format
	GetSchemaBytes() ([]byte, error)
}

// RowsWithIPCStream extends the existing Rows interface with access to raw
// Arrow IPC streams. The ctx parameter matches the GetIPCStreams
// implementation in internal/rows/rows.go.
type RowsWithIPCStream interface {
	Rows

	// GetIPCStreams returns an iterator for raw Arrow IPC streams
	GetIPCStreams(ctx context.Context) (IPCStreamIterator, error)
}
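A hedged usage sketch for the new interface: unwrap driver-level rows to RowsWithIPCStream (obtaining `r`, e.g. via conn.Raw and a driver-level query, is elided) and decode each stream with the Arrow Go ipc reader; the arrow import version is illustrative:

```go
package main

import (
	"context"
	"fmt"
	"io"

	"github.com/apache/arrow/go/v12/arrow/ipc"
	dbsqlrows "github.com/databricks/databricks-sql-go/rows"
)

// consumeIPCStreams reads every batch as a raw [schema][record batch] stream.
func consumeIPCStreams(ctx context.Context, r dbsqlrows.Rows) error {
	ipcRows, ok := r.(dbsqlrows.RowsWithIPCStream)
	if !ok {
		return fmt.Errorf("rows do not support IPC streams")
	}

	it, err := ipcRows.GetIPCStreams(ctx)
	if err != nil {
		return err
	}
	defer it.Close()

	for it.HasNext() {
		stream, err := it.NextIPCStream()
		if err == io.EOF {
			break // no more batches
		}
		if err != nil {
			return err
		}
		// Each stream starts with the schema, so ipc.NewReader can parse it.
		rdr, err := ipc.NewReader(stream)
		if err != nil {
			return err
		}
		for rdr.Next() {
			fmt.Println(rdr.Record().NumRows()) // process the record batch
		}
		rdr.Release()
	}
	return nil
}
```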