Skip to content

Commit 8ed5990

Browse files
authored
Add support for sqlc.embed (#21)
* chore(queries): Add `GetStudentAndScore` to queries.py (feature/sqlc-embed) * feat(sqlc): Add support for all sqlc macros and list supported Query commands (feature/sqlc-embed) * fix: Update sha256 checksum for python wasm file The sha256 checksum for the python wasm file was updated to 9a83b2a13344cebb64f33b3deacbab37a098409bd8027b996e32f041aead9267. (feature/sqlc-embed) * feat(types): add support for 'bigserial' type in SQLite (sqlc-embed) * feat(sqlc-gen): Refactor SQLite3 driver to handle embedded fields (sqlc-embed) * feat(sqlc-gen): Refactor AioSQLite driver to handle embedded fields (sqlc-embed) * refactor: Remove unnecessary log package import and log calls (feature/sqlc-embed) * fix: Remove 'models.' prefix from type names before storing in keepTypes (feature/sqlc-embed) * feat(models): Add SourceName field to Table struct (sqlc-embed) * feat(models): Remove SourceName field to Table struct (sqlc-embed) * feat: add GetStudentAndScoresRow dataclass (feature/sqlc-embed) * feat(sqlc-gen-better-python): add GetStudentAndScores query (sqlc-embed) * fragment
1 parent e7295a4 commit 8ed5990

12 files changed

Lines changed: 162 additions & 26 deletions

File tree

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
kind: Added
2+
body: Added support for `sqlc.embed()`
3+
time: 2025-04-01T03:29:27.9306375+02:00
4+
custom:
5+
Author: rayakame
6+
PR: "21"

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ A WASM plugin for SQLC allowing the generation of Python code.
77
> Please wait for the first GitHub release; before that, this plugin is likely to not work.
88
99
## Feature Support
10+
Every [sqlc macro](https://docs.sqlc.dev/en/latest/reference/macros.html) is supported.
11+
The supported Query commands depend on the SQL driver you are using, supported commands are listed below.
1012
> Every `:batch*` command is not supported by this plugin and probably will never be.
1113
1214
> Prepared Queries are not planned for the near future, but will be implemented sooner or later

internal/builders.go

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"fmt"
55
"github.com/rayakame/sqlc-gen-better-python/internal/core"
66
"github.com/rayakame/sqlc-gen-better-python/internal/inflection"
7-
"github.com/rayakame/sqlc-gen-better-python/internal/log"
87
"github.com/sqlc-dev/plugin-sdk-go/metadata"
98
"github.com/sqlc-dev/plugin-sdk-go/plugin"
109
"github.com/sqlc-dev/plugin-sdk-go/sdk"
@@ -156,7 +155,6 @@ func newGoEmbed(embed *plugin.Identifier, structs []core.Table, defaultSchema st
156155
for i, f := range s.Columns {
157156
fields[i] = f
158157
}
159-
160158
return &goEmbed{
161159
modelType: s.Name,
162160
modelName: s.Name,
@@ -245,15 +243,6 @@ func (gen *PythonGenerator) buildQueries(tables []core.Table) ([]core.Query, err
245243
sameName := f.Name == core.ColumnName(c, i)
246244
sameType := f.Type == gen.makePythonType(c)
247245
sameTable := sdk.SameTableName(c.Table, s.Table, gen.req.Catalog.DefaultSchema)
248-
if gq.MethodName == "ListAuthors" {
249-
log.GlobalLogger.Log(core.SnakeToCamel(core.ColumnName(c, i), gen.config))
250-
if !sameType {
251-
log.GlobalLogger.Log("TypeError")
252-
}
253-
if !sameTable {
254-
log.GlobalLogger.Log("TableError")
255-
}
256-
}
257246
if !sameName || !sameType || !sameTable {
258247
same = false
259248
}
@@ -321,17 +310,22 @@ func (gen *PythonGenerator) columnsToStruct(name string, columns []goColumn, use
321310
if suffix > 0 {
322311
fieldName = fmt.Sprintf("%s_%d", fieldName, suffix)
323312
}
313+
324314
f := core.Column{
325-
Name: core.ColumnName(c.Column, i),
315+
Name: inflection.Singular(inflection.SingularParams{
316+
Name: core.ColumnName(c.Column, i),
317+
Exclusions: gen.config.InflectionExcludeTableNames,
318+
}),
326319
DBName: colName,
327320
Column: c.Column,
328321
}
322+
329323
if c.embed == nil {
330324
f.Type = gen.makePythonType(c.Column)
331325
} else {
332326
f.Type = core.PyType{
333327
SqlType: c.embed.modelType,
334-
Type: c.embed.modelType,
328+
Type: "models." + c.embed.modelType,
335329
IsList: false,
336330
IsNullable: false,
337331
IsEnum: false,

internal/codegen/drivers/aiosqlite.go

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"github.com/rayakame/sqlc-gen-better-python/internal/core"
77
"github.com/sqlc-dev/plugin-sdk-go/metadata"
88
"strconv"
9+
"strings"
910
)
1011

1112
const AioSQLiteConn = "aiosqlite.Connection"
@@ -55,11 +56,23 @@ func AioSQLiteBuildPyQueryFunc(query *core.Query, body *builders.IndentStringBui
5556
body.WriteIndentedLine(indentLevel+2, "return None")
5657
if query.Ret.IsStruct() {
5758
body.WriteIndentedString(indentLevel+1, fmt.Sprintf("return %s(", retType))
58-
for i, col := range query.Ret.Table.Columns {
59+
i := 0
60+
for _, col := range query.Ret.Table.Columns {
5961
if i != 0 {
6062
body.WriteString(", ")
6163
}
62-
body.WriteString(fmt.Sprintf("%s=row[%s]", col.Name, strconv.Itoa(i)))
64+
if len(col.EmbedFields) != 0 {
65+
var inner []string
66+
body.WriteString(fmt.Sprintf("%s=%s(", col.Name, col.Type.Type))
67+
for _, embedCol := range col.EmbedFields {
68+
inner = append(inner, fmt.Sprintf("%s=row[%s]", embedCol.Name, strconv.Itoa(i)))
69+
i++
70+
}
71+
body.WriteString(strings.Join(inner, ", ") + ")")
72+
} else {
73+
body.WriteString(fmt.Sprintf("%s=row[%s]", col.Name, strconv.Itoa(i)))
74+
i++
75+
}
6376
}
6477
body.WriteLine(")")
6578
} else {
@@ -73,11 +86,23 @@ func AioSQLiteBuildPyQueryFunc(query *core.Query, body *builders.IndentStringBui
7386
body.WriteIndentedLine(indentLevel+1, "async for row in stream:")
7487
if query.Ret.IsStruct() {
7588
body.WriteIndentedString(indentLevel+2, fmt.Sprintf("yield %s(", retType))
76-
for i, col := range query.Ret.Table.Columns {
89+
i := 0
90+
for _, col := range query.Ret.Table.Columns {
7791
if i != 0 {
7892
body.WriteString(", ")
7993
}
80-
body.WriteString(fmt.Sprintf("%s=row[%s]", col.Name, strconv.Itoa(i)))
94+
if len(col.EmbedFields) != 0 {
95+
var inner []string
96+
body.WriteString(fmt.Sprintf("%s=%s(", col.Name, col.Type.Type))
97+
for _, embedCol := range col.EmbedFields {
98+
inner = append(inner, fmt.Sprintf("%s=row[%s]", embedCol.Name, strconv.Itoa(i)))
99+
i++
100+
}
101+
body.WriteString(strings.Join(inner, ", ") + ")")
102+
} else {
103+
body.WriteString(fmt.Sprintf("%s=row[%s]", col.Name, strconv.Itoa(i)))
104+
i++
105+
}
81106
}
82107
body.WriteLine(")")
83108
} else {

internal/codegen/drivers/sqlite3.go

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"github.com/rayakame/sqlc-gen-better-python/internal/core"
77
"github.com/sqlc-dev/plugin-sdk-go/metadata"
88
"strconv"
9+
"strings"
910
)
1011

1112
const SQLite3Conn = "sqlite3.Connection"
@@ -55,11 +56,23 @@ func SQLite3BuildPyQueryFunc(query *core.Query, body *builders.IndentStringBuild
5556
body.WriteIndentedLine(indentLevel+2, "return None")
5657
if query.Ret.IsStruct() {
5758
body.WriteIndentedString(indentLevel+1, fmt.Sprintf("return %s(", retType))
58-
for i, col := range query.Ret.Table.Columns {
59+
i := 0
60+
for _, col := range query.Ret.Table.Columns {
5961
if i != 0 {
6062
body.WriteString(", ")
6163
}
62-
body.WriteString(fmt.Sprintf("%s=row[%s]", col.Name, strconv.Itoa(i)))
64+
if len(col.EmbedFields) != 0 {
65+
var inner []string
66+
body.WriteString(fmt.Sprintf("%s=%s(", col.Name, col.Type.Type))
67+
for _, embedCol := range col.EmbedFields {
68+
inner = append(inner, fmt.Sprintf("%s=row[%s]", embedCol.Name, strconv.Itoa(i)))
69+
i++
70+
}
71+
body.WriteString(strings.Join(inner, ", ") + ")")
72+
} else {
73+
body.WriteString(fmt.Sprintf("%s=row[%s]", col.Name, strconv.Itoa(i)))
74+
i++
75+
}
6376
}
6477
body.WriteLine(")")
6578
} else {
@@ -73,11 +86,23 @@ func SQLite3BuildPyQueryFunc(query *core.Query, body *builders.IndentStringBuild
7386
body.WriteLine(").fetchall():")
7487
if query.Ret.IsStruct() {
7588
body.WriteIndentedString(indentLevel+2, fmt.Sprintf("rows.append(%s(", retType))
76-
for i, col := range query.Ret.Table.Columns {
89+
i := 0
90+
for _, col := range query.Ret.Table.Columns {
7791
if i != 0 {
7892
body.WriteString(", ")
7993
}
80-
body.WriteString(fmt.Sprintf("%s=row[%s]", col.Name, strconv.Itoa(i)))
94+
if len(col.EmbedFields) != 0 {
95+
var inner []string
96+
body.WriteString(fmt.Sprintf("%s=%s(", col.Name, col.Type.Type))
97+
for _, embedCol := range col.EmbedFields {
98+
inner = append(inner, fmt.Sprintf("%s=row[%s]", embedCol.Name, strconv.Itoa(i)))
99+
i++
100+
}
101+
body.WriteString(strings.Join(inner, ", ") + ")")
102+
} else {
103+
body.WriteString(fmt.Sprintf("%s=row[%s]", col.Name, strconv.Itoa(i)))
104+
i++
105+
}
81106
}
82107
body.WriteLine("))")
83108
} else {

internal/gen.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"github.com/rayakame/sqlc-gen-better-python/internal/log"
1010
"github.com/rayakame/sqlc-gen-better-python/internal/types"
1111
"github.com/sqlc-dev/plugin-sdk-go/plugin"
12+
"strings"
1213
)
1314

1415
type PythonGenerator struct {
@@ -141,9 +142,9 @@ func filterUnusedStructs(enums []core.Enum, tables []core.Table, queries []core.
141142
keepTypes[query.Ret.Type()] = struct{}{}
142143
if query.Ret.IsStruct() {
143144
for _, field := range query.Ret.Table.Columns {
144-
keepTypes[field.Type.Type] = struct{}{}
145+
keepTypes[strings.ReplaceAll(field.Type.Type, "models.", "")] = struct{}{}
145146
for _, embedField := range field.EmbedFields {
146-
keepTypes[embedField.Type.Type] = struct{}{}
147+
keepTypes[strings.ReplaceAll(embedField.Type.Type, "models.", "")] = struct{}{}
147148
}
148149
}
149150
}

internal/types/sqlite.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ func SqliteTypeToPython(_ *plugin.GenerateRequest, col *plugin.Column, _ *core.C
1414

1515
// see: https://github.com/sqlc-dev/sqlc/blob/main/internal/codegen/golang/sqlite_type.go
1616
switch columnType {
17-
case "int", "integer", "tinyint", "smallint", "mediumint", "bigint", "unsignedbigint", "int2", "int8":
17+
case "int", "integer", "tinyint", "smallint", "mediumint", "bigint", "unsignedbigint", "int2", "int8", "bigserial":
1818
return "int"
1919
case "blob":
2020
return "bytes"

sqlc.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: python
44
wasm:
55
url: file://sqlc-gen-better-python.wasm
6-
sha256: 353b5d20dd9bf75f4ca24329fcae7655ffc3f86e43c30aec72b0e7462f94617b
6+
sha256: 9a83b2a13344cebb64f33b3deacbab37a098409bd8027b996e32f041aead9267
77
sql:
88
- schema: test/schema.sql
99
queries: test/queries.sql

test/models.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
__all__: typing.Sequence[str] = (
88
"Author",
9+
"Student",
10+
"TestScore",
911
)
1012

1113
import dataclasses
@@ -17,3 +19,17 @@ class Author:
1719
id: int
1820
name: str
1921
bio: typing.Optional[str]
22+
23+
24+
@dataclasses.dataclass()
25+
class Student:
26+
id: int
27+
name: typing.Optional[str]
28+
age: typing.Optional[int]
29+
30+
31+
@dataclasses.dataclass()
32+
class TestScore:
33+
student_id: typing.Optional[int]
34+
score: typing.Optional[int]
35+
grade: typing.Optional[str]

test/queries.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,13 @@
66

77
__all__: typing.Sequence[str] = (
88
"GetAuthorRow",
9+
"GetStudentAndScoreRow",
10+
"GetStudentAndScoresRow",
911
"create_author",
1012
"delete_author",
1113
"get_author",
14+
"get_student_and_score",
15+
"get_student_and_scores",
1216
"list_authors",
1317
"update_author",
1418
"update_author_t",
@@ -29,6 +33,18 @@ class GetAuthorRow:
2933
name: str
3034

3135

36+
@dataclasses.dataclass()
37+
class GetStudentAndScoreRow:
38+
student: models.Student
39+
test_score: models.TestScore
40+
41+
42+
@dataclasses.dataclass()
43+
class GetStudentAndScoresRow:
44+
student: models.Student
45+
test_score: models.TestScore
46+
47+
3248
CREATE_AUTHOR: typing.Final[str] = """-- name: CreateAuthor :one
3349
INSERT INTO authors (name, bio)
3450
VALUES (?, ?) RETURNING id, name, bio
@@ -46,6 +62,19 @@ class GetAuthorRow:
4662
WHERE id = ? LIMIT 1
4763
"""
4864

65+
GET_STUDENT_AND_SCORE: typing.Final[str] = """-- name: GetStudentAndScore :one
66+
SELECT students.id, students.name, students.age, test_scores.student_id, test_scores.score, test_scores.grade
67+
FROM students
68+
JOIN test_scores ON test_scores.student_id = students.id
69+
WHERE students.id = ?
70+
"""
71+
72+
GET_STUDENT_AND_SCORES: typing.Final[str] = """-- name: GetStudentAndScores :many
73+
SELECT students.id, students.name, students.age, test_scores.student_id, test_scores.score, test_scores.grade
74+
FROM students
75+
JOIN test_scores ON test_scores.student_id = students.id
76+
"""
77+
4978
LIST_AUTHORS: typing.Final[str] = """-- name: ListAuthors :many
5079
SELECT id, name, bio
5180
FROM authors
@@ -91,6 +120,20 @@ def get_author(conn: sqlite3.Connection, *, id: int) -> typing.Optional[GetAutho
91120
return GetAuthorRow(id=row[0], name=row[1])
92121

93122

123+
def get_student_and_score(conn: sqlite3.Connection, *, id: int) -> typing.Optional[GetStudentAndScoreRow]:
124+
row = conn.execute(GET_STUDENT_AND_SCORE,(id, )).fetchone()
125+
if row is None:
126+
return None
127+
return GetStudentAndScoreRow(student=models.Student(id=row[0], name=row[1], age=row[2]), test_score=models.TestScore(student_id=row[3], score=row[4], grade=row[5]))
128+
129+
130+
def get_student_and_scores(conn: sqlite3.Connection) -> typing.List[GetStudentAndScoresRow]:
131+
rows: typing.List[GetStudentAndScoresRow] = []
132+
for row in conn.execute(GET_STUDENT_AND_SCORES).fetchall():
133+
rows.append(GetStudentAndScoresRow(student=models.Student(id=row[0], name=row[1], age=row[2]), test_score=models.TestScore(student_id=row[3], score=row[4], grade=row[5])))
134+
return rows
135+
136+
94137
def list_authors(conn: sqlite3.Connection, *, ids: typing.Sequence[int]) -> typing.List[models.Author]:
95138
rows: typing.List[models.Author] = []
96139
for row in conn.execute(LIST_AUTHORS,(ids, )).fetchall():

0 commit comments

Comments
 (0)