@@ -15,12 +15,14 @@ public struct SwiftGenerator: Language {
1515
1616 public struct Query {
1717 public let statement : Statement
18- public let input : GeneratedStruct ?
19- public let output : GeneratedStruct ?
18+ public let inputStruct : GeneratedStruct ?
19+ public let inputTypeName : String
20+ public let outputStruct : GeneratedStruct ?
21+ public let outputTypeName : String
2022 public let query : DeclSyntax
2123
2224 public var decls : [ DeclSyntax ] {
23- [ input ? . decl, output ? . decl, query] . compactMap ( \. self)
25+ [ inputStruct ? . decl, outputStruct ? . decl, query] . compactMap ( \. self)
2426 }
2527 }
2628
@@ -57,14 +59,17 @@ public struct SwiftGenerator: Language {
5759 ) throws -> Query {
5860 let parameters = statement. parameters
5961
60- let ( inputTypeName, inputDecl) = try inputType ( statement: statement, name: name)
61- let ( outputTypeName, outputDecl) = try outputType ( statement: statement, name: name)
62+ let inputDecl = try generateInputTypeIfNeeded ( statement: statement, name: name)
63+ let outputDecl = try generateOutputTypeIfNeeded ( statement: statement, name: name)
64+
65+ let inputTypeName = inputType ( statement: statement, generatedInputType: inputDecl)
66+ let outputTypeName = outputType ( statement: statement, generatedOutputType: outputDecl)
6267
6368 let queryType : String = if statement. noOutput {
6469 " VoidQuery< \( inputTypeName) > "
6570 } else {
6671 switch statement. outputCardinality {
67- case . many: " FetchManyQuery< \( inputTypeName) , [ \( outputTypeName) ] >"
72+ case . many: " FetchManyQuery< \( inputTypeName) , \( outputTypeName) > "
6873 case . single: " FetchSingleQuery< \( inputTypeName) , \( outputTypeName) > "
6974 }
7075 }
@@ -122,63 +127,84 @@ public struct SwiftGenerator: Language {
122127
123128 return Query (
124129 statement: statement,
125- input: inputDecl,
126- output: outputDecl,
130+ inputStruct: inputDecl,
131+ inputTypeName: inputTypeName,
132+ outputStruct: outputDecl,
133+ outputTypeName: outputTypeName,
127134 query: DeclSyntax ( query)
128135 )
129136 }
130137
131- private static func inputType (
138+ private static func generateInputTypeIfNeeded (
132139 statement: Statement ,
133140 name: Substring
134- ) throws -> ( String , GeneratedStruct ? ) {
141+ ) throws -> GeneratedStruct ? {
142+ guard statement. parameters. count > 1 else { return nil }
143+
144+ let inputTypeName = " \( name. capitalizedFirst) Input "
145+
146+ let inputType = try structDecl (
147+ name: inputTypeName,
148+ fields: statement. parameters. map { ( $0. name, $0. type) } ,
149+ rowDecodable: false
150+ )
151+
152+ return inputType
153+ }
154+
155+ private static func inputType(
156+ statement: Statement ,
157+ generatedInputType: GeneratedStruct ?
158+ ) -> String {
135159 guard let firstParam = statement. parameters. first else {
136- return ( " () " , nil )
160+ return " () "
137161 }
138162
139- if statement. parameters. count > 1 {
140- let inputTypeName = " \( name. capitalizedFirst) Input "
141-
142- let inputType = try structDecl (
143- name: inputTypeName,
144- fields: statement. parameters. map { ( $0. name, $0. type) } ,
145- rowDecodable: false
146- )
147-
148- return ( inputTypeName, inputType)
149- } else {
150- // Single input parameter, just use the single value as the parameter type
151- return ( swiftType ( for: firstParam. type) , nil )
152- }
163+ return generatedInputType? . name ?? swiftType ( for: firstParam. type)
153164 }
154165
155166 private static func outputType(
167+ statement: Statement ,
168+ generatedOutputType: GeneratedStruct ?
169+ ) -> String {
170+ guard !statement. noOutput,
171+ let firstColumn = statement. resultColumns. columns. values. first else {
172+ return " () "
173+ }
174+
175+ // Returns the entire columns of a table, so we can just return the table
176+ if let table = statement. resultColumns. table {
177+ return table. capitalizedFirst
178+ }
179+
180+ let type = generatedOutputType? . name ?? swiftType ( for: firstColumn. root)
181+
182+ return switch statement. outputCardinality {
183+ case . single: type
184+ case . many: " [ \( type) ] "
185+ }
186+ }
187+
188+ private static func generateOutputTypeIfNeeded(
156189 statement: Statement ,
157190 name: Substring
158- ) throws -> ( String , GeneratedStruct ? ) {
191+ ) throws -> GeneratedStruct ? {
159192 // Make sure there is at least one column else return void
160- guard let first = statement. resultColumns
161- . columns. values. first else { return ( " () " , nil ) }
193+ guard !statement. resultColumns. columns. isEmpty else { return nil }
162194
163195 // Output can be mapped to a table struct
164- if let table = statement. resultColumns. table {
165- return ( table. capitalizedFirst, nil )
166- }
196+ guard statement. resultColumns. table == nil else { return nil }
167197
168198 // Only one column returned, just use it's type
169- guard statement. resultColumns. columns. count > 1 else {
170- return ( swiftType ( for: first) , nil )
171- }
199+ guard statement. resultColumns. columns. count > 1 else { return nil }
172200
173201 let outputTypeName = " \( name. capitalizedFirst) Output "
174202
175- let outputType = try structDecl (
203+ return try structDecl (
176204 name: outputTypeName,
177205 fields: statement. resultColumns. columns. map { ( $0. key. description, $0. value) } ,
178206 rowDecodable: true
179207 )
180-
181- return ( outputTypeName, outputType)
182208 }
183209
184210 public static func file(
@@ -202,11 +228,11 @@ public struct SwiftGenerator: Language {
202228 }
203229
204230 for query in queries {
205- if let input = query. input ? . decl {
231+ if let input = query. inputStruct ? . decl {
206232 input
207233 }
208234
209- if let output = query. output ? . decl {
235+ if let output = query. outputStruct ? . decl {
210236 output
211237 }
212238
@@ -215,7 +241,11 @@ public struct SwiftGenerator: Language {
215241 }
216242
217243 for query in queries {
218- if let input = query. input {
244+ if let name = query. statement. name? . capitalizedFirst {
245+ try TypeAliasDeclSyntax ( " typealias \( raw: name) Query = any Query< \( raw: query. inputTypeName) , \( raw: query. outputTypeName) > " )
246+ }
247+
248+ if let input = query. inputStruct {
219249 try ExtensionDeclSyntax ( " extension Query where Input == DB. \( raw: input. name) " ) {
220250 let parameters = input. fields. map { parameter in
221251 " \( parameter. name) : \( parameter. type) "
0 commit comments