Skip to content

Commit 8ddfe3e

Browse files
ymc9claude
andcommitted
fix(orm): cast fields with @db.* native type annotations to base SQL type
When a field has a native database type annotation like @db.Uuid, PostgreSQL stores it as that native type. Comparing such a field with a plain String field (e.g., in policy expressions like `id == x`) causes a type mismatch error. This fix casts fields with @db.* attributes back to their ZModel base SQL type in fieldRef(), which affects both SELECT and WHERE clauses. - Add abstract `getSqlType()` to BaseCrudDialect with implementations in PostgreSQL, SQLite, and MySQL dialects - Add `hasNativeTypeAttribute()` helper to detect @db.* attributes - Modify `fieldRef()` to apply CAST for fields with native type annotations - Update policy expression transformer to use dialect's fieldRef() Fixes #2394 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent d3ab3a6 commit 8ddfe3e

8 files changed

Lines changed: 136 additions & 31 deletions

File tree

packages/orm/src/client/crud/dialects/base-dialect.ts

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,22 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
8484

8585
// #endregion
8686

87+
// #region type mapping
88+
89+
/**
90+
* Maps a ZModel type to the corresponding SQL type for this dialect.
91+
*/
92+
protected abstract getSqlType(zmodelType: string): string | undefined;
93+
94+
/**
95+
* Checks if a field has a native database type attribute (e.g., `@db.Uuid`).
96+
*/
97+
protected hasNativeTypeAttribute(fieldDef: FieldDef): boolean {
98+
return !!fieldDef.attributes?.some((a) => a.name.startsWith('@db.'));
99+
}
100+
101+
// #endregion
102+
87103
// #region value transformation
88104

