Skip to content
This repository was archived by the owner on Mar 1, 2026. It is now read-only.

Commit 2c1ffac

Browse files
ymc9Copilot
andauthored
feat(orm): implement client API extensions, refactor query args extension (#603)
* feat(orm): implement client API extensions, refactor query args extension * address PR comments * Fix upsert validation to merge $create and $update schemas (#604) * Initial plan * fix: merge $create and $update schemas for upsert validation - Handle upsert operation specially to match TypeScript type behavior - When both $create and $update schemas exist, merge them for upsert - Add test case to verify the fix works correctly Co-authored-by: ymc9 <104139426+ymc9@users.noreply.github.com> * fix: improve comment accuracy about Zod merge behavior Co-authored-by: ymc9 <104139426+ymc9@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: ymc9 <104139426+ymc9@users.noreply.github.com> * minor fixes --------- Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
1 parent 27e8910 commit 2c1ffac

File tree

11 files changed

+618
-148
lines changed

11 files changed

+618
-148
lines changed

packages/orm/src/client/client-impl.ts

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ import { ZenStackQueryExecutor } from './executor/zenstack-query-executor';
3737
import * as BuiltinFunctions from './functions';
3838
import { SchemaDbPusher } from './helpers/schema-db-pusher';
3939
import type { ClientOptions, ProceduresOptions } from './options';
40-
import type { RuntimePlugin } from './plugin';
40+
import type { AnyPlugin } from './plugin';
4141
import { createZenStackPromise, type ZenStackPromise } from './promise';
4242
import { ResultProcessor } from './result-processor';
4343

@@ -293,8 +293,8 @@ export class ClientImpl {
293293
await new SchemaDbPusher(this.schema, this.kysely).push();
294294
}
295295

296-
$use(plugin: RuntimePlugin<any, any>) {
297-
const newPlugins: RuntimePlugin<any, any>[] = [...(this.$options.plugins ?? []), plugin];
296+
$use(plugin: AnyPlugin) {
297+
const newPlugins: AnyPlugin[] = [...(this.$options.plugins ?? []), plugin];
298298
const newOptions: ClientOptions<SchemaDef> = {
299299
...this.options,
300300
plugins: newPlugins,
@@ -308,7 +308,7 @@ export class ClientImpl {
308308

309309
$unuse(pluginId: string) {
310310
// tsc perf
311-
const newPlugins: RuntimePlugin<any, any>[] = [];
311+
const newPlugins: AnyPlugin[] = [];
312312
for (const plugin of this.options.plugins ?? []) {
313313
if (plugin.id !== pluginId) {
314314
newPlugins.push(plugin);
@@ -329,7 +329,7 @@ export class ClientImpl {
329329
// tsc perf
330330
const newOptions: ClientOptions<SchemaDef> = {
331331
...this.options,
332-
plugins: [] as RuntimePlugin<any, any>[],
332+
plugins: [] as AnyPlugin[],
333333
};
334334
const newClient = new ClientImpl(this.schema, newOptions, this);
335335
// create a new validator to have a fresh schema cache, because plugins may
@@ -408,6 +408,16 @@ function createClientProxy(client: ClientImpl): ClientImpl {
408408
return new Proxy(client, {
409409
get: (target, prop, receiver) => {
410410
if (typeof prop === 'string' && prop.startsWith('$')) {
411+
// Check for plugin-provided members (search in reverse order so later plugins win)
412+
const plugins = target.$options.plugins ?? [];
413+
for (let i = plugins.length - 1; i >= 0; i--) {
414+
const plugin = plugins[i];
415+
const clientMembers = plugin?.client as Record<string, unknown> | undefined;
416+
if (clientMembers && prop in clientMembers) {
417+
return clientMembers[prop];
418+
}
419+
}
420+
// Fall through to built-in $ methods
411421
return Reflect.get(target, prop, receiver);
412422
}
413423

packages/orm/src/client/contract.ts

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,15 @@ import type {
4040
UpdateManyArgs,
4141
UpsertArgs,
4242
} from './crud-types';
43-
import type { CoreCrudOperations } from './crud/operations/base';
43+
import type {
44+
CoreCreateOperations,
45+
CoreCrudOperations,
46+
CoreDeleteOperations,
47+
CoreReadOperations,
48+
CoreUpdateOperations,
49+
} from './crud/operations/base';
4450
import type { ClientOptions, QueryOptions, ToQueryOptions } from './options';
45-
import type { ExtQueryArgsBase, RuntimePlugin } from './plugin';
51+
import type { ExtClientMembersBase, ExtQueryArgsBase, RuntimePlugin } from './plugin';
4652
import type { ZenStackPromise } from './promise';
4753
import type { ToKysely } from './query-builder';
4854

@@ -51,11 +57,26 @@ type TransactionUnsupportedMethods = (typeof TRANSACTION_UNSUPPORTED_METHODS)[nu
5157
/**
5258
* Extracts extended query args for a specific operation.
5359
*/
54-
type ExtractExtQueryArgs<ExtQueryArgs, Operation extends CoreCrudOperations> = Operation extends keyof ExtQueryArgs
55-
? NonNullable<ExtQueryArgs[Operation]>
56-
: 'all' extends keyof ExtQueryArgs
57-
? NonNullable<ExtQueryArgs['all']>
58-
: {};
60+
type ExtractExtQueryArgs<ExtQueryArgs, Operation extends CoreCrudOperations> = (Operation extends keyof ExtQueryArgs
61+
? ExtQueryArgs[Operation]
62+
: {}) &
63+
('$create' extends keyof ExtQueryArgs
64+
? Operation extends CoreCreateOperations
65+
? ExtQueryArgs['$create']
66+
: {}
67+
: {}) &
68+
('$read' extends keyof ExtQueryArgs ? (Operation extends CoreReadOperations ? ExtQueryArgs['$read'] : {}) : {}) &
69+
('$update' extends keyof ExtQueryArgs
70+
? Operation extends CoreUpdateOperations
71+
? ExtQueryArgs['$update']
72+
: {}
73+
: {}) &
74+
('$delete' extends keyof ExtQueryArgs
75+
? Operation extends CoreDeleteOperations
76+
? ExtQueryArgs['$delete']
77+
: {}
78+
: {}) &
79+
('$all' extends keyof ExtQueryArgs ? ExtQueryArgs['$all'] : {});
5980

6081
/**
6182
* Transaction isolation levels.
@@ -75,6 +96,7 @@ export type ClientContract<
7596
Schema extends SchemaDef,
7697
Options extends ClientOptions<Schema> = ClientOptions<Schema>,
7798
ExtQueryArgs extends ExtQueryArgsBase = {},
99+
ExtClientMembers extends ExtClientMembersBase = {},
78100
> = {
79101
/**
80102
* The schema definition.
@@ -132,7 +154,7 @@ export type ClientContract<
132154
/**
133155
* Sets the current user identity.
134156
*/
135-
$setAuth(auth: AuthType<Schema> | undefined): ClientContract<Schema, Options, ExtQueryArgs>;
157+
$setAuth(auth: AuthType<Schema> | undefined): ClientContract<Schema, Options, ExtQueryArgs, ExtClientMembers>;
136158

137159
/**
138160
* Returns a new client with new options applied.
@@ -141,15 +163,17 @@ export type ClientContract<
141163
* const dbNoValidation = db.$setOptions({ ...db.$options, validateInput: false });
142164
* ```
143165
*/
144-
$setOptions<Options extends ClientOptions<Schema>>(options: Options): ClientContract<Schema, Options, ExtQueryArgs>;
166+
$setOptions<NewOptions extends ClientOptions<Schema>>(
167+
options: NewOptions,
168+
): ClientContract<Schema, NewOptions, ExtQueryArgs, ExtClientMembers>;
145169

146170
/**
147171
* Returns a new client enabling/disabling input validations expressed with attributes like
148172
* `@email`, `@regex`, `@@validate`, etc.
149173
*
150174
* @deprecated Use {@link $setOptions} instead.
151175
*/
152-
$setInputValidation(enable: boolean): ClientContract<Schema, Options, ExtQueryArgs>;
176+
$setInputValidation(enable: boolean): ClientContract<Schema, Options, ExtQueryArgs, ExtClientMembers>;
153177

154178
/**
155179
* The Kysely query builder instance.
@@ -165,7 +189,7 @@ export type ClientContract<
165189
* Starts an interactive transaction.
166190
*/
167191
$transaction<T>(
168-
callback: (tx: TransactionClientContract<Schema, Options, ExtQueryArgs>) => Promise<T>,
192+
callback: (tx: TransactionClientContract<Schema, Options, ExtQueryArgs, ExtClientMembers>) => Promise<T>,
169193
options?: { isolationLevel?: TransactionIsolationLevel },
170194
): Promise<T>;
171195

@@ -180,14 +204,18 @@ export type ClientContract<
180204
/**
181205
* Returns a new client with the specified plugin installed.
182206
*/
183-
$use<PluginSchema extends SchemaDef = Schema, PluginExtQueryArgs extends ExtQueryArgsBase = {}>(
184-
plugin: RuntimePlugin<PluginSchema, PluginExtQueryArgs>,
185-
): ClientContract<Schema, Options, ExtQueryArgs & PluginExtQueryArgs>;
207+
$use<
208+
PluginSchema extends SchemaDef = Schema,
209+
PluginExtQueryArgs extends ExtQueryArgsBase = {},
210+
PluginExtClientMembers extends ExtClientMembersBase = {},
211+
>(
212+
plugin: RuntimePlugin<PluginSchema, PluginExtQueryArgs, PluginExtClientMembers>,
213+
): ClientContract<Schema, Options, ExtQueryArgs & PluginExtQueryArgs, ExtClientMembers & PluginExtClientMembers>;
186214

187215
/**
188216
* Returns a new client with the specified plugin removed.
189217
*/
190-
$unuse(pluginId: string): ClientContract<Schema, Options, ExtQueryArgs>;
218+
$unuse(pluginId: string): ClientContract<Schema, Options, ExtQueryArgs, ExtClientMembers>;
191219

192220
/**
193221
* Returns a new client with all plugins removed.
@@ -216,7 +244,8 @@ export type ClientContract<
216244
ToQueryOptions<Options>,
217245
ExtQueryArgs
218246
>;
219-
} & ProcedureOperations<Schema>;
247+
} & ProcedureOperations<Schema> &
248+
ExtClientMembers;
220249

221250
/**
222251
* The contract for a client in a transaction.
@@ -225,7 +254,8 @@ export type TransactionClientContract<
225254
Schema extends SchemaDef,
226255
Options extends ClientOptions<Schema>,
227256
ExtQueryArgs extends ExtQueryArgsBase,
228-
> = Omit<ClientContract<Schema, Options, ExtQueryArgs>, TransactionUnsupportedMethods>;
257+
ExtClientMembers extends ExtClientMembersBase,
258+
> = Omit<ClientContract<Schema, Options, ExtQueryArgs, ExtClientMembers>, TransactionUnsupportedMethods>;
229259

230260
export type ProcedureOperations<Schema extends SchemaDef> =
231261
Schema['procedures'] extends Record<string, ProcedureDef>

packages/orm/src/client/crud/operations/base.ts

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,36 @@ export const CoreWriteOperations = [
119119
*/
120120
export type CoreWriteOperations = (typeof CoreWriteOperations)[number];
121121

122+
/**
123+
* List of core create operations.
124+
*/
125+
export const CoreCreateOperations = ['create', 'createMany', 'createManyAndReturn', 'upsert'] as const;
126+
127+
/**
128+
* List of core create operations.
129+
*/
130+
export type CoreCreateOperations = (typeof CoreCreateOperations)[number];
131+
132+
/**
133+
* List of core update operations.
134+
*/
135+
export const CoreUpdateOperations = ['update', 'updateMany', 'updateManyAndReturn', 'upsert'] as const;
136+
137+
/**
138+
* List of core update operations.
139+
*/
140+
export type CoreUpdateOperations = (typeof CoreUpdateOperations)[number];
141+
142+
/**
143+
* List of core delete operations.
144+
*/
145+
export const CoreDeleteOperations = ['delete', 'deleteMany'] as const;
146+
147+
/**
148+
* List of core delete operations.
149+
*/
150+
export type CoreDeleteOperations = (typeof CoreDeleteOperations)[number];
151+
122152
/**
123153
* List of all CRUD operations, including 'orThrow' variants.
124154
*/

packages/orm/src/client/crud/validator/index.ts

Lines changed: 81 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ import {
3535
type UpsertArgs,
3636
} from '../../crud-types';
3737
import { createInternalError, createInvalidInputError } from '../../errors';
38+
import type { AnyPlugin } from '../../plugin';
3839
import {
3940
fieldHasDefaultValue,
4041
getDiscriminatorField,
@@ -46,7 +47,13 @@ import {
4647
requireField,
4748
requireModel,
4849
} from '../../query-utils';
49-
import type { CoreCrudOperations } from '../operations/base';
50+
import {
51+
CoreCreateOperations,
52+
CoreDeleteOperations,
53+
CoreReadOperations,
54+
CoreUpdateOperations,
55+
type CoreCrudOperations,
56+
} from '../operations/base';
5057
import {
5158
addBigIntValidation,
5259
addCustomValidation,
@@ -365,8 +372,8 @@ export class InputValidator<Schema extends SchemaDef> {
365372
private mergePluginArgsSchema(schema: ZodObject, operation: CoreCrudOperations) {
366373
let result = schema;
367374
for (const plugin of this.options.plugins ?? []) {
368-
if (plugin.extQueryArgs) {
369-
const pluginSchema = plugin.extQueryArgs.getValidationSchema(operation);
375+
if (plugin.queryArgs) {
376+
const pluginSchema = this.getPluginExtQueryArgsSchema(plugin, operation);
370377
if (pluginSchema) {
371378
result = result.extend(pluginSchema.shape);
372379
}
@@ -375,6 +382,77 @@ export class InputValidator<Schema extends SchemaDef> {
375382
return result.strict();
376383
}
377384

385+
private getPluginExtQueryArgsSchema(plugin: AnyPlugin, operation: string): ZodObject | undefined {
386+
if (!plugin.queryArgs) {
387+
return undefined;
388+
}
389+
390+
let result: ZodType | undefined;
391+
392+
if (operation in plugin.queryArgs && plugin.queryArgs[operation]) {
393+
// most specific operation takes highest precedence
394+
result = plugin.queryArgs[operation];
395+
} else if (operation === 'upsert') {
396+
// upsert is special: it's in both CoreCreateOperations and CoreUpdateOperations
397+
// so we need to merge both $create and $update schemas to match the type system
398+
const createSchema =
399+
'$create' in plugin.queryArgs && plugin.queryArgs['$create'] ? plugin.queryArgs['$create'] : undefined;
400+
const updateSchema =
401+
'$update' in plugin.queryArgs && plugin.queryArgs['$update'] ? plugin.queryArgs['$update'] : undefined;
402+
403+
if (createSchema && updateSchema) {
404+
invariant(
405+
createSchema instanceof z.ZodObject,
406+
'Plugin extended query args schema must be a Zod object',
407+
);
408+
invariant(
409+
updateSchema instanceof z.ZodObject,
410+
'Plugin extended query args schema must be a Zod object',
411+
);
412+
// merge both schemas (combines their properties)
413+
result = createSchema.extend(updateSchema.shape);
414+
} else if (createSchema) {
415+
result = createSchema;
416+
} else if (updateSchema) {
417+
result = updateSchema;
418+
}
419+
} else if (
420+
// then comes grouped operations: $create, $read, $update, $delete
421+
CoreCreateOperations.includes(operation as CoreCreateOperations) &&
422+
'$create' in plugin.queryArgs &&
423+
plugin.queryArgs['$create']
424+
) {
425+
result = plugin.queryArgs['$create'];
426+
} else if (
427+
CoreReadOperations.includes(operation as CoreReadOperations) &&
428+
'$read' in plugin.queryArgs &&
429+
plugin.queryArgs['$read']
430+
) {
431+
result = plugin.queryArgs['$read'];
432+
} else if (
433+
CoreUpdateOperations.includes(operation as CoreUpdateOperations) &&
434+
'$update' in plugin.queryArgs &&
435+
plugin.queryArgs['$update']
436+
) {
437+
result = plugin.queryArgs['$update'];
438+
} else if (
439+
CoreDeleteOperations.includes(operation as CoreDeleteOperations) &&
440+
'$delete' in plugin.queryArgs &&
441+
plugin.queryArgs['$delete']
442+
) {
443+
result = plugin.queryArgs['$delete'];
444+
} else if ('$all' in plugin.queryArgs && plugin.queryArgs['$all']) {
445+
// finally comes $all
446+
result = plugin.queryArgs['$all'];
447+
}
448+
449+
invariant(
450+
result === undefined || result instanceof z.ZodObject,
451+
'Plugin extended query args schema must be a Zod object',
452+
);
453+
return result;
454+
}
455+
378456
// #region Find
379457

380458
private makeFindSchema(model: string, operation: CoreCrudOperations) {

packages/orm/src/client/index.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@ export { BaseCrudDialect } from './crud/dialects/base-dialect';
66
export {
77
AllCrudOperations,
88
AllReadOperations,
9+
CoreCreateOperations,
910
CoreCrudOperations,
11+
CoreDeleteOperations,
1012
CoreReadOperations,
13+
CoreUpdateOperations,
1114
CoreWriteOperations,
1215
} from './crud/operations/base';
1316
export { InputValidator } from './crud/validator';

packages/orm/src/client/options.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import type { PrependParameter } from '../utils/type-utils';
44
import type { ClientContract, CRUD_EXT } from './contract';
55
import type { GetProcedureNames, ProcedureHandlerFunc } from './crud-types';
66
import type { BaseCrudDialect } from './crud/dialects/base-dialect';
7-
import type { RuntimePlugin } from './plugin';
7+
import type { AnyPlugin } from './plugin';
88
import type { ToKyselySchema } from './query-builder';
99

1010
export type ZModelFunctionContext<Schema extends SchemaDef> = {
@@ -59,7 +59,7 @@ export type ClientOptions<Schema extends SchemaDef> = {
5959
/**
6060
* Plugins.
6161
*/
62-
plugins?: RuntimePlugin<any, any>[];
62+
plugins?: AnyPlugin[];
6363

6464
/**
6565
* Logging configuration.

0 commit comments

Comments
 (0)