Skip to content

Commit 12801b3

Browse files
committed
gofmt
Signed-off-by: Jane Doe <jane@example.com>
1 parent 12d2ced commit 12801b3

3 files changed

Lines changed: 186 additions & 0 deletions

File tree

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
package arrowbased
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"io"
7+
8+
"github.com/databricks/databricks-sql-go/internal/cli_service"
9+
"github.com/databricks/databricks-sql-go/internal/config"
10+
"github.com/databricks/databricks-sql-go/internal/rows/rowscanner"
11+
dbsqlrows "github.com/databricks/databricks-sql-go/rows"
12+
"github.com/pierrec/lz4/v4"
13+
)
14+
15+
// ipcStreamIterator provides access to raw Arrow IPC streams without
// deserializing them into Arrow records. It walks the batches of the
// current result page and pages in more batches on demand.
type ipcStreamIterator struct {
	// ctx is the request-scoped context carried through page fetches.
	// NOTE(review): storing a context in a struct is discouraged by Go
	// convention; kept here to match the existing iterator design.
	ctx context.Context
	// resultPageIterator pages additional batches from the server; may be nil.
	resultPageIterator rowscanner.ResultPageIterator
	// currentBatches holds the batches of the page currently being consumed.
	currentBatches []*cli_service.TSparkArrowBatch
	// currentIndex is the index of the next batch to return from currentBatches.
	currentIndex int
	// arrowSchemaBytes is the Arrow schema in IPC format, prepended to every
	// stream returned by NextIPCStream.
	arrowSchemaBytes []byte
	// useLz4 indicates batch payloads are LZ4-compressed and must be
	// decompressed before being returned.
	useLz4 bool
	// hasMorePages caches whether resultPageIterator reported more pages.
	hasMorePages bool
}
25+
26+
// NewIPCStreamIterator creates an iterator that returns raw IPC streams
27+
func NewIPCStreamIterator(
28+
ctx context.Context,
29+
resultPageIterator rowscanner.ResultPageIterator,
30+
initialRowSet *cli_service.TRowSet,
31+
schemaBytes []byte,
32+
cfg *config.Config,
33+
) (dbsqlrows.IPCStreamIterator, error) {
34+
var useLz4 bool
35+
if cfg != nil {
36+
useLz4 = cfg.UseLz4Compression
37+
}
38+
39+
var batches []*cli_service.TSparkArrowBatch
40+
if initialRowSet != nil {
41+
batches = initialRowSet.ArrowBatches
42+
}
43+
44+
return &ipcStreamIterator{
45+
ctx: ctx,
46+
resultPageIterator: resultPageIterator,
47+
currentBatches: batches,
48+
currentIndex: 0,
49+
arrowSchemaBytes: schemaBytes,
50+
useLz4: useLz4,
51+
hasMorePages: resultPageIterator != nil && resultPageIterator.HasNext(),
52+
}, nil
53+
}
54+
55+
// NextIPCStream returns the next Arrow batch as a raw IPC stream
56+
func (it *ipcStreamIterator) NextIPCStream() (io.Reader, error) {
57+
// Check if we need to load more batches from the next page
58+
if it.currentIndex >= len(it.currentBatches) {
59+
if !it.hasMorePages || it.resultPageIterator == nil {
60+
return nil, io.EOF
61+
}
62+
63+
// Fetch next page
64+
fetchResult, err := it.resultPageIterator.Next()
65+
if err != nil {
66+
return nil, err
67+
}
68+
69+
if fetchResult == nil || fetchResult.Results == nil || fetchResult.Results.ArrowBatches == nil {
70+
return nil, io.EOF
71+
}
72+
73+
it.currentBatches = fetchResult.Results.ArrowBatches
74+
it.currentIndex = 0
75+
it.hasMorePages = it.resultPageIterator.HasNext()
76+
77+
// If no batches in this page, recurse to try next page
78+
if len(it.currentBatches) == 0 {
79+
return it.NextIPCStream()
80+
}
81+
}
82+
83+
batch := it.currentBatches[it.currentIndex]
84+
it.currentIndex++
85+
86+
// Create reader for the batch data
87+
var batchReader io.Reader = bytes.NewReader(batch.Batch)
88+
89+
// Handle LZ4 decompression if needed
90+
if it.useLz4 {
91+
batchReader = lz4.NewReader(batchReader)
92+
}
93+
94+
// Combine schema and batch data into a complete IPC stream
95+
// Arrow IPC format expects: [Schema][Batch1][Batch2]...
96+
return io.MultiReader(
97+
bytes.NewReader(it.arrowSchemaBytes),
98+
batchReader,
99+
), nil
100+
}
101+
102+
// HasNext returns true if there are more batches
103+
func (it *ipcStreamIterator) HasNext() bool {
104+
return it.currentIndex < len(it.currentBatches) || it.hasMorePages
105+
}
106+
107+
// Close releases any resources held by the iterator.
//
// This implementation holds no closable resources (batches are plain byte
// slices and readers are created per call), so Close is a no-op; it exists
// to satisfy the IPCStreamIterator interface.
func (it *ipcStreamIterator) Close() {
	// Nothing to close for this implementation
}
111+
112+
// GetSchemaBytes returns the Arrow schema in IPC format.
//
// The returned slice aliases the iterator's internal buffer rather than a
// copy; callers must not mutate it.
func (it *ipcStreamIterator) GetSchemaBytes() ([]byte, error) {
	return it.arrowSchemaBytes, nil
}

