Skip to content

Commit 4686720

Browse files
ymc9claude
andauthored
fix(orm): fix PostgreSQL type mismatch when @db.Uuid fields used in policy expressions (#2532)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 85d8b6b commit 4686720

4 files changed

Lines changed: 284 additions & 21 deletions

File tree

packages/orm/src/client/crud/dialects/base-dialect.ts

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ import {
2424
flattenCompoundUniqueFilters,
2525
getDelegateDescendantModels,
2626
getManyToManyRelation,
27+
getModelFields,
2728
getRelationForeignKeyFieldPairs,
2829
isEnum,
2930
isTypeDef,
30-
getModelFields,
3131
makeDefaultOrderBy,
3232
requireField,
3333
requireIdFields,
@@ -1162,7 +1162,8 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
11621162

11631163
// client-level: check both uncapitalized (current) and original (backward compat) model name
11641164
const uncapModel = lowerCaseFirst(model);
1165-
const omitConfig = (this.options.omit as Record<string, any> | undefined)?.[uncapModel] ??
1165+
const omitConfig =
1166+
(this.options.omit as Record<string, any> | undefined)?.[uncapModel] ??
11661167
(this.options.omit as Record<string, any> | undefined)?.[model];
11671168
if (omitConfig && typeof omitConfig === 'object' && typeof omitConfig[field] === 'boolean') {
11681169
return omitConfig[field];
@@ -1357,7 +1358,9 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
13571358
const computedFields = this.options.computedFields as Record<string, any>;
13581359
// check both uncapitalized (current) and original (backward compat) model name
13591360
const computedModel = fieldDef.originModel ?? model;
1360-
computer = computedFields?.[lowerCaseFirst(computedModel)]?.[field] ?? computedFields?.[computedModel]?.[field];
1361+
computer =
1362+
computedFields?.[lowerCaseFirst(computedModel)]?.[field] ??
1363+
computedFields?.[computedModel]?.[field];
13611364
}
13621365
if (!computer) {
13631366
throw createConfigError(`Computed field "${field}" implementation not provided for model "${model}"`);
@@ -1489,6 +1492,19 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
14891492
*/
14901493
abstract buildValuesTableSelect(fields: FieldDef[], rows: unknown[][]): SelectQueryBuilder<any, any, any>;
14911494

1495+
/**
1496+
* Builds a binary comparison expression between two operands.
1497+
*/
1498+
buildComparison(
1499+
left: Expression<unknown>,
1500+
_leftFieldDef: FieldDef | undefined,
1501+
op: string,
1502+
right: Expression<unknown>,
1503+
_rightFieldDef: FieldDef | undefined,
1504+
): Expression<SqlBool> {
1505+
return this.eb(left, op as any, right) as Expression<SqlBool>;
1506+
}
1507+
14921508
/**
14931509
* Builds a JSON path selection expression.
14941510
*/

packages/orm/src/client/crud/dialects/postgresql.ts

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,34 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends LateralJoinDi
3232
Json: 'jsonb',
3333
};
3434

35+
// Maps @db.* attribute names to PostgreSQL SQL types for use in VALUES table casts
36+
private static readonly dbAttributeToSqlTypeMap: Record<string, string> = {
37+
'@db.Uuid': 'uuid',
38+
'@db.Citext': 'citext',
39+
'@db.Inet': 'inet',
40+
'@db.Bit': 'bit',
41+
'@db.VarBit': 'varbit',
42+
'@db.Xml': 'xml',
43+
'@db.Json': 'json',
44+
'@db.JsonB': 'jsonb',
45+
'@db.ByteA': 'bytea',
46+
'@db.Text': 'text',
47+
'@db.Char': 'bpchar',
48+
'@db.VarChar': 'varchar',
49+
'@db.Date': 'date',
50+
'@db.Time': 'time',
51+
'@db.Timetz': 'timetz',
52+
'@db.Timestamp': 'timestamp',
53+
'@db.Timestamptz': 'timestamptz',
54+
'@db.SmallInt': 'smallint',
55+
'@db.Integer': 'integer',
56+
'@db.BigInt': 'bigint',
57+
'@db.Real': 'real',
58+
'@db.DoublePrecision': 'double precision',
59+
'@db.Decimal': 'decimal',
60+
'@db.Boolean': 'boolean',
61+
};
62+
3563
constructor(schema: Schema, options: ClientOptions<Schema>) {
3664
super(schema, options);
3765
this.overrideTypeParsers();
@@ -406,7 +434,16 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends LateralJoinDi
406434
);
407435
}
408436

409-
private getSqlType(zmodelType: string) {
437+
private getSqlType(zmodelType: string, attributes?: FieldDef['attributes']) {
438+
// Check @db.* attributes first — they specify the exact native PostgreSQL type
439+
if (attributes) {
440+
for (const attr of attributes) {
441+
const mapped = PostgresCrudDialect.dbAttributeToSqlTypeMap[attr.name];
442+
if (mapped) {
443+
return mapped;
444+
}
445+
}
446+
}
410447
if (isEnum(this.schema, zmodelType)) {
411448
// reduce enum to text for type compatibility
412449
return 'text';
@@ -415,6 +452,42 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends LateralJoinDi
415452
}
416453
}
417454

455+
// Resolves the effective SQL type for a field: the native type from any @db.* attribute,
456+
// or the base ZModel SQL type if no attribute is present, or undefined if the field is unknown.
457+
private resolveFieldSqlType(fieldDef: FieldDef | undefined): { sqlType: string | undefined; hasDbOverride: boolean } {
458+
if (!fieldDef) {
459+
return { sqlType: undefined, hasDbOverride: false };
460+
}
461+
const dbAttr = fieldDef.attributes?.find((a) => a.name.startsWith('@db.'));
462+
if (dbAttr) {
463+
return { sqlType: PostgresCrudDialect.dbAttributeToSqlTypeMap[dbAttr.name], hasDbOverride: true };
464+
}
465+
return { sqlType: this.getSqlType(fieldDef.type), hasDbOverride: false };
466+
}
467+
468+
override buildComparison(
469+
left: Expression<unknown>,
470+
leftFieldDef: FieldDef | undefined,
471+
op: string,
472+
right: Expression<unknown>,
473+
rightFieldDef: FieldDef | undefined,
474+
) {
475+
const leftResolved = this.resolveFieldSqlType(leftFieldDef);
476+
const rightResolved = this.resolveFieldSqlType(rightFieldDef);
477+
// If the resolved SQL types differ and at least one side carries a @db.* native type override,
478+
// cast that side back to its base ZModel SQL type so PostgreSQL doesn't reject the comparison
479+
// (e.g. "operator does not exist: uuid = text").
480+
if (leftResolved.sqlType !== rightResolved.sqlType && (leftResolved.hasDbOverride || rightResolved.hasDbOverride)) {
481+
if (leftResolved.hasDbOverride) {
482+
left = this.eb.cast(left, sql.raw(this.getSqlType(leftFieldDef!.type)));
483+
}
484+
if (rightResolved.hasDbOverride) {
485+
right = this.eb.cast(right, sql.raw(this.getSqlType(rightFieldDef!.type)));
486+
}
487+
}
488+
return super.buildComparison(left, leftFieldDef, op, right, rightFieldDef);
489+
}
490+
418491
override getStringCasingBehavior() {
419492
// Postgres `LIKE` is case-sensitive, `ILIKE` is case-insensitive
420493
return { supportsILike: true, likeCaseSensitive: true };
@@ -449,7 +522,7 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends LateralJoinDi
449522
)
450523
.select(
451524
fields.map((f, i) => {
452-
const mappedType = this.getSqlType(f.type);
525+
const mappedType = this.getSqlType(f.type, f.attributes);
453526
const castType = f.array ? sql`${sql.raw(mappedType)}[]` : sql.raw(mappedType);
454527
return this.eb.cast(sql.ref(`$values.column${i + 1}`), castType).as(f.name);
455528
}),

packages/plugins/policy/src/expression-transformer.ts

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,13 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
265265
} else if (this.isNullNode(left)) {
266266
return this.transformNullCheck(right, expr.op);
267267
} else {
268-
return BinaryOperationNode.create(left, this.transformOperator(op), right);
268+
const leftFieldDef = this.getFieldDefFromFieldRef(normalizedLeft, context);
269+
const rightFieldDef = this.getFieldDefFromFieldRef(normalizedRight, context);
270+
// Map ZModel operator to SQL operator string
271+
const sqlOp = op === '==' ? '=' : op;
272+
return this.dialect
273+
.buildComparison(new ExpressionWrapper(left), leftFieldDef, sqlOp, new ExpressionWrapper(right), rightFieldDef)
274+
.toOperationNode();
269275
}
270276
}
271277

@@ -298,17 +304,17 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
298304
// if relation fields are used directly in comparison, it can only be compared with null,
299305
// so we normalize the args with the id field (use the first id field if multiple)
300306
let normalizedLeft: Expression = expr.left;
301-
if (this.isRelationField(expr.left, context.modelOrType)) {
307+
if (this.isRelationField(expr.left, context)) {
302308
invariant(ExpressionUtils.isNull(expr.right), 'only null comparison is supported for relation field');
303-
const leftRelDef = this.getFieldDefFromFieldRef(expr.left, context.modelOrType);
309+
const leftRelDef = this.getFieldDefFromFieldRef(expr.left, context);
304310
invariant(leftRelDef, 'failed to get relation field definition');
305311
const idFields = QueryUtils.requireIdFields(this.schema, leftRelDef.type);
306312
normalizedLeft = this.makeOrAppendMember(normalizedLeft, idFields[0]!);
307313
}
308314
let normalizedRight: Expression = expr.right;
309-
if (this.isRelationField(expr.right, context.modelOrType)) {
315+
if (this.isRelationField(expr.right, context)) {
310316
invariant(ExpressionUtils.isNull(expr.left), 'only null comparison is supported for relation field');
311-
const rightRelDef = this.getFieldDefFromFieldRef(expr.right, context.modelOrType);
317+
const rightRelDef = this.getFieldDefFromFieldRef(expr.right, context);
312318
invariant(rightRelDef, 'failed to get relation field definition');
313319
const idFields = QueryUtils.requireIdFields(this.schema, rightRelDef.type);
314320
normalizedRight = this.makeOrAppendMember(normalizedRight, idFields[0]!);
@@ -349,7 +355,7 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
349355
);
350356

351357
let newContextModel: string;
352-
const fieldDef = this.getFieldDefFromFieldRef(expr.left, context.modelOrType);
358+
const fieldDef = this.getFieldDefFromFieldRef(expr.left, context);
353359
if (fieldDef) {
354360
invariant(fieldDef.relation, `field is not a relation: ${JSON.stringify(expr.left)}`);
355361
newContextModel = fieldDef.type;
@@ -578,13 +584,6 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
578584
return logicalNot(this.dialect, this.transform(expr.operand, context));
579585
}
580586

581-
private transformOperator(op: Exclude<BinaryOperator, '?' | '!' | '^'>) {
582-
const mappedOp = match(op)
583-
.with('==', () => '=' as const)
584-
.otherwise(() => op);
585-
return OperatorNode.create(mappedOp);
586-
}
587-
588587
@expr('call')
589588
// @ts-ignore
590589
private _call(expr: CallExpression, context: ExpressionTransformerContext) {
@@ -979,12 +978,18 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
979978
}
980979
}
981980

982-
private isRelationField(expr: Expression, model: string) {
983-
const fieldDef = this.getFieldDefFromFieldRef(expr, model);
981+
private isRelationField(expr: Expression, context: ExpressionTransformerContext) {
982+
const fieldDef = this.getFieldDefFromFieldRef(expr, context);
984983
return !!fieldDef?.relation;
985984
}
986985

987-
private getFieldDefFromFieldRef(expr: Expression, model: string): FieldDef | undefined {
986+
private getFieldDefFromFieldRef(expr: Expression, context: ExpressionTransformerContext): FieldDef | undefined {
987+
// `this.foo` references belong to `thisType` (the outer model in collection-predicate
988+
// contexts); everything else uses `modelOrType`.
989+
const model =
990+
ExpressionUtils.isMember(expr) && ExpressionUtils.isThis(expr.receiver)
991+
? context.thisType
992+
: context.modelOrType;
988993
if (ExpressionUtils.isField(expr)) {
989994
return QueryUtils.getField(this.schema, model, expr.field);
990995
} else if (
@@ -993,6 +998,19 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
993998
ExpressionUtils.isThis(expr.receiver)
994999
) {
9951000
return QueryUtils.getField(this.schema, model, expr.members[0]!);
1001+
} else if (ExpressionUtils.isMember(expr) && ExpressionUtils.isField(expr.receiver)) {
1002+
// relation chain access (e.g. `owner.id`, `user.profile.uuid_field`): walk the
1003+
// relation hops and return the terminal field's FieldDef so native-type info
1004+
// (@db.*) is available for casting in buildComparison
1005+
const receiverDef = QueryUtils.getField(this.schema, model, expr.receiver.field);
1006+
if (!receiverDef?.relation) return undefined;
1007+
let currModel = receiverDef.type;
1008+
for (let i = 0; i < expr.members.length - 1; i++) {
1009+
const hopDef = QueryUtils.getField(this.schema, currModel, expr.members[i]!);
1010+
if (!hopDef?.relation) return undefined;
1011+
currModel = hopDef.type;
1012+
}
1013+
return QueryUtils.getField(this.schema, currModel, expr.members[expr.members.length - 1]!);
9961014
} else {
9971015
return undefined;
9981016
}

0 commit comments

Comments
 (0)