Skip to content

Commit 3c0f7e4

Browse files
authored
[ES-1804970] Fix CloudFetch returning stale column names from cached results (#351)
## Summary Fixes a bug where `arrow.Record.Schema()` returns stale column aliases when CloudFetch serves cached Arrow IPC files from a structurally identical prior query with different `AS` aliases. - **Root cause:** `NewCloudBatchIterator` was not receiving the authoritative schema bytes from `GetResultSetMetadata`, unlike the local batch path which already had this. CloudFetch Arrow IPC files have column names baked in from the original query, and the driver was reading them as-is. - **Fix:** Pass `arrowSchemaBytes` (the authoritative schema from `GetResultSetMetadata`) into `NewCloudBatchIterator`. After records are deserialized from the IPC stream, replace the stale schema with the authoritative one using `array.NewRecord()` (zero-copy — shares underlying column data, only swaps metadata). ## Changes - **`arrowRecordIterator.go`** — Pass `ri.arrowSchemaBytes` to `NewCloudBatchIterator` in `newBatchIterator()` - **`arrowRows.go`** — Pass `schemaBytes` to `NewCloudBatchIterator` in `NewArrowRowScanner()` - **`batchloader.go`** — Core fix: - `NewCloudBatchIterator` accepts `arrowSchemaBytes`, parses into `*arrow.Schema`, stores on `batchIterator` - `batchIterator.Next()` applies override schema to CloudFetch records only (local path is untouched, `overrideSchema` is `nil`) - Added `schemaFromIPCBytes()` helper - Field count validation guard to prevent panics on schema mismatch - Schema parse failure logged at `Warn` level - **`batchloader_test.go`** — Added `TestCloudFetchSchemaOverride` with two subtests: - Verifies stale column names `["id","name"]` are overridden to `["x","y"]` - Verifies `nil` schema bytes pass through original names unchanged ## Who is affected Go driver users with CloudFetch enabled (`WithCloudFetch(true)`) who read `arrow.Record.Schema()` directly. Python, ODBC, and JDBC drivers are not affected. 
## Test plan - [x] All existing unit tests pass (37 tests in `internal/rows/arrowbased/`) - [x] New unit test `TestCloudFetchSchemaOverride` covers the override and no-override paths - [x] Verified end-to-end against a real Databricks warehouse using `samples.tpch.lineitem` (~30M rows) with two queries differing only in column aliases — confirmed `arrow.Record.Schema()` now returns correct aliases This pull request was AI-assisted by Isaac. --------- Signed-off-by: Sreekanth Vadigi <sreekanth.vadigi@databricks.com>
1 parent 305e3bc commit 3c0f7e4

4 files changed

Lines changed: 172 additions & 6 deletions

File tree

internal/rows/arrowbased/arrowRecordIterator.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ func (ri *arrowRecordIterator) getBatchIterator() error {
169169
func (ri *arrowRecordIterator) newBatchIterator(fr *cli_service.TFetchResultsResp) (BatchIterator, error) {
170170
rowSet := fr.Results
171171
if len(rowSet.ResultLinks) > 0 {
172-
return NewCloudBatchIterator(ri.ctx, rowSet.ResultLinks, rowSet.StartRowOffset, &ri.cfg, nil)
172+
return NewCloudBatchIterator(ri.ctx, rowSet.ResultLinks, rowSet.StartRowOffset, ri.arrowSchemaBytes, &ri.cfg, nil)
173173
} else {
174174
return NewLocalBatchIterator(ri.ctx, rowSet.ArrowBatches, rowSet.StartRowOffset, ri.arrowSchemaBytes, &ri.cfg)
175175
}

internal/rows/arrowbased/arrowRows.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ func NewArrowRowScanner(resultSetMetadata *cli_service.TGetResultSetMetadataResp
121121
for _, resultLink := range rowSet.ResultLinks {
122122
logger.Debug().Msgf("- start row offset: %d, row count: %d", resultLink.StartRowOffset, resultLink.RowCount)
123123
}
124-
bi, err2 = NewCloudBatchIterator(context.Background(), rowSet.ResultLinks, rowSet.StartRowOffset, cfg, onCloudFetchDownload)
124+
bi, err2 = NewCloudBatchIterator(context.Background(), rowSet.ResultLinks, rowSet.StartRowOffset, schemaBytes, cfg, onCloudFetchDownload)
125125
} else {
126126
bi, err2 = NewLocalBatchIterator(context.Background(), rowSet.ArrowBatches, rowSet.StartRowOffset, schemaBytes, cfg)
127127
}

internal/rows/arrowbased/batchloader.go

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ import (
1515

1616
"net/http"
1717

18+
"github.com/apache/arrow/go/v12/arrow"
19+
"github.com/apache/arrow/go/v12/arrow/array"
1820
"github.com/apache/arrow/go/v12/arrow/ipc"
1921
dbsqlerr "github.com/databricks/databricks-sql-go/errors"
2022
"github.com/databricks/databricks-sql-go/internal/cli_service"
@@ -57,19 +59,36 @@ func NewCloudIPCStreamIterator(
5759
return bi, nil
5860
}
5961

// NewCloudBatchIterator creates a cloud-based BatchIterator for backward compatibility.
// arrowSchemaBytes is the authoritative schema from GetResultSetMetadata, used to
// override stale column names in cached Arrow IPC files.
//
// files lists the CloudFetch result links to download; startRowOffset is the row
// offset associated with the first file. onFileDownloaded, when non-nil, is invoked
// with the download duration (ms) of each fetched file.
func NewCloudBatchIterator(
	ctx context.Context,
	files []*cli_service.TSparkArrowResultLink,
	startRowOffset int64,
	arrowSchemaBytes []byte,
	cfg *config.Config,
	onFileDownloaded func(downloadMs int64),
) (BatchIterator, dbsqlerr.DBError) {
	ipcIterator, err := NewCloudIPCStreamIterator(ctx, files, startRowOffset, cfg, onFileDownloaded)
	if err != nil {
		return nil, err
	}

	// Parse the authoritative schema once, up front. On parse failure we only log
	// a warning and leave overrideSchema nil, so iteration falls back to the
	// schema embedded in the downloaded IPC files instead of failing the read.
	var overrideSchema *arrow.Schema
	if len(arrowSchemaBytes) > 0 {
		var schemaErr error
		overrideSchema, schemaErr = schemaFromIPCBytes(arrowSchemaBytes)
		if schemaErr != nil {
			logger.Warn().Msgf("CloudFetch: failed to parse override schema: %v", schemaErr)
		}
	}

	return &batchIterator{
		ipcIterator:    ipcIterator,
		startRowOffset: startRowOffset,
		overrideSchema: overrideSchema,
	}, nil
}
7493

7594
func NewLocalIPCStreamIterator(
@@ -416,6 +435,7 @@ type BatchIterator interface {
// batchIterator adapts an IPCStreamIterator into a BatchIterator, carrying the
// starting row offset and an optional schema override for CloudFetch results.
type batchIterator struct {
	ipcIterator    IPCStreamIterator
	startRowOffset int64
	overrideSchema *arrow.Schema // authoritative schema to fix stale CloudFetch column names
}
420440

421441
// NewBatchIterator creates a BatchIterator from an IPCStreamIterator
@@ -437,6 +457,24 @@ func (bi *batchIterator) Next() (SparkArrowBatch, error) {
437457
return nil, err
438458
}
439459

460+
// When using CloudFetch, cached Arrow IPC files may contain stale column
461+
// names from a previous query. Replace the embedded schema with the
462+
// authoritative schema from GetResultSetMetadata.
463+
if bi.overrideSchema != nil && len(records) > 0 && len(bi.overrideSchema.Fields()) == len(records[0].Columns()) {
464+
for i, rec := range records {
465+
sar, ok := rec.(*sparkArrowRecord)
466+
if !ok {
467+
continue
468+
}
469+
corrected := array.NewRecord(bi.overrideSchema, sar.Columns(), sar.NumRows())
470+
sar.Release()
471+
records[i] = &sparkArrowRecord{
472+
Delimiter: sar.Delimiter,
473+
Record: corrected,
474+
}
475+
}
476+
}
477+
440478
// Calculate total rows in this batch
441479
totalRows := int64(0)
442480
for _, record := range records {
@@ -459,3 +497,13 @@ func (bi *batchIterator) HasNext() bool {
// Close releases the resources held by the underlying IPC stream iterator.
func (bi *batchIterator) Close() {
	bi.ipcIterator.Close()
}
// schemaFromIPCBytes parses Arrow schema bytes (IPC stream format) into an
// *arrow.Schema. The input is expected to be an IPC stream whose only message
// is the schema (e.g. produced by closing an ipc.Writer with no records
// written, as the accompanying test does).
func schemaFromIPCBytes(schemaBytes []byte) (*arrow.Schema, error) {
	reader, err := ipc.NewReader(bytes.NewReader(schemaBytes))
	if err != nil {
		return nil, err
	}
	defer reader.Release()
	// NOTE(review): the returned schema is used after the reader is released;
	// this assumes arrow.Schema is a plain value not tied to the reader's
	// lifetime — confirm against the arrow/go ipc documentation.
	return reader.Schema(), nil
}

internal/rows/arrowbased/batchloader_test.go

Lines changed: 120 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ func TestCloudFetchIterator(t *testing.T) {
7777
context.Background(),
7878
links,
7979
startRowOffset,
80+
nil,
8081
cfg,
8182
nil,
8283
)
@@ -152,6 +153,7 @@ func TestCloudFetchIterator(t *testing.T) {
152153
context.Background(),
153154
links,
154155
startRowOffset,
156+
nil,
155157
cfg,
156158
nil,
157159
)
@@ -211,6 +213,7 @@ func TestCloudFetchIterator(t *testing.T) {
211213
context.Background(),
212214
links,
213215
startRowOffset,
216+
nil,
214217
cfg,
215218
nil,
216219
)
@@ -286,6 +289,7 @@ func TestCloudFetchIterator(t *testing.T) {
286289
RowCount: 1,
287290
}},
288291
startRowOffset,
292+
nil,
289293
cfg,
290294
nil,
291295
)
@@ -325,6 +329,7 @@ func TestCloudFetchIterator(t *testing.T) {
325329
RowCount: 1,
326330
}},
327331
startRowOffset,
332+
nil,
328333
cfg,
329334
nil,
330335
)
@@ -340,6 +345,119 @@ func TestCloudFetchIterator(t *testing.T) {
340345
})
341346
}
342347

// TestCloudFetchSchemaOverride reproduces ES-1804970: when the server result
// cache serves Arrow IPC files from a prior query, the embedded schema has
// stale column names. The authoritative schema from GetResultSetMetadata must
// override them.
func TestCloudFetchSchemaOverride(t *testing.T) {
	// IPC data has columns ["id", "name"] (stale, from cached query).
	staleRecord := generateArrowRecord()
	staleIPCBytes := generateMockArrowBytes(staleRecord)

	// Authoritative schema has columns ["x", "y"] (correct, from GetResultSetMetadata).
	correctFields := []arrow.Field{
		{Name: "x", Type: arrow.PrimitiveTypes.Int32},
		{Name: "y", Type: arrow.BinaryTypes.String},
	}
	correctSchema := arrow.NewSchema(correctFields, nil)
	var schemaBuf bytes.Buffer
	// Closing the writer with no records written emits an IPC stream that
	// contains only the schema message.
	schemaWriter := ipc.NewWriter(&schemaBuf, ipc.WithSchema(correctSchema))
	if err := schemaWriter.Close(); err != nil {
		t.Fatal(err)
	}
	correctSchemaBytes := schemaBuf.Bytes()

	// Serve the stale IPC data via a mock HTTP server standing in for the
	// CloudFetch result links.
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, err := w.Write(staleIPCBytes)
		if err != nil {
			panic(err)
		}
	}))
	defer server.Close()

	t.Run("should override stale column names with authoritative schema", func(t *testing.T) {
		links := []*cli_service.TSparkArrowResultLink{
			{
				FileLink:       server.URL,
				ExpiryTime:     time.Now().Add(10 * time.Minute).Unix(),
				StartRowOffset: 0,
				RowCount:       3,
			},
		}

		cfg := config.WithDefaults()
		cfg.UseLz4Compression = false
		cfg.MaxDownloadThreads = 1

		bi, err := NewCloudBatchIterator(
			context.Background(),
			links,
			0,
			correctSchemaBytes,
			cfg,
			nil,
		)
		assert.Nil(t, err)

		batch, batchErr := bi.Next()
		assert.Nil(t, batchErr)
		assert.NotNil(t, batch)

		rec, recErr := batch.Next()
		assert.Nil(t, recErr)
		assert.NotNil(t, rec)

		// The record schema must use the authoritative names, not the stale ones.
		assert.Equal(t, "x", rec.Schema().Field(0).Name)
		assert.Equal(t, "y", rec.Schema().Field(1).Name)

		// Data must be preserved: same row count and field count as the stale record.
		assert.Equal(t, int64(3), rec.NumRows())
		assert.Equal(t, 2, len(rec.Schema().Fields()))

		rec.Release()
	})

	t.Run("should pass through unchanged when no override schema provided", func(t *testing.T) {
		links := []*cli_service.TSparkArrowResultLink{
			{
				FileLink:       server.URL,
				ExpiryTime:     time.Now().Add(10 * time.Minute).Unix(),
				StartRowOffset: 0,
				RowCount:       3,
			},
		}

		cfg := config.WithDefaults()
		cfg.UseLz4Compression = false
		cfg.MaxDownloadThreads = 1

		bi, err := NewCloudBatchIterator(
			context.Background(),
			links,
			0,
			nil,
			cfg,
			nil,
		)
		assert.Nil(t, err)

		batch, batchErr := bi.Next()
		assert.Nil(t, batchErr)

		rec, recErr := batch.Next()
		assert.Nil(t, recErr)

		// Without an override, the original (stale) column names are preserved.
		assert.Equal(t, "id", rec.Schema().Field(0).Name)
		assert.Equal(t, "name", rec.Schema().Field(1).Name)

		rec.Release()
	})
}
460+
343461
// TestCloudFetchIterator_OnFileDownloaded_CallbackInvokedWithPositiveDuration verifies
344462
// that the onFileDownloaded telemetry callback is called once per downloaded S3 file with
345463
// a positive downloadMs value.
@@ -388,7 +506,7 @@ func TestCloudFetchIterator_OnFileDownloaded_CallbackInvokedWithPositiveDuration
388506
callbackMu.Unlock()
389507
}
390508

391-
bi, err := NewCloudBatchIterator(context.Background(), links, startRowOffset, cfg, onFileDownloaded)
509+
bi, err := NewCloudBatchIterator(context.Background(), links, startRowOffset, nil, cfg, onFileDownloaded)
392510
assert.Nil(t, err)
393511

394512
// Consume all batches to trigger the downloads.
@@ -439,7 +557,7 @@ func TestCloudFetchIterator_OnFileDownloaded_NilCallbackDoesNotPanic(t *testing.
439557
cfg.MaxDownloadThreads = 1
440558

441559
// nil callback — must not panic
442-
bi, err := NewCloudBatchIterator(context.Background(), links, startRowOffset, cfg, nil)
560+
bi, err := NewCloudBatchIterator(context.Background(), links, startRowOffset, nil, cfg, nil)
443561
assert.Nil(t, err)
444562

445563
assert.NotPanics(t, func() {

0 commit comments

Comments
 (0)