Skip to content
This repository was archived by the owner on Mar 1, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 37 additions & 12 deletions packages/orm/src/client/crud/dialects/base-dialect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -489,26 +489,41 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
continue;
}

const value = this.transformInput(_value, fieldType, !!fieldDef.array);
invariant(fieldDef.array, 'Field must be an array type to build array filter');
const value = this.transformInput(_value, fieldType, true);

let receiver = fieldRef;
if (isEnum(this.schema, fieldType)) {
// cast enum array to `text[]` for type compatibility
receiver = this.eb.cast(fieldRef, sql.raw('text[]'));
}

const buildArray = (value: unknown) => {
invariant(Array.isArray(value), 'Array filter value must be an array');
return this.buildArrayValue(
value.map((v) => this.eb.val(v)),
fieldType,
);
};

switch (key) {
case 'equals': {
clauses.push(this.buildLiteralFilter(fieldRef, fieldType, this.eb.val(value)));
clauses.push(this.eb(receiver, '=', buildArray(value)));
break;
}

case 'has': {
clauses.push(this.eb(fieldRef, '@>', this.eb.val([value])));
clauses.push(this.buildArrayContains(receiver, this.eb.val(value)));
break;
}

case 'hasEvery': {
clauses.push(this.eb(fieldRef, '@>', this.eb.val(value)));
clauses.push(this.buildArrayHasEvery(receiver, buildArray(value)));
break;
}

case 'hasSome': {
clauses.push(this.eb(fieldRef, '&&', this.eb.val(value)));
clauses.push(this.buildArrayHasSome(receiver, buildArray(value)));
break;
}

Expand Down Expand Up @@ -1420,9 +1435,24 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
abstract buildArrayLength(array: Expression<unknown>): AliasableExpression<number>;

/**
* Builds an array literal SQL string for the given values.
* Builds an array value expression.
*/
abstract buildArrayLiteralSQL(values: unknown[]): AliasableExpression<unknown>;
abstract buildArrayValue(values: Expression<unknown>[], elemType: string): AliasableExpression<unknown>;

/**
* Builds an expression that checks if an array contains a single value.
*/
abstract buildArrayContains(field: Expression<unknown>, value: Expression<unknown>): AliasableExpression<SqlBool>;

/**
* Builds an expression that checks if an array contains all values from another array.
*/
abstract buildArrayHasEvery(field: Expression<unknown>, values: Expression<unknown>): AliasableExpression<SqlBool>;

/**
* Builds an expression that checks if an array overlaps with another array.
*/
abstract buildArrayHasSome(field: Expression<unknown>, values: Expression<unknown>): AliasableExpression<SqlBool>;

/**
* Casts the given expression to an integer type.
Expand All @@ -1439,11 +1469,6 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
*/
abstract trimTextQuotes<T extends Expression<string>>(expression: T): T;

/**
* Gets the SQL column type for the given field definition.
*/
abstract getFieldSqlType(fieldDef: FieldDef): string;

/*
* Gets the string casing behavior for the dialect.
*/
Expand Down
64 changes: 26 additions & 38 deletions packages/orm/src/client/crud/dialects/mysql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,18 @@ import Decimal from 'decimal.js';
import type { AliasableExpression, TableExpression } from 'kysely';
import {
expressionBuilder,
ExpressionWrapper,
sql,
ValueListNode,
type Expression,
type ExpressionWrapper,
type SelectQueryBuilder,
type SqlBool,
} from 'kysely';
import { match } from 'ts-pattern';
import { AnyNullClass, DbNullClass, JsonNullClass } from '../../../common-types';
import type { BuiltinType, FieldDef, SchemaDef } from '../../../schema';
import type { SortOrder } from '../../crud-types';
import { createInternalError, createInvalidInputError, createNotSupportedError } from '../../errors';
import { createInvalidInputError, createNotSupportedError } from '../../errors';
import type { ClientOptions } from '../../options';
import { isTypeDef } from '../../query-utils';
import { LateralJoinDialectBase } from './lateral-join-dialect-base';
Expand Down Expand Up @@ -223,8 +224,29 @@ export class MySqlCrudDialect<Schema extends SchemaDef> extends LateralJoinDiale
return this.eb.fn('JSON_LENGTH', [array]);
}

