Skip to content

Commit 11e1d2c

Browse files
authored
Add support for emitting classes (#13)
* feat(aiosqlite): update BuildPyQueryFunc to use const for connection type (main) * feat(codegen): add connection type to driver struct Added a new field `connType` to the `Driver` struct in the common.go file to store the connection type. This field is initialized based on the SQL driver configuration. (main) * feat(queries): Add method to build Python class template Added a method to build a Python class template based on the source name. (issue #main) * feat(queries): add Queries class with __init__ method Add `Queries` class with `__init__` method to manage a connection to the DB. Issue: main * feat(builders): add NewLine method to IndentStringBuilder (main) * feat(queries): Added classes for author creation, retrieval, and update Added CreateAuthorParams, GetAuthorRow, UpdateAuthorParams, UpdateAuthorTParams, and UpsertAuthorNameParams classes. Modified queries.py accordingly. Issue: feature/emit-classes (feature/emit-classes) * update queries.py * update queries.py * update queries.py * update queries.py * update queries.py * update queries.py * feat(aiosqlite): add support for class-based query functions (feature/emit-classes) * feat(common): update TypeBuildPyQueryFunc signature (feature/emit-classes) * feat(core): add option to emit classes (feature/emit-classes) * feat(internal/codegen/queries): Add support for emitting Python classes in query generation (emit-classes) * feat: enable class emission in Python code generation (feature/emit-classes)
1 parent 0f86175 commit 11e1d2c

7 files changed

Lines changed: 175 additions & 122 deletions

File tree

internal/codegen/builders/string.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,13 @@ func (b *IndentStringBuilder) WriteImportAnnotations() {
4848
b.WriteLine("from __future__ import annotations")
4949
b.WriteString("\n")
5050
}
51+
52+
func (b *IndentStringBuilder) NewLine() {
53+
b.WriteString("\n")
54+
}
55+
56+
func (b *IndentStringBuilder) NNewLine(n int) {
57+
for range n {
58+
b.WriteString("\n")
59+
}
60+
}

internal/codegen/common.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@ import (
77
"github.com/rayakame/sqlc-gen-better-python/internal/core"
88
)
99

10-
type TypeBuildPyQueryFunc func(*core.Query, *builders.IndentStringBuilder, string, string) error
10+
type TypeBuildPyQueryFunc func(*core.Query, *builders.IndentStringBuilder, string, string, bool) error
1111
type TypeAcceptedDriverCMDs func() []string
1212

1313
type Driver struct {
1414
conf *core.Config
1515

16+
connType string
1617
buildPyQueryFunc TypeBuildPyQueryFunc
1718
acceptedDriverCMDs TypeAcceptedDriverCMDs
1819

@@ -22,15 +23,18 @@ type Driver struct {
2223
func NewDriver(conf *core.Config) (*Driver, error) {
2324
var buildPyQueryFunc TypeBuildPyQueryFunc
2425
var acceptedDriverCMDs TypeAcceptedDriverCMDs
26+
var connType string
2527
switch conf.SqlDriver {
2628
case core.SQLDriverAioSQLite:
2729
buildPyQueryFunc = drivers.BuildPyQueryFunc
2830
acceptedDriverCMDs = drivers.AcceptedDriverCMDs
31+
connType = drivers.AioSQLiteConn
32+
2933
default:
3034
return nil, fmt.Errorf("unsupported driver: %s", conf.SqlDriver.String())
3135
}
3236

33-
return &Driver{buildPyQueryFunc: buildPyQueryFunc, acceptedDriverCMDs: acceptedDriverCMDs, conf: conf}, nil
37+
return &Driver{buildPyQueryFunc: buildPyQueryFunc, acceptedDriverCMDs: acceptedDriverCMDs, conf: conf, connType: connType}, nil
3438
}
3539

3640
func (dr *Driver) supportedCMD(command string) error {

internal/codegen/drivers/aiosqlite.go

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,24 @@ import (
88
"strconv"
99
)
1010

11-
func BuildPyQueryFunc(query *core.Query, body *builders.IndentStringBuilder, argType string, retType string) error {
12-
body.WriteString(fmt.Sprintf("async def %s(conn: aiosqlite.Connection", query.FuncName))
11+
const AioSQLiteConn = "aiosqlite.Connection"
12+
13+
func BuildPyQueryFunc(query *core.Query, body *builders.IndentStringBuilder, argType string, retType string, isClass bool) error {
14+
indentLevel := 0
15+
params := fmt.Sprintf("conn: %s", AioSQLiteConn)
16+
conn := "conn"
17+
if isClass {
18+
params = "self"
19+
conn = "self._conn"
20+
indentLevel = 1
21+
}
22+
body.WriteIndentedString(indentLevel, fmt.Sprintf("async def %s(%s", query.FuncName, params))
1323
if argType != "" {
1424
body.WriteString(fmt.Sprintf(", %s: %s", query.Arg.Name, argType))
1525
}
1626
if query.Cmd == metadata.CmdExec {
1727
body.WriteLine(fmt.Sprintf(") -> %s:", retType))
18-
body.WriteIndentedString(1, fmt.Sprintf("await conn.execute(%s", query.ConstantName))
28+
body.WriteIndentedString(indentLevel+1, fmt.Sprintf("await %s.execute(%s", conn, query.ConstantName))
1929
if argType != "" {
2030
if query.Arg.IsStruct() {
2131
for _, col := range query.Arg.Table.Columns {
@@ -28,7 +38,7 @@ func BuildPyQueryFunc(query *core.Query, body *builders.IndentStringBuilder, arg
2838
body.WriteLine(")")
2939
} else if query.Cmd == metadata.CmdExecResult {
3040
body.WriteLine(fmt.Sprintf(") -> %s:", "aiosqlite.Cursor"))
31-
body.WriteIndentedString(1, fmt.Sprintf("return await conn.execute(%s", query.ConstantName))
41+
body.WriteIndentedString(indentLevel+1, fmt.Sprintf("await %s.execute(%s", conn, query.ConstantName))
3242
if argType != "" {
3343
if query.Arg.IsStruct() {
3444
for _, col := range query.Arg.Table.Columns {
@@ -41,7 +51,7 @@ func BuildPyQueryFunc(query *core.Query, body *builders.IndentStringBuilder, arg
4151
body.WriteLine(")")
4252
} else if query.Cmd == metadata.CmdExecRows {
4353
body.WriteLine(fmt.Sprintf(") -> %s:", retType))
44-
body.WriteIndentedString(1, fmt.Sprintf("return await conn.execute(%s", query.ConstantName))
54+
body.WriteIndentedString(indentLevel+1, fmt.Sprintf("await %s.execute(%s", conn, query.ConstantName))
4555
if argType != "" {
4656
if query.Arg.IsStruct() {
4757
for _, col := range query.Arg.Table.Columns {
@@ -54,7 +64,7 @@ func BuildPyQueryFunc(query *core.Query, body *builders.IndentStringBuilder, arg
5464
body.WriteLine(").rowcount")
5565
} else if query.Cmd == metadata.CmdExecLastId {
5666
body.WriteLine(fmt.Sprintf(") -> %s:", retType))
57-
body.WriteIndentedString(1, fmt.Sprintf("return await conn.execute(%s", query.ConstantName))
67+
body.WriteIndentedString(indentLevel+1, fmt.Sprintf("await %s.execute(%s", conn, query.ConstantName))
5868
if argType != "" {
5969
if query.Arg.IsStruct() {
6070
for _, col := range query.Arg.Table.Columns {
@@ -67,7 +77,7 @@ func BuildPyQueryFunc(query *core.Query, body *builders.IndentStringBuilder, arg
6777
body.WriteLine(").lastrowid")
6878
} else if query.Cmd == metadata.CmdOne {
6979
body.WriteLine(fmt.Sprintf(") -> typing.Optional[%s]:", retType))
70-
body.WriteIndentedString(1, fmt.Sprintf("row = await (await conn.execute(%s", query.ConstantName))
80+
body.WriteIndentedString(indentLevel+1, fmt.Sprintf("row = await (await %s.execute(%s", conn, query.ConstantName))
7181
if argType != "" {
7282
if query.Arg.IsStruct() {
7383
for _, col := range query.Arg.Table.Columns {
@@ -78,10 +88,10 @@ func BuildPyQueryFunc(query *core.Query, body *builders.IndentStringBuilder, arg
7888
}
7989
}
8090
body.WriteLine(")).fetchone()")
81-
body.WriteIndentedLine(1, "if row is None:")
82-
body.WriteIndentedLine(2, "return None")
91+
body.WriteIndentedLine(indentLevel+1, "if row is None:")
92+
body.WriteIndentedLine(indentLevel+2, "return None")
8393
if query.Ret.IsStruct() {
84-
body.WriteIndentedString(1, fmt.Sprintf("return %s(", retType))
94+
body.WriteIndentedString(indentLevel+1, fmt.Sprintf("return %s(", retType))
8595
for i, col := range query.Ret.Table.Columns {
8696
if i != 0 {
8797
body.WriteString(", ")
@@ -90,11 +100,11 @@ func BuildPyQueryFunc(query *core.Query, body *builders.IndentStringBuilder, arg
90100
}
91101
body.WriteLine(")")
92102
} else {
93-
body.WriteIndentedLine(1, fmt.Sprintf("return %s(row[0])", retType))
103+
body.WriteIndentedLine(indentLevel+1, fmt.Sprintf("return %s(row[0])", retType))
94104
}
95105
} else if query.Cmd == metadata.CmdMany {
96106
body.WriteLine(fmt.Sprintf(") -> typing.AsyncIterator[%s]:", retType))
97-
body.WriteIndentedString(1, fmt.Sprintf("stream = await conn.execute(%s", query.ConstantName))
107+
body.WriteIndentedString(indentLevel+1, fmt.Sprintf("stream = await %s.execute(%s", conn, query.ConstantName))
98108
if argType != "" {
99109
if query.Arg.IsStruct() {
100110
for _, col := range query.Arg.Table.Columns {
@@ -105,9 +115,9 @@ func BuildPyQueryFunc(query *core.Query, body *builders.IndentStringBuilder, arg
105115
}
106116
}
107117
body.WriteLine(")")
108-
body.WriteIndentedLine(1, "async for row in stream:")
118+
body.WriteIndentedLine(indentLevel+1, "async for row in stream:")
109119
if query.Ret.IsStruct() {
110-
body.WriteIndentedString(2, fmt.Sprintf("yield %s(", retType))
120+
body.WriteIndentedString(indentLevel+2, fmt.Sprintf("yield %s(", retType))
111121
for i, col := range query.Ret.Table.Columns {
112122
if i != 0 {
113123
body.WriteString(", ")
@@ -116,7 +126,7 @@ func BuildPyQueryFunc(query *core.Query, body *builders.IndentStringBuilder, arg
116126
}
117127
body.WriteLine(")")
118128
} else {
119-
body.WriteIndentedLine(2, fmt.Sprintf("yield %s(row[0])", retType))
129+
body.WriteIndentedLine(indentLevel+2, fmt.Sprintf("yield %s(row[0])", retType))
120130
}
121131
}
122132
return nil

internal/codegen/queries.go

Lines changed: 51 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,18 @@ import (
66
"github.com/rayakame/sqlc-gen-better-python/internal/core"
77
"github.com/sqlc-dev/plugin-sdk-go/metadata"
88
"github.com/sqlc-dev/plugin-sdk-go/plugin"
9+
"sort"
10+
"strings"
911
)
1012

11-
func (dr *Driver) prepareFunctionHeader(query *core.Query, body *builders.IndentStringBuilder) (string, string) {
13+
func (dr *Driver) prepareFunctionHeader(query *core.Query, body *builders.IndentStringBuilder) (string, string, []string) {
14+
pyTableNames := make([]string, 0)
1215
argType := ""
1316
if query.Arg.EmitStruct() && query.Arg.IsStruct() {
1417
BuildPyTabel(dr.conf.ModelType, query.Arg.Table, body)
1518
body.WriteString("\n\n")
1619
argType = query.Arg.Table.Name
20+
pyTableNames = append(pyTableNames, query.Arg.Table.Name)
1721
} else if !query.Arg.IsEmpty() {
1822
argType = query.Arg.Typ.Type
1923
if query.Arg.Typ.IsList {
@@ -25,6 +29,7 @@ func (dr *Driver) prepareFunctionHeader(query *core.Query, body *builders.Indent
2529
BuildPyTabel(dr.conf.ModelType, query.Ret.Table, body)
2630
body.WriteString("\n\n")
2731
retType = query.Ret.Table.Name
32+
pyTableNames = append(pyTableNames, query.Ret.Table.Name)
2833
} else if !query.Ret.IsEmpty() {
2934
if query.Ret.IsStruct() {
3035
retType = fmt.Sprintf("models.%s", query.Ret.Table.Name)
@@ -35,7 +40,7 @@ func (dr *Driver) prepareFunctionHeader(query *core.Query, body *builders.Indent
3540
if query.Cmd == metadata.CmdExecLastId || query.Cmd == metadata.CmdExecRows {
3641
retType = "int"
3742
}
38-
return argType, retType
43+
return argType, retType, pyTableNames
3944
}
4045

4146
func (dr *Driver) BuildPyQueriesFiles(imp *core.Importer, queries []core.Query) ([]*plugin.File, error) {
@@ -67,41 +72,69 @@ func (dr *Driver) BuildPyQueriesFiles(imp *core.Importer, queries []core.Query)
6772
}
6873

6974
func (dr *Driver) buildQueryHeader(query *core.Query, body *builders.IndentStringBuilder) {
70-
body.WriteLine(fmt.Sprintf(`%s = """-- name: %s %s`, query.ConstantName, query.MethodName, query.Cmd))
75+
body.WriteLine(fmt.Sprintf(`%s: typing.Final[str] = """-- name: %s %s`, query.ConstantName, query.MethodName, query.Cmd))
7176
body.WriteLine(query.SQL)
7277
body.WriteLine(`"""`)
7378
}
7479

80+
func (dr *Driver) buildClassTemplate(sourceName string, body *builders.IndentStringBuilder) string {
81+
className := core.SnakeToCamel(strings.ReplaceAll(sourceName, ".sql", ""), dr.conf)
82+
body.WriteLine(fmt.Sprintf("class %s:", className))
83+
body.WriteIndentedLine(1, `__slots__ = ("_conn",)`)
84+
body.NewLine()
85+
body.WriteIndentedLine(1, fmt.Sprintf(`def __init__(self, conn: %s):`, dr.connType))
86+
body.WriteIndentedLine(2, "self._conn = conn")
87+
body.NewLine()
88+
return className
89+
}
90+
7591
func (dr *Driver) buildPyQueriesFile(imp *core.Importer, queries []core.Query, sourceName string) ([]byte, error) {
7692
body := builders.NewIndentStringBuilder(imp.C.IndentChar, imp.C.CharsPerIndentLevel)
7793
body.WriteSqlcHeader()
7894
body.WriteImportAnnotations()
7995

80-
funcNames := make([]string, 0)
81-
queryBody := builders.NewIndentStringBuilder(imp.C.IndentChar, imp.C.CharsPerIndentLevel)
82-
for i, query := range queries {
83-
funcNames = append(funcNames, query.FuncName)
84-
if i != 0 {
85-
queryBody.WriteString("\n\n")
96+
newLines := 2
97+
if dr.conf.EmitClasses {
98+
newLines = 1
99+
}
100+
101+
allNames := make([]string, 0)
102+
funcBody := builders.NewIndentStringBuilder(imp.C.IndentChar, imp.C.CharsPerIndentLevel)
103+
pyTableBody := builders.NewIndentStringBuilder(imp.C.IndentChar, imp.C.CharsPerIndentLevel)
104+
for _, query := range queries {
105+
if !dr.conf.EmitClasses {
106+
allNames = append(allNames, query.FuncName)
86107
}
87-
dr.buildQueryHeader(&query, queryBody)
88-
queryBody.WriteString("\n\n")
89-
argType, retType := dr.prepareFunctionHeader(&query, queryBody)
90-
err := dr.buildPyQueryFunc(&query, queryBody, argType, retType)
108+
dr.buildQueryHeader(&query, funcBody)
109+
funcBody.NewLine()
110+
}
111+
funcBody.NewLine()
112+
if dr.conf.EmitClasses {
113+
allNames = append(allNames, dr.buildClassTemplate(sourceName, funcBody))
114+
}
115+
for i, query := range queries {
116+
argType, retType, addedPyTableNames := dr.prepareFunctionHeader(&query, pyTableBody)
117+
allNames = append(allNames, addedPyTableNames...)
118+
err := dr.buildPyQueryFunc(&query, funcBody, argType, retType, dr.conf.EmitClasses)
91119
if err != nil {
92120
return nil, err
93121
}
122+
if i != len(queries)-1 {
123+
funcBody.NNewLine(newLines)
124+
}
94125
}
95126
body.WriteLine("__all__: typing.Sequence[str] = (")
96-
for _, n := range funcNames {
127+
if len(allNames) > 0 {
128+
sort.Slice(allNames, func(i, j int) bool { return allNames[i] < allNames[j] })
129+
}
130+
for _, n := range allNames {
97131
body.WriteIndentedLine(1, fmt.Sprintf("\"%s\",", n))
98132
}
99133
body.WriteLine(")")
100-
body.WriteString("\n")
134+
body.NewLine()
101135
for _, imp := range imp.Imports(sourceName) {
102136
body.WriteLine(imp)
103137
}
104-
body.WriteString("\n")
105-
106-
return []byte(body.String() + queryBody.String()), nil
138+
body.NNewLine(2)
139+
return []byte(body.String() + pyTableBody.String() + funcBody.String()), nil
107140
}

internal/core/config.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ type Config struct {
1414
ModelType string `json:"model_type" yaml:"model_type"`
1515
Initialisms *[]string `json:"initialisms,omitempty" yaml:"initialisms"`
1616
EmitExactTableNames bool `json:"emit_exact_table_names,omitempty" yaml:"emit_exact_table_names"`
17+
EmitClasses bool `json:"emit_classes" yaml:"emit_classes"`
1718
InflectionExcludeTableNames []string `json:"inflection_exclude_table_names,omitempty" yaml:"inflection_exclude_table_names"`
1819
OmitUnusedStructs bool `json:"omit_unused_structs,omitempty" yaml:"omit_unused_structs"`
1920
QueryParameterLimit *int32 `json:"query_parameter_limit,omitempty" yaml:"query_parameter_limit"`

sqlc.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: python
44
wasm:
55
url: file://sqlc-gen-better-python.wasm
6-
sha256: 5befaea487aa1950817a03b9576da0d5e331541894eb63a95f8e7770194ce83c
6+
sha256: 67389a6e3bfdaf7e78ff7e85a4c497beb849aff7de9d4e283feda44ffe3f22a3
77
sql:
88
- schema: test/schema.sql
99
queries: test/queries.sql
@@ -15,4 +15,5 @@ sql:
1515
package: test
1616
sql_driver: aiosqlite
1717
model_type: dataclass
18+
emit_classes: true
1819

0 commit comments

Comments
 (0)