Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/rows/arrowbased/arrowRecordIterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ func (ri *arrowRecordIterator) getBatchIterator() error {
func (ri *arrowRecordIterator) newBatchIterator(fr *cli_service.TFetchResultsResp) (BatchIterator, error) {
rowSet := fr.Results
if len(rowSet.ResultLinks) > 0 {
return NewCloudBatchIterator(ri.ctx, rowSet.ResultLinks, rowSet.StartRowOffset, &ri.cfg)
return NewCloudBatchIterator(ri.ctx, rowSet.ResultLinks, rowSet.StartRowOffset, ri.arrowSchemaBytes, &ri.cfg)
} else {
return NewLocalBatchIterator(ri.ctx, rowSet.ArrowBatches, rowSet.StartRowOffset, ri.arrowSchemaBytes, &ri.cfg)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/rows/arrowbased/arrowRows.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func NewArrowRowScanner(resultSetMetadata *cli_service.TGetResultSetMetadataResp
for _, resultLink := range rowSet.ResultLinks {
logger.Debug().Msgf("- start row offset: %d, row count: %d", resultLink.StartRowOffset, resultLink.RowCount)
}
bi, err2 = NewCloudBatchIterator(context.Background(), rowSet.ResultLinks, rowSet.StartRowOffset, cfg)
bi, err2 = NewCloudBatchIterator(context.Background(), rowSet.ResultLinks, rowSet.StartRowOffset, schemaBytes, cfg)
} else {
bi, err2 = NewLocalBatchIterator(context.Background(), rowSet.ArrowBatches, rowSet.StartRowOffset, schemaBytes, cfg)
}
Expand Down
52 changes: 50 additions & 2 deletions internal/rows/arrowbased/batchloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import (

"net/http"

"github.com/apache/arrow/go/v12/arrow"
"github.com/apache/arrow/go/v12/arrow/array"
"github.com/apache/arrow/go/v12/arrow/ipc"
dbsqlerr "github.com/databricks/databricks-sql-go/errors"
"github.com/databricks/databricks-sql-go/internal/cli_service"
Expand Down Expand Up @@ -55,18 +57,35 @@ func NewCloudIPCStreamIterator(
return bi, nil
}

// NewCloudBatchIterator creates a cloud-based BatchIterator for backward compatibility
// NewCloudBatchIterator creates a cloud-based BatchIterator for backward compatibility.
// arrowSchemaBytes is the authoritative schema from GetResultSetMetadata, used to
// override stale column names in cached Arrow IPC files.
// NewCloudBatchIterator creates a cloud-based BatchIterator for backward compatibility.
// arrowSchemaBytes is the authoritative schema from GetResultSetMetadata, used to
// override stale column names in cached Arrow IPC files. Passing nil or empty
// bytes disables the override and preserves the schema embedded in the
// downloaded files.
func NewCloudBatchIterator(
	ctx context.Context,
	files []*cli_service.TSparkArrowResultLink,
	startRowOffset int64,
	arrowSchemaBytes []byte,
	cfg *config.Config,
) (BatchIterator, dbsqlerr.DBError) {
	ipcIterator, err := NewCloudIPCStreamIterator(ctx, files, startRowOffset, cfg)
	if err != nil {
		return nil, err
	}

	// NOTE: a leftover early `return NewBatchIterator(...)` (diff artifact)
	// previously made everything below unreachable; it has been removed so the
	// override-schema path actually runs.

	// Parse the authoritative schema once up front. A parse failure is logged
	// and degrades gracefully to the embedded (possibly stale) schema rather
	// than failing the whole fetch.
	var overrideSchema *arrow.Schema
	if len(arrowSchemaBytes) > 0 {
		var schemaErr error
		overrideSchema, schemaErr = schemaFromIPCBytes(arrowSchemaBytes)
		if schemaErr != nil {
			logger.Warn().Msgf("CloudFetch: failed to parse override schema: %v", schemaErr)
		}
	}

	return &batchIterator{
		ipcIterator:    ipcIterator,
		startRowOffset: startRowOffset,
		overrideSchema: overrideSchema,
	}, nil
}

func NewLocalIPCStreamIterator(
Expand Down Expand Up @@ -400,6 +419,7 @@ type BatchIterator interface {
// batchIterator adapts an IPCStreamIterator into a BatchIterator, tracking the
// absolute row offset of the next batch and optionally rewriting record
// schemas with an authoritative one.
type batchIterator struct {
	ipcIterator    IPCStreamIterator // source of raw Arrow IPC streams (local or CloudFetch)
	startRowOffset int64             // absolute row offset of the next batch to return
	overrideSchema *arrow.Schema     // authoritative schema to fix stale CloudFetch column names
}

// NewBatchIterator creates a BatchIterator from an IPCStreamIterator
Expand All @@ -421,6 +441,24 @@ func (bi *batchIterator) Next() (SparkArrowBatch, error) {
return nil, err
}

// When using CloudFetch, cached Arrow IPC files may contain stale column
// names from a previous query. Replace the embedded schema with the
// authoritative schema from GetResultSetMetadata.
if bi.overrideSchema != nil && len(records) > 0 && len(bi.overrideSchema.Fields()) == len(records[0].Columns()) {
for i, rec := range records {
sar, ok := rec.(*sparkArrowRecord)
if !ok {
continue
}
corrected := array.NewRecord(bi.overrideSchema, sar.Columns(), sar.NumRows())
sar.Release()
records[i] = &sparkArrowRecord{
Delimiter: sar.Delimiter,
Record: corrected,
}
}
}

// Calculate total rows in this batch
totalRows := int64(0)
for _, record := range records {
Expand All @@ -443,3 +481,13 @@ func (bi *batchIterator) HasNext() bool {
// Close releases the resources held by the underlying IPC stream iterator.
func (bi *batchIterator) Close() {
	bi.ipcIterator.Close()
}

// schemaFromIPCBytes decodes a serialized Arrow IPC stream and returns the
// *arrow.Schema it declares. The IPC reader is released before returning.
func schemaFromIPCBytes(raw []byte) (*arrow.Schema, error) {
	r, err := ipc.NewReader(bytes.NewReader(raw))
	if err != nil {
		return nil, err
	}
	schema := r.Schema()
	r.Release()
	return schema, nil
}
116 changes: 116 additions & 0 deletions internal/rows/arrowbased/batchloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ func TestCloudFetchIterator(t *testing.T) {
context.Background(),
links,
startRowOffset,
nil,
cfg,
)
if err != nil {
Expand Down Expand Up @@ -150,6 +151,7 @@ func TestCloudFetchIterator(t *testing.T) {
context.Background(),
links,
startRowOffset,
nil,
cfg,
)
if err != nil {
Expand Down Expand Up @@ -208,6 +210,7 @@ func TestCloudFetchIterator(t *testing.T) {
context.Background(),
links,
startRowOffset,
nil,
cfg,
)
if err != nil {
Expand Down Expand Up @@ -282,6 +285,7 @@ func TestCloudFetchIterator(t *testing.T) {
RowCount: 1,
}},
startRowOffset,
nil,
cfg,
)
assert.Nil(t, err)
Expand Down Expand Up @@ -320,6 +324,7 @@ func TestCloudFetchIterator(t *testing.T) {
RowCount: 1,
}},
startRowOffset,
nil,
cfg,
)
assert.Nil(t, err)
Expand All @@ -334,6 +339,117 @@ func TestCloudFetchIterator(t *testing.T) {
})
}

// TestCloudFetchSchemaOverride verifies that the authoritative schema from
// GetResultSetMetadata replaces stale column names embedded in cached
// CloudFetch Arrow IPC files, and that no rewriting happens when no override
// schema is supplied.
func TestCloudFetchSchemaOverride(t *testing.T) {
	// Reproduces ES-1804970: When the server result cache serves Arrow IPC files
	// from a prior query, the embedded schema has stale column names. The
	// authoritative schema from GetResultSetMetadata must override them.

	// IPC data has columns ["id", "name"] (stale, from cached query)
	staleRecord := generateArrowRecord()
	staleIPCBytes := generateMockArrowBytes(staleRecord)

	// Authoritative schema has columns ["x", "y"] (correct, from GetResultSetMetadata)
	correctFields := []arrow.Field{
		{Name: "x", Type: arrow.PrimitiveTypes.Int32},
		{Name: "y", Type: arrow.BinaryTypes.String},
	}
	correctSchema := arrow.NewSchema(correctFields, nil)
	var schemaBuf bytes.Buffer
	// Closing the writer without writing any records serializes just the
	// schema, matching the shape of the schema bytes the server returns.
	schemaWriter := ipc.NewWriter(&schemaBuf, ipc.WithSchema(correctSchema))
	if err := schemaWriter.Close(); err != nil {
		t.Fatal(err)
	}
	correctSchemaBytes := schemaBuf.Bytes()

	// Serve stale IPC data via mock HTTP
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, err := w.Write(staleIPCBytes)
		if err != nil {
			panic(err)
		}
	}))
	defer server.Close()

	t.Run("should override stale column names with authoritative schema", func(t *testing.T) {
		links := []*cli_service.TSparkArrowResultLink{
			{
				FileLink:       server.URL,
				ExpiryTime:     time.Now().Add(10 * time.Minute).Unix(),
				StartRowOffset: 0,
				RowCount:       3,
			},
		}

		cfg := config.WithDefaults()
		cfg.UseLz4Compression = false
		cfg.MaxDownloadThreads = 1

		bi, err := NewCloudBatchIterator(
			context.Background(),
			links,
			0,
			correctSchemaBytes,
			cfg,
		)
		assert.Nil(t, err)

		batch, batchErr := bi.Next()
		assert.Nil(t, batchErr)
		assert.NotNil(t, batch)

		rec, recErr := batch.Next()
		assert.Nil(t, recErr)
		assert.NotNil(t, rec)

		// The record schema must use the authoritative names, not the stale ones
		assert.Equal(t, "x", rec.Schema().Field(0).Name)
		assert.Equal(t, "y", rec.Schema().Field(1).Name)

		// Data must be preserved
		assert.Equal(t, int64(3), rec.NumRows())
		assert.Equal(t, 2, len(rec.Schema().Fields()))

		rec.Release()
	})

	t.Run("should pass through unchanged when no override schema provided", func(t *testing.T) {
		links := []*cli_service.TSparkArrowResultLink{
			{
				FileLink:       server.URL,
				ExpiryTime:     time.Now().Add(10 * time.Minute).Unix(),
				StartRowOffset: 0,
				RowCount:       3,
			},
		}

		cfg := config.WithDefaults()
		cfg.UseLz4Compression = false
		cfg.MaxDownloadThreads = 1

		bi, err := NewCloudBatchIterator(
			context.Background(),
			links,
			0,
			nil,
			cfg,
		)
		assert.Nil(t, err)

		batch, batchErr := bi.Next()
		assert.Nil(t, batchErr)

		rec, recErr := batch.Next()
		assert.Nil(t, recErr)

		// Without override, the original (stale) column names are preserved
		assert.Equal(t, "id", rec.Schema().Field(0).Name)
		assert.Equal(t, "name", rec.Schema().Field(1).Name)

		rec.Release()
	})
}

func generateArrowRecord() arrow.Record {
mem := memory.NewCheckedAllocator(memory.NewGoAllocator())

Expand Down
Loading