override buildArrayLiteralSQL(_values: unknown[]): AliasableExpression<number> {
throw new Error('MySQL does not support array literals');
override buildArrayValue(values: Expression<unknown>[], _elemType: string): AliasableExpression<unknown> {
return new ExpressionWrapper(ValueListNode.create(values.map((v) => v.toOperationNode())));
}

override buildArrayContains(
_field: Expression<unknown>,
_value: Expression<unknown>,
): AliasableExpression<SqlBool> {
throw createNotSupportedError('MySQL does not support native array operations');
}

override buildArrayHasEvery(
_field: Expression<unknown>,
_values: Expression<unknown>,
): AliasableExpression<SqlBool> {
throw createNotSupportedError('MySQL does not support native array operations');
}

override buildArrayHasSome(
_field: Expression<unknown>,
_values: Expression<unknown>,
): AliasableExpression<SqlBool> {
throw createNotSupportedError('MySQL does not support native array operations');
}

protected override buildJsonEqualityFilter(
Expand Down Expand Up @@ -288,40 +310,6 @@ export class MySqlCrudDialect<Schema extends SchemaDef> extends LateralJoinDiale
);
}

override getFieldSqlType(fieldDef: FieldDef) {
// TODO: respect `@db.x` attributes
if (fieldDef.relation) {
throw createInternalError('Cannot get SQL type of a relation field');
}

let result: string;

if (this.schema.enums?.[fieldDef.type]) {
// enums are treated as text/varchar
result = 'varchar(255)';
} else {
result = match(fieldDef.type)
.with('String', () => 'varchar(255)')
.with('Boolean', () => 'tinyint(1)') // MySQL uses tinyint(1) for boolean
.with('Int', () => 'int')
.with('BigInt', () => 'bigint')
.with('Float', () => 'double')
.with('Decimal', () => 'decimal')
.with('DateTime', () => 'datetime')
.with('Bytes', () => 'blob')
.with('Json', () => 'json')
// fallback to text
.otherwise(() => 'text');
}

if (fieldDef.array) {
// MySQL stores arrays as JSON
result = 'json';
}

return result;
}

override getStringCasingBehavior() {
// MySQL LIKE is case-insensitive by default (depends on collation), no ILIKE support
return { supportsILike: false, likeCaseSensitive: false };
Expand Down
124 changes: 44 additions & 80 deletions packages/orm/src/client/crud/dialects/postgresql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ import z from 'zod';
import { AnyNullClass, DbNullClass, JsonNullClass } from '../../../common-types';
import type { BuiltinType, FieldDef, SchemaDef } from '../../../schema';
import type { SortOrder } from '../../crud-types';
import { createInternalError, createInvalidInputError } from '../../errors';
import { createInvalidInputError } from '../../errors';
import type { ClientOptions } from '../../options';
import { getEnum, isEnum, isTypeDef } from '../../query-utils';
import { isEnum, isTypeDef } from '../../query-utils';
import { LateralJoinDialectBase } from './lateral-join-dialect-base';

export class PostgresCrudDialect<Schema extends SchemaDef> extends LateralJoinDialectBase<Schema> {
Expand Down Expand Up @@ -95,18 +95,7 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends LateralJoinDi
if (type === 'Json' && !forArrayField) {
// scalar `Json` fields need their input stringified
return JSON.stringify(value);
}
if (isEnum(this.schema, type)) {
// cast to enum array `CAST(ARRAY[...] AS "enum_type"[])`
return this.eb.cast(
sql`ARRAY[${sql.join(
value.map((v) => this.transformInput(v, type, false)),
sql.raw(','),
)}]`,
this.createSchemaQualifiedEnumType(type, true),
);
} else {
// `Json[]` fields need their input as array (not stringified)
return value.map((v) => this.transformInput(v, type, false));
}
} else {
Expand Down Expand Up @@ -136,32 +125,6 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends LateralJoinDi
}
}

private createSchemaQualifiedEnumType(type: string, array: boolean) {
// determines the postgres schema name for the enum type, and returns the
// qualified name

let qualified = type;

const enumDef = getEnum(this.schema, type);
if (enumDef) {
// check if the enum has a custom "@@schema" attribute
const schemaAttr = enumDef.attributes?.find((attr) => attr.name === '@@schema');
if (schemaAttr) {
const mapArg = schemaAttr.args?.find((arg) => arg.name === 'map');
if (mapArg && mapArg.value.kind === 'literal') {
const schemaName = mapArg.value.value as string;
qualified = `"${schemaName}"."${type}"`;
}
} else {
// no custom schema, use default from datasource or 'public'
const defaultSchema = this.schema.provider.defaultSchema ?? 'public';
qualified = `"${defaultSchema}"."${type}"`;
}
}

return array ? sql.raw(`${qualified}[]`) : sql.raw(qualified);
}

