Skip to content

Commit 5a2669c

Browse files
authored
aitools: add --param flag for parameterized SQL queries (#5336)
## Why The Databricks SQL Statement Execution API supports named parameters (`:name` markers in SQL plus a `parameters` payload), but the experimental aitools `query` and `statement submit` commands never set that field. Users who want to avoid SQL injection, sidestep shell-quoting issues with dates and strings, or run typed bindings (`DATE`, `INT`, `DECIMAL(...)`, etc.) currently have to drop down to raw HTTP. This wires the field through. ## Changes **Before:** no way to pass parameters. SQL had to inline every value as a literal, with all the quoting and injection risk that implies. **Now:** `--param` is a repeatable flag on `query`, `statement submit`, and the multi-query batch path. Format: - `--param name=value` (default type, server-side STRING) - `--param name:TYPE=value` for typed bindings, e.g. `--param since:DATE=2026-01-01` Empty value (`--param opt=`) is sent as NULL via `omitempty`. Duplicate names and missing `=` are rejected at flag-parse time. In batch mode the same parameter set is applied to every statement. Implementation: - New `parseParams` helper in `experimental/aitools/cmd/params.go`, plus parser unit tests. - Plumbed `[]sql.StatementParameterListItem` through `executeAndPoll`, `submitStatement`, `executeBatch`, and `runOneBatchQuery`. - `--param` flag registered on both `newQueryCmd` and `newStatementSubmitCmd`. - Help text and examples updated. No `NEXT_CHANGELOG.md` entry: this is still under `experimental aitools tools`. ## Test plan - [x] `./task checks` clean (tidy, whitespace, links, deadcode) - [x] `./task lint-q` clean (0 issues) - [x] `./task fmt` clean (no changes) - [x] `go test ./experimental/aitools/...` passes - [x] New unit tests for parser: typed, untyped, value with embedded `=`/`:`, decimal types with parens, empty value, whitespace trimming, error cases (no `=`, empty name, duplicates) - [x] New mock-based tests confirming `Parameters` reaches `ExecuteStatement` for `executeAndPoll`, `submitStatement`, and `executeBatch` - [ ] Manual smoke test against a real warehouse (`databricks experimental aitools tools query --param name=alice "SELECT :name"`)
1 parent 30f637f commit 5a2669c

8 files changed

Lines changed: 313 additions & 26 deletions

File tree

experimental/aitools/cmd/batch.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,11 @@ type batchResultError struct {
5353
// On context cancellation (Ctrl+C or parent context), still-running statements
5454
// are cancelled server-side via CancelExecution. Statements that finished
5555
// before cancellation are left as-is.
56-
func executeBatch(ctx context.Context, api sql.StatementExecutionInterface, warehouseID string, sqls []string, concurrency int) []batchResult {
56+
//
57+
// params, if non-nil, are bound on every statement. The same parameter set is
58+
// reused across the batch, so callers must ensure each SQL uses only markers
59+
// that are covered.
60+
func executeBatch(ctx context.Context, api sql.StatementExecutionInterface, warehouseID string, sqls []string, params []sql.StatementParameterListItem, concurrency int) []batchResult {
5761
pollCtx, pollCancel := context.WithCancel(ctx)
5862
defer pollCancel()
5963

@@ -97,7 +101,7 @@ func executeBatch(ctx context.Context, api sql.StatementExecutionInterface, ware
97101
g.SetLimit(concurrency)
98102
for i, sqlStr := range sqls {
99103
g.Go(func() error {
100-
results[i] = runOneBatchQuery(pollCtx, api, warehouseID, sqlStr, statementIDs, i)
104+
results[i] = runOneBatchQuery(pollCtx, api, warehouseID, sqlStr, params, statementIDs, i)
101105
completed.Add(1)
102106
return nil
103107
})
@@ -115,13 +119,14 @@ func executeBatch(ctx context.Context, api sql.StatementExecutionInterface, ware
115119

116120
// runOneBatchQuery submits one SQL, polls to completion, and returns its
117121
// batchResult. All errors are encoded into the result; never returns an error.
118-
func runOneBatchQuery(ctx context.Context, api sql.StatementExecutionInterface, warehouseID, sqlStr string, statementIDs []string, idx int) batchResult {
122+
func runOneBatchQuery(ctx context.Context, api sql.StatementExecutionInterface, warehouseID, sqlStr string, params []sql.StatementParameterListItem, statementIDs []string, idx int) batchResult {
119123
start := time.Now()
120124
result := batchResult{SQL: sqlStr}
121125

122126
resp, err := api.ExecuteStatement(ctx, sql.ExecuteStatementRequest{
123127
WarehouseId: warehouseID,
124128
Statement: sqlStr,
129+
Parameters: params,
125130
WaitTimeout: "0s",
126131
OnWaitTimeout: sql.ExecuteStatementRequestOnWaitTimeoutContinue,
127132
})

experimental/aitools/cmd/batch_test.go

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ func TestExecuteBatchAllSucceed(t *testing.T) {
7373
}, nil).Once()
7474
}
7575

76-
results := executeBatch(ctx, mockAPI, "wh-123", sqls, 8)
76+
results := executeBatch(ctx, mockAPI, "wh-123", sqls, nil, 8)
7777

7878
require.Len(t, results, 3)
7979
for i, r := range results {
@@ -86,6 +86,25 @@ func TestExecuteBatchAllSucceed(t *testing.T) {
8686
}
8787
}
8888

89+
func TestExecuteBatchPassesParametersToAllStatements(t *testing.T) {
90+
ctx := cmdio.MockDiscard(t.Context())
91+
mockAPI := mocksql.NewMockStatementExecutionInterface(t)
92+
93+
params := []sql.StatementParameterListItem{
94+
{Name: "since", Type: "DATE", Value: "2026-01-01"},
95+
}
96+
97+
mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool {
98+
return assert.ObjectsAreEqual(params, req.Parameters)
99+
})).Return(&sql.StatementResponse{
100+
StatementId: "stmt",
101+
Status: &sql.StatementStatus{State: sql.StatementStateSucceeded},
102+
}, nil).Times(2)
103+
104+
results := executeBatch(ctx, mockAPI, "wh-1", []string{"SELECT 1 WHERE 1=1 AND :since IS NOT NULL", "SELECT 2 WHERE :since IS NOT NULL"}, params, 8)
105+
require.Len(t, results, 2)
106+
}
107+
89108
func TestExecuteBatchPartialFailure(t *testing.T) {
90109
ctx := cmdio.MockDiscard(t.Context())
91110
mockAPI := mocksql.NewMockStatementExecutionInterface(t)
@@ -112,7 +131,7 @@ func TestExecuteBatchPartialFailure(t *testing.T) {
112131
},
113132
}, nil).Once()
114133

115-
results := executeBatch(ctx, mockAPI, "wh-123", []string{"SELECT 1", "SELECT bad"}, 8)
134+
results := executeBatch(ctx, mockAPI, "wh-123", []string{"SELECT 1", "SELECT bad"}, nil, 8)
116135

117136
require.Len(t, results, 2)
118137
assert.Nil(t, results[0].Error)
@@ -141,7 +160,7 @@ func TestExecuteBatchSubmissionFailure(t *testing.T) {
141160
return req.Statement == "SELECT broken"
142161
})).Return(nil, errors.New("network unreachable")).Once()
143162

144-
results := executeBatch(ctx, mockAPI, "wh-123", []string{"SELECT good", "SELECT broken"}, 8)
163+
results := executeBatch(ctx, mockAPI, "wh-123", []string{"SELECT good", "SELECT broken"}, nil, 8)
145164

146165
require.Len(t, results, 2)
147166
assert.Nil(t, results[0].Error)
@@ -163,7 +182,7 @@ func TestExecuteBatchSetsOnWaitTimeoutContinue(t *testing.T) {
163182
Status: &sql.StatementStatus{State: sql.StatementStateSucceeded},
164183
}, nil).Times(2)
165184

166-
results := executeBatch(ctx, mockAPI, "wh-123", []string{"q1", "q2"}, 8)
185+
results := executeBatch(ctx, mockAPI, "wh-123", []string{"q1", "q2"}, nil, 8)
167186
require.Len(t, results, 2)
168187
}
169188

