diff --git a/packages/orm/src/client/contract.ts b/packages/orm/src/client/contract.ts index 4a451e203..4d54c2e49 100644 --- a/packages/orm/src/client/contract.ts +++ b/packages/orm/src/client/contract.ts @@ -28,15 +28,20 @@ import type { FindUniqueArgs, GroupByArgs, GroupByResult, + OmitWhere, ProcedureFunc, SelectSubset, + SelectSubsetWithWhere, SimplifiedPlainResult, Subset, + SubsetWithWhere, TypeDefResult, UpdateArgs, UpdateManyAndReturnArgs, UpdateManyArgs, UpsertArgs, + WhereInput, + WhereUniqueInput, } from './crud-types'; import type { Diagnostics } from './diagnostics'; import type { ClientOptions, QueryOptions } from './options'; @@ -405,8 +410,8 @@ export type AllModelOperations< * }); * ``` */ - updateManyAndReturn>( - args: Subset>, + updateManyAndReturn>>( + args: { where?: WhereInput } & SubsetWithWhere>>, ): ZenStackPromise[]>; }); @@ -498,8 +503,8 @@ type CommonModelOperations< * }); // result: `{ _count: { posts: number } }` * ``` */ - findMany>( - args?: SelectSubset>, + findMany>>( + args?: { where?: WhereInput } & SelectSubset>, ): ZenStackPromise[]>; /** @@ -508,8 +513,8 @@ type CommonModelOperations< * @returns a single entity or null if not found * @see {@link findMany} */ - findUnique>( - args: SelectSubset>, + findUnique>>( + args: { where: WhereUniqueInput } & SelectSubset>, ): ZenStackPromise | null>; /** @@ -518,8 +523,8 @@ type CommonModelOperations< * @returns a single entity * @see {@link findMany} */ - findUniqueOrThrow>( - args: SelectSubset>, + findUniqueOrThrow>>( + args: { where: WhereUniqueInput } & SelectSubset>, ): ZenStackPromise>; /** @@ -528,8 +533,8 @@ type CommonModelOperations< * @returns a single entity or null if not found * @see {@link findMany} */ - findFirst>( - args?: SelectSubset>, + findFirst>>( + args?: { where?: WhereInput } & SelectSubset>, ): ZenStackPromise | null>; /** @@ -538,8 +543,8 @@ type CommonModelOperations< * @returns a single entity * @see {@link findMany} */ - findFirstOrThrow>( - args?: SelectSubset>, + findFirstOrThrow>>( + args?: { where?: WhereInput } & SelectSubset>, ): ZenStackPromise>; /** @@ -744,8 +749,8 @@ type CommonModelOperations< * }); * ``` */ - update>( - args: SelectSubset>, + update>>( + args: { where: WhereUniqueInput } & SelectSubsetWithWhere>>, ): ZenStackPromise>; /** @@ -768,8 +773,8 @@ type CommonModelOperations< * limit: 10 * }); */ - updateMany>( - args: Subset>, + updateMany>>( + args: { where?: WhereInput } & SubsetWithWhere>>, ): ZenStackPromise; /** @@ -792,8 +797,8 @@ type CommonModelOperations< * }); * ``` */ - upsert>( - args: SelectSubset>, + upsert>>( + args: { where: WhereUniqueInput } & SelectSubsetWithWhere>>, ): ZenStackPromise>; /** @@ -815,8 +820,8 @@ type CommonModelOperations< * }); // result: `{ id: string; email: string }` * ``` */ - delete>( - args: SelectSubset>, + delete>>( + args: { where: WhereUniqueInput } & SelectSubset>, ): ZenStackPromise>; /** @@ -838,8 +843,8 @@ type CommonModelOperations< * }); * ``` */ - deleteMany>( - args?: Subset>, + deleteMany>>( + args?: { where?: WhereInput } & Subset>, ): ZenStackPromise; /** diff --git a/packages/orm/src/client/crud-types.ts b/packages/orm/src/client/crud-types.ts index d29822209..8c8108089 100644 --- a/packages/orm/src/client/crud-types.ts +++ b/packages/orm/src/client/crud-types.ts @@ -347,6 +347,22 @@ export type WhereInput< AND?: OrArray>; OR?: WhereInput[]; NOT?: OrArray>; +} & (IsDelegateModel extends true + ? { $is?: SubModelWhereInput } + : object); + +/** + * Where filter that targets a specific sub-model of a delegate (polymorphic) base model. + * Keys are direct sub-model names; values are `WhereInput` for that sub-model. + * Multiple sub-model entries are combined with OR semantics. + */ +export type SubModelWhereInput< + Schema extends SchemaDef, + Model extends GetModels, + Options extends QueryOptions = QueryOptions, + ScalarOnly extends boolean = false, +> = { + [SubModel in GetSubModels]?: WhereInput | null; }; type FieldFilter< @@ -1088,6 +1104,42 @@ export type SelectSubset = { ? 'Please either choose `select` or `omit`.' : {}); +/** + * Strips the `where` field from an args type so the remaining fields can be used as + * the generic type parameter `T` in CRUD methods, allowing `where` to be typed directly + * and benefit from TypeScript's excess property checking. + * @internal + */ +export type OmitWhere = Omit; + +/** + * Like {@link Subset} but maps the `where` key to `unknown` (instead of `never`) when + * `where` is not present in `U`. This is used in CRUD method signatures where `where` + * is separately typed as `{ where: WhereXxxInput }`: because TypeScript infers T from + * the full argument object (including the `where` field), a naive `Subset>` + * would produce `where: never` in the mapped result, collapsing the `where` type in the + * intersection to `never`. Mapping to `unknown` instead gives + * `{ where: W } & { where: unknown }` = `{ where: W }`, preserving both the correct type + * and TypeScript's excess-property checking on `where`. + * @internal + */ +export type SubsetWithWhere = { + [key in keyof T]: key extends keyof U ? T[key] : key extends 'where' ? unknown : never; +}; + +/** + * Like {@link SelectSubset} but maps the `where` key to `unknown` (instead of `never`) when + * `where` is not present in `U`. See {@link SubsetWithWhere} for the rationale. + * @internal + */ +export type SelectSubsetWithWhere = { + [key in keyof T]: key extends keyof U ? T[key] : key extends 'where' ? unknown : never; +} & (T extends { select: any; include: any } + ? 'Please either choose `select` or `include`.' + : T extends { select: any; omit: any } + ? 'Please either choose `select` or `omit`.' + : {}); + type ToManyRelationFilter< Schema extends SchemaDef, Model extends GetModels, diff --git a/packages/orm/src/client/crud/dialects/base-dialect.ts b/packages/orm/src/client/crud/dialects/base-dialect.ts index b525ac486..3a7ddbf38 100644 --- a/packages/orm/src/client/crud/dialects/base-dialect.ts +++ b/packages/orm/src/client/crud/dialects/base-dialect.ts @@ -23,6 +23,7 @@ import { ensureArray, flattenCompoundUniqueFilters, getDelegateDescendantModels, + getDiscriminatorField, getManyToManyRelation, getModelFields, getRelationForeignKeyFieldPairs, @@ -224,6 +225,11 @@ export abstract class BaseCrudDialect { result = this.and(result, _where['$expr'](this.eb)); } + // handle $is filter for delegate (polymorphic) base models + if ('$is' in _where && _where['$is'] != null && typeof _where['$is'] === 'object') { + result = this.and(result, this.buildIsFilter(model, modelAlias, _where['$is'] as Record)); + } + return result; } @@ -319,6 +325,59 @@ export abstract class BaseCrudDialect { .exhaustive(); } + /** + * Builds a filter expression for the `$is` operator on a delegate (polymorphic) base model. + * Each key in `payload` is a direct sub-model name; the value is an optional `WhereInput` for + * that sub-model. Multiple sub-model entries are combined with OR semantics. + */ + private buildIsFilter(model: string, modelAlias: string, payload: Record): Expression { + const discriminatorField = getDiscriminatorField(this.schema, model); + if (!discriminatorField) { + throw createInvalidInputError( + `"$is" filter is only supported on delegate models; "${model}" is not a delegate model. ` + + `Only models with a @@delegate attribute support the "$is" filter.`, + ); + } + + const discriminatorFieldDef = requireField(this.schema, model, discriminatorField); + const discriminatorTableAlias = discriminatorFieldDef.originModel ?? modelAlias; + const discriminatorRef = this.eb.ref(`${discriminatorTableAlias}.${discriminatorField}`); + + const conditions: Expression[] = []; + + for (const [subModelName, subWhere] of Object.entries(payload)) { + // discriminator must equal the sub-model name + const discriminatorCheck = this.eb(discriminatorRef, '=', subModelName); + + if (subWhere == null || (typeof subWhere === 'object' && Object.keys(subWhere).length === 0)) { + // no sub-model field filter — just check the discriminator + conditions.push(discriminatorCheck); + } else { + // build a correlated EXISTS subquery for sub-model-specific field filters + const subAlias = tmpAlias(`${modelAlias}__is__${subModelName}`); + const idFields = requireIdFields(this.schema, model); + + // correlate sub-model rows to the outer model rows via primary key + const joinConditions = idFields.map((idField) => + this.eb(this.eb.ref(`${subAlias}.${idField}`), '=', this.eb.ref(`${modelAlias}.${idField}`)), + ); + + const subWhereFilter = this.buildFilter(subModelName, subAlias, subWhere); + + const existsSubquery = this.eb + .selectFrom(`${subModelName} as ${subAlias}`) + .select(this.eb.lit(1).as('__exists')) + .where(this.and(...joinConditions, subWhereFilter)); + + conditions.push(this.and(discriminatorCheck, this.eb.exists(existsSubquery))); + } + } + + if (conditions.length === 0) return this.true(); + if (conditions.length === 1) return conditions[0]!; + return this.or(...conditions); + } + private buildRelationFilter(model: string, modelAlias: string, field: string, fieldDef: FieldDef, payload: any) { if (!fieldDef.array) { return this.buildToOneRelationFilter(model, modelAlias, field, fieldDef, payload); diff --git a/packages/orm/src/client/zod/factory.ts b/packages/orm/src/client/zod/factory.ts index 0f8b22701..3db45a192 100644 --- a/packages/orm/src/client/zod/factory.ts +++ b/packages/orm/src/client/zod/factory.ts @@ -484,6 +484,23 @@ export class ZodSchemaFactory< // expression builder fields['$expr'] = z.custom((v) => typeof v === 'function', { error: '"$expr" must be a function' }).optional(); + // $is sub-model filter for delegate (polymorphic) base models + const modelDef = requireModel(this.schema, model); + if (modelDef.isDelegate && modelDef.subModels && modelDef.subModels.length > 0) { + const subModelSchema = z.object( + Object.fromEntries( + modelDef.subModels.map((subModel) => [ + subModel, + z + .lazy(() => this.makeWhereSchema(subModel, false, false, false, options)) + .nullish() + .optional(), + ]), + ), + ); + fields['$is'] = subModelSchema.optional(); + } + // logical operators fields['AND'] = this.orArray( z.lazy(() => this.makeWhereSchema(model, false, withoutRelationFields, false, options)), diff --git a/tests/e2e/orm/client-api/delegate.test.ts b/tests/e2e/orm/client-api/delegate.test.ts index e325aaa5b..b4075a88c 100644 --- a/tests/e2e/orm/client-api/delegate.test.ts +++ b/tests/e2e/orm/client-api/delegate.test.ts @@ -557,6 +557,69 @@ describe('Delegate model tests ', () => { }), ).toResolveFalsy(); }); + + it('works with $is sub-model filter on base model', async () => { + // add an Image so we can test OR semantics of $is + await client.image.create({ + data: { format: 'png', viewCount: 2 }, + }); + + // $is: { Video: {} } — all assets that are Videos (2 RatedVideos) + await expect( + client.asset.findMany({ + where: { $is: { Video: {} } }, + }), + ).toResolveWithLength(2); + + // $is: { Video: null } — null value also means "is a Video" + await expect( + client.asset.findMany({ + where: { $is: { Video: null } }, + }), + ).toResolveWithLength(2); + + // $is: { Video: { duration: { gt: 100 } } } — only v2 + await expect( + client.asset.findMany({ + where: { $is: { Video: { duration: { gt: 100 } } } }, + }), + ).toResolveWithLength(1); + + // $is combined with base-model field filter + await expect( + client.asset.findMany({ + where: { viewCount: { gt: 0 }, $is: { Video: { duration: { gt: 100 } } } }, + }), + ).toResolveWithLength(1); + + // $is: { Video: { duration: { gte: 100 } } } — both videos + await expect( + client.asset.findMany({ + where: { $is: { Video: { duration: { gte: 100 } } } }, + }), + ).toResolveWithLength(2); + + // $is with multiple sub-models → OR semantics (1 Video with viewCount>0 OR the Image) + await expect( + client.asset.findMany({ + where: { $is: { Video: { duration: { gt: 100 } }, Image: { format: 'png' } } }, + }), + ).toResolveWithLength(2); + + // $is on Video (which is itself a delegate) — filter on its sub-model RatedVideo + await expect( + client.video.findMany({ + where: { $is: { RatedVideo: { rating: 5 } } }, + }), + ).toResolveWithLength(1); + + // nested $is: Asset.$is.Video.$is.RatedVideo + await expect( + client.asset.findMany({ + where: { $is: { Video: { $is: { RatedVideo: { rating: 5 } } } } }, + }), + ).toResolveWithLength(1); + }); }); describe('Delegate update tests', () => { diff --git a/tests/e2e/orm/schemas/delegate/typecheck.ts b/tests/e2e/orm/schemas/delegate/typecheck.ts index 76e36115f..c17981818 100644 --- a/tests/e2e/orm/schemas/delegate/typecheck.ts +++ b/tests/e2e/orm/schemas/delegate/typecheck.ts @@ -176,11 +176,48 @@ async function queryBuilder() { client.$qb.selectFrom('Video').select(['viewCount']).execute(); } +async function whereEPC() { + // unknown fields in `where` clause should produce a TypeScript error + await client.asset.findMany({ + where: { + viewCount: 1, + // @ts-expect-error notExistsColumn is not a valid field + notExistsColumn: 1, + }, + }); + + await client.asset.findFirst({ + where: { + viewCount: 1, + // @ts-expect-error notExistsColumn is not a valid field + notExistsColumn: 1, + }, + }); + + // valid fields should not produce errors + await client.asset.findMany({ + where: { + viewCount: { gt: 0 }, + }, + }); + + // unknown fields in `where` clause for update should also produce TypeScript errors + await client.asset.update({ + where: { + id: 1, + // @ts-expect-error notExistsColumn is not a valid field + notExistsColumn: 1, + }, + data: { viewCount: 2 }, + }); +} + async function main() { await create(); await update(); await find(); await queryBuilder(); + await whereEPC(); } main(); diff --git a/tests/e2e/orm/schemas/typing/typecheck.ts b/tests/e2e/orm/schemas/typing/typecheck.ts index 9f8b2aa86..d28059153 100644 --- a/tests/e2e/orm/schemas/typing/typecheck.ts +++ b/tests/e2e/orm/schemas/typing/typecheck.ts @@ -220,6 +220,30 @@ async function find() { console.log(u.posts[0]?.author?.role); // @ts-expect-error console.log(u.posts[0]?.author?.email); + + // unknown fields in `where` clause should produce TypeScript errors + await client.user.findMany({ + where: { + email: 'test@test.com', + // @ts-expect-error notExistsColumn is not a valid field + notExistsColumn: 1, + }, + }); + + await client.user.findFirst({ + where: { + email: 'test@test.com', + // @ts-expect-error notExistsColumn is not a valid field + notExistsColumn: 1, + }, + }); + + // valid where fields should not produce errors + await client.user.findMany({ + where: { + email: { contains: '@test.com' }, + }, + }); } async function create() { @@ -554,6 +578,35 @@ async function update() { email: 'alex@zenstack.dev', }, }); + + // unknown fields in `where` clause should produce TypeScript errors for update/upsert/updateMany + await client.user.update({ + where: { + id: 1, + // @ts-expect-error notExistsColumn is not a valid field + notExistsColumn: 1, + }, + data: { name: 'Alex' }, + }); + + await client.user.upsert({ + where: { + id: 1, + // @ts-expect-error notExistsColumn is not a valid field + notExistsColumn: 1, + }, + create: { name: 'Alex', email: 'alex@zenstack.dev' }, + update: { name: 'Alex New' }, + }); + + await client.user.updateMany({ + where: { + email: 'test@test.com', + // @ts-expect-error notExistsColumn is not a valid field + notExistsColumn: 1, + }, + data: { name: 'Alex' }, + }); } async function del() {