Skip to content

Commit 67d4a5e

Browse files
authored
Experimental postgres query (PR 3/4): multi-input + error formatting (#5138)
## PR Stack

1. [#5135](#5135) — PR 1: scaffold + autoscaling targeting + text output
2. [#5136](#5136) — PR 2: provisioned + JSON/CSV streaming + types + `sqlcli.ResolveFormat`
3. **PR 3 (this PR)** — [#5138](#5138) — multi-input + multi-statement rejection + error formatting + `sqlcli.Collect`
4. [#5143](#5143) — PR 4: cancellation + timeout + TUI

Stacked on PR 2.

## Why

PR 2 shipped a single-statement, single-input command. Real workflows want multi-input (set-then-query, file-then-stdin), multi-statement rejection with a friendly hint, and rich pg error formatting.

This PR also extends `experimental/libs/sqlcli` with input-collection logic shared by aitools and postgres. Same architectural principle as PR 2: instead of postgres growing its own duplicate of aitools' `resolveSQLs`, both commands now call `sqlcli.Collect`.

## Changes

**Architectural:** `experimental/libs/sqlcli/input.go` adds:

- `sqlcli.SQLFileExtension` const (`.sql`).
- `sqlcli.Input{SQL, Source}` type — Source is the human-readable origin label ("--file PATH", "argv[N]", "stdin").
- `sqlcli.CollectOptions{Cleaner func(string) string}` — aitools passes its `cleanSQL` (strips comments+quotes); postgres passes the default `TrimSpace` because its multi-statement scanner needs comments preserved.
- `sqlcli.Collect` — files-first then positionals, stdin only when neither is present, .sql autodetect on positionals.

aitools' `resolveSQLs` collapses to a thin wrapper around `sqlcli.Collect` (it returns just the SQL strings and ignores Source). The "SQL statement #N is empty after removing comments" wording is replaced with sqlcli's `argv[N] is empty`; aitools tests updated.

**User-facing changes for postgres query:**

- Variadic positionals + repeatable `--file` + stdin fallback.
- Multi-statement strings rejected up front with a hint (the hand-written conservative scanner ignores `;` inside string literals, identifiers, line/block comments, and dollar-quoted bodies; a dollar-quote tag must be a valid unquoted identifier, so `$1` and `$foo-bar$` are correctly NOT treated as tags).
- Multi-input output: per-unit blocks for text; canonical-shape JSON array `{"source","sql","kind","elapsed_ms",...}` for json; csv rejected pre-flight when N>1.
- Rich pg error formatting (`SEVERITY: message (SQLSTATE XXXXX)` with DETAIL/HINT lines), applied on both single-input and multi-input paths.

Single-input keeps streaming. `runUnitBuffered` is a thin wrapper around `executeOne` + a `bufferSink`, so the row-loop and error-wrapping logic stays in one place.

## Test plan

- [x] `go test ./experimental/...` (multistatement scanner: 28 cases including dollar-tag punctuation rejection, sqlcli.Collect: 12 cases including a custom-cleaner test, error formatting, multi-input renderers including byte-equal canonical-shape and first-unit-fails framing)
- [x] `go tool ... golangci-lint run ./experimental/...` (0 issues)
1 parent 42e344b commit 67d4a5e

13 files changed

Lines changed: 1148 additions & 88 deletions

File tree

experimental/aitools/cmd/query.go

Lines changed: 12 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"context"
55
"errors"
66
"fmt"
7-
"io"
87
"os"
98
"os/signal"
109
"strings"
@@ -23,9 +22,6 @@ import (
2322
)
2423

2524
const (
26-
// sqlFileExtension is the file extension used to auto-detect SQL files.
27-
sqlFileExtension = ".sql"
28-
2925
// pollIntervalInitial is the starting interval between status polls.
3026
pollIntervalInitial = 1 * time.Second
3127

@@ -204,65 +200,21 @@ interactive table browser. Use --output csv to export results as CSV.`,
204200
}
205201

206202
// resolveSQLs collects SQL statements from --file paths, positional args, and
207-
// stdin. The returned slice preserves source order: --file paths first (in flag
208-
// order), then positional args (in arg order), then stdin (only if no other
209-
// source produced anything). Each SQL is run through cleanSQL.
203+
// stdin via sqlcli.Collect, then runs each through cleanSQL (the warehouse
204+
// statement API doesn't care about line comments, so we strip them up front
205+
// to normalise the wire payload). Returns just the SQL strings so the rest of
206+
// this command's flow stays unchanged; the Source labels sqlcli adds are
207+
// dropped on the floor (this command surfaces statement_id, not source).
210208
func resolveSQLs(ctx context.Context, cmd *cobra.Command, args, filePaths []string) ([]string, error) {
211-
var raws []string
212-
213-
for _, path := range filePaths {
214-
data, err := os.ReadFile(path)
215-
if err != nil {
216-
return nil, fmt.Errorf("read SQL file %s: %w", path, err)
217-
}
218-
raws = append(raws, string(data))
219-
}
220-
221-
for _, arg := range args {
222-
// If the argument looks like a .sql file, try to read it.
223-
// Only fall through to literal SQL if the file doesn't exist.
224-
// Surface other errors (permission denied, etc.) directly.
225-
if strings.HasSuffix(arg, sqlFileExtension) {
226-
data, err := os.ReadFile(arg)
227-
if err != nil && !errors.Is(err, os.ErrNotExist) {
228-
return nil, fmt.Errorf("read SQL file: %w", err)
229-
}
230-
if err == nil {
231-
raws = append(raws, string(data))
232-
continue
233-
}
234-
}
235-
raws = append(raws, arg)
236-
}
237-
238-
if len(raws) == 0 {
239-
// No --file and no positional args: try reading from stdin if it's piped.
240-
// If stdin was overridden (e.g. cmd.SetIn in tests), always read from it.
241-
// Otherwise, only read if stdin is not a TTY (i.e. piped input).
242-
in := cmd.InOrStdin()
243-
_, isOsFile := in.(*os.File)
244-
if isOsFile && cmdio.IsPromptSupported(ctx) {
245-
return nil, errors.New("no SQL provided; pass a SQL string, use --file, or pipe via stdin")
246-
}
247-
data, err := io.ReadAll(in)
248-
if err != nil {
249-
return nil, fmt.Errorf("read stdin: %w", err)
250-
}
251-
raws = append(raws, string(data))
209+
inputs, err := sqlcli.Collect(ctx, cmd.InOrStdin(), args, filePaths, sqlcli.CollectOptions{Cleaner: cleanSQL})
210+
if err != nil {
211+
return nil, err
252212
}
253-
254-
cleaned := make([]string, 0, len(raws))
255-
for i, raw := range raws {
256-
c := cleanSQL(raw)
257-
if c == "" {
258-
if len(raws) == 1 {
259-
return nil, errors.New("SQL statement is empty after removing comments and blank lines")
260-
}
261-
return nil, fmt.Errorf("SQL statement #%d is empty after removing comments and blank lines", i+1)
262-
}
263-
cleaned = append(cleaned, c)
213+
out := make([]string, len(inputs))
214+
for i, in := range inputs {
215+
out[i] = in.SQL
264216
}
265-
return cleaned, nil
217+
return out, nil
266218
}
267219

268220
// runBatch executes multiple SQL statements in parallel and renders the result

experimental/aitools/cmd/query_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,7 @@ func TestResolveSQLsUnreadableSQLFileReturnsError(t *testing.T) {
539539
cmd := newTestCmd()
540540
_, err = resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, []string{path}, nil)
541541
require.Error(t, err)
542-
assert.Contains(t, err.Error(), "read SQL file")
542+
assert.Contains(t, err.Error(), "permission denied")
543543
}
544544

545545
func TestResolveSQLsFromStdin(t *testing.T) {
@@ -579,14 +579,14 @@ func TestResolveSQLsBatchEmptyAtIndexReturnsIndexedError(t *testing.T) {
579579
cmd := newTestCmd()
580580
_, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, []string{"SELECT 1", "-- comment only", "SELECT 3"}, nil)
581581
require.Error(t, err)
582-
assert.Contains(t, err.Error(), "SQL statement #2 is empty")
582+
assert.Contains(t, err.Error(), "argv[2] is empty")
583583
}
584584

585585
func TestResolveSQLsMissingFileReturnsError(t *testing.T) {
586586
cmd := newTestCmd()
587587
_, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, nil, []string{"/nonexistent/path/query.sql"})
588588
require.Error(t, err)
589-
assert.Contains(t, err.Error(), "read SQL file")
589+
assert.Contains(t, err.Error(), "no such file")
590590
}
591591

592592
func TestQueryCommandUnsupportedOutputReturnsError(t *testing.T) {

experimental/libs/sqlcli/input.go

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
package sqlcli
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"io"
8+
"os"
9+
"strings"
10+
11+
"github.com/databricks/cli/libs/cmdio"
12+
)
13+
14+
// SQLFileExtension is the suffix that opts a positional argument into .sql
// auto-detection. A positional ending in this suffix is read as a SQL file
// when a file with that name exists on disk; when no such file exists the
// argument is used verbatim as literal SQL.
const SQLFileExtension = ".sql"
18+
19+
// Input is a single SQL statement ready for execution, together with a label
// that says where it came from so multi-input renderers and error messages
// can point at the specific input that failed.
type Input struct {
	// SQL is the statement text after the cleaner has run. Collect never
	// returns an Input whose SQL is empty; inputs that clean to "" are
	// rejected with an error instead.
	SQL string
	// Source is a human-readable origin label:
	//   - "--file PATH" for a --file flag,
	//   - "argv[N]" (1-based) for a positional used as literal SQL,
	//   - the bare path for a positional auto-detected as a .sql file,
	//   - "stdin" for piped input.
	Source string
}
29+
30+
// CollectOptions tunes Collect for an individual command. Commands that only
// need whitespace-trimmed input can pass the zero value.
type CollectOptions struct {
	// Cleaner normalises every raw SQL string right after it is read and
	// before Collect checks whether it is empty. When Cleaner is nil,
	// Collect uses strings.TrimSpace. aitools supplies a richer cleaner
	// that strips SQL comments and surrounding quotes; Postgres keeps the
	// default because its multi-statement scanner needs the comments left
	// in place.
	Cleaner func(string) string
}
40+
41+
// Collect assembles the ordered list of inputs from --file paths, positional
42+
// arguments, and stdin.
43+
//
44+
// Order is files-first, then positionals. Cobra/pflag does not preserve the
45+
// user's interleaved CLI spelling: it collects all --file flags into one
46+
// slice and all positionals into another, so callers cannot honour
47+
// `--file q1.sql "SELECT 1" --file q2.sql` as written.
48+
//
49+
// Stdin is read only when neither --file nor positional input was provided,
50+
// and only when stdin is not a prompt-capable TTY (otherwise we'd block
51+
// waiting for input the user did not realise they had to type).
52+
//
53+
// Errors when:
54+
// - A --file path can't be read or cleans to empty.
55+
// - A positional that looks like a .sql file but read fails with a non-
56+
// "does not exist" error (e.g. permission denied).
57+
// - A positional cleans to empty.
58+
// - Stdin is the only source and it's empty / blocked on a TTY.
59+
func Collect(ctx context.Context, in io.Reader, args, files []string, opts CollectOptions) ([]Input, error) {
60+
cleaner := opts.Cleaner
61+
if cleaner == nil {
62+
cleaner = strings.TrimSpace
63+
}
64+
65+
var inputs []Input
66+
67+
for _, path := range files {
68+
data, err := os.ReadFile(path)
69+
if err != nil {
70+
return nil, fmt.Errorf("read --file %q: %w", path, err)
71+
}
72+
sql := cleaner(string(data))
73+
if sql == "" {
74+
return nil, fmt.Errorf("--file %q is empty", path)
75+
}
76+
inputs = append(inputs, Input{SQL: sql, Source: "--file " + path})
77+
}
78+
79+
for i, arg := range args {
80+
// .sql autodetect: if the positional ends in .sql AND the file
81+
// exists, read it as a SQL file. Other read errors (permission
82+
// denied) surface directly. If the file does not exist, fall
83+
// through and treat the positional as literal SQL — useful when
84+
// the user passes a string that happens to end with ".sql".
85+
if strings.HasSuffix(arg, SQLFileExtension) {
86+
data, err := os.ReadFile(arg)
87+
if err != nil && !errors.Is(err, os.ErrNotExist) {
88+
return nil, fmt.Errorf("read positional %q: %w", arg, err)
89+
}
90+
if err == nil {
91+
sql := cleaner(string(data))
92+
if sql == "" {
93+
return nil, fmt.Errorf("positional %q is empty", arg)
94+
}
95+
inputs = append(inputs, Input{SQL: sql, Source: arg})
96+
continue
97+
}
98+
}
99+
sql := cleaner(arg)
100+
if sql == "" {
101+
return nil, fmt.Errorf("argv[%d] is empty", i+1)
102+
}
103+
inputs = append(inputs, Input{SQL: sql, Source: fmt.Sprintf("argv[%d]", i+1)})
104+
}
105+
106+
if len(inputs) == 0 {
107+
_, isOsFile := in.(*os.File)
108+
if isOsFile && cmdio.IsPromptSupported(ctx) {
109+
return nil, errors.New("no SQL provided; pass a SQL string, use --file, or pipe via stdin")
110+
}
111+
data, err := io.ReadAll(in)
112+
if err != nil {
113+
return nil, fmt.Errorf("read stdin: %w", err)
114+
}
115+
sql := cleaner(string(data))
116+
if sql == "" {
117+
return nil, errors.New("no SQL provided")
118+
}
119+
inputs = append(inputs, Input{SQL: sql, Source: "stdin"})
120+
}
121+
122+
return inputs, nil
123+
}
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
package sqlcli
2+
3+
import (
4+
"os"
5+
"path/filepath"
6+
"strings"
7+
"testing"
8+
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
)
12+
13+
func writeTemp(t *testing.T, name, contents string) string {
14+
t.Helper()
15+
dir := t.TempDir()
16+
p := filepath.Join(dir, name)
17+
require.NoError(t, os.WriteFile(p, []byte(contents), 0o644))
18+
return p
19+
}
20+
21+
func TestCollect_PositionalOnly(t *testing.T) {
22+
got, err := Collect(t.Context(), strings.NewReader(""), []string{"SELECT 1"}, nil, CollectOptions{})
23+
require.NoError(t, err)
24+
require.Len(t, got, 1)
25+
assert.Equal(t, "SELECT 1", got[0].SQL)
26+
assert.Equal(t, "argv[1]", got[0].Source)
27+
}
28+
29+
func TestCollect_MultiplePositionals(t *testing.T) {
30+
got, err := Collect(t.Context(), strings.NewReader(""), []string{"SELECT 1", "SELECT 2"}, nil, CollectOptions{})
31+
require.NoError(t, err)
32+
require.Len(t, got, 2)
33+
assert.Equal(t, "SELECT 1", got[0].SQL)
34+
assert.Equal(t, "SELECT 2", got[1].SQL)
35+
}
36+
37+
func TestCollect_FileOnly(t *testing.T) {
38+
p := writeTemp(t, "q.sql", "SELECT * FROM t")
39+
got, err := Collect(t.Context(), strings.NewReader(""), nil, []string{p}, CollectOptions{})
40+
require.NoError(t, err)
41+
require.Len(t, got, 1)
42+
assert.Equal(t, "SELECT * FROM t", got[0].SQL)
43+
assert.Contains(t, got[0].Source, "--file")
44+
}
45+
46+
func TestCollect_FilesFirstThenPositionals(t *testing.T) {
47+
p1 := writeTemp(t, "a.sql", "SELECT 1")
48+
p2 := writeTemp(t, "b.sql", "SELECT 2")
49+
got, err := Collect(t.Context(), strings.NewReader(""), []string{"SELECT 3"}, []string{p1, p2}, CollectOptions{})
50+
require.NoError(t, err)
51+
require.Len(t, got, 3)
52+
assert.Equal(t, "SELECT 1", got[0].SQL)
53+
assert.Equal(t, "SELECT 2", got[1].SQL)
54+
assert.Equal(t, "SELECT 3", got[2].SQL)
55+
}
56+
57+
func TestCollect_DotSQLAutoDetect(t *testing.T) {
58+
p := writeTemp(t, "data.sql", "SELECT 42")
59+
got, err := Collect(t.Context(), strings.NewReader(""), []string{p}, nil, CollectOptions{})
60+
require.NoError(t, err)
61+
require.Len(t, got, 1)
62+
assert.Equal(t, "SELECT 42", got[0].SQL)
63+
}
64+
65+
func TestCollect_DotSQLNotExistingFallsThroughToLiteral(t *testing.T) {
66+
got, err := Collect(t.Context(), strings.NewReader(""), []string{"/nonexistent/path.sql"}, nil, CollectOptions{})
67+
require.NoError(t, err)
68+
require.Len(t, got, 1)
69+
assert.Equal(t, "/nonexistent/path.sql", got[0].SQL)
70+
}
71+
72+
func TestCollect_StdinOnly(t *testing.T) {
73+
got, err := Collect(t.Context(), strings.NewReader("SELECT 1\n"), nil, nil, CollectOptions{})
74+
require.NoError(t, err)
75+
require.Len(t, got, 1)
76+
assert.Equal(t, "SELECT 1", got[0].SQL)
77+
assert.Equal(t, "stdin", got[0].Source)
78+
}
79+
80+
func TestCollect_StdinIgnoredWhenPositionalsPresent(t *testing.T) {
81+
got, err := Collect(t.Context(), strings.NewReader("FROM STDIN"), []string{"SELECT 1"}, nil, CollectOptions{})
82+
require.NoError(t, err)
83+
require.Len(t, got, 1)
84+
assert.Equal(t, "SELECT 1", got[0].SQL)
85+
}
86+
87+
func TestCollect_EmptyStdinErrors(t *testing.T) {
88+
_, err := Collect(t.Context(), strings.NewReader(""), nil, nil, CollectOptions{})
89+
assert.ErrorContains(t, err, "no SQL provided")
90+
}
91+
92+
func TestCollect_EmptyFileErrors(t *testing.T) {
93+
p := writeTemp(t, "empty.sql", "")
94+
_, err := Collect(t.Context(), strings.NewReader(""), nil, []string{p}, CollectOptions{})
95+
assert.ErrorContains(t, err, "is empty")
96+
}
97+
98+
func TestCollect_EmptyPositional(t *testing.T) {
99+
_, err := Collect(t.Context(), strings.NewReader(""), []string{" "}, nil, CollectOptions{})
100+
assert.ErrorContains(t, err, "is empty")
101+
}
102+
103+
func TestCollect_CustomCleanerStripsComments(t *testing.T) {
104+
cleaner := func(s string) string {
105+
// Naive comment stripper: drop lines starting with --
106+
var lines []string
107+
for line := range strings.SplitSeq(s, "\n") {
108+
line = strings.TrimSpace(line)
109+
if line != "" && !strings.HasPrefix(line, "--") {
110+
lines = append(lines, line)
111+
}
112+
}
113+
return strings.Join(lines, "\n")
114+
}
115+
got, err := Collect(
116+
t.Context(), strings.NewReader(""),
117+
[]string{"-- ignored\nSELECT 1\n-- also ignored"},
118+
nil,
119+
CollectOptions{Cleaner: cleaner},
120+
)
121+
require.NoError(t, err)
122+
require.Len(t, got, 1)
123+
assert.Equal(t, "SELECT 1", got[0].SQL)
124+
}

0 commit comments

Comments
 (0)