Commit a3f0765

aitools: parallelize discover-schema across tables and probes (#5097)
## Stack

This PR is part of a 4-PR stack making `aitools` data exploration faster for ai-dev-kit. Each PR is independently reviewable; merge in order.

1. #5092 - aitools: extract pollStatement helper and pin OnWaitTimeout *(base: `main`)*
2. #5093 - aitools: run multiple SQL queries in parallel from one query invocation *(base: #5092)*
3. #5095 - aitools: add 'tools statement' lifecycle commands *(base: #5093)*
4. **#5097 - aitools: parallelize discover-schema across tables and probes** *(base: #5095)* - **this PR**

Use `git diff <base>...HEAD` or set the comparison base in the GitHub UI to see only this PR's changes; the default "Files changed" diff against `main` includes ancestor PRs.

---

## Why

`discover-schema` walked tables sequentially and ran each table's three probes (DESCRIBE, sample SELECT, null counts) one after the other. For ai-dev-kit's data-exploration phase that meant warehouse-bound work was idle most of the time. Same root cause as the multi-query exploration latency that #5093 (batch query) fixed; same fix.

This is a pure latency win. No new user-facing API surface, no output-shape change.

## Changes

**Two layers of parallelism plus a shared statement budget:**

1. **Across tables.** The for-loop in `RunE` becomes an `errgroup.Group`. A failure on one table never aborts the others; it's rendered inline as `"Error discovering ..."` exactly as before.
2. **Within a table.** `discoverTable` still runs DESCRIBE first because the column list feeds the null-counts query. After DESCRIBE returns, the sample SELECT and null-counts probes run concurrently. Output text is assembled once both probes finish, preserving the existing `COLUMNS / SAMPLE DATA / NULL COUNTS` order.
3. **Single warehouse-statement budget.** A new `sqlGate` (chan struct{} of capacity N + statement_id tracking) wraps every `executeSQL` call. `--concurrency` (default 8) caps total in-flight statements globally, regardless of how many tables you pass. So `--concurrency 1` actually serializes statement load, not just table fan-out.

**Switch `executeSQL` to use `pollStatement`** (the helper extracted in #5092) instead of the SDK's `ExecuteAndWait`. Pins `OnWaitTimeout: CONTINUE`. Failed states flow through `checkFailedState`, yielding more specific error messages (e.g. `"query failed: SYNTAX_ERROR near 'oops'"`) than the previous hand-rolled branch. The user-visible `"SAMPLE DATA: Error - %v" / "NULL COUNTS: Error - %v"` wrapping is unchanged. Future polling-helper improvements land here for free.

**Cancellation discipline mirroring batch.go (#5093):** signal handler cancels a derived `pollCtx`; `sqlGate` records each `statement_id` post-submission; on cancellation the recorded IDs are swept via `CancelExecution` before returning `root.ErrAlreadyPrinted`. Without this, parallelism would orphan up to N×2 statements server-side on Ctrl+C.

**`--concurrency` validation** mirrors `cmd/fs/cp.go` and #5093: `PreRunE` rejects values <= 0 with `errInvalidBatchConcurrency`. Table-name validation also runs in `PreRunE` so malformed identifiers are rejected before `MustWorkspaceClient` runs (no unnecessary auth roundtrip on bad input).

**Output unchanged** for any input that previously succeeded. Same dividers, same header/probe ordering, same per-probe error wrapping.

## Test plan

- [x] `go test ./experimental/aitools/...` passes.
- [x] `make checks` clean.
- [x] `make fmt` no drift.
- [x] `make lint` 0 issues.
- [x] New unit tests in `discover_schema_test.go`:
  - `quoteTableName` table-driven (valid, missing parts, too many parts, injection attempts, empty parts, leading-digit identifiers, backtick in name)
  - `parseDescribeResult` skips metadata rows (`#`-prefixed and empty)
  - `sqlGate.run` pins `OnWaitTimeout: CONTINUE`, propagates FAILED state, wraps transport errors, records IDs, respects cancelled context
  - `cancelDiscoverInFlight` calls API per ID; empty list is a no-op
  - `discoverTable`: sample and null-count probes run concurrently after DESCRIBE (deterministic atomic-counter + sync.OnceFunc + channel-close barrier; sequential execution surfaces a timeout error)
  - `discoverTable`: a sample-probe failure does not abort null counts
  - `--concurrency 0` and `-1` rejected at PreRunE
  - Invalid table name (not `CATALOG.SCHEMA.TABLE`) and injection attempts rejected at PreRunE before any API call
- [x] Manual smoke against a real warehouse:

  ```bash
  databricks experimental aitools tools discover-schema \
    samples.nyctaxi.trips samples.tpch.orders samples.tpch.customer
  ```
1 parent d0d58e8 commit a3f0765

2 files changed

Lines changed: 518 additions & 55 deletions

experimental/aitools/cmd/discover_schema.go

Lines changed: 192 additions & 55 deletions
@@ -4,22 +4,96 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"os"
+	"os/signal"
 	"regexp"
+	"slices"
 	"strings"
+	"sync"
+	"syscall"

 	"github.com/databricks/cli/cmd/root"
 	"github.com/databricks/cli/experimental/aitools/lib/middlewares"
 	"github.com/databricks/cli/experimental/aitools/lib/session"
 	"github.com/databricks/cli/libs/cmdctx"
 	"github.com/databricks/cli/libs/cmdio"
+	"github.com/databricks/cli/libs/log"
 	"github.com/databricks/databricks-sdk-go"
 	dbsql "github.com/databricks/databricks-sdk-go/service/sql"
 	"github.com/spf13/cobra"
+	"golang.org/x/sync/errgroup"
 )

 var sqlIdentifierRe = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`)

+// sqlGate caps in-flight SQL statements globally and records each statement_id
+// so a Ctrl+C sweep can cancel anything still running server-side. The gate's
+// concurrency limit applies across all probes (DESCRIBE, sample SELECT, null
+// counts) and across all tables, so --concurrency really means "max statements
+// in flight," not "max tables in flight."
+type sqlGate struct {
+	sem chan struct{}
+	mu  sync.Mutex
+	ids []string
+}
+
+func newSQLGate(limit int) *sqlGate {
+	return &sqlGate{sem: make(chan struct{}, limit)}
+}
+
+// run executes a SQL statement asynchronously, polls until terminal, and
+// records the statement_id so it can be cancelled if the parent context is
+// cancelled. Acquires a slot from the gate before submitting and releases it
+// when polling completes (or the caller's context is cancelled).
+func (g *sqlGate) run(ctx context.Context, w *databricks.WorkspaceClient, warehouseID, statement string) (*dbsql.StatementResponse, error) {
+	// If the caller cancelled before we even tried, don't enter the select:
+	// when the gate has free slots both cases are ready and Go picks one
+	// pseudo-randomly. Without this early-out we'd occasionally submit a
+	// statement under a cancelled context.
+	if err := ctx.Err(); err != nil {
+		return nil, err
+	}
+	select {
+	case g.sem <- struct{}{}:
+		defer func() { <-g.sem }()
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	}
+
+	resp, err := w.StatementExecution.ExecuteStatement(ctx, dbsql.ExecuteStatementRequest{
+		WarehouseId:   warehouseID,
+		Statement:     statement,
+		WaitTimeout:   "0s",
+		OnWaitTimeout: dbsql.ExecuteStatementRequestOnWaitTimeoutContinue,
+	})
+	if err != nil {
+		return nil, fmt.Errorf("execute statement: %w", err)
+	}
+
+	g.mu.Lock()
+	g.ids = append(g.ids, resp.StatementId)
+	g.mu.Unlock()
+
+	pollResp, err := pollStatement(ctx, w.StatementExecution, resp)
+	if err != nil {
+		return nil, err
+	}
+	if err := checkFailedState(pollResp.Status); err != nil {
+		return nil, err
+	}
+	return pollResp, nil
+}
+
+// trackedIDs returns a snapshot of statement_ids submitted through this gate.
+func (g *sqlGate) trackedIDs() []string {
+	g.mu.Lock()
+	defer g.mu.Unlock()
+	return slices.Clone(g.ids)
+}
+
 func newDiscoverSchemaCmd() *cobra.Command {
+	var concurrency int
+
 	cmd := &cobra.Command{
 		Use:   "discover-schema TABLE...",
 		Short: "Discover schema for one or more tables",
@@ -31,21 +105,33 @@ For each table, returns:
 - Column names and types
 - Sample data (5 rows)
 - Null counts per column
-- Total row count`,
+- Total row count
+
+Tables and probes (DESCRIBE, sample SELECT, null counts) all share a
+single warehouse-statement budget. --concurrency (default 8) caps the
+total number of statements in flight at any moment, regardless of how
+many tables you pass in.
+
+On Ctrl+C, in-flight statements are cancelled server-side via
+CancelExecution before the command exits.`,
 		Example: `  databricks experimental aitools tools discover-schema samples.nyctaxi.trips
   databricks experimental aitools tools discover-schema catalog.schema.table1 catalog.schema.table2`,
-		Args:    cobra.MinimumNArgs(1),
-		PreRunE: root.MustWorkspaceClient,
-		RunE: func(cmd *cobra.Command, args []string) error {
-			ctx := cmd.Context()
-			w := cmdctx.WorkspaceClient(ctx)
-
-			// validate table names: each part must be a safe SQL identifier
+		Args: cobra.MinimumNArgs(1),
+		PreRunE: func(cmd *cobra.Command, args []string) error {
+			if concurrency <= 0 {
+				return errInvalidBatchConcurrency
+			}
+			// Reject malformed identifiers before any auth/profile work.
 			for _, table := range args {
 				if _, err := quoteTableName(table); err != nil {
 					return err
 				}
 			}
+			return root.MustWorkspaceClient(cmd, args)
+		},
+		RunE: func(cmd *cobra.Command, args []string) error {
+			ctx := cmd.Context()
+			w := cmdctx.WorkspaceClient(ctx)

 			// set up session with client for middleware compatibility
 			sess := session.NewSession()
@@ -57,13 +143,43 @@ For each table, returns:
 				return err
 			}

-			var results []string
-			for _, table := range args {
-				result, err := discoverTable(ctx, w, warehouseID, table)
-				if err != nil {
-					result = fmt.Sprintf("Error discovering %s: %v", table, err)
+			pollCtx, pollCancel := context.WithCancel(ctx)
+			defer pollCancel()
+
+			sigCh := make(chan os.Signal, 1)
+			signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
+			defer signal.Stop(sigCh)
+
+			go func() {
+				select {
+				case <-sigCh:
+					log.Infof(ctx, "Received interrupt, cancelling in-flight discover-schema statements")
+					pollCancel()
+				case <-pollCtx.Done():
 				}
-				results = append(results, result)
+			}()
+
+			gate := newSQLGate(concurrency)
+
+			results := make([]string, len(args))
+			g := new(errgroup.Group)
+			for i, table := range args {
+				g.Go(func() error {
+					result, err := discoverTable(pollCtx, gate, w, warehouseID, table)
+					if err != nil {
+						results[i] = fmt.Sprintf("Error discovering %s: %v", table, err)
+					} else {
+						results[i] = result
+					}
+					// A failure on one table shouldn't abort the others.
+					return nil
+				})
+			}
+			_ = g.Wait()
+
+			if pollCtx.Err() != nil {
+				cancelDiscoverInFlight(ctx, w.StatementExecution, gate.trackedIDs())
+				return root.ErrAlreadyPrinted
 			}

 			// format output with dividers for multiple tables
@@ -90,20 +206,39 @@ For each table, returns:
 		},
 	}

+	cmd.Flags().IntVar(&concurrency, "concurrency", defaultBatchConcurrency, "Maximum SQL statements in flight at once across all tables and probes")
+
 	return cmd
 }

-func discoverTable(ctx context.Context, w *databricks.WorkspaceClient, warehouseID, table string) (string, error) {
-	var sb strings.Builder
+// cancelDiscoverInFlight sends CancelExecution for every recorded statement_id.
+// Best effort: errors are logged but don't fail the user-visible exit.
+// Statements that already finished server-side return an error which we just
+// swallow at warn level; the alternative (per-statement state tracking) isn't
+// worth the bookkeeping here.
+func cancelDiscoverInFlight(ctx context.Context, api dbsql.StatementExecutionInterface, ids []string) {
+	if len(ids) == 0 {
+		cmdio.LogString(ctx, "discover-schema cancelled.")
+		return
+	}
+	for _, id := range ids {
+		cancelCtx, cancel := context.WithTimeout(ctx, cancelTimeout)
+		if err := api.CancelExecution(cancelCtx, dbsql.CancelExecutionRequest{StatementId: id}); err != nil {
+			log.Warnf(ctx, "Failed to cancel statement %s: %v", id, err)
+		}
+		cancel()
+	}
+	cmdio.LogString(ctx, fmt.Sprintf("discover-schema cancelled; sent CancelExecution for %d statement(s).", len(ids)))
+}

+func discoverTable(ctx context.Context, gate *sqlGate, w *databricks.WorkspaceClient, warehouseID, table string) (string, error) {
 	quoted, err := quoteTableName(table)
 	if err != nil {
 		return "", err
 	}

 	// 1. describe table - get columns and types
-	describeSQL := "DESCRIBE TABLE " + quoted
-	descResp, err := executeSQL(ctx, w, warehouseID, describeSQL)
+	descResp, err := gate.run(ctx, w, warehouseID, "DESCRIBE TABLE "+quoted)
 	if err != nil {
 		return "", fmt.Errorf("describe table: %w", err)
 	}
@@ -113,32 +248,55 @@ func discoverTable(ctx context.Context, w *databricks.WorkspaceClient, warehouse
 		return "", errors.New("no columns found")
 	}

+	// 2 + 3. Sample data and null counts run in parallel; both depend only on
+	// the column list (already known) and not on each other. The gate (not
+	// errgroup) is what actually limits warehouse concurrency.
+	sampleSQL := fmt.Sprintf("SELECT * FROM %s LIMIT 5", quoted)
+
+	nullCountExprs := make([]string, len(columns))
+	for i, col := range columns {
+		// Backticks inside an identifier are escaped by doubling them in
+		// Databricks/Delta SQL (`` ` `` → `` `` ``). Without this, a column
+		// name containing a backtick would terminate the quoted identifier
+		// mid-string and produce a PARSE_SYNTAX_ERROR. Sample-data uses
+		// SELECT * so the failure shows up only as a confusing
+		// "NULL COUNTS: Error - ..." line in the user-facing output.
+		escaped := strings.ReplaceAll(col, "`", "``")
+		nullCountExprs[i] = fmt.Sprintf("SUM(CASE WHEN `%s` IS NULL THEN 1 ELSE 0 END) AS `%s_nulls`", escaped, escaped)
+	}
+	nullSQL := fmt.Sprintf("SELECT COUNT(*) AS total_rows, %s FROM %s",
+		strings.Join(nullCountExprs, ", "), quoted)
+
+	var sampleResp, nullResp *dbsql.StatementResponse
+	var sampleErr, nullErr error
+
+	g := new(errgroup.Group)
+	g.Go(func() error {
+		sampleResp, sampleErr = gate.run(ctx, w, warehouseID, sampleSQL)
+		return nil
+	})
+	g.Go(func() error {
+		nullResp, nullErr = gate.run(ctx, w, warehouseID, nullSQL)
+		return nil
+	})
+	_ = g.Wait()
+
+	// Assemble the output in the established order: columns, sample, null counts.
+	var sb strings.Builder
 	sb.WriteString("COLUMNS:\n")
 	for i, col := range columns {
 		fmt.Fprintf(&sb, "  %s: %s\n", col, types[i])
 	}