internal/rows/rows.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"database/sql"
66
"database/sql/driver"
7+
"fmt"
78
"math"
89
"reflect"
910
"time"
@@ -527,6 +528,14 @@ func (r *rows) logger() *dbsqllog.DBSQLLogger {
527528
return r.logger_
528529
}
529530

531+
// getArrowSchemaBytes converts the table schema to Arrow IPC format bytes.
//
// Currently a stub: it always returns an error. Callers should prefer the
// ArrowSchema bytes from the result set metadata when available (see
// GetIPCStreams). The schema parameter is accepted but not yet used.
func (r *rows) getArrowSchemaBytes(schema *cli_service.TTableSchema) ([]byte, error) {
	// We need to use the arrow-based row scanner's conversion methods
	// This is a temporary solution - ideally this would be refactored to share code
	// For now, delegate to the arrowbased package
	return nil, fmt.Errorf("schema conversion not yet implemented - use ArrowSchema from metadata")
}
530539
func (r *rows) GetArrowBatches(ctx context.Context) (dbsqlrows.ArrowBatchIterator, error) {
531540
// update context with correlationId and connectionId which will be used in logging and errors
532541
ctx = driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(ctx, r.connId), r.correlationId)
@@ -539,3 +548,37 @@ func (r *rows) GetArrowBatches(ctx context.Context) (dbsqlrows.ArrowBatchIterato
539548

540549
return arrowbased.NewArrowRecordIterator(ctx, r.ResultPageIterator, nil, nil, *r.config), nil
541550
}
551+
552+
// GetIPCStreams returns an iterator that provides raw Arrow IPC streams
553+
func (r *rows) GetIPCStreams(ctx context.Context) (dbsqlrows.IPCStreamIterator, error) {
554+
// Update context with correlationId and connectionId
555+
ctx = driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(ctx, r.connId), r.correlationId)
556+
557+
// First try to get Arrow schema bytes from metadata if available
558+
var schemaBytes []byte
559+
if r.resultSetMetadata != nil && r.resultSetMetadata.ArrowSchema != nil {
560+
schemaBytes = r.resultSetMetadata.ArrowSchema
561+
} else {
562+
// Fall back to generating from table schema
563+
schema, err := r.getResultSetSchema()
564+
if err != nil {
565+
return nil, dbsqlerr_int.NewDriverError(ctx, "failed to get result set schema", err)
566+
}
567+
568+
// Convert schema to IPC format bytes
569+
var err2 error
570+
schemaBytes, err2 = r.getArrowSchemaBytes(schema)
571+
if err2 != nil {
572+
return nil, dbsqlerr_int.NewDriverError(ctx, "failed to convert schema to IPC format", err2)
573+
}
574+
}
575+
576+
// Create IPC stream iterator
577+
return arrowbased.NewIPCStreamIterator(
578+
ctx,
579+
r.ResultPageIterator,
580+
nil, // We don't have access to initial row set here
581+
schemaBytes,
582+
r.config,
583+
)
584+
}

rows/ipc_stream.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package rows
2+
3+
import (
	"context"
	"io"
)
6+
7+
// IPCStreamIterator provides access to raw Arrow IPC streams.
type IPCStreamIterator interface {
	// NextIPCStream returns the next Arrow batch as an IPC stream reader.
	// Returns io.EOF when no more batches are available.
	// (The comment previously named this method "GetNextIPCStream", which
	// did not match the declaration.)
	NextIPCStream() (io.Reader, error)

	// HasNext returns true if there are more batches.
	HasNext() bool

	// Close releases any resources held by the iterator.
	Close()

	// GetSchemaBytes returns the Arrow schema in IPC format.
	GetSchemaBytes() ([]byte, error)
}
22+
23+
// Extension to existing Rows interface
24+
type RowsWithIPCStream interface {
25+
Rows
26+
// GetIPCStreams returns an iterator for raw Arrow IPC streams
27+
GetIPCStreams() (IPCStreamIterator, error)
28+
}

0 commit comments

Comments
 (0)