@@ -11,7 +11,7 @@ import {
1111 factory ,
1212} from "typescript" ;
1313
14- import { GenerateRequest , GenerateResponse , File } from "./gen/plugin/codegen_pb" ;
14+ import { GenerateRequest , GenerateResponse , File , Enum } from "./gen/plugin/codegen_pb" ;
1515
1616import { argName , colName } from "./drivers/utils" ;
1717import { assertUniqueNames } from "./validate" ;
@@ -24,8 +24,64 @@ const result = codegen(input);
2424// Write the result to stdout
2525writeOutput ( result ) ;
2626
27+ /**
28+ * Build a map of enum names to their values from the catalog.
29+ * This allows us to recognize enum types and generate appropriate TypeScript types.
30+ */
31+ function buildEnumMap ( input : GenerateRequest ) : Map < string , Enum > {
32+ const enumMap = new Map < string , Enum > ( ) ;
33+ const defaultSchema = input . catalog ?. defaultSchema ?? "public" ;
34+
35+ for ( const schema of input . catalog ?. schemas ?? [ ] ) {
36+ if ( schema . name === "pg_catalog" || schema . name === "information_schema" ) {
37+ continue ;
38+ }
39+
40+ for ( const enumDef of schema . enums ) {
41+ // Store with both qualified and unqualified names
42+ enumMap . set ( enumDef . name , enumDef ) ;
43+ if ( schema . name !== defaultSchema ) {
44+ enumMap . set ( `${ schema . name } .${ enumDef . name } ` , enumDef ) ;
45+ }
46+ }
47+ }
48+
49+ return enumMap ;
50+ }
51+
52+ /**
53+ * Generate TypeScript union type for an enum.
54+ * e.g., type EventSource = 'user' | 'runner' | 'system';
55+ */
56+ function enumTypeDecl ( name : string , enumDef : Enum ) : Node {
57+ const unionType = factory . createUnionTypeNode (
58+ enumDef . vals . map ( ( val ) => factory . createLiteralTypeNode ( factory . createStringLiteral ( val ) ) ) ,
59+ ) ;
60+
61+ return factory . createTypeAliasDeclaration (
62+ [ factory . createToken ( SyntaxKind . ExportKeyword ) ] ,
63+ factory . createIdentifier ( pascalCase ( name ) ) ,
64+ undefined ,
65+ unionType ,
66+ ) ;
67+ }
68+
69+ /**
70+ * Convert snake_case to PascalCase
71+ */
72+ function pascalCase ( str : string ) : string {
73+ return str
74+ . split ( "_" )
75+ . map ( ( part ) => part . charAt ( 0 ) . toUpperCase ( ) + part . slice ( 1 ) . toLowerCase ( ) )
76+ . join ( "" ) ;
77+ }
78+
2779function codegen ( input : GenerateRequest ) : GenerateResponse {
28- let files = [ ] ;
80+ const files = [ ] ;
81+ const enumMap = buildEnumMap ( input ) ;
82+
83+ // Set the enum map in the postgres driver so columnType can use it
84+ postgres . setEnumMap ( enumMap ) ;
2985
3086 const querymap = new Map < string , typeof input . queries > ( ) ;
3187
@@ -37,9 +93,15 @@ function codegen(input: GenerateRequest): GenerateResponse {
3793 qs ?. push ( query ) ;
3894 }
3995
96+ // Track which enums are used across all files
97+ const usedEnums = new Set < string > ( ) ;
98+
4099 for ( const [ filename , queries ] of querymap . entries ( ) ) {
41100 const nodes : Node [ ] = [ ...postgres . preamble ( ) ] ;
42101
102+ // Track enums used in this file
103+ const fileEnums = new Set < string > ( ) ;
104+
43105 for ( const query of queries ) {
44106 const lowerName = query . name [ 0 ] . toLowerCase ( ) + query . name . slice ( 1 ) ;
45107
@@ -56,22 +118,37 @@ function codegen(input: GenerateRequest): GenerateResponse {
56118 names,
57119 } ) ;
58120
59- nodes . push (
60- factory . createInterfaceDeclaration (
61- [ factory . createToken ( SyntaxKind . ExportKeyword ) ] ,
62- factory . createIdentifier ( argIface ) ,
63- undefined ,
64- undefined ,
65- query . params . map ( ( param , i ) =>
66- factory . createPropertySignature (
67- undefined ,
68- factory . createIdentifier ( argName ( i , param . column ) ) ,
69- undefined ,
70- postgres . columnType ( param . column ) ,
121+ // Check for enum usage in params
122+ for ( const param of query . params ) {
123+ const enumName = postgres . getEnumName ( param . column ) ;
124+ if ( enumName ) {
125+ fileEnums . add ( enumName ) ;
126+ usedEnums . add ( enumName ) ;
127+ }
128+ }
129+
130+ try {
131+ nodes . push (
132+ factory . createInterfaceDeclaration (
133+ [ factory . createToken ( SyntaxKind . ExportKeyword ) ] ,
134+ factory . createIdentifier ( argIface ) ,
135+ undefined ,
136+ undefined ,
137+ query . params . map ( ( param , i ) =>
138+ factory . createPropertySignature (
139+ undefined ,
140+ factory . createIdentifier ( argName ( i , param . column ) ) ,
141+ undefined ,
142+ postgres . columnType ( param . column ) ,
143+ ) ,
71144 ) ,
72145 ) ,
73- ) ,
74- ) ;
146+ ) ;
147+ } catch ( err ) {
148+ throw new Error (
149+ `Error in query "${ query . name } " (${ filename } ): ${ err instanceof Error ? err . message : String ( err ) } ` ,
150+ ) ;
151+ }
75152 }
76153
77154 if ( query . columns . length > 0 ) {
@@ -84,22 +161,37 @@ function codegen(input: GenerateRequest): GenerateResponse {
84161 names,
85162 } ) ;
86163
87- nodes . push (
88- factory . createInterfaceDeclaration (
89- [ factory . createToken ( SyntaxKind . ExportKeyword ) ] ,
90- factory . createIdentifier ( returnIface ) ,
91- undefined ,
92- undefined ,
93- query . columns . map ( ( column , i ) =>
94- factory . createPropertySignature (
95- undefined ,
96- factory . createIdentifier ( colName ( i , column ) ) ,
97- undefined ,
98- postgres . columnType ( column ) ,
164+ // Check for enum usage in columns
165+ for ( const col of query . columns ) {
166+ const enumName = postgres . getEnumName ( col ) ;
167+ if ( enumName ) {
168+ fileEnums . add ( enumName ) ;
169+ usedEnums . add ( enumName ) ;
170+ }
171+ }
172+
173+ try {
174+ nodes . push (
175+ factory . createInterfaceDeclaration (
176+ [ factory . createToken ( SyntaxKind . ExportKeyword ) ] ,
177+ factory . createIdentifier ( returnIface ) ,
178+ undefined ,
179+ undefined ,
180+ query . columns . map ( ( column , i ) =>
181+ factory . createPropertySignature (
182+ undefined ,
183+ factory . createIdentifier ( colName ( i , column ) ) ,
184+ undefined ,
185+ postgres . columnType ( column ) ,
186+ ) ,
99187 ) ,
100188 ) ,
101- ) ,
102- ) ;
189+ ) ;
190+ } catch ( err ) {
191+ throw new Error (
192+ `Error in query "${ query . name } " (${ filename } ): ${ err instanceof Error ? err . message : String ( err ) } ` ,
193+ ) ;
194+ }
103195 }
104196
105197 switch ( query . cmd ) {
@@ -140,6 +232,19 @@ function codegen(input: GenerateRequest): GenerateResponse {
140232 }
141233 }
142234
235+ // Add enum type declarations at the beginning of the file (after imports)
236+ const enumNodes : Node [ ] = [ ] ;
237+ for ( const enumName of fileEnums ) {
238+ const enumDef = enumMap . get ( enumName ) ;
239+ if ( enumDef ) {
240+ enumNodes . push ( enumTypeDecl ( enumName , enumDef ) ) ;
241+ }
242+ }
243+
244+ // Insert enum declarations after the preamble (imports)
245+ const preambleLength = postgres . preamble ( ) . length ;
246+ nodes . splice ( preambleLength , 0 , ...enumNodes ) ;
247+
143248 files . push (
144249 new File ( {
145250 name : `${ filename . replace ( "." , "_" ) } .ts` ,
@@ -169,7 +274,7 @@ function printNode(nodes: Node[]): string {
169274 ) ;
170275 const printer = createPrinter ( { newLine : NewLineKind . LineFeed } ) ;
171276 let output = "// Code generated by sqlc. DO NOT EDIT.\n\n" ;
172- for ( let node of nodes ) {
277+ for ( const node of nodes ) {
173278 output += printer . printNode ( EmitHint . Unspecified , node , resultFile ) ;
174279 output += "\n\n" ;
175280 }
0 commit comments