Skip to content

Commit 778b45c

Browse files
committed
fixes and tests
1 parent a8fec25 commit 778b45c

File tree

6 files changed

+292
-70
lines changed

6 files changed

+292
-70
lines changed

internal/cmd/generate.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -234,23 +234,23 @@ type generator struct {
234234
codegenHandlerOverride grpc.ClientConnInterface
235235
}
236236

237-
func (g *generator) Pairs(ctx context.Context, conf *config.Config) []OutputPair {
238-
var pairs []OutputPair
237+
func (g *generator) Pairs(ctx context.Context, conf *config.Config) []outputPair {
238+
var pairs []outputPair
239239
for _, sql := range conf.SQL {
240240
if sql.Gen.Go != nil {
241-
pairs = append(pairs, OutputPair{
241+
pairs = append(pairs, outputPair{
242242
SQL: sql,
243243
Gen: config.SQLGen{Go: sql.Gen.Go},
244244
})
245245
}
246246
if sql.Gen.JSON != nil {
247-
pairs = append(pairs, OutputPair{
247+
pairs = append(pairs, outputPair{
248248
SQL: sql,
249249
Gen: config.SQLGen{JSON: sql.Gen.JSON},
250250
})
251251
}
252252
for i := range sql.Codegen {
253-
pairs = append(pairs, OutputPair{
253+
pairs = append(pairs, outputPair{
254254
SQL: sql,
255255
Plugin: &sql.Codegen[i],
256256
})
@@ -259,7 +259,7 @@ func (g *generator) Pairs(ctx context.Context, conf *config.Config) []OutputPair
259259
return pairs
260260
}
261261

262-
func (g *generator) ProcessResult(ctx context.Context, combo config.CombinedSettings, sql OutputPair, result *compiler.Result) error {
262+
func (g *generator) ProcessResult(ctx context.Context, combo config.CombinedSettings, sql outputPair, result *compiler.Result) error {
263263
out, resp, err := codegen(ctx, combo, sql, result, g.codegenHandlerOverride)
264264
if err != nil {
265265
return err
@@ -393,7 +393,7 @@ func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.C
393393
return c.Result(), false
394394
}
395395

396-
func codegen(ctx context.Context, combo config.CombinedSettings, sql OutputPair, result *compiler.Result, codegenOverride grpc.ClientConnInterface) (string, *plugin.GenerateResponse, error) {
396+
func codegen(ctx context.Context, combo config.CombinedSettings, sql outputPair, result *compiler.Result, codegenOverride grpc.ClientConnInterface) (string, *plugin.GenerateResponse, error) {
397397
defer trace.StartRegion(ctx, "codegen").End()
398398
req := codeGenRequest(result, combo)
399399
var handler grpc.ClientConnInterface

internal/cmd/plugin_engine.go

Lines changed: 58 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -93,27 +93,13 @@ var defaultCommentSyntax = metadata.CommentSyntax(source.CommentSyntax{Dash: tru
9393
// engine plugin via ParseRequest; the responses are turned into compiler.Result and
9494
// passed to ProcessResult. No AST or compiler parsing is used.
9595
// When inputs.FileContents is set, schema/query bytes are taken from it (no disk read).
96-
func runPluginQuerySet(ctx context.Context, rp ResultProcessor, name, dir string, sql OutputPair, combo config.CombinedSettings, inputs *sourceFiles, o *Options) error {
96+
func runPluginQuerySet(ctx context.Context, rp resultProcessor, name, dir string, sql outputPair, combo config.CombinedSettings, inputs *sourceFiles, o *Options) error {
9797
enginePlugin, found := config.FindEnginePlugin(&combo.Global, string(combo.Package.Engine))
9898
if !found || enginePlugin.Process == nil {
9999
e := string(combo.Package.Engine)
100100
return fmt.Errorf("unknown engine: %s\n\nAdd the engine to the 'engines' section of sqlc.yaml. See the engine plugins documentation: https://docs.sqlc.dev/en/latest/guides/engine-plugins.html", e)
101101
}
102102

103-
var parseFn func(schemaSQL, querySQL string) (*pb.ParseResponse, error)
104-
if o != nil && o.PluginParseFunc != nil {
105-
parseFn = o.PluginParseFunc
106-
} else {
107-
r := newEngineProcessRunner(enginePlugin.Process.Cmd, combo.Dir, enginePlugin.Env)
108-
parseFn = func(schemaSQL, querySQL string) (*pb.ParseResponse, error) {
109-
req := &pb.ParseRequest{Sql: querySQL}
110-
if schemaSQL != "" {
111-
req.SchemaSource = &pb.ParseRequest_SchemaSql{SchemaSql: schemaSQL}
112-
}
113-
return r.parseRequest(ctx, req)
114-
}
115-
}
116-
117103
readFile := func(path string) ([]byte, error) {
118104
if inputs != nil && inputs.FileContents != nil {
119105
if b, ok := inputs.FileContents[path]; ok {
@@ -123,24 +109,54 @@ func runPluginQuerySet(ctx context.Context, rp ResultProcessor, name, dir string
123109
return os.ReadFile(path)
124110
}
125111

112+
databaseOnly := combo.Package.Analyzer.Database.IsOnly() && combo.Package.Database != nil && combo.Package.Database.URI != ""
113+
126114
var schemaSQL string
127-
var err error
128-
if inputs != nil && inputs.FileContents != nil {
129-
var parts []string
130-
for _, p := range sql.Schema {
131-
if b, ok := inputs.FileContents[p]; ok {
132-
parts = append(parts, string(b))
115+
if !databaseOnly {
116+
var err error
117+
if inputs != nil && inputs.FileContents != nil {
118+
var parts []string
119+
for _, p := range sql.Schema {
120+
if b, ok := inputs.FileContents[p]; ok {
121+
parts = append(parts, string(b))
122+
}
123+
}
124+
schemaSQL = strings.Join(parts, "\n")
125+
} else {
126+
schemaSQL, err = loadSchemaSQL(sql.Schema, readFile)
127+
if err != nil {
128+
return err
133129
}
134130
}
135-
schemaSQL = strings.Join(parts, "\n")
131+
}
132+
133+
type parseFuncType func(querySQL string) (*pb.ParseResponse, error)
134+
var parseFn parseFuncType
135+
if o != nil && o.PluginParseFunc != nil {
136+
schemaStr := schemaSQL
137+
if databaseOnly {
138+
schemaStr = ""
139+
}
140+
parseFn = func(querySQL string) (*pb.ParseResponse, error) {
141+
return o.PluginParseFunc(schemaStr, querySQL)
142+
}
136143
} else {
137-
schemaSQL, err = loadSchemaSQL(sql.Schema, readFile)
138-
if err != nil {
139-
return err
144+
r := newEngineProcessRunner(enginePlugin.Process.Cmd, combo.Dir, enginePlugin.Env)
145+
parseFn = func(querySQL string) (*pb.ParseResponse, error) {
146+
req := &pb.ParseRequest{Sql: querySQL}
147+
if databaseOnly {
148+
req.SchemaSource = &pb.ParseRequest_ConnectionParams{
149+
ConnectionParams: &pb.ConnectionParams{Dsn: combo.Package.Database.URI},
150+
}
151+
} else {
152+
req.SchemaSource = &pb.ParseRequest_SchemaSql{SchemaSql: schemaSQL}
153+
}
154+
return r.parseRequest(ctx, req)
140155
}
141156
}
142157

143158
var queryPaths []string
159+
var err error
144160
if inputs != nil && inputs.FileContents != nil {
145161
queryPaths = sql.Queries
146162
} else {
@@ -161,32 +177,28 @@ func runPluginQuerySet(ctx context.Context, rp ResultProcessor, name, dir string
161177
continue
162178
}
163179
queryContent := string(blob)
164-
nameStr, cmd, err := metadata.ParseQueryNameAndType(queryContent, defaultCommentSyntax)
165-
if err != nil {
166-
merr.Add(filename, queryContent, 0, err)
167-
continue
168-
}
169-
if nameStr == "" {
170-
continue
171-
}
172-
173-
resp, err := parseFn(schemaSQL, queryContent)
180+
blocks, err := metadata.QueryBlocks(queryContent, defaultCommentSyntax)
174181
if err != nil {
175182
merr.Add(filename, queryContent, 0, err)
176183
continue
177184
}
178-
179-
q := pluginResponseToCompilerQuery(nameStr, cmd, filepath.Base(filename), resp)
180-
if q == nil {
181-
continue
182-
}
183-
184-
if _, exists := set[nameStr]; exists {
185-
merr.Add(filename, queryContent, 0, fmt.Errorf("duplicate query name: %s", nameStr))
186-
continue
185+
for _, b := range blocks {
186+
resp, err := parseFn(b.SQL)
187+
if err != nil {
188+
merr.Add(filename, queryContent, 0, err)
189+
continue
190+
}
191+
q := pluginResponseToCompilerQuery(b.Name, b.Cmd, filepath.Base(filename), resp)
192+
if q == nil {
193+
continue
194+
}
195+
if _, exists := set[b.Name]; exists {
196+
merr.Add(filename, queryContent, 0, fmt.Errorf("duplicate query name: %s", b.Name))
197+
continue
198+
}
199+
set[b.Name] = struct{}{}
200+
queries = append(queries, q)
187201
}
188-
set[nameStr] = struct{}{}
189-
queries = append(queries, q)
190202
}
191203

192204
if len(merr.Errs()) > 0 {

internal/cmd/plugin_engine_test.go

Lines changed: 137 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@ package cmd
99
// mock → codegen request) and that the plugin package is used when PluginParseFunc is nil.
1010
//
1111
// Proof that the technology works:
12-
// - TestPluginPipeline_FullPipeline: with PluginParseFunc set, the pipeline sends schema+query into the mock,
13-
// takes Sql/Params/Columns from it, builds compiler.Result → plugin.GenerateRequest, and the codegen mock
14-
// receives that. So "plugin engine → ParseRequest contract → codegen" is validated.
12+
// - TestPluginPipeline_FullPipeline: one block → one Parse call; that call receives schema; codegen gets the result.
13+
// - TestPluginPipeline_NBlocksNCalls: N blocks in query.sql → exactly N Parse calls; each call receives schema.
14+
// - TestPluginPipeline_DatabaseOnly_ReceivesNoSchema: with analyzer.database: only + database.uri, each Parse
15+
// call receives empty schema (the real runner would get connection_params in ParseRequest).
1516
// - TestPluginPipeline_WithoutOverride_UsesPluginPackage: with PluginParseFunc nil, generate fails with an error
1617
// that is NOT "unknown engine", so we did enter runPluginQuerySet and call the engine process runner.
17-
// The plugin package is required for that path. Vet does not support plugin engines.
1818

1919
import (
2020
"bytes"
@@ -50,10 +50,13 @@ engines:
5050
// engineMockRecord holds what the engine-plugin mock received and returned.
5151
// Used to validate that the pipeline passes schema + raw query in, and that
5252
// the plugin's Sql/Parameters/Columns are what later reach codegen.
53+
// CalledWith records each Parse call so we can assert N blocks → N calls and
54+
// that every call received schema (or "" in databaseOnly mode).
5355
type engineMockRecord struct {
5456
Calls int
55-
SchemaSQL string
56-
QuerySQL string
57+
SchemaSQL string // last call (backward compat)
58+
QuerySQL string // last call (backward compat)
59+
CalledWith []struct{ SchemaSQL, QuerySQL string }
5760
ReturnedSQL string
5861
ReturnedParams []*pb.Parameter
5962
ReturnedCols []*pb.Column
@@ -101,6 +104,7 @@ func TestPluginPipeline_FullPipeline(t *testing.T) {
101104
engineRecord.Calls++
102105
engineRecord.SchemaSQL = schemaSQL
103106
engineRecord.QuerySQL = querySQL
107+
engineRecord.CalledWith = append(engineRecord.CalledWith, struct{ SchemaSQL, QuerySQL string }{schemaSQL, querySQL})
104108
return &pb.ParseResponse{
105109
Sql: engineRecord.ReturnedSQL,
106110
Parameters: engineRecord.ReturnedParams,
@@ -146,12 +150,20 @@ func TestPluginPipeline_FullPipeline(t *testing.T) {
146150
}
147151

148152
// ---- 1) Validate what was sent INTO the engine plugin ----
153+
// N blocks in query.sql must yield N Parse calls; each call must receive schema (or connection in databaseOnly).
149154
if engineRecord.Calls != 1 {
150-
t.Errorf("engine mock called %d times, want 1", engineRecord.Calls)
155+
t.Errorf("engine mock called %d times, want 1 (one block → one Parse call)", engineRecord.Calls)
156+
}
157+
if len(engineRecord.CalledWith) != 1 {
158+
t.Errorf("engine mock CalledWith has %d entries, want 1", len(engineRecord.CalledWith))
159+
}
160+
if len(engineRecord.CalledWith) > 0 && engineRecord.CalledWith[0].SchemaSQL != schemaContent {
161+
t.Errorf("every Parse call must receive schema: got %q", engineRecord.CalledWith[0].SchemaSQL)
151162
}
152163
if engineRecord.SchemaSQL != schemaContent {
153164
t.Errorf("engine received schema:\n got: %q\n want: %q", engineRecord.SchemaSQL, schemaContent)
154165
}
166+
// With one block, query SQL is the whole block (same as queryContent)
155167
if engineRecord.QuerySQL != queryContent {
156168
t.Errorf("engine received query:\n got: %q\n want: %q", engineRecord.QuerySQL, queryContent)
157169
}
@@ -265,6 +277,124 @@ func TestPluginPipeline_WithoutOverride_UsesPluginPackage(t *testing.T) {
265277
}
266278
}
267279

280+
// TestPluginPipeline_NBlocksNCalls verifies that N sqlc blocks in query.sql yield N plugin Parse calls,
281+
// and each call receives the schema (or connection params in databaseOnly mode).
282+
func TestPluginPipeline_NBlocksNCalls(t *testing.T) {
283+
const (
284+
schemaContent = "CREATE TABLE users (id INT, name TEXT);"
285+
block1 = "-- name: GetUser :one\nSELECT id, name FROM users WHERE id = $1"
286+
block2 = "-- name: ListUsers :many\nSELECT id, name FROM users ORDER BY id"
287+
)
288+
queryContent := block1 + "\n\n" + block2
289+
// QueryBlocks slices from " name: " line to the next " name: " (exclusive), so first block includes "\n\n".
290+
expectedBlock1 := block1 + "\n\n"
291+
expectedBlock2 := block2
292+
293+
engineRecord := &engineMockRecord{
294+
ReturnedSQL: "SELECT id, name FROM users WHERE id = $1",
295+
ReturnedParams: []*pb.Parameter{{Position: 1, DataType: "int", Nullable: false}},
296+
ReturnedCols: []*pb.Column{{Name: "id", DataType: "int", Nullable: false}, {Name: "name", DataType: "text", Nullable: false}},
297+
}
298+
pluginParse := func(schemaSQL, querySQL string) (*pb.ParseResponse, error) {
299+
engineRecord.Calls++
300+
engineRecord.CalledWith = append(engineRecord.CalledWith, struct{ SchemaSQL, QuerySQL string }{schemaSQL, querySQL})
301+
return &pb.ParseResponse{Sql: querySQL, Parameters: engineRecord.ReturnedParams, Columns: engineRecord.ReturnedCols}, nil
302+
}
303+
conf, _ := config.ParseConfig(strings.NewReader(testPluginPipelineConfig))
304+
inputs := &sourceFiles{
305+
Config: &conf, ConfigPath: "sqlc.yaml", Dir: ".",
306+
FileContents: map[string][]byte{"schema.sql": []byte(schemaContent), "query.sql": []byte(queryContent)},
307+
}
308+
debug := opts.DebugFromString("")
309+
debug.ProcessPlugins = true
310+
o := &Options{
311+
Env: Env{Debug: debug}, Stderr: &bytes.Buffer{}, PluginParseFunc: pluginParse,
312+
CodegenHandlerOverride: ext.HandleFunc(func(_ context.Context, req *plugin.GenerateRequest) (*plugin.GenerateResponse, error) { return &plugin.GenerateResponse{}, nil }),
313+
}
314+
_, err := generate(context.Background(), inputs, o)
315+
if err != nil {
316+
t.Fatalf("generate failed: %v", err)
317+
}
318+
if n := len(engineRecord.CalledWith); n != 2 {
319+
t.Errorf("expected 2 Parse calls (2 blocks), got %d", n)
320+
}
321+
for i, call := range engineRecord.CalledWith {
322+
if call.SchemaSQL != schemaContent {
323+
t.Errorf("Parse call %d: every call must receive schema; got schemaSQL %q", i+1, call.SchemaSQL)
324+
}
325+
}
326+
if len(engineRecord.CalledWith) >= 1 && engineRecord.CalledWith[0].QuerySQL != expectedBlock1 {
327+
t.Errorf("Parse call 1: query must be first block; got %q", engineRecord.CalledWith[0].QuerySQL)
328+
}
329+
if len(engineRecord.CalledWith) >= 2 && engineRecord.CalledWith[1].QuerySQL != expectedBlock2 {
330+
t.Errorf("Parse call 2: query must be second block; got %q", engineRecord.CalledWith[1].QuerySQL)
331+
}
332+
}
333+
334+
const testPluginPipelineConfigDatabaseOnly = `version: "2"
335+
sql:
336+
- engine: "testeng"
337+
schema: ["schema.sql"]
338+
queries: ["query.sql"]
339+
analyzer:
340+
database: only
341+
database:
342+
uri: "postgres://localhost/test"
343+
codegen:
344+
- plugin: "mock"
345+
out: "."
346+
plugins:
347+
- name: "mock"
348+
process:
349+
cmd: "echo"
350+
engines:
351+
- name: "testeng"
352+
process:
353+
cmd: "echo"
354+
`
355+
356+
// TestPluginPipeline_DatabaseOnly_ReceivesNoSchema verifies that in databaseOnly mode (analyzer.database: only +
357+
// database.uri) the plugin receives empty schema and the core would pass connection_params to the real runner.
358+
// The mock only sees (schemaSQL, querySQL); in databaseOnly we pass schemaSQL="".
359+
func TestPluginPipeline_DatabaseOnly_ReceivesNoSchema(t *testing.T) {
360+
const queryContent = "-- name: GetOne :one\nSELECT 1"
361+
engineRecord := &engineMockRecord{
362+
ReturnedSQL: "SELECT 1", ReturnedParams: nil, ReturnedCols: []*pb.Column{{Name: "?column?", DataType: "int", Nullable: true}},
363+
}
364+
pluginParse := func(schemaSQL, querySQL string) (*pb.ParseResponse, error) {
365+
engineRecord.Calls++
366+
engineRecord.CalledWith = append(engineRecord.CalledWith, struct{ SchemaSQL, QuerySQL string }{schemaSQL, querySQL})
367+
return &pb.ParseResponse{Sql: querySQL, Parameters: nil, Columns: engineRecord.ReturnedCols}, nil
368+
}
369+
conf, err := config.ParseConfig(strings.NewReader(testPluginPipelineConfigDatabaseOnly))
370+
if err != nil {
371+
t.Fatalf("parse config: %v", err)
372+
}
373+
inputs := &sourceFiles{
374+
Config: &conf, ConfigPath: "sqlc.yaml", Dir: ".",
375+
FileContents: map[string][]byte{"schema.sql": []byte("CREATE TABLE t (id INT);"), "query.sql": []byte(queryContent)},
376+
}
377+
debug := opts.DebugFromString("")
378+
debug.ProcessPlugins = true
379+
o := &Options{
380+
Env: Env{Debug: debug}, Stderr: &bytes.Buffer{}, PluginParseFunc: pluginParse,
381+
CodegenHandlerOverride: ext.HandleFunc(func(_ context.Context, req *plugin.GenerateRequest) (*plugin.GenerateResponse, error) { return &plugin.GenerateResponse{}, nil }),
382+
}
383+
_, err = generate(context.Background(), inputs, o)
384+
if err != nil {
385+
t.Fatalf("generate failed: %v", err)
386+
}
387+
if len(engineRecord.CalledWith) != 1 {
388+
t.Errorf("expected 1 Parse call, got %d", len(engineRecord.CalledWith))
389+
}
390+
if len(engineRecord.CalledWith) > 0 && engineRecord.CalledWith[0].SchemaSQL != "" {
391+
t.Errorf("databaseOnly mode: each Parse call must receive empty schema (connection_params are used by real runner); got %q", engineRecord.CalledWith[0].SchemaSQL)
392+
}
393+
if len(engineRecord.CalledWith) > 0 && engineRecord.CalledWith[0].QuerySQL != queryContent {
394+
t.Errorf("query SQL must still be passed; got %q", engineRecord.CalledWith[0].QuerySQL)
395+
}
396+
}
397+
268398
// TestPluginPipeline_OptionsOverrideNil ensures default Options do not inject mocks.
269399
func TestPluginPipeline_OptionsOverrideNil(t *testing.T) {
270400
o := &Options{}

0 commit comments

Comments
 (0)