Skip to content

Commit fca4321

Browse files
luismingaticlaude
andcommitted
feat(#2348, #2997): add sqlc.nembed() for nullable embeds
Introduces `sqlc.nembed()` function that generates nullable (pointer) embed structs in Go codegen. When a LEFT JOIN may produce NULL rows, nembed scans into temporary pointer variables and conditionally constructs the embedded struct only when at least one field is non-nil. Closes #2348 Closes #2997 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent ce83d3f commit fca4321

File tree

30 files changed

+746
-34
lines changed

30 files changed

+746
-34
lines changed

internal/analysis/analysis.pb.go

Lines changed: 9 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/cmd/shim.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,8 @@ func pluginQueryColumn(c *compiler.Column) *plugin.Column {
213213
}
214214
}
215215

216+
out.IsNullableEmbed = c.NullableEmbed
217+
216218
return out
217219
}
218220

internal/codegen/golang/field.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,16 @@ import (
1010
"github.com/sqlc-dev/sqlc/internal/plugin"
1111
)
1212

13+
// NullableEmbedFieldInfo stores metadata for scanning nullable embed fields
14+
type NullableEmbedFieldInfo struct {
15+
TempVarName string // ex: "nembedPostID"
16+
ScanType string // ex: "sql.NullInt32" for stdlib, "*int32" for pgx
17+
ValidExpr string // ex: "nembedPostID.Valid" or "nembedPostID != nil"
18+
AssignExpr string // ex: "nembedPostID.Int32" or "*nembedPostID"
19+
StructField string // ex: "ID"
20+
OriginalType string // original type in the model struct (ex: "int32")
21+
}
22+
1323
type Field struct {
1424
Name string // CamelCased name for Go
1525
DBName string // Name as used in the DB
@@ -19,6 +29,10 @@ type Field struct {
1929
Column *plugin.Column
2030
// EmbedFields contains the embedded fields that require scanning.
2131
EmbedFields []Field
32+
// IsNullableEmbed indicates this field is a nullable embed (*Struct)
33+
IsNullableEmbed bool
34+
// NullableEmbedInfo stores scan metadata for each embedded field
35+
NullableEmbedInfo []NullableEmbedFieldInfo
2236
}
2337

2438
func (gf Field) Tag() string {

internal/codegen/golang/gen.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,9 @@ func filterUnusedStructs(enums []Enum, structs []Struct, queries []Query) ([]Enu
383383
keepTypes[query.Ret.Type()] = struct{}{}
384384
if query.Ret.IsStruct() {
385385
for _, field := range query.Ret.Struct.Fields {
386-
keepTypes[strings.TrimPrefix(field.Type, "[]")] = struct{}{}
386+
trimmedType := strings.TrimPrefix(field.Type, "[]")
387+
trimmedType = strings.TrimPrefix(trimmedType, "*")
388+
keepTypes[trimmedType] = struct{}{}
387389
for _, embedField := range field.EmbedFields {
388390
keepTypes[embedField.Type] = struct{}{}
389391
}

internal/codegen/golang/imports.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,12 @@ func (i *importer) queryImports(filename string) fileImports {
319319
if hasPrefixIgnoringSliceAndPointerPrefix(f.Type, name) {
320320
return true
321321
}
322+
// Check nullable embed scan types
323+
for _, info := range f.NullableEmbedInfo {
324+
if hasPrefixIgnoringSliceAndPointerPrefix(info.ScanType, name) {
325+
return true
326+
}
327+
}
322328
}
323329
}
324330
if hasPrefixIgnoringSliceAndPointerPrefix(q.Ret.Type(), name) {
@@ -459,6 +465,11 @@ func (i *importer) batchImports() fileImports {
459465
if hasPrefixIgnoringSliceAndPointerPrefix(f.Type, name) {
460466
return true
461467
}
468+
for _, info := range f.NullableEmbedInfo {
469+
if hasPrefixIgnoringSliceAndPointerPrefix(info.ScanType, name) {
470+
return true
471+
}
472+
}
462473
}
463474
}
464475
if hasPrefixIgnoringSliceAndPointerPrefix(q.Ret.Type(), name) {

internal/codegen/golang/query.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,14 @@ func (v QueryValue) Scan() string {
201201
} else {
202202
for _, f := range v.Struct.Fields {
203203

204+
// nullable embed: scan into temporary variables
205+
if f.IsNullableEmbed && len(f.NullableEmbedInfo) > 0 {
206+
for _, info := range f.NullableEmbedInfo {
207+
out = append(out, "&"+info.TempVarName)
208+
}
209+
continue
210+
}
211+
204212
// append any embedded fields
205213
if len(f.EmbedFields) > 0 {
206214
for _, embed := range f.EmbedFields {
@@ -227,6 +235,77 @@ func (v QueryValue) Scan() string {
227235
return "\n" + strings.Join(out, ",\n")
228236
}
229237

238+
// HasNullableEmbeds returns true if the query value has nullable embed fields
239+
func (v QueryValue) HasNullableEmbeds() bool {
240+
if v.Struct == nil {
241+
return false
242+
}
243+
for _, f := range v.Struct.Fields {
244+
if f.IsNullableEmbed {
245+
return true
246+
}
247+
}
248+
return false
249+
}
250+
251+
// NullableEmbedDecls generates declarations for nullable embed temporary variables
252+
func (v QueryValue) NullableEmbedDecls() string {
253+
if v.Struct == nil {
254+
return ""
255+
}
256+
var lines []string
257+
for _, f := range v.Struct.Fields {
258+
if !f.IsNullableEmbed {
259+
continue
260+
}
261+
for _, info := range f.NullableEmbedInfo {
262+
lines = append(lines, fmt.Sprintf("var %s %s", info.TempVarName, info.ScanType))
263+
}
264+
}
265+
if len(lines) == 0 {
266+
return ""
267+
}
268+
return "\n" + strings.Join(lines, "\n")
269+
}
270+
271+
// NullableEmbedAssigns generates post-scan code to construct nullable embed structs
272+
func (v QueryValue) NullableEmbedAssigns() string {
273+
if v.Struct == nil {
274+
return ""
275+
}
276+
var blocks []string
277+
for _, f := range v.Struct.Fields {
278+
if !f.IsNullableEmbed || len(f.NullableEmbedInfo) == 0 {
279+
continue
280+
}
281+
282+
// Build the validity check: any field non-nil means the row exists
283+
var validChecks []string
284+
for _, info := range f.NullableEmbedInfo {
285+
validChecks = append(validChecks, info.ValidExpr)
286+
}
287+
288+
// Build the struct assignment
289+
modelType := strings.TrimPrefix(f.Type, "*")
290+
var assignments []string
291+
for _, info := range f.NullableEmbedInfo {
292+
assignments = append(assignments, fmt.Sprintf("%s: %s,", info.StructField, info.AssignExpr))
293+
}
294+
295+
block := fmt.Sprintf("if %s {\n%s.%s = &%s{\n%s\n}\n}",
296+
strings.Join(validChecks, " || "),
297+
v.Name, f.Name,
298+
modelType,
299+
strings.Join(assignments, "\n"),
300+
)
301+
blocks = append(blocks, block)
302+
}
303+
if len(blocks) == 0 {
304+
return ""
305+
}
306+
return "\n" + strings.Join(blocks, "\n")
307+
}
308+
230309
// Deprecated: This method does not respect the Emit field set on the
231310
// QueryValue. It's used by the go-sql-driver-mysql/copyfromCopy.tmpl and should
232311
// not be used other places.

internal/codegen/golang/result.go

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,12 @@ type goEmbed struct {
120120
modelType string
121121
modelName string
122122
fields []Field
123+
nullable bool
123124
}
124125

125126
// look through all the structs and attempt to find a matching one to embed
126127
// We need the name of the struct and its field names.
127-
func newGoEmbed(embed *plugin.Identifier, structs []Struct, defaultSchema string) *goEmbed {
128+
func newGoEmbed(embed *plugin.Identifier, structs []Struct, defaultSchema string, nullable bool) *goEmbed {
128129
if embed == nil {
129130
return nil
130131
}
@@ -147,6 +148,7 @@ func newGoEmbed(embed *plugin.Identifier, structs []Struct, defaultSchema string
147148
modelType: s.Name,
148149
modelName: s.Name,
149150
fields: fields,
151+
nullable: nullable,
150152
}
151153
}
152154

@@ -304,7 +306,7 @@ func buildQueries(req *plugin.GenerateRequest, options *opts.Options, structs []
304306
columns = append(columns, goColumn{
305307
id: i,
306308
Column: c,
307-
embed: newGoEmbed(c.EmbedTable, structs, req.Catalog.DefaultSchema),
309+
embed: newGoEmbed(c.EmbedTable, structs, req.Catalog.DefaultSchema, c.IsNullableEmbed),
308310
})
309311
}
310312
var err error
@@ -396,6 +398,11 @@ func columnsToStruct(req *plugin.GenerateRequest, options *opts.Options, name st
396398
}
397399
if c.embed == nil {
398400
f.Type = goType(req, options, c.Column)
401+
} else if c.embed.nullable {
402+
f.Type = "*" + c.embed.modelType
403+
f.EmbedFields = c.embed.fields
404+
f.IsNullableEmbed = true
405+
f.NullableEmbedInfo = computeNullableEmbedInfo(c.embed)
399406
} else {
400407
f.Type = c.embed.modelType
401408
f.EmbedFields = c.embed.fields
@@ -435,6 +442,40 @@ func columnsToStruct(req *plugin.GenerateRequest, options *opts.Options, name st
435442
return &gs, nil
436443
}
437444

445+
// computeNullableEmbedInfo computes scan metadata for nullable embed fields.
446+
// For each field in the embed, we scan into a pointer-typed temporary variable,
447+
// then check if any temp var is non-nil to construct the struct.
448+
func computeNullableEmbedInfo(embed *goEmbed) []NullableEmbedFieldInfo {
449+
var infos []NullableEmbedFieldInfo
450+
for _, f := range embed.fields {
451+
varName := "nembed" + embed.modelName + f.Name
452+
originalType := f.Type
453+
var scanType, validExpr, assignExpr string
454+
455+
if strings.HasPrefix(originalType, "*") {
456+
// Already a pointer type (e.g., pgx nullable), scan directly
457+
scanType = originalType
458+
validExpr = varName + " != nil"
459+
assignExpr = varName
460+
} else {
461+
// Wrap in pointer for nullable scan
462+
scanType = "*" + originalType
463+
validExpr = varName + " != nil"
464+
assignExpr = "*" + varName
465+
}
466+
467+
infos = append(infos, NullableEmbedFieldInfo{
468+
TempVarName: varName,
469+
ScanType: scanType,
470+
ValidExpr: validExpr,
471+
AssignExpr: assignExpr,
472+
StructField: f.Name,
473+
OriginalType: originalType,
474+
})
475+
}
476+
return infos
477+
}
478+
438479
func checkIncompatibleFieldTypes(fields []Field) error {
439480
fieldTypes := map[string]string{}
440481
for _, field := range fields {

internal/codegen/golang/templates/pgx/batchCode.tmpl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,11 @@ func (b *{{.MethodName}}BatchResults) Query(f func(int, []{{.Ret.DefineType}}, e
9191
defer rows.Close()
9292
for rows.Next() {
9393
var {{.Ret.Name}} {{.Ret.Type}}
94+
{{- .Ret.NullableEmbedDecls}}
9495
if err := rows.Scan({{.Ret.Scan}}); err != nil {
9596
return err
9697
}
98+
{{- .Ret.NullableEmbedAssigns}}
9799
items = append(items, {{.Ret.ReturnName}})
98100
}
99101
return rows.Err()
@@ -110,6 +112,7 @@ func (b *{{.MethodName}}BatchResults) QueryRow(f func(int, {{.Ret.DefineType}},
110112
defer b.br.Close()
111113
for t := 0; t < b.tot; t++ {
112114
var {{.Ret.Name}} {{.Ret.Type}}
115+
{{- .Ret.NullableEmbedDecls}}
113116
if b.closed {
114117
if f != nil {
115118
f(t, {{if .Ret.IsPointer}}nil{{else}}{{.Ret.Name}}{{end}}, ErrBatchAlreadyClosed)
@@ -118,6 +121,7 @@ func (b *{{.MethodName}}BatchResults) QueryRow(f func(int, {{.Ret.DefineType}},
118121
}
119122
row := b.br.QueryRow()
120123
err := row.Scan({{.Ret.Scan}})
124+
{{- .Ret.NullableEmbedAssigns}}
121125
if f != nil {
122126
f(t, {{.Ret.ReturnName}}, err)
123127
}

internal/codegen/golang/templates/pgx/queryCode.tmpl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.De
3636
{{- if or (ne .Arg.Pair .Ret.Pair) (ne .Arg.DefineType .Ret.DefineType) }}
3737
var {{.Ret.Name}} {{.Ret.Type}}
3838
{{- end}}
39+
{{- .Ret.NullableEmbedDecls}}
3940
err := row.Scan({{.Ret.Scan}})
41+
{{- .Ret.NullableEmbedAssigns}}
4042
{{- if $.WrapErrors}}
4143
if err != nil {
4244
err = fmt.Errorf("query {{.MethodName}}: %w", err)
@@ -67,9 +69,11 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret.
6769
{{end -}}
6870
for rows.Next() {
6971
var {{.Ret.Name}} {{.Ret.Type}}
72+
{{- .Ret.NullableEmbedDecls}}
7073
if err := rows.Scan({{.Ret.Scan}}); err != nil {
7174
return nil, {{if $.WrapErrors}}fmt.Errorf("query {{.MethodName}}: %w", err){{else}}err{{end}}
7275
}
76+
{{- .Ret.NullableEmbedAssigns}}
7377
items = append(items, {{.Ret.ReturnName}})
7478
}
7579
if err := rows.Err(); err != nil {

internal/codegen/golang/templates/stdlib/queryCode.tmpl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}
2727
{{- if or (ne .Arg.Pair .Ret.Pair) (ne .Arg.DefineType .Ret.DefineType) }}
2828
var {{.Ret.Name}} {{.Ret.Type}}
2929
{{- end}}
30+
{{- .Ret.NullableEmbedDecls}}
3031
err := row.Scan({{.Ret.Scan}})
32+
{{- .Ret.NullableEmbedAssigns}}
3133
{{- if $.WrapErrors}}
3234
if err != nil {
3335
err = fmt.Errorf("query {{.MethodName}}: %w", err)
@@ -53,9 +55,11 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}
5355
{{end -}}
5456
for rows.Next() {
5557
var {{.Ret.Name}} {{.Ret.Type}}
58+
{{- .Ret.NullableEmbedDecls}}
5659
if err := rows.Scan({{.Ret.Scan}}); err != nil {
5760
return nil, {{if $.WrapErrors}}fmt.Errorf("query {{.MethodName}}: %w", err){{else}}err{{end}}
5861
}
62+
{{- .Ret.NullableEmbedAssigns}}
5963
items = append(items, {{.Ret.ReturnName}})
6064
}
6165
if err := rows.Close(); err != nil {

0 commit comments

Comments
 (0)