Skip to content

Commit 65a8750

Browse files
committed
[ES-1804970] Fix CloudFetch returning stale column names from cached results
When the server result cache serves Arrow IPC files from a prior query, the embedded schema contains stale column aliases. The Go driver's CloudFetch path read these stale names directly, while the local path already used the authoritative schema from GetResultSetMetadata. Pass the authoritative schema bytes into NewCloudBatchIterator and replace stale column names on deserialized records using array.NewRecord, which is zero-copy (shares underlying column data).

Co-authored-by: Isaac
Signed-off-by: Sreekanth Vadigi <sreekanth.vadigi@databricks.com>
1 parent 6dd935f commit 65a8750

File tree

4 files changed

+163
-4
lines changed

4 files changed

+163
-4
lines changed

internal/rows/arrowbased/arrowRecordIterator.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ func (ri *arrowRecordIterator) getBatchIterator() error {
169169
func (ri *arrowRecordIterator) newBatchIterator(fr *cli_service.TFetchResultsResp) (BatchIterator, error) {
170170
rowSet := fr.Results
171171
if len(rowSet.ResultLinks) > 0 {
172-
return NewCloudBatchIterator(ri.ctx, rowSet.ResultLinks, rowSet.StartRowOffset, &ri.cfg)
172+
return NewCloudBatchIterator(ri.ctx, rowSet.ResultLinks, rowSet.StartRowOffset, ri.arrowSchemaBytes, &ri.cfg)
173173
} else {
174174
return NewLocalBatchIterator(ri.ctx, rowSet.ArrowBatches, rowSet.StartRowOffset, ri.arrowSchemaBytes, &ri.cfg)
175175
}

internal/rows/arrowbased/arrowRows.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ func NewArrowRowScanner(resultSetMetadata *cli_service.TGetResultSetMetadataResp
119119
for _, resultLink := range rowSet.ResultLinks {
120120
logger.Debug().Msgf("- start row offset: %d, row count: %d", resultLink.StartRowOffset, resultLink.RowCount)
121121
}
122-
bi, err2 = NewCloudBatchIterator(context.Background(), rowSet.ResultLinks, rowSet.StartRowOffset, cfg)
122+
bi, err2 = NewCloudBatchIterator(context.Background(), rowSet.ResultLinks, rowSet.StartRowOffset, schemaBytes, cfg)
123123
} else {
124124
bi, err2 = NewLocalBatchIterator(context.Background(), rowSet.ArrowBatches, rowSet.StartRowOffset, schemaBytes, cfg)
125125
}

internal/rows/arrowbased/batchloader.go

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ import (
1515

1616
"net/http"
1717

18+
"github.com/apache/arrow/go/v12/arrow"
19+
"github.com/apache/arrow/go/v12/arrow/array"
1820
"github.com/apache/arrow/go/v12/arrow/ipc"
1921
dbsqlerr "github.com/databricks/databricks-sql-go/errors"
2022
"github.com/databricks/databricks-sql-go/internal/cli_service"
@@ -55,18 +57,35 @@ func NewCloudIPCStreamIterator(
5557
return bi, nil
5658
}
5759

58-
// NewCloudBatchIterator creates a cloud-based BatchIterator for backward compatibility
60+
// NewCloudBatchIterator creates a cloud-based BatchIterator for backward compatibility.
61+
// arrowSchemaBytes is the authoritative schema from GetResultSetMetadata, used to
62+
// override stale column names in cached Arrow IPC files.
5963
func NewCloudBatchIterator(
6064
ctx context.Context,
6165
files []*cli_service.TSparkArrowResultLink,
6266
startRowOffset int64,
67+
arrowSchemaBytes []byte,
6368
cfg *config.Config,
6469
) (BatchIterator, dbsqlerr.DBError) {
6570
ipcIterator, err := NewCloudIPCStreamIterator(ctx, files, startRowOffset, cfg)
6671
if err != nil {
6772
return nil, err
6873
}
69-
return NewBatchIterator(ipcIterator, startRowOffset), nil
74+
75+
var overrideSchema *arrow.Schema
76+
if len(arrowSchemaBytes) > 0 {
77+
var schemaErr error
78+
overrideSchema, schemaErr = schemaFromIPCBytes(arrowSchemaBytes)
79+
if schemaErr != nil {
80+
logger.Warn().Msgf("CloudFetch: failed to parse override schema: %v", schemaErr)
81+
}
82+
}
83+
84+
return &batchIterator{
85+
ipcIterator: ipcIterator,
86+
startRowOffset: startRowOffset,
87+
overrideSchema: overrideSchema,
88+
}, nil
7089
}
7190

7291
func NewLocalIPCStreamIterator(
@@ -400,6 +419,7 @@ type BatchIterator interface {
400419
// batchIterator adapts an IPCStreamIterator into a BatchIterator.
type batchIterator struct {
	ipcIterator IPCStreamIterator // source of Arrow IPC streams (local or CloudFetch)
	startRowOffset int64 // row offset of the first row produced by this iterator
	overrideSchema *arrow.Schema // authoritative schema to fix stale CloudFetch column names
}
404424

405425
// NewBatchIterator creates a BatchIterator from an IPCStreamIterator
@@ -421,6 +441,24 @@ func (bi *batchIterator) Next() (SparkArrowBatch, error) {
421441
return nil, err
422442
}
423443

444+
// When using CloudFetch, cached Arrow IPC files may contain stale column
445+
// names from a previous query. Replace the embedded schema with the
446+
// authoritative schema from GetResultSetMetadata.
447+
if bi.overrideSchema != nil && len(records) > 0 && len(bi.overrideSchema.Fields()) == len(records[0].Columns()) {
448+
for i, rec := range records {
449+
sar, ok := rec.(*sparkArrowRecord)
450+
if !ok {
451+
continue
452+
}
453+
corrected := array.NewRecord(bi.overrideSchema, sar.Columns(), sar.NumRows())
454+
sar.Release()
455+
records[i] = &sparkArrowRecord{
456+
Delimiter: sar.Delimiter,
457+
Record: corrected,
458+
}
459+
}
460+
}
461+
424462
// Calculate total rows in this batch
425463
totalRows := int64(0)
426464
for _, record := range records {
@@ -443,3 +481,13 @@ func (bi *batchIterator) HasNext() bool {
443481
// Close releases resources held by the underlying IPC stream iterator.
func (bi *batchIterator) Close() {
	bi.ipcIterator.Close()
}
484+
485+
// schemaFromIPCBytes parses Arrow schema bytes (IPC format) into an *arrow.Schema.
486+
func schemaFromIPCBytes(schemaBytes []byte) (*arrow.Schema, error) {
487+
reader, err := ipc.NewReader(bytes.NewReader(schemaBytes))
488+
if err != nil {
489+
return nil, err
490+
}
491+
defer reader.Release()
492+
return reader.Schema(), nil
493+
}

internal/rows/arrowbased/batchloader_test.go

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ func TestCloudFetchIterator(t *testing.T) {
7676
context.Background(),
7777
links,
7878
startRowOffset,
79+
nil,
7980
cfg,
8081
)
8182
if err != nil {
@@ -150,6 +151,7 @@ func TestCloudFetchIterator(t *testing.T) {
150151
context.Background(),
151152
links,
152153
startRowOffset,
154+
nil,
153155
cfg,
154156
)
155157
if err != nil {
@@ -208,6 +210,7 @@ func TestCloudFetchIterator(t *testing.T) {
208210
context.Background(),
209211
links,
210212
startRowOffset,
213+
nil,
211214
cfg,
212215
)
213216
if err != nil {
@@ -282,6 +285,7 @@ func TestCloudFetchIterator(t *testing.T) {
282285
RowCount: 1,
283286
}},
284287
startRowOffset,
288+
nil,
285289
cfg,
286290
)
287291
assert.Nil(t, err)
@@ -320,6 +324,7 @@ func TestCloudFetchIterator(t *testing.T) {
320324
RowCount: 1,
321325
}},
322326
startRowOffset,
327+
nil,
323328
cfg,
324329
)
325330
assert.Nil(t, err)
@@ -334,6 +339,112 @@ func TestCloudFetchIterator(t *testing.T) {
334339
})
335340
}
336341

342+
func TestCloudFetchSchemaOverride(t *testing.T) {
343+
// Reproduces ES-1804970: When the server result cache serves Arrow IPC files
344+
// from a prior query, the embedded schema has stale column names. The
345+
// authoritative schema from GetResultSetMetadata must override them.
346+
347+
// IPC data has columns ["id", "name"] (stale, from cached query)
348+
staleRecord := generateArrowRecord()
349+
staleIPCBytes := generateMockArrowBytes(staleRecord)
350+
351+
// Authoritative schema has columns ["x", "y"] (correct, from GetResultSetMetadata)
352+
correctFields := []arrow.Field{
353+
{Name: "x", Type: arrow.PrimitiveTypes.Int32},
354+
{Name: "y", Type: arrow.BinaryTypes.String},
355+
}
356+
correctSchema := arrow.NewSchema(correctFields, nil)
357+
var schemaBuf bytes.Buffer
358+
schemaWriter := ipc.NewWriter(&schemaBuf, ipc.WithSchema(correctSchema))
359+
schemaWriter.Close()
360+
correctSchemaBytes := schemaBuf.Bytes()
361+
362+
// Serve stale IPC data via mock HTTP
363+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
364+
w.WriteHeader(http.StatusOK)
365+
w.Write(staleIPCBytes)
366+
}))
367+
defer server.Close()
368+
369+
t.Run("should override stale column names with authoritative schema", func(t *testing.T) {
370+
links := []*cli_service.TSparkArrowResultLink{
371+
{
372+
FileLink: server.URL,
373+
ExpiryTime: time.Now().Add(10 * time.Minute).Unix(),
374+
StartRowOffset: 0,
375+
RowCount: 3,
376+
},
377+
}
378+
379+
cfg := config.WithDefaults()
380+
cfg.UseLz4Compression = false
381+
cfg.MaxDownloadThreads = 1
382+
383+
bi, err := NewCloudBatchIterator(
384+
context.Background(),
385+
links,
386+
0,
387+
correctSchemaBytes,
388+
cfg,
389+
)
390+
assert.Nil(t, err)
391+
392+
batch, batchErr := bi.Next()
393+
assert.Nil(t, batchErr)
394+
assert.NotNil(t, batch)
395+
396+
rec, recErr := batch.Next()
397+
assert.Nil(t, recErr)
398+
assert.NotNil(t, rec)
399+
400+
// The record schema must use the authoritative names, not the stale ones
401+
assert.Equal(t, "x", rec.Schema().Field(0).Name)
402+
assert.Equal(t, "y", rec.Schema().Field(1).Name)
403+
404+
// Data must be preserved
405+
assert.Equal(t, int64(3), rec.NumRows())
406+
assert.Equal(t, 2, len(rec.Schema().Fields()))
407+
408+
rec.Release()
409+
})
410+
411+
t.Run("should pass through unchanged when no override schema provided", func(t *testing.T) {
412+
links := []*cli_service.TSparkArrowResultLink{
413+
{
414+
FileLink: server.URL,
415+
ExpiryTime: time.Now().Add(10 * time.Minute).Unix(),
416+
StartRowOffset: 0,
417+
RowCount: 3,
418+
},
419+
}
420+
421+
cfg := config.WithDefaults()
422+
cfg.UseLz4Compression = false
423+
cfg.MaxDownloadThreads = 1
424+
425+
bi, err := NewCloudBatchIterator(
426+
context.Background(),
427+
links,
428+
0,
429+
nil,
430+
cfg,
431+
)
432+
assert.Nil(t, err)
433+
434+
batch, batchErr := bi.Next()
435+
assert.Nil(t, batchErr)
436+
437+
rec, recErr := batch.Next()
438+
assert.Nil(t, recErr)
439+
440+
// Without override, the original (stale) column names are preserved
441+
assert.Equal(t, "id", rec.Schema().Field(0).Name)
442+
assert.Equal(t, "name", rec.Schema().Field(1).Name)
443+
444+
rec.Release()
445+
})
446+
}
447+
337448
func generateArrowRecord() arrow.Record {
338449
mem := memory.NewCheckedAllocator(memory.NewGoAllocator())
339450

0 commit comments

Comments
 (0)