89105
/**
@@ -1143,7 +1159,16 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
11431159
) {
11441160
continue;
11451161
}
1146-
jsonObject[field] = eb.ref(`${subModel.name}.${field}`);
1162+
const subFieldDef = requireField(this.schema, subModel.name, field);
1163+
const castSqlType = this.hasNativeTypeAttribute(subFieldDef)
1164+
? this.getSqlType(subFieldDef.type)
1165+
: undefined;
1166+
if (castSqlType) {
1167+
jsonObject[field] =
1168+
sql`CAST(${sql.ref(`${subModel.name}.${field}`)} AS ${sql.raw(castSqlType)})`;
1169+
} else {
1170+
jsonObject[field] = eb.ref(`${subModel.name}.${field}`);
1171+
}
11471172
}
11481173
return this.buildJsonObject(jsonObject).as(`${DELEGATE_JOINED_FIELD_PREFIX}${subModel.name}`);
11491174
});
@@ -1344,7 +1369,18 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
13441369

13451370
if (!fieldDef.computed) {
13461371
// regular field
1347-
return this.eb.ref(modelAlias ? `${modelAlias}.${field}` : field);
1372+
const ref = modelAlias ? `${modelAlias}.${field}` : field;
1373+
1374+
// if the field has a native database type annotation (e.g., @db.Uuid), cast it
1375+
// back to the base SQL type to avoid type mismatch in comparisons
1376+
if (this.hasNativeTypeAttribute(fieldDef)) {
1377+
const sqlType = this.getSqlType(fieldDef.type);
1378+
if (sqlType) {
1379+
return sql`CAST(${sql.ref(ref)} AS ${sql.raw(sqlType)})`;
1380+
}
1381+
}
1382+
1383+
return this.eb.ref(ref);
13481384
} else {
13491385
// computed field
13501386
if (!inlineComputedField) {

packages/orm/src/client/crud/dialects/mysql.ts

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import type { BuiltinType, FieldDef, SchemaDef } from '../../../schema';
1616
import type { SortOrder } from '../../crud-types';
1717
import { createInvalidInputError, createNotSupportedError } from '../../errors';
1818
import type { ClientOptions } from '../../options';
19-
import { isTypeDef } from '../../query-utils';
19+
import { isEnum, isTypeDef } from '../../query-utils';
2020
import { LateralJoinDialectBase } from './lateral-join-dialect-base';
2121

2222
export class MySqlCrudDialect<Schema extends SchemaDef> extends LateralJoinDialectBase<Schema> {
@@ -318,6 +318,23 @@ export class MySqlCrudDialect<Schema extends SchemaDef> extends LateralJoinDiale
318318
);
319319
}
320320

321+
protected override getSqlType(zmodelType: string) {
322+
if (isEnum(this.schema, zmodelType)) {
323+
return 'varchar(191)';
324+
}
325+
return match(zmodelType)
326+
.with('String', () => 'varchar(191)')
327+
.with('Boolean', () => 'tinyint(1)')
328+
.with('Int', () => 'signed')
329+
.with('BigInt', () => 'bigint')
330+
.with('Float', () => 'double')
331+
.with('Decimal', () => 'decimal(65,30)')
332+
.with('DateTime', () => 'datetime(3)')
333+
.with('Bytes', () => 'longblob')
334+
.with('Json', () => 'json')
335+
.otherwise(() => undefined);
336+
}
337+
321338
override getStringCasingBehavior() {
322339
// MySQL LIKE is case-insensitive by default (depends on collation), no ILIKE support
323340
return { supportsILike: false, likeCaseSensitive: false };

packages/orm/src/client/crud/dialects/postgresql.ts

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,11 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends LateralJoinDi
281281
override buildArrayValue(values: Expression<unknown>[], elemType: string): AliasableExpression<unknown> {
282282
const arr = sql`ARRAY[${sql.join(values, sql.raw(','))}]`;
283283
const mappedType = this.getSqlType(elemType);
284-
return this.eb.cast(arr, sql`${sql.raw(mappedType)}[]`);
284+
if (mappedType) {
285+
return this.eb.cast(arr, sql`${sql.raw(mappedType)}[]`);
286+
} else {
287+
return arr;
288+
}
285289
}
286290

287291
override buildArrayContains(
@@ -293,7 +297,7 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends LateralJoinDi
293297
const arrayExpr = sql`ARRAY[${value}]`;
294298
if (elemType) {
295299
const mappedType = this.getSqlType(elemType);
296-
const typedArray = this.eb.cast(arrayExpr, sql`${sql.raw(mappedType)}[]`);
300+
const typedArray = mappedType ? this.eb.cast(arrayExpr, sql`${sql.raw(mappedType)}[]`) : arrayExpr;
297301
return this.eb(field, '@>', typedArray);
298302
} else {
299303
return this.eb(field, '@>', arrayExpr);
@@ -357,25 +361,22 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends LateralJoinDi
357361
);
358362
}
359363

360-
private getSqlType(zmodelType: string) {
364+
protected override getSqlType(zmodelType: string) {
361365
if (isEnum(this.schema, zmodelType)) {
362366
// reduce enum to text for type compatibility
363367
return 'text';
364368
} else {
365-
return (
366-
match(zmodelType)
367-
.with('String', () => 'text')
368-
.with('Boolean', () => 'boolean')
369-
.with('Int', () => 'integer')
370-
.with('BigInt', () => 'bigint')
371-
.with('Float', () => 'double precision')
372-
.with('Decimal', () => 'decimal')
373-
.with('DateTime', () => 'timestamp')
374-
.with('Bytes', () => 'bytea')
375-
.with('Json', () => 'jsonb')
376-
// fallback to text
377-
.otherwise(() => 'text')
378-
);
369+
return match(zmodelType)
370+
.with('String', () => 'text')
371+
.with('Boolean', () => 'boolean')
372+
.with('Int', () => 'integer')
373+
.with('BigInt', () => 'bigint')
374+
.with('Float', () => 'double precision')
375+
.with('Decimal', () => 'decimal(65,30)')
376+
.with('DateTime', () => 'timestamp(3)')
377+
.with('Bytes', () => 'bytea')
378+
.with('Json', () => 'jsonb')
379+
.otherwise(() => undefined);
379380
}
380381
}
381382

@@ -414,8 +415,12 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends LateralJoinDi
414415
.select(
415416
fields.map((f, i) => {
416417
const mappedType = this.getSqlType(f.type);
417-
const castType = f.array ? sql`${sql.raw(mappedType)}[]` : sql.raw(mappedType);
418-
return this.eb.cast(sql.ref(`$values.column${i + 1}`), castType).as(f.name);
418+
if (mappedType) {
419+
const castType = f.array ? sql`${sql.raw(mappedType)}[]` : sql.raw(mappedType);
420+
return this.eb.cast(sql.ref(`$values.column${i + 1}`), castType).as(f.name);
421+
} else {
422+
return sql.ref(`$values.column${i + 1}`).as(f.name);
423+
}
419424
}),
420425
);
421426
}

packages/orm/src/client/crud/dialects/sqlite.ts

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import {
2222
getDelegateDescendantModels,
2323
getManyToManyRelation,
2424
getRelationForeignKeyFieldPairs,
25+
isEnum,
2526
requireField,
2627
requireIdFields,
2728
requireModel,
@@ -488,6 +489,23 @@ export class SqliteCrudDialect<Schema extends SchemaDef> extends BaseCrudDialect
488489
return this.eb.fn('trim', [expression, sql.lit('"')]) as unknown as T;
489490
}
490491

492+
protected override getSqlType(zmodelType: string) {
493+
if (isEnum(this.schema, zmodelType)) {
494+
return 'text';
495+
}
496+
return match(zmodelType)
497+
.with('String', () => 'text')
498+
.with('Boolean', () => 'integer')
499+
.with('Int', () => 'integer')
500+
.with('BigInt', () => 'integer')
501+
.with('Float', () => 'real')
502+
.with('Decimal', () => 'decimal')
503+
.with('DateTime', () => 'numeric')
504+
.with('Bytes', () => 'blob')
505+
.with('Json', () => 'jsonb')
506+
.otherwise(() => undefined);
507+
}
508+
491509
override getStringCasingBehavior() {
492510
// SQLite `LIKE` is case-insensitive, and there is no `ILIKE`
493511
return { supportsILike: false, likeCaseSensitive: false };

packages/plugins/policy/src/expression-transformer.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -908,7 +908,7 @@ export class ExpressionTransformer<Schema extends SchemaDef> {
908908

909909
const fieldDef = QueryUtils.requireField(this.schema, context.modelOrType, column);
910910
if (!fieldDef.originModel || fieldDef.originModel === context.modelOrType) {
911-
return ReferenceNode.create(ColumnNode.create(column), TableNode.create(tableName));
911+
return this.dialect.fieldRef(context.modelOrType, column, tableName, false).toOperationNode();
912912
}
913913

914914
return this.buildDelegateBaseFieldSelect(context.modelOrType, tableName, column, fieldDef.originModel);

pnpm-lock.yaml

Lines changed: 7 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/regression/package.json

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,18 @@
1313
},
1414
"dependencies": {
1515
"@zenstackhq/testtools": "workspace:*",
16-
"decimal.js": "catalog:"
16+
"decimal.js": "catalog:",
17+
"uuid": "^11.0.5"
1718
},
1819
"devDependencies": {
20+
"@types/node": "catalog:",
1921
"@zenstackhq/cli": "workspace:*",
2022
"@zenstackhq/language": "workspace:*",
21-
"@zenstackhq/schema": "workspace:*",
2223
"@zenstackhq/orm": "workspace:*",
23-
"@zenstackhq/sdk": "workspace:*",
2424
"@zenstackhq/plugin-policy": "workspace:*",
25+
"@zenstackhq/schema": "workspace:*",
26+
"@zenstackhq/sdk": "workspace:*",
2527
"@zenstackhq/typescript-config": "workspace:*",
26-
"@zenstackhq/vitest-config": "workspace:*",
27-
"@types/node": "catalog:"
28+
"@zenstackhq/vitest-config": "workspace:*"
2829
}
2930
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import { createPolicyTestClient } from '@zenstackhq/testtools';
2+
import { v4 as uuid } from 'uuid';
3+
import { describe, expect, it } from 'vitest';
4+
5+
describe('Regression for issue #2394', () => {
6+
const UUID_SCHEMA = `
7+
model Foo {
8+
id String @id @db.Uuid @default(dbgenerated("gen_random_uuid()"))
9+
x String
10+
11+
@@allow('all', id == x)
12+
}
13+
`;
14+
15+
it('works with policies', async () => {
16+
const db = await createPolicyTestClient(UUID_SCHEMA, {
17+
provider: 'postgresql',
18+
usePrismaPush: true,
19+
});
20+
21+
await db.$unuseAll().foo.create({ data: { x: uuid() } });
22+
await expect(db.foo.findMany()).toResolveTruthy();
23+
});
24+
});

0 commit comments

Comments
 (0)