Skip to content
This repository was archived by the owner on Mar 1, 2026. It is now read-only.

Commit 92c34c4

Browse files
authored
fix(orm): add special treatment to enum arrays for postgres db (#577)
* fix(orm): add special treatment to enum arrays for postgres db - For input, the string array needs to be casted with "Enum"[] - For output, the raw pg array string needs to be parsed back to a proper JS string array fixes #576 * fix enum array filtering and improve test cases * update * addressing PR comments
1 parent 880e3b6 commit 92c34c4

11 files changed

Lines changed: 328 additions & 55 deletions

File tree

packages/orm/package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@
9090
"json-stable-stringify": "^1.3.0",
9191
"kysely": "catalog:",
9292
"nanoid": "^5.0.9",
93+
"postgres-array": "^3.0.4",
9394
"toposort": "^2.0.2",
9495
"ts-pattern": "catalog:",
9596
"ulid": "^3.0.0",

packages/orm/src/client/crud/dialects/base-dialect.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
4848
return value;
4949
}
5050

51-
transformOutput(value: unknown, _type: BuiltinType) {
51+
transformOutput(value: unknown, _type: BuiltinType, _array: boolean) {
5252
return value;
5353
}
5454

packages/orm/src/client/crud/dialects/postgresql.ts

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import {
99
type SelectQueryBuilder,
1010
type SqlBool,
1111
} from 'kysely';
12+
import { parse as parsePostgresArray } from 'postgres-array';
1213
import { match } from 'ts-pattern';
1314
import z from 'zod';
1415
import { AnyNullClass, DbNullClass, JsonNullClass } from '../../../common-types';
@@ -20,14 +21,17 @@ import type { ClientOptions } from '../../options';
2021
import {
2122
buildJoinPairs,
2223
getDelegateDescendantModels,
24+
getEnum,
2325
getManyToManyRelation,
26+
isEnum,
2427
isRelationField,
2528
isTypeDef,
2629
requireField,
2730
requireIdFields,
2831
requireModel,
2932
} from '../../query-utils';
3033
import { BaseCrudDialect } from './base-dialect';
34+
3135
export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDialect<Schema> {
3236
private isoDateSchema = z.iso.datetime({ local: true, offset: true });
3337

@@ -70,6 +74,16 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
7074
if (type === 'Json' && !forArrayField) {
7175
// scalar `Json` fields need their input stringified
7276
return JSON.stringify(value);
77+
}
78+
if (isEnum(this.schema, type)) {
79+
// cast to enum array `CAST(ARRAY[...] AS "enum_type"[])`
80+
return this.eb.cast(
81+
sql`ARRAY[${sql.join(
82+
value.map((v) => this.transformPrimitive(v, type, false)),
83+
sql.raw(','),
84+
)}]`,
85+
this.createSchemaQualifiedEnumType(type, true),
86+
);
7387
} else {
7488
// `Json[]` fields need their input as array (not stringified)
7589
return value.map((v) => this.transformPrimitive(v, type, false));
@@ -96,7 +110,33 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
96110
}
97111
}
98112

99-
override transformOutput(value: unknown, type: BuiltinType) {
113+
private createSchemaQualifiedEnumType(type: string, array: boolean) {
114+
// determines the postgres schema name for the enum type, and returns the
115+
// qualified name
116+
117+
let qualified = type;
118+
119+
const enumDef = getEnum(this.schema, type);
120+
if (enumDef) {
121+
// check if the enum has a custom "@@schema" attribute
122+
const schemaAttr = enumDef.attributes?.find((attr) => attr.name === '@@schema');
123+
if (schemaAttr) {
124+
const mapArg = schemaAttr.args?.find((arg) => arg.name === 'map');
125+
if (mapArg && mapArg.value.kind === 'literal') {
126+
const schemaName = mapArg.value.value as string;
127+
qualified = `"${schemaName}"."${type}"`;
128+
}
129+
} else {
130+
// no custom schema, use default from datasource or 'public'
131+
const defaultSchema = this.schema.provider.defaultSchema ?? 'public';
132+
qualified = `"${defaultSchema}"."${type}"`;
133+
}
134+
}
135+
136+
return array ? sql.raw(`${qualified}[]`) : sql.raw(qualified);
137+
}
138+
139+
override transformOutput(value: unknown, type: BuiltinType, array: boolean) {
100140
if (value === null || value === undefined) {
101141
return value;
102142
}
@@ -105,7 +145,11 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
105145
.with('Bytes', () => this.transformOutputBytes(value))
106146
.with('BigInt', () => this.transformOutputBigInt(value))
107147
.with('Decimal', () => this.transformDecimal(value))
108-
.otherwise(() => super.transformOutput(value, type));
148+
.when(
149+
(type) => isEnum(this.schema, type),
150+
() => this.transformOutputEnum(value, array),
151+
)
152+
.otherwise(() => super.transformOutput(value, type, array));
109153
}
110154

111155
private transformOutputBigInt(value: unknown) {
@@ -162,6 +206,19 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
162206
: value;
163207
}
164208

209+
private transformOutputEnum(value: unknown, array: boolean) {
210+
if (array && typeof value === 'string') {
211+
try {
212+
// postgres returns enum arrays as `{"val 1",val2}` strings, parse them back
213+
// to string arrays here
214+
return parsePostgresArray(value);
215+
} catch {
216+
// fall through - return as-is if parsing fails
217+
}
218+
}
219+
return value;
220+
}
221+
165222
override buildRelationSelection(
166223
query: SelectQueryBuilder<any, any, any>,
167224
model: string,

packages/orm/src/client/crud/dialects/sqlite.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ export class SqliteCrudDialect<Schema extends SchemaDef> extends BaseCrudDialect
6767
}
6868
}
6969

70-
override transformOutput(value: unknown, type: BuiltinType) {
70+
override transformOutput(value: unknown, type: BuiltinType, array: boolean) {
7171
if (value === null || value === undefined) {
7272
return value;
7373
} else if (this.schema.typeDefs && type in this.schema.typeDefs) {
@@ -81,7 +81,7 @@ export class SqliteCrudDialect<Schema extends SchemaDef> extends BaseCrudDialect
8181
.with('Decimal', () => this.transformOutputDecimal(value))
8282
.with('BigInt', () => this.transformOutputBigInt(value))
8383
.with('Json', () => this.transformOutputJson(value))
84-
.otherwise(() => super.transformOutput(value, type));
84+
.otherwise(() => super.transformOutput(value, type, array));
8585
}
8686
}
8787

packages/orm/src/client/crud/validator/index.ts

Lines changed: 47 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ import {
99
type BuiltinType,
1010
type EnumDef,
1111
type FieldDef,
12-
type ProcedureDef,
1312
type GetModels,
1413
type ModelDef,
14+
type ProcedureDef,
1515
type SchemaDef,
1616
} from '../../../schema';
1717
import { extractFields } from '../../../utils/object-utils';
@@ -199,10 +199,7 @@ export class InputValidator<Schema extends SchemaDef> {
199199
>(model, 'find', options, (model, options) => this.makeFindSchema(model, options), args);
200200
}
201201

202-
validateExistsArgs(
203-
model: GetModels<Schema>,
204-
args: unknown,
205-
): ExistsArgs<Schema, GetModels<Schema>> | undefined {
202+
validateExistsArgs(model: GetModels<Schema>, args: unknown): ExistsArgs<Schema, GetModels<Schema>> | undefined {
206203
return this.validate<ExistsArgs<Schema, GetModels<Schema>>>(
207204
model,
208205
'exists',
@@ -429,9 +426,11 @@ export class InputValidator<Schema extends SchemaDef> {
429426
}
430427

431428
private makeExistsSchema(model: string) {
432-
return z.strictObject({
433-
where: this.makeWhereSchema(model, false).optional(),
434-
}).optional();
429+
return z
430+
.strictObject({
431+
where: this.makeWhereSchema(model, false).optional(),
432+
})
433+
.optional();
435434
}
436435

437436
private makeScalarSchema(type: string, attributes?: readonly AttributeApplication[]) {
@@ -577,7 +576,12 @@ export class InputValidator<Schema extends SchemaDef> {
577576
if (enumDef) {
578577
// enum
579578
if (Object.keys(enumDef.values).length > 0) {
580-
fieldSchema = this.makeEnumFilterSchema(enumDef, !!fieldDef.optional, withAggregations);
579+
fieldSchema = this.makeEnumFilterSchema(
580+
enumDef,
581+
!!fieldDef.optional,
582+
withAggregations,
583+
!!fieldDef.array,
584+
);
581585
}
582586
} else if (fieldDef.array) {
583587
// array field
@@ -614,7 +618,12 @@ export class InputValidator<Schema extends SchemaDef> {
614618
if (enumDef) {
615619
// enum
616620
if (Object.keys(enumDef.values).length > 0) {
617-
fieldSchema = this.makeEnumFilterSchema(enumDef, !!def.optional, false);
621+
fieldSchema = this.makeEnumFilterSchema(
622+
enumDef,
623+
!!def.optional,
624+
false,
625+
false,
626+
);
618627
} else {
619628
fieldSchema = z.never();
620629
}
@@ -696,24 +705,23 @@ export class InputValidator<Schema extends SchemaDef> {
696705
!!fieldDef.array,
697706
).optional();
698707
} else {
699-
// array, enum, primitives
700-
if (fieldDef.array) {
708+
// enum, array, primitives
709+
const enumDef = getEnum(this.schema, fieldDef.type);
710+
if (enumDef) {
711+
fieldSchemas[fieldName] = this.makeEnumFilterSchema(
712+
enumDef,
713+
!!fieldDef.optional,
714+
false,
715+
!!fieldDef.array,
716+
).optional();
717+
} else if (fieldDef.array) {
701718
fieldSchemas[fieldName] = this.makeArrayFilterSchema(fieldDef.type as BuiltinType).optional();
702719
} else {
703-
const enumDef = getEnum(this.schema, fieldDef.type);
704-
if (enumDef) {
705-
fieldSchemas[fieldName] = this.makeEnumFilterSchema(
706-
enumDef,
707-
!!fieldDef.optional,
708-
false,
709-
).optional();
710-
} else {
711-
fieldSchemas[fieldName] = this.makePrimitiveFilterSchema(
712-
fieldDef.type as BuiltinType,
713-
!!fieldDef.optional,
714-
false,
715-
).optional();
716-
}
720+
fieldSchemas[fieldName] = this.makePrimitiveFilterSchema(
721+
fieldDef.type as BuiltinType,
722+
!!fieldDef.optional,
723+
false,
724+
).optional();
717725
}
718726
}
719727
}
@@ -757,24 +765,31 @@ export class InputValidator<Schema extends SchemaDef> {
757765
return this.schema.typeDefs && type in this.schema.typeDefs;
758766
}
759767

760-
private makeEnumFilterSchema(enumDef: EnumDef, optional: boolean, withAggregations: boolean) {
768+
private makeEnumFilterSchema(enumDef: EnumDef, optional: boolean, withAggregations: boolean, array: boolean) {
761769
const baseSchema = z.enum(Object.keys(enumDef.values) as [string, ...string[]]);
770+
if (array) {
771+
return this.internalMakeArrayFilterSchema(baseSchema);
772+
}
762773
const components = this.makeCommonPrimitiveFilterComponents(
763774
baseSchema,
764775
optional,
765-
() => z.lazy(() => this.makeEnumFilterSchema(enumDef, optional, withAggregations)),
776+
() => z.lazy(() => this.makeEnumFilterSchema(enumDef, optional, withAggregations, array)),
766777
['equals', 'in', 'notIn', 'not'],
767778
withAggregations ? ['_count', '_min', '_max'] : undefined,
768779
);
769780
return z.union([this.nullableIf(baseSchema, optional), z.strictObject(components)]);
770781
}
771782

772783
private makeArrayFilterSchema(type: BuiltinType) {
784+
return this.internalMakeArrayFilterSchema(this.makeScalarSchema(type));
785+
}
786+
787+
private internalMakeArrayFilterSchema(elementSchema: ZodType) {
773788
return z.strictObject({
774-
equals: this.makeScalarSchema(type).array().optional(),
775-
has: this.makeScalarSchema(type).optional(),
776-
hasEvery: this.makeScalarSchema(type).array().optional(),
777-
hasSome: this.makeScalarSchema(type).array().optional(),
789+
equals: elementSchema.array().optional(),
790+
has: elementSchema.optional(),
791+
hasEvery: elementSchema.array().optional(),
792+
hasSome: elementSchema.array().optional(),
778793
isEmpty: z.boolean().optional(),
779794
});
780795
}

packages/orm/src/client/executor/name-mapper.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -530,9 +530,9 @@ export class QueryNameMapper extends OperationNodeTransformer {
530530
let schema = this.schema.provider.defaultSchema ?? 'public';
531531
const schemaAttr = this.schema.models[model]?.attributes?.find((attr) => attr.name === '@@schema');
532532
if (schemaAttr) {
533-
const nameArg = schemaAttr.args?.find((arg) => arg.name === 'map');
534-
if (nameArg && nameArg.value.kind === 'literal') {
535-
schema = nameArg.value.value as string;
533+
const mapArg = schemaAttr.args?.find((arg) => arg.name === 'map');
534+
if (mapArg && mapArg.value.kind === 'literal') {
535+
schema = mapArg.value.value as string;
536536
}
537537
}
538538
return schema;

packages/orm/src/client/result-processor.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ export class ResultProcessor<Schema extends SchemaDef> {
4949
// merge delegate descendant fields
5050
if (value) {
5151
// descendant fields are packed as JSON
52-
const subRow = this.dialect.transformOutput(value, 'Json');
52+
const subRow = this.dialect.transformOutput(value, 'Json', false);
5353

5454
// process the sub-row
5555
const subModel = key.slice(DELEGATE_JOINED_FIELD_PREFIX.length) as GetModels<Schema>;
@@ -93,10 +93,10 @@ export class ResultProcessor<Schema extends SchemaDef> {
9393
private processFieldValue(value: unknown, fieldDef: FieldDef) {
9494
const type = fieldDef.type as BuiltinType;
9595
if (Array.isArray(value)) {
96-
value.forEach((v, i) => (value[i] = this.dialect.transformOutput(v, type)));
96+
value.forEach((v, i) => (value[i] = this.dialect.transformOutput(v, type, false)));
9797
return value;
9898
} else {
99-
return this.dialect.transformOutput(value, type);
99+
return this.dialect.transformOutput(value, type, !!fieldDef.array);
100100
}
101101
}
102102

packages/testtools/src/client.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ export async function createTestClient(
205205
fs.writeFileSync(path.resolve(workDir!, 'schema.prisma'), prismaSchemaText);
206206
execSync('npx prisma db push --schema ./schema.prisma --skip-generate --force-reset', {
207207
cwd: workDir,
208-
stdio: 'ignore',
208+
stdio: options.debug ? 'inherit' : 'ignore',
209209
});
210210
} else {
211211
if (provider === 'postgresql') {

0 commit comments

Comments
 (0)