117 changes: 117 additions & 0 deletions examples/ipcstreams/main.go
@@ -0,0 +1,117 @@
package main

import (
	"context"
	"database/sql"
	"database/sql/driver"
	"io"
	"log"
	"os"
	"strconv"
	"time"

	"github.com/apache/arrow/go/v12/arrow/ipc"
	dbsql "github.com/databricks/databricks-sql-go"
	dbsqlrows "github.com/databricks/databricks-sql-go/rows"
	"github.com/joho/godotenv"
)

func main() {
	// Load environment variables from .env file if it exists.
	// This will not override existing environment variables.
	_ = godotenv.Load()

	port, err := strconv.Atoi(os.Getenv("DATABRICKS_PORT"))
	if err != nil {
		log.Fatal(err.Error())
	}

	connector, err := dbsql.NewConnector(
		dbsql.WithServerHostname(os.Getenv("DATABRICKS_HOST")),
		dbsql.WithPort(port),
		dbsql.WithHTTPPath(os.Getenv("DATABRICKS_HTTPPATH")),
		dbsql.WithAccessToken(os.Getenv("DATABRICKS_ACCESSTOKEN")),
	)
	if err != nil {
		log.Fatal(err)
	}

	db := sql.OpenDB(connector)
	defer db.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	conn, err := db.Conn(ctx)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	query := `SELECT * FROM samples.nyctaxi.trips LIMIT 1000`

	var rows driver.Rows
	err = conn.Raw(func(d interface{}) error {
		var err error
		rows, err = d.(driver.QueryerContext).QueryContext(ctx, query, nil)
		return err
	})

	if err != nil {
		log.Fatal("Failed to execute query: ", err)
	}
	defer rows.Close()

	// Get the IPC stream iterator
	ipcStreams, err := rows.(dbsqlrows.Rows).GetArrowIPCStreams(ctx)
	if err != nil {
		log.Fatal("Failed to get IPC streams: ", err)
	}
	defer ipcStreams.Close()

	// Get the schema bytes
	schemaBytes, err := ipcStreams.SchemaBytes()
	if err != nil {
		log.Fatal("Failed to get schema bytes: ", err)
	}
	log.Printf("Schema bytes length: %d", len(schemaBytes))

	// Process IPC streams
	streamCount := 0
	recordCount := 0

	for ipcStreams.HasNext() {
		// Get the next IPC stream
		reader, err := ipcStreams.Next()
		if err != nil {
			if err == io.EOF {
				break
			}
			log.Fatal("Failed to get next IPC stream: ", err)
		}

		streamCount++

		// Create an IPC reader for this stream
		ipcReader, err := ipc.NewReader(reader)
		if err != nil {
			log.Fatal("Failed to create IPC reader: ", err)
		}

		// Process records in the stream
		for ipcReader.Next() {
			record := ipcReader.Record()
			recordCount++
			log.Printf("Stream %d, Record %d: %d rows, %d columns",
				streamCount, recordCount, record.NumRows(), record.NumCols())

			// Don't forget to release the record when done
			record.Release()
		}

		if err := ipcReader.Err(); err != nil {
			log.Fatal("IPC reader error: ", err)
		}

		ipcReader.Release()
	}

	log.Printf("Processed %d IPC streams with %d total records", streamCount, recordCount)
}
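The example above only logs the length of the schema bytes. If the decoded schema is wanted as well, the bytes returned by SchemaBytes() are a serialized Arrow IPC stream (see the getArrowSchemaBytes comment further down in this diff), so they can be fed back through the ipc package. A minimal sketch under that assumption; decodeSchema is a hypothetical helper, not part of the driver:

// decodeSchema turns the bytes returned by ipcStreams.SchemaBytes() back into
// an *arrow.Schema. It assumes the bytes form an Arrow IPC stream whose first
// message is the schema. Requires "bytes" and
// "github.com/apache/arrow/go/v12/arrow" added to the import block above.
func decodeSchema(schemaBytes []byte) (*arrow.Schema, error) {
	// ipc.NewReader parses the schema message during construction,
	// so no record batches need to be present in the stream.
	r, err := ipc.NewReader(bytes.NewReader(schemaBytes))
	if err != nil {
		return nil, err
	}
	defer r.Release()
	return r.Schema(), nil
}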
145 changes: 145 additions & 0 deletions internal/rows/arrowbased/arrowIPCStreamIterator.go
@@ -0,0 +1,145 @@
package arrowbased

import (
	"context"
	"io"

	"github.com/databricks/databricks-sql-go/internal/cli_service"
	"github.com/databricks/databricks-sql-go/internal/config"
	dbsqlerrint "github.com/databricks/databricks-sql-go/internal/errors"
	"github.com/databricks/databricks-sql-go/internal/rows/rowscanner"
	"github.com/databricks/databricks-sql-go/rows"
)

// NewArrowIPCStreamIterator creates a new iterator for Arrow IPC streams
func NewArrowIPCStreamIterator(ctx context.Context, rpi rowscanner.ResultPageIterator, ipcIterator IPCStreamIterator, arrowSchemaBytes []byte, cfg config.Config) rows.ArrowIPCStreamIterator {
	return &arrowIPCStreamIterator{
		cfg:                cfg,
		ipcStreamIterator:  ipcIterator,
		resultPageIterator: rpi,
		ctx:                ctx,
		arrowSchemaBytes:   arrowSchemaBytes,
	}
}

// arrowIPCStreamIterator implements rows.ArrowIPCStreamIterator
type arrowIPCStreamIterator struct {
	ctx                context.Context
	cfg                config.Config
	ipcStreamIterator  IPCStreamIterator
	resultPageIterator rowscanner.ResultPageIterator
	isFinished         bool
	arrowSchemaBytes   []byte
}

var _ rows.ArrowIPCStreamIterator = (*arrowIPCStreamIterator)(nil)

// Next retrieves the next Arrow IPC stream
func (ri *arrowIPCStreamIterator) Next() (io.Reader, error) {
	if !ri.HasNext() {
		return nil, io.EOF
	}

	if ri.ipcStreamIterator != nil && ri.ipcStreamIterator.HasNext() {
		return ri.ipcStreamIterator.Next()
	}

	// If there is no iterator, or we have exhausted the current iterator, try to load more data
	if err := ri.fetchNextData(); err != nil {
		return nil, err
	}

	// Try again after fetching new data
	if ri.ipcStreamIterator != nil && ri.ipcStreamIterator.HasNext() {
		return ri.ipcStreamIterator.Next()
	}

	return nil, io.EOF
}

// HasNext returns true if there are more streams available
func (ri *arrowIPCStreamIterator) HasNext() bool {
	if ri.isFinished {
		return false
	}

	if ri.ipcStreamIterator != nil && ri.ipcStreamIterator.HasNext() {
		return true
	}

	if ri.resultPageIterator == nil || !ri.resultPageIterator.HasNext() {
		return false
	}

	return true
}

// Close releases resources
func (ri *arrowIPCStreamIterator) Close() {
	if ri.ipcStreamIterator != nil {
		ri.ipcStreamIterator.Close()
		ri.ipcStreamIterator = nil
	}
	ri.isFinished = true
}

// SchemaBytes returns the Arrow schema bytes
func (ri *arrowIPCStreamIterator) SchemaBytes() ([]byte, error) {
	return ri.arrowSchemaBytes, nil
}

// fetchNextData loads the next page of data
func (ri *arrowIPCStreamIterator) fetchNextData() error {
	if ri.isFinished {
		return io.EOF
	}

	// First close any existing iterator
	if ri.ipcStreamIterator != nil {
		ri.ipcStreamIterator.Close()
		ri.ipcStreamIterator = nil
	}

	if ri.resultPageIterator == nil || !ri.resultPageIterator.HasNext() {
		ri.isFinished = true
		return io.EOF
	}

	// Get the next page of the result set
	resp, err := ri.resultPageIterator.Next()
	if err != nil {
		ri.isFinished = true
		return err
	}

	// Check the result format
	resultFormat := resp.ResultSetMetadata.GetResultFormat()
	if resultFormat != cli_service.TSparkRowSetType_ARROW_BASED_SET && resultFormat != cli_service.TSparkRowSetType_URL_BASED_SET {
		return dbsqlerrint.NewDriverError(ri.ctx, errArrowRowsNotArrowFormat, nil)
	}

	// Update schema if this is the first fetch
	if ri.arrowSchemaBytes == nil && resp.ResultSetMetadata != nil && resp.ResultSetMetadata.ArrowSchema != nil {
		ri.arrowSchemaBytes = resp.ResultSetMetadata.ArrowSchema
	}

	// Create new iterator from the fetched data
	bi, err := ri.newIPCStreamIterator(resp)
	if err != nil {
		ri.isFinished = true
		return err
	}

	ri.ipcStreamIterator = bi
	return nil
}

// Create a new IPC stream iterator from a page of the result set
func (ri *arrowIPCStreamIterator) newIPCStreamIterator(fr *cli_service.TFetchResultsResp) (IPCStreamIterator, error) {
	rowSet := fr.Results
	if len(rowSet.ResultLinks) > 0 {
		return NewCloudIPCStreamIterator(ri.ctx, rowSet.ResultLinks, rowSet.StartRowOffset, &ri.cfg)
	} else {
		return NewLocalIPCStreamIterator(ri.ctx, rowSet.ArrowBatches, rowSet.StartRowOffset, ri.arrowSchemaBytes, &ri.cfg)
	}
}
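The IPCStreamIterator interface consumed here is defined elsewhere in the arrowbased package and is not part of this diff. Judging from the calls made above (HasNext, a Next that yields an io.Reader, Close) and the two constructors, its shape is presumably close to this sketch:

// Assumed shape of the package-internal iterator over raw Arrow IPC streams;
// the real definition lives elsewhere in internal/rows/arrowbased.
type IPCStreamIterator interface {
	Next() (io.Reader, error) // next raw Arrow IPC stream
	HasNext() bool            // more streams available from the current page?
	Close()                   // release any held resources
}

NewCloudIPCStreamIterator and NewLocalIPCStreamIterator would then be the implementations backed by cloud-fetch result links and inline ArrowBatches, respectively, matching the two branches of newIPCStreamIterator above.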
14 changes: 14 additions & 0 deletions internal/rows/arrowbased/arrowRows.go
@@ -329,6 +329,20 @@ func (ars *arrowRowScanner) GetArrowBatches(ctx context.Context, cfg config.Conf
	return ri, nil
}

func (ars *arrowRowScanner) GetArrowIPCStreams(ctx context.Context, cfg config.Config, rpi rowscanner.ResultPageIterator) (dbsqlrows.ArrowIPCStreamIterator, error) {
	// Get the underlying IPC stream iterator from the batch iterator
	var ipcIterator IPCStreamIterator
	if ars.batchIterator != nil {
		// If we have a batch iterator, extract its IPC stream iterator
		if wrapper, ok := ars.batchIterator.(*batchIterator); ok {
			ipcIterator = wrapper.ipcIterator
		}
	}

	ri := NewArrowIPCStreamIterator(ctx, rpi, ipcIterator, ars.arrowSchemaBytes, cfg)
	return ri, nil
}

// getArrowSchemaBytes returns the serialized schema in ipc format
func getArrowSchemaBytes(schema *arrow.Schema, ctx context.Context) ([]byte, dbsqlerr.DBError) {
	if schema == nil {
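A note on the extraction in GetArrowIPCStreams: the type assertion only pulls an inner IPC iterator out of the concrete *batchIterator type. With any other BatchIterator implementation (for example the test wrapper in the next file), ipcIterator stays nil, which appears to be safe: arrowIPCStreamIterator treats a nil inner iterator as exhausted and starts by fetching the next result page in fetchNextData.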
16 changes: 7 additions & 9 deletions internal/rows/arrowbased/arrowRows_test.go
@@ -1041,10 +1041,8 @@ func TestArrowRowScanner(t *testing.T) {
		ars := d.(*arrowRowScanner)
		assert.Equal(t, int64(53940), ars.NRows())

-		bi, ok := ars.batchIterator.(*localBatchIterator)
-		assert.True(t, ok)
-		fbi := &batchIteratorWrapper{
-			bi: bi,
+		fbi := &testBatchIteratorWrapper{
+			bi: ars.batchIterator,
		}

		ars.batchIterator = fbi
@@ -1674,26 +1672,26 @@ func (fbi *fakeBatchIterator) Close() {
	fbi.lastReadBatch = nil
}

-type batchIteratorWrapper struct {
+type testBatchIteratorWrapper struct {
	bi              BatchIterator
	callCount       int
	lastLoadedBatch SparkArrowBatch
}

-var _ BatchIterator = (*batchIteratorWrapper)(nil)
+var _ BatchIterator = (*testBatchIteratorWrapper)(nil)

-func (biw *batchIteratorWrapper) Next() (SparkArrowBatch, error) {
+func (biw *testBatchIteratorWrapper) Next() (SparkArrowBatch, error) {
	biw.callCount += 1
	batch, err := biw.bi.Next()
	biw.lastLoadedBatch = batch
	return batch, err
}

-func (biw *batchIteratorWrapper) HasNext() bool {
+func (biw *testBatchIteratorWrapper) HasNext() bool {
	return biw.bi.HasNext()
}

-func (biw *batchIteratorWrapper) Close() {
+func (biw *testBatchIteratorWrapper) Close() {
	biw.bi.Close()
}
