Skip to content

Commit 4008bbe

Browse files
authored
Add general-purpose SQL Statement Execution engine (#5416)
## Summary Adds `libs/sqlexec`, a general-purpose, non-interactive engine for running SQL through the Databricks SQL Statement Execution API, and refactors the experimental aitools query commands to use it instead of each re-implementing the submit/poll/fetch loop. A `Client` binds to a single SQL warehouse and exposes the full lifecycle: `Submit` (async, returns immediately with a statement ID so callers can wire up cancellation), `Poll` (additive backoff between status checks), `Get`, `Cancel`, `Results`, and the convenience wrappers `Execute` and `ExecuteScalar`. Failures surface as a typed `*StatementError` carrying the terminal `State`, `Code`, and `Message`, so callers compare with `errors.As` rather than string-matching. The `Client` holds no mutable state and is safe for concurrent use — aitools fans many statements out through one instance. The engine speaks only the INLINE disposition with JSON_ARRAY format (the API caps this at 25 MiB per result set), which covers every caller today; EXTERNAL_LINKS is intentionally left out as a separate concern. It exists to be shared by programmatic callers such as bundle deploy resources (e.g. metric views, which have no REST API and are managed via SQL DDL) and the aitools commands. The aitools consumers (`query.go`, `batch.go`, `discover_schema.go`, `statement*.go`) are reworked to delegate to the engine, removing the duplicated polling and result-assembly code (net ~549 lines deleted from those files). ## Test plan - Hermetic unit + HTTP tests in `libs/sqlexec` covering path/response decoding, polling, cancellation, parameters, and error mapping. - Live integration coverage in `integration/libs/sqlexec` (skips without `CLOUD_ENV` / `TEST_DEFAULT_WAREHOUSE_ID`); all 6 tests verified green against a real workspace. This pull request and its description were written by Isaac.
1 parent 520ab83 commit 4008bbe

17 files changed

Lines changed: 1192 additions & 549 deletions

experimental/aitools/cmd/batch.go

Lines changed: 24 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212

1313
"github.com/databricks/cli/libs/cmdio"
1414
"github.com/databricks/cli/libs/log"
15+
"github.com/databricks/cli/libs/sqlexec"
1516
"github.com/databricks/databricks-sdk-go/service/sql"
1617
"golang.org/x/sync/errgroup"
1718
)
@@ -58,6 +59,8 @@ type batchResultError struct {
5859
// reused across the batch, so callers must ensure each SQL uses only markers
5960
// that are covered.
6061
func executeBatch(ctx context.Context, api sql.StatementExecutionInterface, warehouseID string, sqls []string, params []sql.StatementParameterListItem, concurrency int) []batchResult {
62+
client := sqlexec.New(api, warehouseID)
63+
6164
pollCtx, pollCancel := context.WithCancel(ctx)
6265
defer pollCancel()
6366

@@ -101,51 +104,45 @@ func executeBatch(ctx context.Context, api sql.StatementExecutionInterface, ware
101104
g.SetLimit(concurrency)
102105
for i, sqlStr := range sqls {
103106
g.Go(func() error {
104-
results[i] = runOneBatchQuery(pollCtx, api, warehouseID, sqlStr, params, statementIDs, i)
107+
results[i] = runOneBatchQuery(pollCtx, client, sqlStr, params, statementIDs, i)
105108
completed.Add(1)
106109
return nil
107110
})
108111
}
109112
_ = g.Wait()
110113

111-
// pollStatement is a pure helper that returns ctx.Err() on cancellation
112-
// without touching the server. Sweep any not-yet-terminal statements here.
114+
// Poll returns ctx.Err() on cancellation without touching the server.
115+
// Sweep any not-yet-terminal statements here.
113116
if pollCtx.Err() != nil {
114-
cancelInFlight(ctx, api, statementIDs, results)
117+
cancelInFlight(ctx, client, statementIDs, results)
115118
}
116119

117120
return results
118121
}
119122

120123
// runOneBatchQuery submits one SQL, polls to completion, and returns its
121124
// batchResult. All errors are encoded into the result; never returns an error.
122-
func runOneBatchQuery(ctx context.Context, api sql.StatementExecutionInterface, warehouseID, sqlStr string, params []sql.StatementParameterListItem, statementIDs []string, idx int) batchResult {
125+
func runOneBatchQuery(ctx context.Context, client *sqlexec.Client, sqlStr string, params []sql.StatementParameterListItem, statementIDs []string, idx int) batchResult {
123126
start := time.Now()
124127
result := batchResult{SQL: sqlStr}
125128

126-
resp, err := api.ExecuteStatement(ctx, sql.ExecuteStatementRequest{
127-
WarehouseId: warehouseID,
128-
Statement: sqlStr,
129-
Parameters: params,
130-
WaitTimeout: "0s",
131-
OnWaitTimeout: sql.ExecuteStatementRequestOnWaitTimeoutContinue,
132-
})
129+
stmt, err := client.Submit(ctx, sqlStr, sqlexec.WithParameters(params))
133130
if err != nil {
134131
if ctx.Err() != nil {
135132
result.State = sql.StatementStateCanceled
136133
result.Error = &batchResultError{Message: "submission cancelled"}
137134
} else {
138135
result.State = sql.StatementStateFailed
139-
result.Error = &batchResultError{Message: fmt.Sprintf("execute statement: %v", err)}
136+
result.Error = &batchResultError{Message: err.Error()}
140137
}
141138
result.ElapsedMs = time.Since(start).Milliseconds()
142139
return result
143140
}
144141

145-
statementIDs[idx] = resp.StatementId
146-
result.StatementID = resp.StatementId
142+
statementIDs[idx] = stmt.ID
143+
result.StatementID = stmt.ID
147144

148-
pollResp, err := pollStatement(ctx, api, resp)
145+
stmt, err = client.Poll(ctx, stmt)
149146
if err != nil {
150147
if ctx.Err() != nil {
151148
result.State = sql.StatementStateCanceled
@@ -158,38 +155,34 @@ func runOneBatchQuery(ctx context.Context, api sql.StatementExecutionInterface,
158155
return result
159156
}
160157

161-
if pollResp.Status != nil {
162-
result.State = pollResp.Status.State
163-
}
158+
result.State = stmt.State
164159

165-
if result.State != sql.StatementStateSucceeded {
166-
result.Error = &batchResultError{}
167-
if pollResp.Status != nil && pollResp.Status.Error != nil {
168-
result.Error.Message = pollResp.Status.Error.Message
169-
result.Error.ErrorCode = string(pollResp.Status.Error.ErrorCode)
170-
} else {
171-
result.Error.Message = fmt.Sprintf("query reached terminal state %s", result.State)
160+
if err := stmt.Err(); err != nil {
161+
se, _ := errors.AsType[*sqlexec.StatementError](err)
162+
result.Error = &batchResultError{
163+
Message: se.Message,
164+
ErrorCode: string(se.Code),
172165
}
173166
result.ElapsedMs = time.Since(start).Milliseconds()
174167
return result
175168
}
176169

177-
result.Columns = extractColumns(pollResp.Manifest)
178-
rows, err := fetchAllRows(ctx, api, pollResp)
170+
res, err := client.Results(ctx, stmt)
179171
if err != nil {
180172
result.Error = &batchResultError{Message: fmt.Sprintf("fetch rows: %v", err)}
181173
result.ElapsedMs = time.Since(start).Milliseconds()
182174
return result
183175
}
184-
result.Rows = rows
176+
result.Columns = res.Columns
177+
result.Rows = res.Rows
185178
result.ElapsedMs = time.Since(start).Milliseconds()
186179
return result
187180
}
188181

189182
// cancelInFlight sends CancelExecution for every statement that didn't reach
190183
// a terminal state server-side before context cancellation. Best effort: errors
191184
// are logged at warn but don't fail the batch.
192-
func cancelInFlight(ctx context.Context, api sql.StatementExecutionInterface, statementIDs []string, results []batchResult) {
185+
func cancelInFlight(ctx context.Context, client *sqlexec.Client, statementIDs []string, results []batchResult) {
193186
var cancelled int
194187
for i, sid := range statementIDs {
195188
if sid == "" {
@@ -208,7 +201,7 @@ func cancelInFlight(ctx context.Context, api sql.StatementExecutionInterface, st
208201
// values but drops the cancellation signal so the cancel RPC actually
209202
// reaches the warehouse instead of short-circuiting on ctx.Err().
210203
cancelCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), cancelTimeout)
211-
if err := api.CancelExecution(cancelCtx, sql.CancelExecutionRequest{StatementId: sid}); err != nil {
204+
if err := client.Cancel(cancelCtx, sid); err != nil {
212205
log.Warnf(ctx, "Failed to cancel statement %s: %v", sid, err)
213206
}
214207
cancel()

experimental/aitools/cmd/discover_schema.go

Lines changed: 32 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"github.com/databricks/cli/libs/cmdctx"
2020
"github.com/databricks/cli/libs/cmdio"
2121
"github.com/databricks/cli/libs/log"
22+
"github.com/databricks/cli/libs/sqlexec"
2223
"github.com/databricks/databricks-sdk-go"
2324
dbsql "github.com/databricks/databricks-sdk-go/service/sql"
2425
"github.com/spf13/cobra"
@@ -45,8 +46,10 @@ func newSQLGate(limit int) *sqlGate {
4546
// run executes a SQL statement asynchronously, polls until terminal, and
4647
// records the statement_id so it can be cancelled if the parent context is
4748
// cancelled. Acquires a slot from the gate before submitting and releases it
48-
// when polling completes (or the caller's context is cancelled).
49-
func (g *sqlGate) run(ctx context.Context, w *databricks.WorkspaceClient, warehouseID, statement string) (*dbsql.StatementResponse, error) {
49+
// when polling completes (or the caller's context is cancelled). On success it
50+
// returns the assembled result; a terminal non-success state is surfaced as the
51+
// CLI-facing query error.
52+
func (g *sqlGate) run(ctx context.Context, w *databricks.WorkspaceClient, warehouseID, statement string) (*sqlexec.Result, error) {
5053
// If the caller cancelled before we even tried, don't enter the select:
5154
// when the gate has free slots both cases are ready and Go picks one
5255
// pseudo-randomly. Without this early-out we'd occasionally submit a
@@ -61,28 +64,25 @@ func (g *sqlGate) run(ctx context.Context, w *databricks.WorkspaceClient, wareho
6164
return nil, ctx.Err()
6265
}
6366

64-
resp, err := w.StatementExecution.ExecuteStatement(ctx, dbsql.ExecuteStatementRequest{
65-
WarehouseId: warehouseID,
66-
Statement: statement,
67-
WaitTimeout: "0s",
68-
OnWaitTimeout: dbsql.ExecuteStatementRequestOnWaitTimeoutContinue,
69-
})
67+
client := sqlexec.New(w.StatementExecution, warehouseID)
68+
69+
stmt, err := client.Submit(ctx, statement)
7070
if err != nil {
71-
return nil, fmt.Errorf("execute statement: %w", err)
71+
return nil, err
7272
}
7373

7474
g.mu.Lock()
75-
g.ids = append(g.ids, resp.StatementId)
75+
g.ids = append(g.ids, stmt.ID)
7676
g.mu.Unlock()
7777

78-
pollResp, err := pollStatement(ctx, w.StatementExecution, resp)
78+
stmt, err = client.Poll(ctx, stmt)
7979
if err != nil {
8080
return nil, err
8181
}
82-
if err := checkFailedState(pollResp.Status); err != nil {
82+
if err := presentQueryError(stmt.Err()); err != nil {
8383
return nil, err
8484
}
85-
return pollResp, nil
85+
return client.Results(ctx, stmt)
8686
}
8787

8888
// trackedIDs returns a snapshot of statement_ids submitted through this gate.
@@ -235,9 +235,11 @@ func cancelDiscoverInFlight(ctx context.Context, api dbsql.StatementExecutionInt
235235
cmdio.LogString(ctx, "discover-schema cancelled.")
236236
return
237237
}
238+
// Cancel/Poll/Get don't use the warehouse ID, so an empty one is fine here.
239+
client := sqlexec.New(api, "")
238240
for _, id := range ids {
239241
cancelCtx, cancel := context.WithTimeout(ctx, cancelTimeout)
240-
if err := api.CancelExecution(cancelCtx, dbsql.CancelExecutionRequest{StatementId: id}); err != nil {
242+
if err := client.Cancel(cancelCtx, id); err != nil {
241243
log.Warnf(ctx, "Failed to cancel statement %s: %v", id, err)
242244
}
243245
cancel()
@@ -252,12 +254,12 @@ func discoverTable(ctx context.Context, gate *sqlGate, w *databricks.WorkspaceCl
252254
}
253255

254256
// 1. describe table - get columns and types
255-
descResp, err := gate.run(ctx, w, warehouseID, "DESCRIBE TABLE "+quoted)
257+
descResult, err := gate.run(ctx, w, warehouseID, "DESCRIBE TABLE "+quoted)
256258
if err != nil {
257259
return "", fmt.Errorf("describe table: %w", err)
258260
}
259261

260-
columns, types := parseDescribeResult(descResp)
262+
columns, types := parseDescribeResult(descResult)
261263
if len(columns) == 0 {
262264
return "", errors.New("no columns found")
263265
}
@@ -281,16 +283,16 @@ func discoverTable(ctx context.Context, gate *sqlGate, w *databricks.WorkspaceCl
281283
nullSQL := fmt.Sprintf("SELECT COUNT(*) AS total_rows, %s FROM %s",
282284
strings.Join(nullCountExprs, ", "), quoted)
283285

284-
var sampleResp, nullResp *dbsql.StatementResponse
286+
var sampleResult, nullResult *sqlexec.Result
285287
var sampleErr, nullErr error
286288

287289
g := new(errgroup.Group)
288290
g.Go(func() error {
289-
sampleResp, sampleErr = gate.run(ctx, w, warehouseID, sampleSQL)
291+
sampleResult, sampleErr = gate.run(ctx, w, warehouseID, sampleSQL)
290292
return nil
291293
})
292294
g.Go(func() error {
293-
nullResp, nullErr = gate.run(ctx, w, warehouseID, nullSQL)
295+
nullResult, nullErr = gate.run(ctx, w, warehouseID, nullSQL)
294296
return nil
295297
})
296298
_ = g.Wait()
@@ -306,25 +308,21 @@ func discoverTable(ctx context.Context, gate *sqlGate, w *databricks.WorkspaceCl
306308
fmt.Fprintf(&sb, "\nSAMPLE DATA: Error - %v\n", sampleErr)
307309
} else {
308310
sb.WriteString("\nSAMPLE DATA:\n")
309-
sb.WriteString(formatTableData(sampleResp))
311+
sb.WriteString(formatTableData(sampleResult))
310312
}
311313

312314
if nullErr != nil {
313315
fmt.Fprintf(&sb, "\nNULL COUNTS: Error - %v\n", nullErr)
314316
} else {
315317
sb.WriteString("\nNULL COUNTS:\n")
316-
sb.WriteString(formatNullCounts(nullResp, columns))
318+
sb.WriteString(formatNullCounts(nullResult, columns))
317319
}
318320

319321
return sb.String(), nil
320322
}
321323

322-
func parseDescribeResult(resp *dbsql.StatementResponse) (columns, types []string) {
323-
if resp.Result == nil || resp.Result.DataArray == nil {
324-
return nil, nil
325-
}
326-
327-
for _, row := range resp.Result.DataArray {
324+
func parseDescribeResult(result *sqlexec.Result) (columns, types []string) {
325+
for _, row := range result.Rows {
328326
if len(row) < 2 {
329327
continue
330328
}
@@ -340,20 +338,15 @@ func parseDescribeResult(resp *dbsql.StatementResponse) (columns, types []string
340338
return columns, types
341339
}
342340

343-
func formatTableData(resp *dbsql.StatementResponse) string {
344-
if resp.Result == nil || resp.Result.DataArray == nil || len(resp.Result.DataArray) == 0 {
341+
func formatTableData(result *sqlexec.Result) string {
342+
if len(result.Rows) == 0 {
345343
return " (no data)\n"
346344
}
347345

348346
var sb strings.Builder
349-
var columns []string
350-
if resp.Manifest != nil && resp.Manifest.Schema != nil {
351-
for _, col := range resp.Manifest.Schema.Columns {
352-
columns = append(columns, col.Name)
353-
}
354-
}
347+
columns := result.Columns
355348

356-
for i, row := range resp.Result.DataArray {
349+
for i, row := range result.Rows {
357350
fmt.Fprintf(&sb, " Row %d:\n", i+1)
358351
for j, val := range row {
359352
colName := fmt.Sprintf("col%d", j)
@@ -366,12 +359,12 @@ func formatTableData(resp *dbsql.StatementResponse) string {
366359
return sb.String()
367360
}
368361

369-
func formatNullCounts(resp *dbsql.StatementResponse, columns []string) string {
370-
if resp.Result == nil || resp.Result.DataArray == nil || len(resp.Result.DataArray) == 0 {
362+
func formatNullCounts(result *sqlexec.Result, columns []string) string {
363+
if len(result.Rows) == 0 {
371364
return " (no data)\n"
372365
}
373366

374-
row := resp.Result.DataArray[0]
367+
row := result.Rows[0]
375368
var sb strings.Builder
376369

377370
// first value is total_rows

experimental/aitools/cmd/discover_schema_test.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"time"
1111

1212
"github.com/databricks/cli/libs/cmdio"
13+
"github.com/databricks/cli/libs/sqlexec"
1314
"github.com/databricks/databricks-sdk-go"
1415
mocksql "github.com/databricks/databricks-sdk-go/experimental/mocks/service/sql"
1516
dbsql "github.com/databricks/databricks-sdk-go/service/sql"
@@ -50,17 +51,17 @@ func TestQuoteTableName(t *testing.T) {
5051
}
5152

5253
func TestParseDescribeResultSkipsMetadataRows(t *testing.T) {
53-
resp := &dbsql.StatementResponse{
54-
Result: &dbsql.ResultData{DataArray: [][]string{
54+
result := &sqlexec.Result{
55+
Rows: [][]string{
5556
{"id", "BIGINT", ""},
5657
{"name", "STRING", ""},
5758
{"# Partition Information", "", ""},
5859
{"region", "STRING", ""},
5960
{"", "STRING", ""},
60-
}},
61+
},
6162
}
6263

63-
cols, types := parseDescribeResult(resp)
64+
cols, types := parseDescribeResult(result)
6465
assert.Equal(t, []string{"id", "name", "region"}, cols)
6566
assert.Equal(t, []string{"BIGINT", "STRING", "STRING"}, types)
6667
}
@@ -82,9 +83,9 @@ func TestSQLGateRunPinsOnWaitTimeoutAndRecordsID(t *testing.T) {
8283
w := &databricks.WorkspaceClient{StatementExecution: mockAPI}
8384
gate := newSQLGate(2)
8485

85-
resp, err := gate.run(ctx, w, "wh-1", "SELECT 1")
86+
result, err := gate.run(ctx, w, "wh-1", "SELECT 1")
8687
require.NoError(t, err)
87-
assert.Equal(t, "stmt-1", resp.StatementId)
88+
assert.Equal(t, [][]string{{"1"}}, result.Rows)
8889
assert.Equal(t, []string{"stmt-1"}, gate.trackedIDs())
8990
}
9091

0 commit comments

Comments
 (0)