@@ -196,7 +215,7 @@ func TestExecuteBatchPreservesInputOrder(t *testing.T) {
196215
}
197216

198217
sqls := []string{"SELECT 'slow'", "SELECT 'fast1'", "SELECT 'fast2'"}
199-
results := executeBatch(ctx, mockAPI, "wh-1", sqls, 8)
218+
results := executeBatch(ctx, mockAPI, "wh-1", sqls, nil, 8)
200219

201220
require.Len(t, results, 3)
202221
for i, r := range results {
@@ -233,7 +252,7 @@ func TestExecuteBatchContextCancellationCancelsInFlight(t *testing.T) {
233252

234253
cancel()
235254

236-
results := executeBatch(ctx, mockAPI, "wh", []string{"q1", "q2", "q3"}, 8)
255+
results := executeBatch(ctx, mockAPI, "wh", []string{"q1", "q2", "q3"}, nil, 8)
237256

238257
require.Len(t, results, 3)
239258
for i, r := range results {

experimental/aitools/cmd/params.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package aitools
2+
3+
import (
4+
"fmt"
5+
"strings"
6+
7+
"github.com/databricks/databricks-sdk-go/service/sql"
8+
)
9+
10+
// parseParams converts --param flag values into SDK parameter list items for
11+
// the Databricks SQL Statement Execution API. Each input is either
12+
// "name=value" (defaults to STRING server-side) or "name:TYPE=value" (typed,
13+
// e.g. "since:DATE=2026-01-01"). An empty value becomes NULL on the wire
14+
// because StatementParameterListItem.Value uses omitempty.
15+
//
16+
// The Databricks API only supports named markers (`:name`), not positional
17+
// `?`, and parameter names must be unique within a statement.
18+
func parseParams(raw []string) ([]sql.StatementParameterListItem, error) {
19+
if len(raw) == 0 {
20+
return nil, nil
21+
}
22+
23+
out := make([]sql.StatementParameterListItem, 0, len(raw))
24+
seen := make(map[string]struct{}, len(raw))
25+
for _, s := range raw {
26+
lhs, value, ok := strings.Cut(s, "=")
27+
if !ok {
28+
return nil, fmt.Errorf("invalid --param %q: expected name=value or name:TYPE=value", s)
29+
}
30+
31+
name, typ, _ := strings.Cut(lhs, ":")
32+
name = strings.TrimSpace(name)
33+
typ = strings.TrimSpace(typ)
34+
35+
if name == "" {
36+
return nil, fmt.Errorf("invalid --param %q: name is empty", s)
37+
}
38+
if _, dup := seen[name]; dup {
39+
return nil, fmt.Errorf("duplicate --param name %q", name)
40+
}
41+
seen[name] = struct{}{}
42+
43+
out = append(out, sql.StatementParameterListItem{
44+
Name: name,
45+
Type: typ,
46+
Value: value,
47+
})
48+
}
49+
return out, nil
50+
}
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
package aitools
2+
3+
import (
4+
"testing"
5+
6+
"github.com/databricks/databricks-sdk-go/service/sql"
7+
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func TestParseParams(t *testing.T) {
12+
tests := []struct {
13+
name string
14+
in []string
15+
want []sql.StatementParameterListItem
16+
}{
17+
{
18+
name: "nil input returns nil",
19+
in: nil,
20+
want: nil,
21+
},
22+
{
23+
name: "empty input returns nil",
24+
in: []string{},
25+
want: nil,
26+
},
27+
{
28+
name: "single string param defaults type to empty (server-side STRING)",
29+
in: []string{"name=alice"},
30+
want: []sql.StatementParameterListItem{
31+
{Name: "name", Value: "alice"},
32+
},
33+
},
34+
{
35+
name: "typed param splits name and type on first colon",
36+
in: []string{"since:DATE=2026-01-01"},
37+
want: []sql.StatementParameterListItem{
38+
{Name: "since", Type: "DATE", Value: "2026-01-01"},
39+
},
40+
},
41+
{
42+
name: "value can contain = and :",
43+
in: []string{"clause=ts >= '2026-01-01T00:00:00'"},
44+
want: []sql.StatementParameterListItem{
45+
{Name: "clause", Value: "ts >= '2026-01-01T00:00:00'"},
46+
},
47+
},
48+
{
49+
name: "decimal type with parens preserved",
50+
in: []string{"amount:DECIMAL(10,2)=42.50"},
51+
want: []sql.StatementParameterListItem{
52+
{Name: "amount", Type: "DECIMAL(10,2)", Value: "42.50"},
53+
},
54+
},
55+
{
56+
name: "empty value becomes NULL on the wire via omitempty",
57+
in: []string{"opt="},
58+
want: []sql.StatementParameterListItem{
59+
{Name: "opt", Value: ""},
60+
},
61+
},
62+
{
63+
name: "whitespace around name and type is trimmed",
64+
in: []string{" name : INT =42"},
65+
want: []sql.StatementParameterListItem{
66+
{Name: "name", Type: "INT", Value: "42"},
67+
},
68+
},
69+
{
70+
name: "multiple params preserve input order",
71+
in: []string{"a=1", "b:INT=2", "c=three"},
72+
want: []sql.StatementParameterListItem{
73+
{Name: "a", Value: "1"},
74+
{Name: "b", Type: "INT", Value: "2"},
75+
{Name: "c", Value: "three"},
76+
},
77+
},
78+
}
79+
80+
for _, tc := range tests {
81+
t.Run(tc.name, func(t *testing.T) {
82+
got, err := parseParams(tc.in)
83+
require.NoError(t, err)
84+
assert.Equal(t, tc.want, got)
85+
})
86+
}
87+
}
88+
89+
func TestParseParamsErrors(t *testing.T) {
90+
tests := []struct {
91+
name string
92+
in []string
93+
wantMsg string
94+
}{
95+
{
96+
name: "no equals sign",
97+
in: []string{"foo"},
98+
wantMsg: "expected name=value",
99+
},
100+
{
101+
name: "empty name",
102+
in: []string{"=value"},
103+
wantMsg: "name is empty",
104+
},
105+
{
106+
name: "empty name with type",
107+
in: []string{":INT=42"},
108+
wantMsg: "name is empty",
109+
},
110+
{
111+
name: "whitespace-only name",
112+
in: []string{" =value"},
113+
wantMsg: "name is empty",
114+
},
115+
{
116+
name: "duplicate name",
117+
in: []string{"name=alice", "name=bob"},
118+
wantMsg: `duplicate --param name "name"`,
119+
},
120+
}
121+
122+
for _, tc := range tests {
123+
t.Run(tc.name, func(t *testing.T) {
124+
_, err := parseParams(tc.in)
125+
require.Error(t, err)
126+
assert.Contains(t, err.Error(), tc.wantMsg)
127+
})
128+
}
129+
}

experimental/aitools/cmd/query.go

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ func newQueryCmd() *cobra.Command {
6868
var filePaths []string
6969
var outputFormat string
7070
var concurrency int
71+
var paramFlags []string
72+
var params []sql.StatementParameterListItem
7173

7274
cmd := &cobra.Command{
7375
Use: "query [SQL | file.sql]...",
@@ -91,19 +93,33 @@ or the DATABRICKS_WAREHOUSE_ID environment variable is configured.
9193
9294
For a single query, output is JSON in non-interactive contexts. In
9395
interactive terminals it renders tables, and large results open an
94-
interactive table browser. Use --output csv to export results as CSV.`,
96+
interactive table browser. Use --output csv to export results as CSV.
97+
98+
Pass named parameters with --param. Use ":name" markers in the SQL and
99+
"--param name=value" (string) or "--param name:TYPE=value" (typed, e.g.
100+
DATE, INT) to bind values. Positional "?" markers are not supported. In
101+
multi-query mode, the same parameter set is applied to every statement.`,
95102
Example: ` databricks experimental aitools tools query "SELECT * FROM samples.nyctaxi.trips LIMIT 5"
96103
databricks experimental aitools tools query --warehouse abc123 "SELECT 1"
97104
databricks experimental aitools tools query --file report.sql
98105
databricks experimental aitools tools query report.sql
99106
databricks experimental aitools tools query --output csv "SELECT * FROM samples.nyctaxi.trips LIMIT 5"
100107
databricks experimental aitools tools query --output json "SELECT 1" "SELECT 2" "SELECT 3"
108+
databricks experimental aitools tools query --param name=alice "SELECT * FROM users WHERE name = :name"
109+
databricks experimental aitools tools query --param since:DATE=2026-01-01 "SELECT * FROM events WHERE ts > :since"
101110
echo "SELECT 1" | databricks experimental aitools tools query`,
102111
Args: cobra.ArbitraryArgs,
103112
PreRunE: func(cmd *cobra.Command, args []string) error {
104113
if concurrency <= 0 {
105114
return errInvalidBatchConcurrency
106115
}
116+
117+
var err error
118+
params, err = parseParams(paramFlags)
119+
if err != nil {
120+
return err
121+
}
122+
107123
return root.MustWorkspaceClient(cmd, args)
108124
},
109125
RunE: func(cmd *cobra.Command, args []string) error {
@@ -139,10 +155,10 @@ interactive table browser. Use --output csv to export results as CSV.`,
139155
}
140156

141157
if len(sqls) > 1 {
142-
return runBatch(ctx, cmd, w.StatementExecution, wID, sqls, concurrency)
158+
return runBatch(ctx, cmd, w.StatementExecution, wID, sqls, params, concurrency)
143159
}
144160

145-
resp, err := executeAndPoll(ctx, w.StatementExecution, wID, sqls[0])
161+
resp, err := executeAndPoll(ctx, w.StatementExecution, wID, sqls[0], params)
146162
if err != nil {
147163
return err
148164
}
@@ -185,6 +201,7 @@ interactive table browser. Use --output csv to export results as CSV.`,
185201
cmd.Flags().StringVarP(&warehouseID, "warehouse", "w", "", "SQL warehouse ID to use for execution")
186202
cmd.Flags().StringSliceVarP(&filePaths, "file", "f", nil, "Path to a SQL file to execute (repeatable; pair with positional SQLs to run a batch)")
187203
cmd.Flags().IntVar(&concurrency, "concurrency", defaultBatchConcurrency, "Maximum in-flight statements when running a batch of queries")
204+
cmd.Flags().StringArrayVar(&paramFlags, "param", nil, "Named parameter, repeatable. Format: name=value (STRING) or name:TYPE=value (e.g. name:DATE=2026-01-01). Empty value is sent as NULL.")
188205
// Local --output flag shadows the root command's persistent --output flag,
189206
// adding csv support for this command only.
190207
cmd.Flags().StringVarP(&outputFormat, "output", "o", string(sqlcli.OutputText), "Output format: text, json, or csv")
@@ -222,8 +239,10 @@ func resolveSQLs(ctx context.Context, cmd *cobra.Command, args, filePaths []stri
222239
// without an extra error message) when any statement failed; the failure detail
223240
// is already encoded in the printed JSON. The caller is responsible for
224241
// rejecting incompatible output formats before invoking this.
225-
func runBatch(ctx context.Context, cmd *cobra.Command, api sql.StatementExecutionInterface, warehouseID string, sqls []string, concurrency int) error {
226-
results := executeBatch(ctx, api, warehouseID, sqls, concurrency)
242+
//
243+
// params, if non-nil, are applied to every statement in the batch.
244+
func runBatch(ctx context.Context, cmd *cobra.Command, api sql.StatementExecutionInterface, warehouseID string, sqls []string, params []sql.StatementParameterListItem, concurrency int) error {
245+
results := executeBatch(ctx, api, warehouseID, sqls, params, concurrency)
227246
if err := renderBatchJSON(cmd.OutOrStdout(), results); err != nil {
228247
return err
229248
}
@@ -252,11 +271,12 @@ func resolveWarehouseID(ctx context.Context, w any, flagValue string) (string, e
252271

253272
// executeAndPoll submits a SQL statement asynchronously and polls until completion.
254273
// It shows a spinner in interactive mode and supports Ctrl+C cancellation.
255-
func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, warehouseID, statement string) (*sql.StatementResponse, error) {
274+
func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, warehouseID, statement string, params []sql.StatementParameterListItem) (*sql.StatementResponse, error) {
256275
// Submit asynchronously to get the statement ID immediately for cancellation.
257276
resp, err := api.ExecuteStatement(ctx, sql.ExecuteStatementRequest{
258277
WarehouseId: warehouseID,
259278
Statement: statement,
279+
Parameters: params,
260280
WaitTimeout: "0s",
261281
OnWaitTimeout: sql.ExecuteStatementRequestOnWaitTimeoutContinue,
262282
})

0 commit comments

Comments
 (0)