Skip to content

Commit 85cd6a8

Browse files
committed
Better enum handling
1 parent 5cf1494 commit 85cd6a8

File tree

2 files changed

+260
-37
lines changed

2 files changed

+260
-37
lines changed

src/app.ts

Lines changed: 136 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import {
1111
factory,
1212
} from "typescript";
1313

14-
import { GenerateRequest, GenerateResponse, File } from "./gen/plugin/codegen_pb";
14+
import { GenerateRequest, GenerateResponse, File, Enum } from "./gen/plugin/codegen_pb";
1515

1616
import { argName, colName } from "./drivers/utils";
1717
import { assertUniqueNames } from "./validate";
@@ -24,8 +24,64 @@ const result = codegen(input);
2424
// Write the result to stdout
2525
writeOutput(result);
2626

27+
/**
28+
* Build a map of enum names to their values from the catalog.
29+
* This allows us to recognize enum types and generate appropriate TypeScript types.
30+
*/
31+
function buildEnumMap(input: GenerateRequest): Map<string, Enum> {
32+
const enumMap = new Map<string, Enum>();
33+
const defaultSchema = input.catalog?.defaultSchema ?? "public";
34+
35+
for (const schema of input.catalog?.schemas ?? []) {
36+
if (schema.name === "pg_catalog" || schema.name === "information_schema") {
37+
continue;
38+
}
39+
40+
for (const enumDef of schema.enums) {
41+
// Store with both qualified and unqualified names
42+
enumMap.set(enumDef.name, enumDef);
43+
if (schema.name !== defaultSchema) {
44+
enumMap.set(`${schema.name}.${enumDef.name}`, enumDef);
45+
}
46+
}
47+
}
48+
49+
return enumMap;
50+
}
51+
52+
/**
53+
* Generate TypeScript union type for an enum.
54+
* e.g., type EventSource = 'user' | 'runner' | 'system';
55+
*/
56+
function enumTypeDecl(name: string, enumDef: Enum): Node {
57+
const unionType = factory.createUnionTypeNode(
58+
enumDef.vals.map((val) => factory.createLiteralTypeNode(factory.createStringLiteral(val))),
59+
);
60+
61+
return factory.createTypeAliasDeclaration(
62+
[factory.createToken(SyntaxKind.ExportKeyword)],
63+
factory.createIdentifier(pascalCase(name)),
64+
undefined,
65+
unionType,
66+
);
67+
}
68+
69+
/**
70+
* Convert snake_case to PascalCase
71+
*/
72+
function pascalCase(str: string): string {
73+
return str
74+
.split("_")
75+
.map((part) => part.charAt(0).toUpperCase() + part.slice(1).toLowerCase())
76+
.join("");
77+
}
78+
2779
function codegen(input: GenerateRequest): GenerateResponse {
28-
let files = [];
80+
const files = [];
81+
const enumMap = buildEnumMap(input);
82+
83+
// Set the enum map in the postgres driver so columnType can use it
84+
postgres.setEnumMap(enumMap);
2985

3086
const querymap = new Map<string, typeof input.queries>();
3187

@@ -37,9 +93,15 @@ function codegen(input: GenerateRequest): GenerateResponse {
3793
qs?.push(query);
3894
}
3995

96+
// Track which enums are used across all files
97+
const usedEnums = new Set<string>();
98+
4099
for (const [filename, queries] of querymap.entries()) {
41100
const nodes: Node[] = [...postgres.preamble()];
42101

102+
// Track enums used in this file
103+
const fileEnums = new Set<string>();
104+
43105
for (const query of queries) {
44106
const lowerName = query.name[0].toLowerCase() + query.name.slice(1);
45107

@@ -56,22 +118,37 @@ function codegen(input: GenerateRequest): GenerateResponse {
56118
names,
57119
});
58120

59-
nodes.push(
60-
factory.createInterfaceDeclaration(
61-
[factory.createToken(SyntaxKind.ExportKeyword)],
62-
factory.createIdentifier(argIface),
63-
undefined,
64-
undefined,
65-
query.params.map((param, i) =>
66-
factory.createPropertySignature(
67-
undefined,
68-
factory.createIdentifier(argName(i, param.column)),
69-
undefined,
70-
postgres.columnType(param.column),
121+
// Check for enum usage in params
122+
for (const param of query.params) {
123+
const enumName = postgres.getEnumName(param.column);
124+
if (enumName) {
125+
fileEnums.add(enumName);
126+
usedEnums.add(enumName);
127+
}
128+
}
129+
130+
try {
131+
nodes.push(
132+
factory.createInterfaceDeclaration(
133+
[factory.createToken(SyntaxKind.ExportKeyword)],
134+
factory.createIdentifier(argIface),
135+
undefined,
136+
undefined,
137+
query.params.map((param, i) =>
138+
factory.createPropertySignature(
139+
undefined,
140+
factory.createIdentifier(argName(i, param.column)),
141+
undefined,
142+
postgres.columnType(param.column),
143+
),
71144
),
72145
),
73-
),
74-
);
146+
);
147+
} catch (err) {
148+
throw new Error(
149+
`Error in query "${query.name}" (${filename}): ${err instanceof Error ? err.message : String(err)}`,
150+
);
151+
}
75152
}
76153

77154
if (query.columns.length > 0) {
@@ -84,22 +161,37 @@ function codegen(input: GenerateRequest): GenerateResponse {
84161
names,
85162
});
86163

87-
nodes.push(
88-
factory.createInterfaceDeclaration(
89-
[factory.createToken(SyntaxKind.ExportKeyword)],
90-
factory.createIdentifier(returnIface),
91-
undefined,
92-
undefined,
93-
query.columns.map((column, i) =>
94-
factory.createPropertySignature(
95-
undefined,
96-
factory.createIdentifier(colName(i, column)),
97-
undefined,
98-
postgres.columnType(column),
164+
// Check for enum usage in columns
165+
for (const col of query.columns) {
166+
const enumName = postgres.getEnumName(col);
167+
if (enumName) {
168+
fileEnums.add(enumName);
169+
usedEnums.add(enumName);
170+
}
171+
}
172+
173+
try {
174+
nodes.push(
175+
factory.createInterfaceDeclaration(
176+
[factory.createToken(SyntaxKind.ExportKeyword)],
177+
factory.createIdentifier(returnIface),
178+
undefined,
179+
undefined,
180+
query.columns.map((column, i) =>
181+
factory.createPropertySignature(
182+
undefined,
183+
factory.createIdentifier(colName(i, column)),
184+
undefined,
185+
postgres.columnType(column),
186+
),
99187
),
100188
),
101-
),
102-
);
189+
);
190+
} catch (err) {
191+
throw new Error(
192+
`Error in query "${query.name}" (${filename}): ${err instanceof Error ? err.message : String(err)}`,
193+
);
194+
}
103195
}
104196

105197
switch (query.cmd) {
@@ -140,6 +232,19 @@ function codegen(input: GenerateRequest): GenerateResponse {
140232
}
141233
}
142234

235+
// Add enum type declarations at the beginning of the file (after imports)
236+
const enumNodes: Node[] = [];
237+
for (const enumName of fileEnums) {
238+
const enumDef = enumMap.get(enumName);
239+
if (enumDef) {
240+
enumNodes.push(enumTypeDecl(enumName, enumDef));
241+
}
242+
}
243+
244+
// Insert enum declarations after the preamble (imports)
245+
const preambleLength = postgres.preamble().length;
246+
nodes.splice(preambleLength, 0, ...enumNodes);
247+
143248
files.push(
144249
new File({
145250
name: `${filename.replace(".", "_")}.ts`,
@@ -169,7 +274,7 @@ function printNode(nodes: Node[]): string {
169274
);
170275
const printer = createPrinter({ newLine: NewLineKind.LineFeed });
171276
let output = "// Code generated by sqlc. DO NOT EDIT.\n\n";
172-
for (let node of nodes) {
277+
for (const node of nodes) {
173278
output += printer.printNode(EmitHint.Unspecified, node, resultFile);
174279
output += "\n\n";
175280
}

src/drivers/postgres.ts

Lines changed: 124 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,37 +9,102 @@
99

1010
import { SyntaxKind, NodeFlags, TypeNode, factory, FunctionDeclaration } from "typescript";
1111

12-
import { Parameter, Column } from "../gen/plugin/codegen_pb";
12+
import { Parameter, Column, Enum } from "../gen/plugin/codegen_pb";
1313
import { argName } from "./utils";
1414

15+
// Map of enum names to their definitions, set by app.ts
16+
let enumMap: Map<string, Enum> = new Map();
17+
18+
/**
19+
* Set the enum map from the catalog. Called by app.ts before generating code.
20+
*/
21+
export function setEnumMap(map: Map<string, Enum>): void {
22+
enumMap = map;
23+
}
24+
25+
/**
26+
* Check if a column type is an enum and return the enum name if so.
27+
*/
28+
export function getEnumName(column?: Column): string | null {
29+
if (column === undefined || column.type === undefined) {
30+
return null;
31+
}
32+
const typeName = column.type.name.toLowerCase();
33+
if (enumMap.has(typeName)) {
34+
return typeName;
35+
}
36+
return null;
37+
}
38+
39+
/**
40+
* Convert snake_case to PascalCase for enum type names
41+
*/
42+
function pascalCase(str: string): string {
43+
return str
44+
.split("_")
45+
.map((part) => part.charAt(0).toUpperCase() + part.slice(1).toLowerCase())
46+
.join("");
47+
}
48+
1549
export function columnType(column?: Column): TypeNode {
1650
if (column === undefined || column.type === undefined) {
1751
return factory.createKeywordTypeNode(SyntaxKind.AnyKeyword);
1852
}
19-
let typeName = column.type.name;
53+
const originalTypeName = column.type.name;
54+
let typeName = originalTypeName;
2055
const pgCatalog = "pg_catalog.";
2156
if (typeName.startsWith(pgCatalog)) {
2257
typeName = typeName.slice(pgCatalog.length);
2358
}
2459

25-
typeName = typeName.toLowerCase();
60+
const lowerTypeName = typeName.toLowerCase();
2661

27-
let typ: TypeNode = factory.createKeywordTypeNode(SyntaxKind.StringKeyword);
28-
switch (typeName) {
62+
// Check if it's an enum type
63+
if (enumMap.has(lowerTypeName)) {
64+
const typ = factory.createTypeReferenceNode(
65+
factory.createIdentifier(pascalCase(lowerTypeName)),
66+
undefined,
67+
);
68+
if (column.isArray || column.arrayDims > 0) {
69+
let arrayType: TypeNode = typ;
70+
const dims = Math.max(column.arrayDims || 1);
71+
for (let i = 0; i < dims; i++) {
72+
arrayType = factory.createArrayTypeNode(arrayType);
73+
}
74+
if (column.notNull) {
75+
return arrayType;
76+
}
77+
return factory.createUnionTypeNode([
78+
arrayType,
79+
factory.createLiteralTypeNode(factory.createNull()),
80+
]);
81+
}
82+
if (column.notNull) {
83+
return typ;
84+
}
85+
return factory.createUnionTypeNode([typ, factory.createLiteralTypeNode(factory.createNull())]);
86+
}
87+
88+
let typ: TypeNode;
89+
switch (lowerTypeName) {
90+
// Boolean types
2991
case "bool":
3092
case "boolean":
3193
typ = factory.createKeywordTypeNode(SyntaxKind.BooleanKeyword);
3294
break;
95+
// Binary types
3396
case "bytea":
3497
typ = factory.createTypeReferenceNode(factory.createIdentifier("Buffer"), undefined);
3598
break;
99+
// Date/time types
36100
case "date":
37101
case "timestamp":
38102
case "timestamp without time zone":
39103
case "timestamptz":
40104
case "timestamp with time zone":
41105
typ = factory.createTypeReferenceNode(factory.createIdentifier("Date"), undefined);
42106
break;
107+
// Numeric types
43108
case "float4":
44109
case "real":
45110
case "float8":
@@ -61,11 +126,64 @@ export function columnType(column?: Column): TypeNode {
61126
case "oid":
62127
typ = factory.createKeywordTypeNode(SyntaxKind.NumberKeyword);
63128
break;
129+
// JSON types - any to allow flexible object access
64130
case "json":
65131
case "jsonb":
66132
typ = factory.createKeywordTypeNode(SyntaxKind.AnyKeyword);
67133
break;
68-
// All other types default to string (uuid, text, varchar, etc.)
134+
// Void type (from functions like pg_advisory_xact_lock)
135+
case "void":
136+
typ = factory.createKeywordTypeNode(SyntaxKind.VoidKeyword);
137+
break;
138+
// String types - explicitly listed (unambiguously representable as string)
139+
case "text":
140+
case "varchar":
141+
case "character varying":
142+
case "char":
143+
case "character":
144+
case "bpchar":
145+
case "name":
146+
case "uuid":
147+
case "citext":
148+
case "inet":
149+
case "cidr":
150+
case "macaddr":
151+
case "macaddr8":
152+
case "money":
153+
case "numeric":
154+
case "decimal":
155+
case "xml":
156+
case "bit":
157+
case "varbit":
158+
case "bit varying":
159+
case "interval":
160+
case "time":
161+
case "time without time zone":
162+
case "timetz":
163+
case "time with time zone":
164+
case "tsvector":
165+
case "tsquery":
166+
typ = factory.createKeywordTypeNode(SyntaxKind.StringKeyword);
167+
break;
168+
// Geometric types - postgres.js returns these as objects, not strings
169+
case "point":
170+
case "line":
171+
case "lseg":
172+
case "box":
173+
case "path":
174+
case "polygon":
175+
case "circle":
176+
throw new Error(
177+
`Unrecognized PostgreSQL type: "${originalTypeName}". ` +
178+
`Please add support for this type in sqlc-gen-typescript/src/drivers/postgres.ts`,
179+
);
180+
default:
181+
throw new Error(
182+
`Unrecognized PostgreSQL type: "${originalTypeName}" for column "${column.name || "unknown"}". ` +
183+
`This usually means sqlc couldn't infer the type. ` +
184+
`Try adding an explicit cast like "sqlc.arg(${column.name})::text" or "sqlc.narg('${column.name}')" in your query. ` +
185+
`If this is a valid PostgreSQL type that needs support, please add it to sqlc-gen-typescript/src/drivers/postgres.ts`,
186+
);
69187
}
70188

71189
if (column.isArray || column.arrayDims > 0) {

0 commit comments

Comments
 (0)