@@ -4,14 +4,15 @@ import (
44 "fmt"
55 "github.com/rayakame/sqlc-gen-better-python/internal/codegen/builders"
66 "github.com/rayakame/sqlc-gen-better-python/internal/core"
7+ "github.com/rayakame/sqlc-gen-better-python/internal/log"
78 "github.com/sqlc-dev/plugin-sdk-go/metadata"
89 "strconv"
910 "strings"
1011)
1112
1213const AsyncpgConn = "asyncpg.Connection[asyncpg.Record]"
1314
14- func AsyncpgBuildPyQueryFunc (query * core.Query , body * builders.IndentStringBuilder , args []string , retType string , isClass bool ) error {
15+ func AsyncpgBuildPyQueryFunc (query * core.Query , body * builders.IndentStringBuilder , args []string , retType core. PyType , isClass bool ) error {
1516 indentLevel := 0
1617 params := fmt .Sprintf ("conn: %s" , AsyncpgConn )
1718 conn := "conn"
@@ -28,19 +29,19 @@ func AsyncpgBuildPyQueryFunc(query *core.Query, body *builders.IndentStringBuild
2829 body .WriteString (fmt .Sprintf (", %s" , arg ))
2930 }
3031 if query .Cmd == metadata .CmdExec {
31- body .WriteLine (fmt .Sprintf (") -> %s:" , retType ))
32+ body .WriteLine (fmt .Sprintf (") -> %s:" , retType . Type ))
3233 body .WriteIndentedString (indentLevel + 1 , fmt .Sprintf ("await %s.execute(%s" , conn , query .ConstantName ))
3334 asyncpgWriteParams (query , body )
3435 body .WriteLine (")" )
3536 } else if query .Cmd == metadata .CmdOne {
36- body .WriteLine (fmt .Sprintf (") -> typing.Optional[%s]:" , retType ))
37+ body .WriteLine (fmt .Sprintf (") -> typing.Optional[%s]:" , retType . Type ))
3738 body .WriteIndentedString (indentLevel + 1 , fmt .Sprintf ("row = await %s.fetchrow(%s" , conn , query .ConstantName ))
3839 asyncpgWriteParams (query , body )
3940 body .WriteLine (")" )
4041 body .WriteIndentedLine (indentLevel + 1 , "if row is None:" )
4142 body .WriteIndentedLine (indentLevel + 2 , "return None" )
4243 if query .Ret .IsStruct () {
43- body .WriteIndentedString (indentLevel + 1 , fmt .Sprintf ("return %s(" , retType ))
44+ body .WriteIndentedString (indentLevel + 1 , fmt .Sprintf ("return %s(" , retType . Type ))
4445 i := 0
4546 for _ , col := range query .Ret .Table .Columns {
4647 if i != 0 {
@@ -50,28 +51,43 @@ func AsyncpgBuildPyQueryFunc(query *core.Query, body *builders.IndentStringBuild
5051 var inner []string
5152 body .WriteString (fmt .Sprintf ("%s=%s(" , col .Name , col .Type .Type ))
5253 for _ , embedCol := range col .EmbedFields {
53- inner = append (inner , fmt .Sprintf ("%s=row[%s]" , embedCol .Name , strconv .Itoa (i )))
54+ if embedCol .Name == "age" || embedCol .Name == "id" {
55+ log .GlobalLogger .LogByte ([]byte (embedCol .Type .SqlType ))
56+ }
57+ if _ , found := AsyncpgDoTypeConversion ()[embedCol .Type .SqlType ]; found {
58+ inner = append (inner , fmt .Sprintf ("%s=%s(row[%s])" , embedCol .Name , embedCol .Type .Type , strconv .Itoa (i )))
59+ } else {
60+ inner = append (inner , fmt .Sprintf ("%s=row[%s]" , embedCol .Name , strconv .Itoa (i )))
61+ }
5462 i ++
5563 }
5664 body .WriteString (strings .Join (inner , ", " ) + ")" )
5765 } else {
58- body .WriteString (fmt .Sprintf ("%s=row[%s]" , col .Name , strconv .Itoa (i )))
66+ if _ , found := AsyncpgDoTypeConversion ()[col .Type .SqlType ]; found {
67+ body .WriteString (fmt .Sprintf ("%s=%s(row[%s])" , col .Name , col .Type .Type , strconv .Itoa (i )))
68+ } else {
69+ body .WriteString (fmt .Sprintf ("%s=row[%s]" , col .Name , strconv .Itoa (i )))
70+ }
5971 i ++
6072 }
6173 }
6274 body .WriteLine (")" )
6375 } else {
64- body .WriteIndentedLine (indentLevel + 1 , fmt .Sprintf ("return %s(row[0])" , retType ))
76+ if _ , found := AsyncpgDoTypeConversion ()[retType .SqlType ]; found {
77+ body .WriteIndentedLine (indentLevel + 1 , fmt .Sprintf ("return %s(row[0])" , retType .Type ))
78+ } else {
79+ body .WriteIndentedLine (indentLevel + 1 , "return row[0]" )
80+ }
6581 }
6682 } else if query .Cmd == metadata .CmdMany {
67- body .WriteLine (fmt .Sprintf (") -> typing.Sequence[%s]:" , retType ))
83+ body .WriteLine (fmt .Sprintf (") -> typing.Sequence[%s]:" , retType . Type ))
6884 body .WriteIndentedString (indentLevel + 1 , fmt .Sprintf ("rows = await %s.fetch(%s" , conn , query .ConstantName ))
6985 asyncpgWriteParams (query , body )
7086 body .WriteLine (")" )
71- body .WriteIndentedLine (indentLevel + 1 , fmt .Sprintf ("return_rows: typing.List[%s] = []" , retType ))
87+ body .WriteIndentedLine (indentLevel + 1 , fmt .Sprintf ("return_rows: typing.List[%s] = []" , retType . Type ))
7288 body .WriteIndentedLine (indentLevel + 1 , "for row in rows:" )
7389 if query .Ret .IsStruct () {
74- body .WriteIndentedString (indentLevel + 2 , fmt .Sprintf ("return_rows.append(%s(" , retType ))
90+ body .WriteIndentedString (indentLevel + 2 , fmt .Sprintf ("return_rows.append(%s(" , retType . Type ))
7591 i := 0
7692 for _ , col := range query .Ret .Table .Columns {
7793 if i != 0 {
@@ -81,7 +97,11 @@ func AsyncpgBuildPyQueryFunc(query *core.Query, body *builders.IndentStringBuild
8197 var inner []string
8298 body .WriteString (fmt .Sprintf ("%s=%s(" , col .Name , col .Type .Type ))
8399 for _ , embedCol := range col .EmbedFields {
84- inner = append (inner , fmt .Sprintf ("%s=row[%s]" , embedCol .Name , strconv .Itoa (i )))
100+ if _ , found := AsyncpgDoTypeConversion ()[embedCol .Type .SqlType ]; found {
101+ inner = append (inner , fmt .Sprintf ("%s=%s(row[%s])" , embedCol .Name , embedCol .Type .Type , strconv .Itoa (i )))
102+ } else {
103+ inner = append (inner , fmt .Sprintf ("%s=row[%s]" , embedCol .Name , strconv .Itoa (i )))
104+ }
85105 i ++
86106 }
87107 body .WriteString (strings .Join (inner , ", " ) + ")" )
@@ -93,7 +113,11 @@ func AsyncpgBuildPyQueryFunc(query *core.Query, body *builders.IndentStringBuild
93113 body .WriteLine ("))" )
94114 body .WriteIndentedLine (indentLevel + 1 , "return return_rows" )
95115 } else {
96- body .WriteIndentedLine (indentLevel + 2 , fmt .Sprintf ("return_rows.append(%s(row[0]))" , retType ))
116+ if _ , found := AsyncpgDoTypeConversion ()[retType .SqlType ]; found {
117+ body .WriteIndentedLine (indentLevel + 2 , fmt .Sprintf ("return_rows.append(%s(row[0]))" , retType .Type ))
118+ } else {
119+ body .WriteIndentedLine (indentLevel + 2 , "return_rows.append(row[0])" )
120+ }
97121 body .WriteIndentedLine (indentLevel + 1 , "return return_rows" )
98122 }
99123 }
@@ -108,6 +132,16 @@ func AsyncpgAcceptedDriverCMDs() []string {
108132 }
109133}
110134
135+ func AsyncpgDoTypeConversion () map [string ]struct {} {
136+ return map [string ]struct {}{
137+ "bytea" : {},
138+ "blob" : {},
139+ "pg_catalog.bytea" : {},
140+ "inet" : {},
141+ "cidr" : {},
142+ }
143+ }
144+
111145func asyncpgWriteParams (query * core.Query , body * builders.IndentStringBuilder ) {
112146 if len (query .Args ) == 0 {
113147 return
0 commit comments