Add IPC stream interface for zero-copy Arrow data access #278
Changes from 1 commit
@@ -0,0 +1,115 @@

```go
package arrowbased

import (
	"bytes"
	"context"
	"io"

	"github.com/databricks/databricks-sql-go/internal/cli_service"
	"github.com/databricks/databricks-sql-go/internal/config"
	"github.com/databricks/databricks-sql-go/internal/rows/rowscanner"
	dbsqlrows "github.com/databricks/databricks-sql-go/rows"
	"github.com/pierrec/lz4/v4"
)

// ipcStreamIterator provides access to raw Arrow IPC streams without deserialization
type ipcStreamIterator struct {
	ctx                context.Context
	resultPageIterator rowscanner.ResultPageIterator
	currentBatches     []*cli_service.TSparkArrowBatch
	currentIndex       int
	arrowSchemaBytes   []byte
	useLz4             bool
	hasMorePages       bool
}

// NewIPCStreamIterator creates an iterator that returns raw IPC streams
func NewIPCStreamIterator(
	ctx context.Context,
	resultPageIterator rowscanner.ResultPageIterator,
	initialRowSet *cli_service.TRowSet,
	schemaBytes []byte,
	cfg *config.Config,
) (dbsqlrows.IPCStreamIterator, error) {
```
Collaborator: do we have a scenario where we could return error here?
```go
	var useLz4 bool
	if cfg != nil {
		useLz4 = cfg.UseLz4Compression
	}

	var batches []*cli_service.TSparkArrowBatch
	if initialRowSet != nil {
		batches = initialRowSet.ArrowBatches
	}

	return &ipcStreamIterator{
		ctx:                ctx,
		resultPageIterator: resultPageIterator,
		currentBatches:     batches,
		currentIndex:       0,
		arrowSchemaBytes:   schemaBytes,
		useLz4:             useLz4,
		hasMorePages:       resultPageIterator != nil && resultPageIterator.HasNext(),
	}, nil
}

// NextIPCStream returns the next Arrow batch as a raw IPC stream
func (it *ipcStreamIterator) NextIPCStream() (io.Reader, error) {
	// Check if we need to load more batches from the next page
	if it.currentIndex >= len(it.currentBatches) {
		if !it.hasMorePages || it.resultPageIterator == nil {
			return nil, io.EOF
		}

		// Fetch next page
		fetchResult, err := it.resultPageIterator.Next()
		if err != nil {
			return nil, err
		}

		if fetchResult == nil || fetchResult.Results == nil || fetchResult.Results.ArrowBatches == nil {
			return nil, io.EOF
		}
```
Collaborator: this assumes that fetchResult will always have Arrow batches, but we could also have cloud fetch links. We could use BatchIterator to abstract those details for us: https://github.com/databricks/databricks-sql-go/blob/main/internal/rows/arrowbased/arrowRecordIterator.go#L141-L162
```go
		it.currentBatches = fetchResult.Results.ArrowBatches
		it.currentIndex = 0
		it.hasMorePages = it.resultPageIterator.HasNext()

		// If no batches in this page, recurse to try next page
		if len(it.currentBatches) == 0 {
			return it.NextIPCStream()
		}
	}

	batch := it.currentBatches[it.currentIndex]
	it.currentIndex++

	// Create reader for the batch data
	var batchReader io.Reader = bytes.NewReader(batch.Batch)

	// Handle LZ4 decompression if needed
	if it.useLz4 {
		batchReader = lz4.NewReader(batchReader)
	}

	// Combine schema and batch data into a complete IPC stream
	// Arrow IPC format expects: [Schema][Batch1][Batch2]...
	return io.MultiReader(
		bytes.NewReader(it.arrowSchemaBytes),
		batchReader,
	), nil
}

// HasNext returns true if there are more batches
func (it *ipcStreamIterator) HasNext() bool {
	return it.currentIndex < len(it.currentBatches) || it.hasMorePages
}

// Close releases any resources
func (it *ipcStreamIterator) Close() {
	// Nothing to close for this implementation
}

// GetSchemaBytes returns the Arrow schema in IPC format
func (it *ipcStreamIterator) GetSchemaBytes() ([]byte, error) {
	return it.arrowSchemaBytes, nil
}
```
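Not part of the diff: each reader returned by NextIPCStream is a complete Arrow IPC stream (schema message followed by one record batch), so it can be decoded with the Arrow Go IPC reader. A hedged consumer sketch follows; it assumes the apache/arrow Go module at major version v12 (adjust the import path to whichever version the driver pins) and uses only the interface methods added in this PR.

```go
package example

import (
	"errors"
	"fmt"
	"io"

	"github.com/apache/arrow/go/v12/arrow/ipc"
	dbsqlrows "github.com/databricks/databricks-sql-go/rows"
)

// readAllRecords drains the iterator, decoding each raw IPC stream into Arrow
// records with the standard IPC reader.
func readAllRecords(it dbsqlrows.IPCStreamIterator) error {
	defer it.Close()

	for it.HasNext() {
		stream, err := it.NextIPCStream()
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			return err
		}

		// Each stream is [schema][record batch], so the reader can parse it directly.
		rdr, err := ipc.NewReader(stream)
		if err != nil {
			return err
		}
		for rdr.Next() {
			rec := rdr.Record()
			fmt.Printf("decoded a batch with %d rows\n", rec.NumRows())
		}
		err = rdr.Err()
		rdr.Release()
		if err != nil && !errors.Is(err, io.EOF) {
			return err
		}
	}
	return nil
}
```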
@@ -4,6 +4,7 @@ import (

```go
	"context"
	"database/sql"
	"database/sql/driver"
	"fmt"
	"math"
	"reflect"
	"time"
```

@@ -527,6 +528,14 @@ func (r *rows) logger() *dbsqllog.DBSQLLogger {

```go
	return r.logger_
}

// getArrowSchemaBytes converts the table schema to Arrow IPC format bytes
func (r *rows) getArrowSchemaBytes(schema *cli_service.TTableSchema) ([]byte, error) {
	// We need to use the arrow-based row scanner's conversion methods
	// This is a temporary solution - ideally this would be refactored to share code
	// For now, delegate to the arrowbased package
	return nil, fmt.Errorf("schema conversion not yet implemented - use ArrowSchema from metadata")
}
```
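An editorial aside, not part of this PR: the stub above could be filled in by first converting the Thrift TTableSchema to an *arrow.Schema (the reviewer notes below that tTableSchemaToArrowSchema in arrowRows already does this) and then serializing that schema with the Arrow IPC writer. A hedged sketch of the serialization half follows; the behavior of Close on a writer that never wrote a record, and the 8-byte end-of-stream trim, are assumptions to verify against the arrow-go version the driver pins.

```go
package example

import (
	"bytes"

	"github.com/apache/arrow/go/v12/arrow"
	"github.com/apache/arrow/go/v12/arrow/ipc"
)

// schemaToIPCBytes serializes an Arrow schema as a stream-format schema
// message, suitable for prepending to raw record-batch bytes the way
// ipcStreamIterator does.
func schemaToIPCBytes(schema *arrow.Schema) ([]byte, error) {
	var buf bytes.Buffer

	// Assumption: closing the writer without writing any records emits the
	// schema message followed by an end-of-stream marker.
	w := ipc.NewWriter(&buf, ipc.WithSchema(schema))
	if err := w.Close(); err != nil {
		return nil, err
	}

	b := buf.Bytes()
	// Drop the trailing end-of-stream marker (assumed to be the final 8 bytes)
	// so record-batch messages can follow the schema message.
	if len(b) > 8 {
		b = b[:len(b)-8]
	}
	return b, nil
}
```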
```go
func (r *rows) GetArrowBatches(ctx context.Context) (dbsqlrows.ArrowBatchIterator, error) {
	// update context with correlationId and connectionId which will be used in logging and errors
	ctx = driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(ctx, r.connId), r.correlationId)
```

@@ -539,3 +548,37 @@ func (r *rows) GetArrowBatches(ctx context.Context) (dbsqlrows.ArrowBatchIterato

```go
	return arrowbased.NewArrowRecordIterator(ctx, r.ResultPageIterator, nil, nil, *r.config), nil
}
```
```go
// GetIPCStreams returns an iterator that provides raw Arrow IPC streams
func (r *rows) GetIPCStreams(ctx context.Context) (dbsqlrows.IPCStreamIterator, error) {
	// Update context with correlationId and connectionId
	ctx = driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(ctx, r.connId), r.correlationId)

	// First try to get Arrow schema bytes from metadata if available
	var schemaBytes []byte
	if r.resultSetMetadata != nil && r.resultSetMetadata.ArrowSchema != nil {
		schemaBytes = r.resultSetMetadata.ArrowSchema
	} else {
		// Fall back to generating from table schema
```
Collaborator: we already have tTableSchemaToArrowSchema in arrowRows
```go
		schema, err := r.getResultSetSchema()
		if err != nil {
			return nil, dbsqlerr_int.NewDriverError(ctx, "failed to get result set schema", err)
		}

		// Convert schema to IPC format bytes
		var err2 error
		schemaBytes, err2 = r.getArrowSchemaBytes(schema)
		if err2 != nil {
			return nil, dbsqlerr_int.NewDriverError(ctx, "failed to convert schema to IPC format", err2)
		}
	}

	// Create IPC stream iterator
	return arrowbased.NewIPCStreamIterator(
		ctx,
		r.ResultPageIterator,
		nil, // We don't have access to initial row set here
		schemaBytes,
		r.config,
	)
}
```
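Also not part of the diff: a sketch of how an application might reach GetIPCStreams through database/sql, modeled on the conn.Raw pattern used for GetArrowBatches. The connector options and the query are placeholders. Note that the implementation above takes a context.Context even though the RowsWithIPCStream interface below declares GetIPCStreams without one, so this sketch asserts against an ad-hoc interface matching the implementation's signature.

```go
package main

import (
	"context"
	"database/sql"
	"database/sql/driver"
	"fmt"
	"io"

	dbsql "github.com/databricks/databricks-sql-go"
	dbsqlrows "github.com/databricks/databricks-sql-go/rows"
)

func main() {
	ctx := context.Background()

	// Placeholder connector options; fill in hostname, HTTP path, and token.
	connector, err := dbsql.NewConnector()
	if err != nil {
		panic(err)
	}
	db := sql.OpenDB(connector)
	defer db.Close()

	conn, err := db.Conn(ctx)
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	err = conn.Raw(func(d interface{}) error {
		rows, err := d.(driver.QueryerContext).QueryContext(ctx, "SELECT 1", nil)
		if err != nil {
			return err
		}
		defer rows.Close()

		// Assert against the implementation's signature (takes a context).
		it, err := rows.(interface {
			GetIPCStreams(context.Context) (dbsqlrows.IPCStreamIterator, error)
		}).GetIPCStreams(ctx)
		if err != nil {
			return err
		}
		defer it.Close()

		for it.HasNext() {
			stream, err := it.NextIPCStream()
			if err == io.EOF {
				break
			}
			if err != nil {
				return err
			}
			// Hand the raw stream to another process, a file, etc.; here we just count bytes.
			n, _ := io.Copy(io.Discard, stream)
			fmt.Println("IPC stream bytes:", n)
		}
		return nil
	})
	if err != nil {
		panic(err)
	}
}
```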
@@ -0,0 +1,28 @@

```go
package rows

import (
	"io"
)

// IPCStreamIterator provides access to raw Arrow IPC streams
type IPCStreamIterator interface {
	// NextIPCStream returns the next Arrow batch as an IPC stream reader
	// Returns io.EOF when no more batches are available
	NextIPCStream() (io.Reader, error)

	// HasNext returns true if there are more batches
	HasNext() bool

	// Close releases any resources
	Close()

	// GetSchemaBytes returns the Arrow schema in IPC format
	GetSchemaBytes() ([]byte, error)
}

// Extension to existing Rows interface
type RowsWithIPCStream interface {
	Rows
	// GetIPCStreams returns an iterator for raw Arrow IPC streams
	GetIPCStreams() (IPCStreamIterator, error)
}
```
Review comment: needs tests
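To the reviewer's point about tests, a minimal unit-test sketch for the iterator. It only exercises the path visible in this diff (batches supplied via the initial row set, nil page iterator, no LZ4); the Thrift field names come from the diff, everything else is an assumption.

```go
package arrowbased

import (
	"bytes"
	"context"
	"errors"
	"io"
	"testing"

	"github.com/databricks/databricks-sql-go/internal/cli_service"
)

func TestIPCStreamIterator_InitialRowSetOnly(t *testing.T) {
	schema := []byte{0x01, 0x02}
	batch := []byte{0x03, 0x04, 0x05}
	rowSet := &cli_service.TRowSet{
		ArrowBatches: []*cli_service.TSparkArrowBatch{{Batch: batch}},
	}

	// nil page iterator and nil config: only the initial row set is consumed.
	it, err := NewIPCStreamIterator(context.Background(), nil, rowSet, schema, nil)
	if err != nil {
		t.Fatalf("unexpected constructor error: %v", err)
	}
	defer it.Close()

	if !it.HasNext() {
		t.Fatal("expected HasNext to be true before the first batch is read")
	}

	r, err := it.NextIPCStream()
	if err != nil {
		t.Fatalf("unexpected error from NextIPCStream: %v", err)
	}
	got, err := io.ReadAll(r)
	if err != nil {
		t.Fatalf("unexpected error reading stream: %v", err)
	}
	want := append(append([]byte{}, schema...), batch...)
	if !bytes.Equal(got, want) {
		t.Fatalf("stream mismatch: got %v, want %v", got, want)
	}

	if it.HasNext() {
		t.Fatal("expected HasNext to be false after the only batch was read")
	}
	if _, err := it.NextIPCStream(); !errors.Is(err, io.EOF) {
		t.Fatalf("expected io.EOF, got %v", err)
	}
}
```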