diff --git a/internal/rows/arrowbased/arrowRecordIterator.go b/internal/rows/arrowbased/arrowRecordIterator.go index 787a0ba..c7ba8cc 100644 --- a/internal/rows/arrowbased/arrowRecordIterator.go +++ b/internal/rows/arrowbased/arrowRecordIterator.go @@ -169,7 +169,7 @@ func (ri *arrowRecordIterator) getBatchIterator() error { func (ri *arrowRecordIterator) newBatchIterator(fr *cli_service.TFetchResultsResp) (BatchIterator, error) { rowSet := fr.Results if len(rowSet.ResultLinks) > 0 { - return NewCloudBatchIterator(ri.ctx, rowSet.ResultLinks, rowSet.StartRowOffset, &ri.cfg) + return NewCloudBatchIterator(ri.ctx, rowSet.ResultLinks, rowSet.StartRowOffset, ri.arrowSchemaBytes, &ri.cfg) } else { return NewLocalBatchIterator(ri.ctx, rowSet.ArrowBatches, rowSet.StartRowOffset, ri.arrowSchemaBytes, &ri.cfg) } diff --git a/internal/rows/arrowbased/arrowRows.go b/internal/rows/arrowbased/arrowRows.go index 4e2cf80..47aeede 100644 --- a/internal/rows/arrowbased/arrowRows.go +++ b/internal/rows/arrowbased/arrowRows.go @@ -119,7 +119,7 @@ func NewArrowRowScanner(resultSetMetadata *cli_service.TGetResultSetMetadataResp for _, resultLink := range rowSet.ResultLinks { logger.Debug().Msgf("- start row offset: %d, row count: %d", resultLink.StartRowOffset, resultLink.RowCount) } - bi, err2 = NewCloudBatchIterator(context.Background(), rowSet.ResultLinks, rowSet.StartRowOffset, cfg) + bi, err2 = NewCloudBatchIterator(context.Background(), rowSet.ResultLinks, rowSet.StartRowOffset, schemaBytes, cfg) } else { bi, err2 = NewLocalBatchIterator(context.Background(), rowSet.ArrowBatches, rowSet.StartRowOffset, schemaBytes, cfg) } diff --git a/internal/rows/arrowbased/batchloader.go b/internal/rows/arrowbased/batchloader.go index e12ea4e..460dd80 100644 --- a/internal/rows/arrowbased/batchloader.go +++ b/internal/rows/arrowbased/batchloader.go @@ -15,6 +15,8 @@ import ( "net/http" + "github.com/apache/arrow/go/v12/arrow" + "github.com/apache/arrow/go/v12/arrow/array" "github.com/apache/arrow/go/v12/arrow/ipc" dbsqlerr "github.com/databricks/databricks-sql-go/errors" "github.com/databricks/databricks-sql-go/internal/cli_service" @@ -55,18 +57,35 @@ func NewCloudIPCStreamIterator( return bi, nil } -// NewCloudBatchIterator creates a cloud-based BatchIterator for backward compatibility +// NewCloudBatchIterator creates a cloud-based BatchIterator for backward compatibility. +// arrowSchemaBytes is the authoritative schema from GetResultSetMetadata, used to +// override stale column names in cached Arrow IPC files. func NewCloudBatchIterator( ctx context.Context, files []*cli_service.TSparkArrowResultLink, startRowOffset int64, + arrowSchemaBytes []byte, cfg *config.Config, ) (BatchIterator, dbsqlerr.DBError) { ipcIterator, err := NewCloudIPCStreamIterator(ctx, files, startRowOffset, cfg) if err != nil { return nil, err } - return NewBatchIterator(ipcIterator, startRowOffset), nil + + var overrideSchema *arrow.Schema + if len(arrowSchemaBytes) > 0 { + var schemaErr error + overrideSchema, schemaErr = schemaFromIPCBytes(arrowSchemaBytes) + if schemaErr != nil { + logger.Warn().Msgf("CloudFetch: failed to parse override schema: %v", schemaErr) + } + } + + return &batchIterator{ + ipcIterator: ipcIterator, + startRowOffset: startRowOffset, + overrideSchema: overrideSchema, + }, nil } func NewLocalIPCStreamIterator( @@ -400,6 +419,7 @@ type BatchIterator interface { type batchIterator struct { ipcIterator IPCStreamIterator startRowOffset int64 + overrideSchema *arrow.Schema // authoritative schema to fix stale CloudFetch column names } // NewBatchIterator creates a BatchIterator from an IPCStreamIterator @@ -421,6 +441,24 @@ func (bi *batchIterator) Next() (SparkArrowBatch, error) { return nil, err } + // When using CloudFetch, cached Arrow IPC files may contain stale column + // names from a previous query. Replace the embedded schema with the + // authoritative schema from GetResultSetMetadata. + if bi.overrideSchema != nil && len(records) > 0 && len(bi.overrideSchema.Fields()) == len(records[0].Columns()) { + for i, rec := range records { + sar, ok := rec.(*sparkArrowRecord) + if !ok { + continue + } + corrected := array.NewRecord(bi.overrideSchema, sar.Columns(), sar.NumRows()) + sar.Release() + records[i] = &sparkArrowRecord{ + Delimiter: sar.Delimiter, + Record: corrected, + } + } + } + // Calculate total rows in this batch totalRows := int64(0) for _, record := range records { @@ -443,3 +481,13 @@ func (bi *batchIterator) HasNext() bool { func (bi *batchIterator) Close() { bi.ipcIterator.Close() } + +// schemaFromIPCBytes parses Arrow schema bytes (IPC format) into an *arrow.Schema. +func schemaFromIPCBytes(schemaBytes []byte) (*arrow.Schema, error) { + reader, err := ipc.NewReader(bytes.NewReader(schemaBytes)) + if err != nil { + return nil, err + } + defer reader.Release() + return reader.Schema(), nil +} diff --git a/internal/rows/arrowbased/batchloader_test.go b/internal/rows/arrowbased/batchloader_test.go index 99538bb..8d17274 100644 --- a/internal/rows/arrowbased/batchloader_test.go +++ b/internal/rows/arrowbased/batchloader_test.go @@ -76,6 +76,7 @@ func TestCloudFetchIterator(t *testing.T) { context.Background(), links, startRowOffset, + nil, cfg, ) if err != nil { @@ -150,6 +151,7 @@ func TestCloudFetchIterator(t *testing.T) { context.Background(), links, startRowOffset, + nil, cfg, ) if err != nil { @@ -208,6 +210,7 @@ func TestCloudFetchIterator(t *testing.T) { context.Background(), links, startRowOffset, + nil, cfg, ) if err != nil { @@ -282,6 +285,7 @@ func TestCloudFetchIterator(t *testing.T) { RowCount: 1, }}, startRowOffset, + nil, cfg, ) assert.Nil(t, err) @@ -320,6 +324,7 @@ func TestCloudFetchIterator(t *testing.T) { RowCount: 1, }}, startRowOffset, + nil, cfg, ) assert.Nil(t, err) @@ -334,6 +339,117 @@ func TestCloudFetchIterator(t *testing.T) { }) } +func TestCloudFetchSchemaOverride(t *testing.T) { + // Reproduces ES-1804970: When the server result cache serves Arrow IPC files + // from a prior query, the embedded schema has stale column names. The + // authoritative schema from GetResultSetMetadata must override them. + + // IPC data has columns ["id", "name"] (stale, from cached query) + staleRecord := generateArrowRecord() + staleIPCBytes := generateMockArrowBytes(staleRecord) + + // Authoritative schema has columns ["x", "y"] (correct, from GetResultSetMetadata) + correctFields := []arrow.Field{ + {Name: "x", Type: arrow.PrimitiveTypes.Int32}, + {Name: "y", Type: arrow.BinaryTypes.String}, + } + correctSchema := arrow.NewSchema(correctFields, nil) + var schemaBuf bytes.Buffer + schemaWriter := ipc.NewWriter(&schemaBuf, ipc.WithSchema(correctSchema)) + if err := schemaWriter.Close(); err != nil { + t.Fatal(err) + } + correctSchemaBytes := schemaBuf.Bytes() + + // Serve stale IPC data via mock HTTP + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write(staleIPCBytes) + if err != nil { + panic(err) + } + })) + defer server.Close() + + t.Run("should override stale column names with authoritative schema", func(t *testing.T) { + links := []*cli_service.TSparkArrowResultLink{ + { + FileLink: server.URL, + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: 0, + RowCount: 3, + }, + } + + cfg := config.WithDefaults() + cfg.UseLz4Compression = false + cfg.MaxDownloadThreads = 1 + + bi, err := NewCloudBatchIterator( + context.Background(), + links, + 0, + correctSchemaBytes, + cfg, + ) + assert.Nil(t, err) + + batch, batchErr := bi.Next() + assert.Nil(t, batchErr) + assert.NotNil(t, batch) + + rec, recErr := batch.Next() + assert.Nil(t, recErr) + assert.NotNil(t, rec) + + // The record schema must use the authoritative names, not the stale ones + assert.Equal(t, "x", rec.Schema().Field(0).Name) + assert.Equal(t, "y", rec.Schema().Field(1).Name) + + // Data must be preserved + assert.Equal(t, int64(3), rec.NumRows()) + assert.Equal(t, 2, len(rec.Schema().Fields())) + + rec.Release() + }) + + t.Run("should pass through unchanged when no override schema provided", func(t *testing.T) { + links := []*cli_service.TSparkArrowResultLink{ + { + FileLink: server.URL, + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: 0, + RowCount: 3, + }, + } + + cfg := config.WithDefaults() + cfg.UseLz4Compression = false + cfg.MaxDownloadThreads = 1 + + bi, err := NewCloudBatchIterator( + context.Background(), + links, + 0, + nil, + cfg, + ) + assert.Nil(t, err) + + batch, batchErr := bi.Next() + assert.Nil(t, batchErr) + + rec, recErr := batch.Next() + assert.Nil(t, recErr) + + // Without override, the original (stale) column names are preserved + assert.Equal(t, "id", rec.Schema().Field(0).Name) + assert.Equal(t, "name", rec.Schema().Field(1).Name) + + rec.Release() + }) +} + func generateArrowRecord() arrow.Record { mem := memory.NewCheckedAllocator(memory.NewGoAllocator())