override transformOutput(value: unknown, type: BuiltinType, array: boolean) {
if (value === null || value === undefined) {
return value;
Expand Down Expand Up @@ -290,15 +253,25 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends LateralJoinDi
return this.eb.fn('array_length', [array]);
}

override buildArrayLiteralSQL(values: unknown[]): AliasableExpression<unknown> {
if (values.length === 0) {
return sql`{}`;
} else {
return sql`ARRAY[${sql.join(
values.map((v) => sql.val(v)),
sql.raw(','),
)}]`;
}
override buildArrayValue(values: Expression<unknown>[], elemType: string): AliasableExpression<unknown> {
const arr = sql`ARRAY[${sql.join(values, sql.raw(','))}]`;
const mappedType = this.getSqlType(elemType);
return this.eb.cast(arr, sql`${sql.raw(mappedType)}[]`);
}

override buildArrayContains(field: Expression<unknown>, value: Expression<unknown>): AliasableExpression<SqlBool> {
// PostgreSQL @> operator expects array on both sides, so wrap single value in array
return this.eb(field, '@>', sql`ARRAY[${value}]`);
}

override buildArrayHasEvery(field: Expression<unknown>, values: Expression<unknown>): AliasableExpression<SqlBool> {
// PostgreSQL @> operator: field contains all elements in values
return this.eb(field, '@>', values);
}

override buildArrayHasSome(field: Expression<unknown>, values: Expression<unknown>): AliasableExpression<SqlBool> {
// PostgreSQL && operator: arrays have any elements in common
return this.eb(field, '&&', values);
}

protected override buildJsonPathSelection(receiver: Expression<any>, path: string | undefined) {
Expand Down Expand Up @@ -348,37 +321,26 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends LateralJoinDi
);
}

override getFieldSqlType(fieldDef: FieldDef) {
// TODO: respect `@db.x` attributes
if (fieldDef.relation) {
throw createInternalError('Cannot get SQL type of a relation field');
}

let result: string;

if (this.schema.enums?.[fieldDef.type]) {
// enums are treated as text
result = 'text';
private getSqlType(zmodelType: string) {
if (isEnum(this.schema, zmodelType)) {
// reduce enum to text for type compatibility
return 'text';
} else {
result = match(fieldDef.type)
.with('String', () => 'text')
.with('Boolean', () => 'boolean')
.with('Int', () => 'integer')
.with('BigInt', () => 'bigint')
.with('Float', () => 'double precision')
.with('Decimal', () => 'decimal')
.with('DateTime', () => 'timestamp')
.with('Bytes', () => 'bytea')
.with('Json', () => 'jsonb')
// fallback to text
.otherwise(() => 'text');
}

if (fieldDef.array) {
result += '[]';
return (
match(zmodelType)
.with('String', () => 'text')
.with('Boolean', () => 'boolean')
.with('Int', () => 'integer')
.with('BigInt', () => 'bigint')
.with('Float', () => 'double precision')
.with('Decimal', () => 'decimal')
.with('DateTime', () => 'timestamp')
.with('Bytes', () => 'bytea')
.with('Json', () => 'jsonb')
// fallback to text
.otherwise(() => 'text')
);
}

return result;
}

override getStringCasingBehavior() {
Expand Down Expand Up @@ -414,9 +376,11 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends LateralJoinDi
)})`.as('$values'),
)
.select(
fields.map((f, i) =>
sql`CAST(${sql.ref(`$values.column${i + 1}`)} AS ${sql.raw(this.getFieldSqlType(f))})`.as(f.name),
),
fields.map((f, i) => {
const mappedType = this.getSqlType(f.type);
const castType = f.array ? sql`${sql.raw(mappedType)}[]` : sql.raw(mappedType);
return this.eb.cast(sql.ref(`$values.column${i + 1}`), castType).as(f.name);
}),
);
}

Expand Down
Loading