Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Changed-20250401-022038.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Changed
body: Query functions now dont take param-structs
time: 2025-04-01T02:20:38.5896761+02:00
custom:
Author: rayakame
PR: "20"
28 changes: 11 additions & 17 deletions internal/builders.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,35 +197,29 @@ func (gen *PythonGenerator) buildQueries(tables []core.Table) ([]core.Query, err

if len(query.Params) == 1 && qpl != 0 {
p := query.Params[0]
gq.Arg = core.QueryValue{
gq.Args = []core.QueryValue{{
Name: core.Escape(core.ParamName(p)),
DBName: p.Column.GetName(),
Typ: gen.makePythonType(p.Column),
Column: p.Column,
}
}}
} else if len(query.Params) >= 1 {
var cols []goColumn
var values []core.QueryValue
for _, p := range query.Params {
cols = append(cols, goColumn{
id: int(p.Number),
values = append(values, core.QueryValue{
Name: core.Escape(core.ParamName(p)),
DBName: p.Column.GetName(),
Typ: gen.makePythonType(p.Column),
Column: p.Column,
})
}
s, err := gen.columnsToStruct(gq.MethodName+"Params", cols, false)
if err != nil {
return nil, err
}
gq.Arg = core.QueryValue{
Emit: true,
Name: "arg",
Table: s,
}
gq.Args = values

// if query params is 2, and query params limit is 4 AND this is a copyfrom, we still want to emit the query's model
// otherwise we end up with a copyfrom using a struct without the struct definition
if len(query.Params) <= qpl && query.Cmd != ":copyfrom" {
gq.Arg.Emit = false
}
//if len(query.Params) <= qpl && query.Cmd != ":copyfrom" {
// gq.Args.Emit = false
//}
}

if len(query.Columns) == 1 && query.Columns[0].EmbedTable == nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/codegen/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"github.com/rayakame/sqlc-gen-better-python/internal/core"
)

type TypeBuildPyQueryFunc func(*core.Query, *builders.IndentStringBuilder, string, string, bool) error
type TypeBuildPyQueryFunc func(*core.Query, *builders.IndentStringBuilder, []string, string, bool) error
type TypeAcceptedDriverCMDs func() []string

type Driver struct {
Expand Down
82 changes: 25 additions & 57 deletions internal/codegen/drivers/aiosqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (

const AioSQLiteConn = "aiosqlite.Connection"

func AioSQLiteBuildPyQueryFunc(query *core.Query, body *builders.IndentStringBuilder, argType string, retType string, isClass bool) error {
func AioSQLiteBuildPyQueryFunc(query *core.Query, body *builders.IndentStringBuilder, args []string, retType string, isClass bool) error {
indentLevel := 0
params := fmt.Sprintf("conn: %s", AioSQLiteConn)
conn := "conn"
Expand All @@ -20,73 +20,36 @@ func AioSQLiteBuildPyQueryFunc(query *core.Query, body *builders.IndentStringBui
indentLevel = 1
}
body.WriteIndentedString(indentLevel, fmt.Sprintf("async def %s(%s", query.FuncName, params))
if argType != "" {
body.WriteString(fmt.Sprintf(", %s: %s", query.Arg.Name, argType))
for i, arg := range args {
if i == 0 {
body.WriteString(", *")
}
body.WriteString(fmt.Sprintf(", %s", arg))
}
if query.Cmd == metadata.CmdExec {
body.WriteLine(fmt.Sprintf(") -> %s:", retType))
body.WriteIndentedString(indentLevel+1, fmt.Sprintf("await %s.execute(%s", conn, query.ConstantName))
if argType != "" {
if query.Arg.IsStruct() {
for _, col := range query.Arg.Table.Columns {
body.WriteString(fmt.Sprintf(", %s.%s", query.Arg.Name, col.Name))
}
} else {
body.WriteString(fmt.Sprintf(", %s", query.Arg.Name))
}
}
aiosqliteWriteParams(query, body)
body.WriteLine(")")
} else if query.Cmd == metadata.CmdExecResult {
body.WriteLine(fmt.Sprintf(") -> %s:", "aiosqlite.Cursor"))
body.WriteIndentedString(indentLevel+1, fmt.Sprintf("await %s.execute(%s", conn, query.ConstantName))
if argType != "" {
if query.Arg.IsStruct() {
for _, col := range query.Arg.Table.Columns {
body.WriteString(fmt.Sprintf(", %s.%s", query.Arg.Name, col.Name))
}
} else {
body.WriteString(fmt.Sprintf(", %s", query.Arg.Name))
}
}
aiosqliteWriteParams(query, body)
body.WriteLine(")")
} else if query.Cmd == metadata.CmdExecRows {
body.WriteLine(fmt.Sprintf(") -> %s:", retType))
body.WriteIndentedString(indentLevel+1, fmt.Sprintf("await %s.execute(%s", conn, query.ConstantName))
if argType != "" {
if query.Arg.IsStruct() {
for _, col := range query.Arg.Table.Columns {
body.WriteString(fmt.Sprintf(", %s.%s", query.Arg.Name, col.Name))
}
} else {
body.WriteString(fmt.Sprintf(", %s", query.Arg.Name))
}
}
aiosqliteWriteParams(query, body)
body.WriteLine(").rowcount")
} else if query.Cmd == metadata.CmdExecLastId {
body.WriteLine(fmt.Sprintf(") -> %s:", retType))
body.WriteIndentedString(indentLevel+1, fmt.Sprintf("await %s.execute(%s", conn, query.ConstantName))
if argType != "" {
if query.Arg.IsStruct() {
for _, col := range query.Arg.Table.Columns {
body.WriteString(fmt.Sprintf(", %s.%s", query.Arg.Name, col.Name))
}
} else {
body.WriteString(fmt.Sprintf(", %s", query.Arg.Name))
}
}
aiosqliteWriteParams(query, body)
body.WriteLine(").lastrowid")
} else if query.Cmd == metadata.CmdOne {
body.WriteLine(fmt.Sprintf(") -> typing.Optional[%s]:", retType))
body.WriteIndentedString(indentLevel+1, fmt.Sprintf("row = await (await %s.execute(%s", conn, query.ConstantName))
if argType != "" {
if query.Arg.IsStruct() {
for _, col := range query.Arg.Table.Columns {
body.WriteString(fmt.Sprintf(", %s.%s", query.Arg.Name, col.Name))
}
} else {
body.WriteString(fmt.Sprintf(", %s", query.Arg.Name))
}
}
aiosqliteWriteParams(query, body)
body.WriteLine(")).fetchone()")
body.WriteIndentedLine(indentLevel+1, "if row is None:")
body.WriteIndentedLine(indentLevel+2, "return None")
Expand All @@ -105,15 +68,7 @@ func AioSQLiteBuildPyQueryFunc(query *core.Query, body *builders.IndentStringBui
} else if query.Cmd == metadata.CmdMany {
body.WriteLine(fmt.Sprintf(") -> typing.AsyncIterator[%s]:", retType))
body.WriteIndentedString(indentLevel+1, fmt.Sprintf("stream = await %s.execute(%s", conn, query.ConstantName))
if argType != "" {
if query.Arg.IsStruct() {
for _, col := range query.Arg.Table.Columns {
body.WriteString(fmt.Sprintf(", %s.%s", query.Arg.Name, col.Name))
}
} else {
body.WriteString(fmt.Sprintf(", %s", query.Arg.Name))
}
}
aiosqliteWriteParams(query, body)
body.WriteLine(")")
body.WriteIndentedLine(indentLevel+1, "async for row in stream:")
if query.Ret.IsStruct() {
Expand Down Expand Up @@ -142,3 +97,16 @@ func AioSQLiteAcceptedDriverCMDs() []string {
metadata.CmdMany,
}
}

func aiosqliteWriteParams(query *core.Query, body *builders.IndentStringBuilder) {
if len(query.Args) == 0 {
return
}
params := "("
for _, arg := range query.Args {
if !arg.IsEmpty() {
params += fmt.Sprintf("%s, ", arg.Name)
}
}
body.WriteString("," + params + ")")
}
42 changes: 24 additions & 18 deletions internal/codegen/drivers/sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (

const SQLite3Conn = "sqlite3.Connection"

func SQLite3BuildPyQueryFunc(query *core.Query, body *builders.IndentStringBuilder, argType string, retType string, isClass bool) error {
func SQLite3BuildPyQueryFunc(query *core.Query, body *builders.IndentStringBuilder, args []string, retType string, isClass bool) error {
indentLevel := 0
params := fmt.Sprintf("conn: %s", SQLite3Conn)
conn := "conn"
Expand All @@ -20,33 +20,36 @@ func SQLite3BuildPyQueryFunc(query *core.Query, body *builders.IndentStringBuild
indentLevel = 1
}
body.WriteIndentedString(indentLevel, fmt.Sprintf("def %s(%s", query.FuncName, params))
if argType != "" {
body.WriteString(fmt.Sprintf(", %s: %s", query.Arg.Name, argType))
for i, arg := range args {
if i == 0 {
body.WriteString(", *")
}
body.WriteString(fmt.Sprintf(", %s", arg))
}
if query.Cmd == metadata.CmdExec {
body.WriteLine(fmt.Sprintf(") -> %s:", retType))
body.WriteIndentedString(indentLevel+1, fmt.Sprintf("%s.execute(%s", conn, query.ConstantName))
writeParams(query, body, argType)
sqlite3WriteParams(query, body)
body.WriteLine(")")
} else if query.Cmd == metadata.CmdExecResult {
body.WriteLine(fmt.Sprintf(") -> %s:", "sqlite3.Cursor"))
body.WriteIndentedString(indentLevel+1, fmt.Sprintf("%s.execute(%s", conn, query.ConstantName))
writeParams(query, body, argType)
sqlite3WriteParams(query, body)
body.WriteLine(")")
} else if query.Cmd == metadata.CmdExecRows {
body.WriteLine(fmt.Sprintf(") -> %s:", retType))
body.WriteIndentedString(indentLevel+1, fmt.Sprintf("%s.execute(%s", conn, query.ConstantName))
writeParams(query, body, argType)
sqlite3WriteParams(query, body)
body.WriteLine(").rowcount")
} else if query.Cmd == metadata.CmdExecLastId {
body.WriteLine(fmt.Sprintf(") -> %s:", retType))
body.WriteIndentedString(indentLevel+1, fmt.Sprintf("%s.execute(%s", conn, query.ConstantName))
writeParams(query, body, argType)
sqlite3WriteParams(query, body)
body.WriteLine(").lastrowid")
} else if query.Cmd == metadata.CmdOne {
body.WriteLine(fmt.Sprintf(") -> typing.Optional[%s]:", retType))
body.WriteIndentedString(indentLevel+1, fmt.Sprintf("row = %s.execute(%s", conn, query.ConstantName))
writeParams(query, body, argType)
sqlite3WriteParams(query, body)
body.WriteLine(").fetchone()")
body.WriteIndentedLine(indentLevel+1, "if row is None:")
body.WriteIndentedLine(indentLevel+2, "return None")
Expand All @@ -66,7 +69,7 @@ func SQLite3BuildPyQueryFunc(query *core.Query, body *builders.IndentStringBuild
body.WriteLine(fmt.Sprintf(") -> typing.List[%s]:", retType))
body.WriteIndentedLine(indentLevel+1, fmt.Sprintf("rows: typing.List[%s] = []", retType))
body.WriteIndentedString(indentLevel+1, fmt.Sprintf("for row in %s.execute(%s", conn, query.ConstantName))
writeParams(query, body, argType)
sqlite3WriteParams(query, body)
body.WriteLine(").fetchall():")
if query.Ret.IsStruct() {
body.WriteIndentedString(indentLevel+2, fmt.Sprintf("rows.append(%s(", retType))
Expand Down Expand Up @@ -96,16 +99,19 @@ func SQLite3AcceptedDriverCMDs() []string {
}
}

func writeParams(query *core.Query, body *builders.IndentStringBuilder, argType string) {
if argType != "" {
params := "("
if query.Arg.IsStruct() {
for _, col := range query.Arg.Table.Columns {
params += fmt.Sprintf("%s.%s, ", query.Arg.Name, col.Name)
func sqlite3WriteParams(query *core.Query, body *builders.IndentStringBuilder) {
if len(query.Args) == 0 {
return
}
params := "("
for i, arg := range query.Args {
if !arg.IsEmpty() {
if i == len(query.Args)-1 && i != 0 {
params += fmt.Sprintf("%s", arg.Name)
} else {
params += fmt.Sprintf("%s, ", arg.Name)
}
} else {
params += fmt.Sprintf("%s, ", query.Arg.Name)
}
body.WriteString("," + params + ")")
}
body.WriteString("," + params + ")")
}
26 changes: 12 additions & 14 deletions internal/codegen/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,16 @@ import (
"strings"
)

func (dr *Driver) prepareFunctionHeader(query *core.Query, body *builders.IndentStringBuilder) (string, string, []string) {
func (dr *Driver) prepareFunctionHeader(query *core.Query, body *builders.IndentStringBuilder) ([]string, string, []string) {
pyTableNames := make([]string, 0)
argType := ""
if query.Arg.EmitStruct() && query.Arg.IsStruct() {
BuildPyTabel(dr.conf.ModelType, query.Arg.Table, body)
body.WriteString("\n\n")
argType = query.Arg.Table.Name
pyTableNames = append(pyTableNames, query.Arg.Table.Name)
} else if !query.Arg.IsEmpty() {
argType = query.Arg.Typ.Type
if query.Arg.Typ.IsList {
argType = fmt.Sprintf("typing.Sequence[%s]", argType)
args := make([]string, 0)
for _, arg := range query.Args {
if !arg.IsEmpty() {
argType := arg.Typ.Type
if arg.Typ.IsList {
argType = fmt.Sprintf("typing.Sequence[%s]", argType)
}
args = append(args, fmt.Sprintf("%s: %s", arg.Name, argType))
}
}
retType := "None"
Expand All @@ -40,7 +38,7 @@ func (dr *Driver) prepareFunctionHeader(query *core.Query, body *builders.Indent
if query.Cmd == metadata.CmdExecLastId || query.Cmd == metadata.CmdExecRows {
retType = "int"
}
return argType, retType, pyTableNames
return args, retType, pyTableNames
}

func (dr *Driver) BuildPyQueriesFiles(imp *core.Importer, queries []core.Query) ([]*plugin.File, error) {
Expand Down Expand Up @@ -113,9 +111,9 @@ func (dr *Driver) buildPyQueriesFile(imp *core.Importer, queries []core.Query, s
allNames = append(allNames, dr.buildClassTemplate(sourceName, funcBody))
}
for i, query := range queries {
argType, retType, addedPyTableNames := dr.prepareFunctionHeader(&query, pyTableBody)
args, retType, addedPyTableNames := dr.prepareFunctionHeader(&query, pyTableBody)
allNames = append(allNames, addedPyTableNames...)
err := dr.buildPyQueryFunc(&query, funcBody, argType, retType, dr.conf.EmitClasses)
err := dr.buildPyQueryFunc(&query, funcBody, args, retType, dr.conf.EmitClasses)
if err != nil {
return nil, err
}
Expand Down
7 changes: 4 additions & 3 deletions internal/core/importer.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,10 @@ func (i *Importer) queryImportSpecs(fileName string) (map[string]importSpec, map
if queryValueUses(name, q.Ret) {
return true
}
if queryValueUses(name, q.Arg) {
return true
for _, arg := range q.Args {
if queryValueUses(name, arg) {
return true
}
}
}
return false
Expand All @@ -143,7 +145,6 @@ func (i *Importer) queryImportSpecs(fileName string) (map[string]importSpec, map
// continue
//}
queryValueModelImports(q.Ret)
queryValueModelImports(q.Arg)
}

loc["models"] = importSpec{Module: i.C.Package, Name: "models"}
Expand Down
18 changes: 7 additions & 11 deletions internal/core/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,11 @@ func EnumReplace(value string) string {
}

type QueryValue struct {
Emit bool
EmitPointer bool
Name string
DBName string // The name of the field in the database. Only set if Struct==nil.
Table *Table
Typ PyType
Emit bool
Name string
DBName string // The name of the field in the database. Only set if Struct==nil.
Table *Table
Typ PyType

// Column is kept so late in the generation process around to differentiate
// between mysql slices and pg arrays
Expand All @@ -71,10 +70,6 @@ func (v QueryValue) IsStruct() bool {
return v.Table != nil
}

func (v QueryValue) IsPointer() bool {
return v.EmitPointer && v.Table != nil
}

func (v QueryValue) IsEmpty() bool {
return v.Typ.Type == "" && v.Name == "" && v.Table == nil
}
Expand All @@ -99,7 +94,8 @@ type Query struct {
SQL string
SourceName string
Ret QueryValue
Arg QueryValue
Args []QueryValue

// Used for :copyfrom
Table *plugin.Identifier
}
Expand Down
Loading