-	// 2. sample data (5 rows)
-	sampleSQL := fmt.Sprintf("SELECT * FROM %s LIMIT 5", quoted)
-	sampleResp, err := executeSQL(ctx, w, warehouseID, sampleSQL)
-	if err != nil {
-		fmt.Fprintf(&sb, "\nSAMPLE DATA: Error - %v\n", err)
+	if sampleErr != nil {
+		fmt.Fprintf(&sb, "\nSAMPLE DATA: Error - %v\n", sampleErr)
 	} else {
 		sb.WriteString("\nSAMPLE DATA:\n")
 		sb.WriteString(formatTableData(sampleResp))
 	}

-	// 3. null counts per column
-	nullCountExprs := make([]string, len(columns))
-	for i, col := range columns {
-		nullCountExprs[i] = fmt.Sprintf("SUM(CASE WHEN `%s` IS NULL THEN 1 ELSE 0 END) AS `%s_nulls`", col, col)
-	}
-	nullSQL := fmt.Sprintf("SELECT COUNT(*) AS total_rows, %s FROM %s",
-		strings.Join(nullCountExprs, ", "), quoted)
-
-	nullResp, err := executeSQL(ctx, w, warehouseID, nullSQL)
-	if err != nil {
-		fmt.Fprintf(&sb, "\nNULL COUNTS: Error - %v\n", err)
+	if nullErr != nil {
+		fmt.Fprintf(&sb, "\nNULL COUNTS: Error - %v\n", nullErr)
 	} else {
 		sb.WriteString("\nNULL COUNTS:\n")
 		sb.WriteString(formatNullCounts(nullResp, columns))
@@ -147,27 +305,6 @@ func discoverTable(ctx context.Context, w *databricks.WorkspaceClient, warehouse
 	return sb.String(), nil
 }

-func executeSQL(ctx context.Context, w *databricks.WorkspaceClient, warehouseID, statement string) (*dbsql.StatementResponse, error) {
-	resp, err := w.StatementExecution.ExecuteAndWait(ctx, dbsql.ExecuteStatementRequest{
-		WarehouseId: warehouseID,
-		Statement:   statement,
-		WaitTimeout: "50s",
-	})
-	if err != nil {
-		return nil, err
-	}
-
-	if resp.Status != nil && resp.Status.State == dbsql.StatementStateFailed {
-		errMsg := "query failed"
-		if resp.Status.Error != nil {
-			errMsg = resp.Status.Error.Message
-		}
-		return nil, errors.New(errMsg)
-	}
-
-	return resp, nil
-}
-
 func parseDescribeResult(resp *dbsql.StatementResponse) (columns, types []string) {
 	if resp.Result == nil || resp.Result.DataArray == nil {
 		return nil, nil
