Skip to content

Commit 7de2af2

Browse files
authored
fix(cli): dynamically load pg module in "db pull" (#2421)
1 parent 3336505 commit 7de2af2

1 file changed

Lines changed: 37 additions & 22 deletions

File tree

packages/cli/src/actions/pull/provider/postgresql.ts

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1+
import type { ZModelServices } from '@zenstackhq/language';
12
import type { Attribute, BuiltinType, Enum, Expression } from '@zenstackhq/language/ast';
23
import { AstFactory, DataFieldAttributeFactory, ExpressionBuilder } from '@zenstackhq/language/factory';
3-
import { Client } from 'pg';
4+
import { CliError } from '../../../cli-error';
45
import { getAttributeRef, getDbName, getFunctionRef, normalizeDecimalDefault, normalizeFloatDefault } from '../utils';
56
import type { IntrospectedEnum, IntrospectedSchema, IntrospectedTable, IntrospectionProvider } from './provider';
6-
import type { ZModelServices } from '@zenstackhq/language';
7-
import { CliError } from '../../../cli-error';
87

98
/**
109
* Maps PostgreSQL internal type names to their standard SQL names for comparison.
@@ -110,8 +109,8 @@ const pgTypnameToZenStackNativeType: Record<string, string> = {
110109

111110
export const postgresql: IntrospectionProvider = {
112111
isSupportedFeature(feature) {
113-
const supportedFeatures = ['Schema', 'NativeEnum'];
114-
return supportedFeatures.includes(feature);
112+
const supportedFeatures = ['Schema', 'NativeEnum'];
113+
return supportedFeatures.includes(feature);
115114
},
116115
getBuiltinType(type) {
117116
const t = (type || '').toLowerCase();
@@ -176,7 +175,11 @@ export const postgresql: IntrospectionProvider = {
176175
return { type: 'Unsupported' as const, isArray };
177176
}
178177
},
179-
async introspect(connectionString: string, options: { schemas: string[]; modelCasing: 'pascal' | 'camel' | 'snake' | 'none' }): Promise<IntrospectedSchema> {
178+
async introspect(
179+
connectionString: string,
180+
options: { schemas: string[]; modelCasing: 'pascal' | 'camel' | 'snake' | 'none' },
181+
): Promise<IntrospectedSchema> {
182+
const { Client } = await import('pg');
180183
const client = new Client({ connectionString });
181184
await client.connect();
182185

@@ -233,7 +236,7 @@ export const postgresql: IntrospectionProvider = {
233236
}
234237
}
235238
// Fall through to typeCastingConvert if datatype_name lookup fails
236-
return typeCastingConvert({defaultValue,enums,val,services});
239+
return typeCastingConvert({ defaultValue, enums, val, services });
237240
}
238241

239242
switch (fieldType) {
@@ -243,7 +246,7 @@ export const postgresql: IntrospectionProvider = {
243246
}
244247

245248
if (val.includes('::')) {
246-
return typeCastingConvert({defaultValue,enums,val,services});
249+
return typeCastingConvert({ defaultValue, enums, val, services });
247250
}
248251

249252
// Fallback to string literal for other DateTime defaults
@@ -256,19 +259,19 @@ export const postgresql: IntrospectionProvider = {
256259
}
257260

258261
if (val.includes('::')) {
259-
return typeCastingConvert({defaultValue,enums,val,services});
262+
return typeCastingConvert({ defaultValue, enums, val, services });
260263
}
261264
return (ab) => ab.NumberLiteral.setValue(val);
262265

263266
case 'Float':
264267
if (val.includes('::')) {
265-
return typeCastingConvert({defaultValue,enums,val,services});
268+
return typeCastingConvert({ defaultValue, enums, val, services });
266269
}
267270
return normalizeFloatDefault(val);
268271

269272
case 'Decimal':
270273
if (val.includes('::')) {
271-
return typeCastingConvert({defaultValue,enums,val,services});
274+
return typeCastingConvert({ defaultValue, enums, val, services });
272275
}
273276
return normalizeDecimalDefault(val);
274277

@@ -277,7 +280,7 @@ export const postgresql: IntrospectionProvider = {
277280

278281
case 'String':
279282
if (val.includes('::')) {
280-
return typeCastingConvert({defaultValue,enums,val,services});
283+
return typeCastingConvert({ defaultValue, enums, val, services });
281284
}
282285

283286
if (val.startsWith("'") && val.endsWith("'")) {
@@ -286,12 +289,12 @@ export const postgresql: IntrospectionProvider = {
286289
return (ab) => ab.StringLiteral.setValue(val);
287290
case 'Json':
288291
if (val.includes('::')) {
289-
return typeCastingConvert({defaultValue,enums,val,services});
292+
return typeCastingConvert({ defaultValue, enums, val, services });
290293
}
291294
return (ab) => ab.StringLiteral.setValue(val);
292295
case 'Bytes':
293296
if (val.includes('::')) {
294-
return typeCastingConvert({defaultValue,enums,val,services});
297+
return typeCastingConvert({ defaultValue, enums, val, services });
295298
}
296299
return (ab) => ab.StringLiteral.setValue(val);
297300
}
@@ -303,15 +306,20 @@ export const postgresql: IntrospectionProvider = {
303306
);
304307
}
305308

306-
console.warn(`Unsupported default value type: "${defaultValue}" for field type "${fieldType}". Skipping default value.`);
309+
console.warn(
310+
`Unsupported default value type: "${defaultValue}" for field type "${fieldType}". Skipping default value.`,
311+
);
307312
return null;
308313
},
309314

310315
getFieldAttributes({ fieldName, fieldType, datatype, length, precision, services }) {
311316
const factories: DataFieldAttributeFactory[] = [];
312317

313318
// Add @updatedAt for DateTime fields named updatedAt or updated_at
314-
if (fieldType === 'DateTime' && (fieldName.toLowerCase() === 'updatedat' || fieldName.toLowerCase() === 'updated_at')) {
319+
if (
320+
fieldType === 'DateTime' &&
321+
(fieldName.toLowerCase() === 'updatedat' || fieldName.toLowerCase() === 'updated_at')
322+
) {
315323
factories.push(new DataFieldAttributeFactory().setDecl(getAttributeRef('@updatedAt', services)));
316324
}
317325

@@ -338,8 +346,7 @@ export const postgresql: IntrospectionProvider = {
338346
dbAttr &&
339347
defaultDatabaseType &&
340348
(defaultDatabaseType.type !== normalizedDatatype ||
341-
(defaultDatabaseType.precision &&
342-
defaultDatabaseType.precision !== (length ?? precision)))
349+
(defaultDatabaseType.precision && defaultDatabaseType.precision !== (length ?? precision)))
343350
) {
344351
const dbAttrFactory = new DataFieldAttributeFactory().setDecl(dbAttr);
345352
// Only add length/precision if it's meaningful (not the standard bit width for the type)
@@ -628,7 +635,17 @@ WHERE
628635
ORDER BY "ns"."nspname", "cls"."relname" ASC;
629636
`;
630637

631-
function typeCastingConvert({defaultValue, enums, val, services}:{val: string, enums: Enum[], defaultValue:string, services:ZModelServices}): ((builder: ExpressionBuilder) => AstFactory<Expression>) | null {
638+
function typeCastingConvert({
639+
defaultValue,
640+
enums,
641+
val,
642+
services,
643+
}: {
644+
val: string;
645+
enums: Enum[];
646+
defaultValue: string;
647+
services: ZModelServices;
648+
}): ((builder: ExpressionBuilder) => AstFactory<Expression>) | null {
632649
const [value, type] = val
633650
.replace(/'/g, '')
634651
.split('::')
@@ -653,9 +670,7 @@ function typeCastingConvert({defaultValue, enums, val, services}:{val: string, e
653670
}
654671
const enumField = enumDef.fields.find((v) => getDbName(v) === value);
655672
if (!enumField) {
656-
throw new CliError(
657-
`Enum value ${value} not found in enum ${type} for default value ${defaultValue}`,
658-
);
673+
throw new CliError(`Enum value ${value} not found in enum ${type} for default value ${defaultValue}`);
659674
}
660675
return (ab) => ab.ReferenceExpr.setTarget(enumField);
661676
}

0 commit comments

Comments
 (0)