diff --git a/packages/clients/tanstack-query/src/common/types.ts b/packages/clients/tanstack-query/src/common/types.ts index 8993445ed..6f8a64694 100644 --- a/packages/clients/tanstack-query/src/common/types.ts +++ b/packages/clients/tanstack-query/src/common/types.ts @@ -76,3 +76,15 @@ type WithOptimisticFlag = T extends object : T; export type WithOptimistic = T extends Array ? Array> : WithOptimisticFlag; + +export type HttpMethod = 'GET' | 'POST' | 'PUT' | 'PATCH' | 'DELETE'; + +export type CustomOperationKind = 'query' | 'suspenseQuery' | 'infiniteQuery' | 'suspenseInfiniteQuery' | 'mutation'; + +export type CustomOperationDefinition = { + kind: CustomOperationKind; + method?: HttpMethod; + /** Phantom fields for typing */ + __args?: TArgs; + __result?: TResult; +}; diff --git a/packages/clients/tanstack-query/src/react.ts b/packages/clients/tanstack-query/src/react.ts index 4449e29fb..2965bbd0d 100644 --- a/packages/clients/tanstack-query/src/react.ts +++ b/packages/clients/tanstack-query/src/react.ts @@ -55,6 +55,7 @@ import { getQueryKey } from './common/query-key'; import type { ExtraMutationOptions, ExtraQueryOptions, + CustomOperationDefinition, QueryContext, TrimDelegateModelOperations, WithOptimistic, @@ -131,8 +132,32 @@ export type ModelMutationModelResult< ): Promise>; }; -export type ClientHooks = QueryOptions> = { - [Model in GetModels as `${Uncapitalize}`]: ModelQueryHooks; +type CustomOperationHooks> = {}> = { + [K in keyof CustomOperations as `use${Capitalize}`]: CustomOperations[K] extends CustomOperationDefinition< + infer TArgs, + infer TResult + > + ? CustomOperations[K]['kind'] extends 'mutation' + ? (options?: ModelMutationOptions) => ModelMutationResult + : CustomOperations[K]['kind'] extends 'query' + ? (args?: TArgs, options?: ModelQueryOptions) => ModelQueryResult + : CustomOperations[K]['kind'] extends 'suspenseQuery' + ? (args?: TArgs, options?: ModelSuspenseQueryOptions) => ModelSuspenseQueryResult + : CustomOperations[K]['kind'] extends 'infiniteQuery' + ? (args?: TArgs, options?: ModelInfiniteQueryOptions) => ModelInfiniteQueryResult< + InfiniteData + > + : (args?: TArgs, options?: ModelSuspenseInfiniteQueryOptions) => + ModelSuspenseInfiniteQueryResult> + : never; +}; + +export type ClientHooks< + Schema extends SchemaDef, + Options extends QueryOptions = QueryOptions, + CustomOperations extends Record> = {}, +> = { + [Model in GetModels as `${Uncapitalize}`]: ModelQueryHooks; }; // Note that we can potentially use TypeScript's mapped type to directly map from ORM contract, but that seems @@ -141,6 +166,7 @@ export type ModelQueryHooks< Schema extends SchemaDef, Model extends GetModels, Options extends QueryOptions = QueryOptions, + CustomOperations extends Record> = {}, > = TrimDelegateModelOperations< Schema, Model, @@ -250,7 +276,7 @@ export type ModelQueryHooks< args: Subset>, options?: ModelSuspenseQueryOptions>, ): ModelSuspenseQueryResult>; - } + } & CustomOperationHooks >; /** @@ -259,20 +285,27 @@ export type ModelQueryHooks< * @param schema The schema. * @param options Options for all queries originated from this hook. */ -export function useClientQueries = QueryOptions>( - schema: Schema, - options?: QueryContext, -): ClientHooks { +export function useClientQueries< + Schema extends SchemaDef, + Options extends QueryOptions = QueryOptions, + CustomOperations extends Record> = {}, +>(schema: Schema, options?: QueryContext, customOperations?: CustomOperations): ClientHooks { return Object.keys(schema.models).reduce( (acc, model) => { - (acc as any)[lowerCaseFirst(model)] = useModelQueries, Options>( + (acc as any)[lowerCaseFirst(model)] = useModelQueries< + Schema, + GetModels, + Options, + CustomOperations + >( schema, model as GetModels, options, + customOperations, ); return acc; }, - {} as ClientHooks, + {} as ClientHooks, ); } @@ -283,7 +316,13 @@ export function useModelQueries< Schema extends SchemaDef, Model extends GetModels, Options extends QueryOptions, ->(schema: Schema, model: Model, rootOptions?: QueryContext): ModelQueryHooks { + CustomOperations extends Record> = {}, +>( + schema: Schema, + model: Model, + rootOptions?: QueryContext, + customOperations?: CustomOperations, +): ModelQueryHooks { const modelDef = Object.values(schema.models).find((m) => m.name.toLowerCase() === model.toLowerCase()); if (!modelDef) { throw new Error(`Model "${model}" not found in schema`); @@ -291,7 +330,7 @@ export function useModelQueries< const modelName = modelDef.name; - return { + const builtInHooks = { useFindUnique: (args: any, options?: any) => { return useInternalQuery(schema, modelName, 'findUnique', args, { ...rootOptions, ...options }); }, @@ -390,6 +429,78 @@ export function useModelQueries< return useInternalSuspenseQuery(schema, modelName, 'groupBy', args, { ...rootOptions, ...options }); }, } as ModelQueryHooks; + + const customHooks = createCustomOperationHooks(schema, modelName, rootOptions, customOperations); + + return { ...builtInHooks, ...customHooks } as ModelQueryHooks; +} + +function createCustomOperationHooks< + Schema extends SchemaDef, + CustomOperations extends Record> = {}, +>( + schema: Schema, + modelName: string, + rootOptions: QueryContext | undefined, + customOperations?: CustomOperations, +) { + if (!customOperations) { + return {} as CustomOperationHooks; + } + + const hooks: Record = {}; + for (const [name, def] of Object.entries(customOperations)) { + const hookName = `use${name.charAt(0).toUpperCase()}${name.slice(1)}`; + switch (def.kind) { + case 'query': + hooks[hookName] = (args?: unknown, options?: unknown) => + useInternalQuery(schema, modelName, name, args, { + ...(rootOptions ?? {}), + ...((options as object) ?? {}), + }); + break; + case 'suspenseQuery': + hooks[hookName] = (args?: unknown, options?: unknown) => + useInternalSuspenseQuery(schema, modelName, name, args, { + ...(rootOptions ?? {}), + ...((options as object) ?? {}), + }); + break; + case 'infiniteQuery': + hooks[hookName] = (args?: unknown, options?: unknown) => + useInternalInfiniteQuery(schema, modelName, name, args, buildInfiniteOptions(rootOptions, options)); + break; + case 'suspenseInfiniteQuery': + hooks[hookName] = (args?: unknown, options?: unknown) => + useInternalSuspenseInfiniteQuery( + schema, + modelName, + name, + args, + buildInfiniteOptions(rootOptions, options) as any, + ); + break; + case 'mutation': + hooks[hookName] = (options?: unknown) => + useInternalMutation(schema, modelName, (def.method ?? 'POST') as any, name, { + ...(rootOptions ?? {}), + ...((options as object) ?? {}), + }); + break; + default: + break; + } + } + + return hooks as CustomOperationHooks; +} + +function buildInfiniteOptions(rootOptions: QueryContext | undefined, options: unknown) { + const merged = { ...(rootOptions ?? {}), ...((options as object) ?? {}) } as Record; + if (typeof merged.getNextPageParam !== 'function') { + merged.getNextPageParam = () => undefined; + } + return merged; } export function useInternalQuery( diff --git a/packages/clients/tanstack-query/src/svelte/index.svelte.ts b/packages/clients/tanstack-query/src/svelte/index.svelte.ts index a94941c40..1974798e3 100644 --- a/packages/clients/tanstack-query/src/svelte/index.svelte.ts +++ b/packages/clients/tanstack-query/src/svelte/index.svelte.ts @@ -54,6 +54,7 @@ import { getContext, setContext } from 'svelte'; import { getAllQueries, invalidateQueriesMatchingPredicate } from '../common/client'; import { getQueryKey } from '../common/query-key'; import type { + CustomOperationDefinition, ExtraMutationOptions, ExtraQueryOptions, QueryContext, @@ -120,8 +121,30 @@ export type ModelMutationModelResult< ): Promise>; }; -export type ClientHooks = QueryOptions> = { - [Model in GetModels as `${Uncapitalize}`]: ModelQueryHooks; +type CustomOperationHooks> = {}> = { + [K in keyof CustomOperations as `use${Capitalize}`]: CustomOperations[K] extends CustomOperationDefinition< + infer TArgs, + infer TResult + > + ? CustomOperations[K]['kind'] extends 'mutation' + ? (options?: ModelMutationOptions) => ModelMutationResult + : CustomOperations[K]['kind'] extends 'infiniteQuery' | 'suspenseInfiniteQuery' + ? (args?: TArgs, options?: ModelInfiniteQueryOptions) => ModelInfiniteQueryResult + : (args?: TArgs, options?: ModelQueryOptions) => ModelQueryResult + : never; +}; + +export type ClientHooks< + Schema extends SchemaDef, + Options extends QueryOptions = QueryOptions, + CustomOperations extends Record> = {}, +> = { + [Model in GetModels as `${Uncapitalize}`]: ModelQueryHooks< + Schema, + Model, + Options, + CustomOperations + >; }; // Note that we can potentially use TypeScript's mapped type to directly map from ORM contract, but that seems @@ -130,6 +153,7 @@ export type ModelQueryHooks< Schema extends SchemaDef, Model extends GetModels, Options extends QueryOptions = QueryOptions, + CustomOperations extends Record> = {}, > = TrimDelegateModelOperations< Schema, Model, @@ -202,26 +226,37 @@ export type ModelQueryHooks< args: Accessor>>, options?: Accessor>>, ): ModelQueryResult>; - } + } & CustomOperationHooks >; /** * Gets data query hooks for all models in the schema. */ -export function useClientQueries = QueryOptions>( +export function useClientQueries< + Schema extends SchemaDef, + Options extends QueryOptions = QueryOptions, + CustomOperations extends Record> = {}, +>( schema: Schema, options?: Accessor, -): ClientHooks { + customOperations?: CustomOperations, +): ClientHooks { return Object.keys(schema.models).reduce( (acc, model) => { - (acc as any)[lowerCaseFirst(model)] = useModelQueries, Options>( + (acc as any)[lowerCaseFirst(model)] = useModelQueries< + Schema, + GetModels, + Options, + CustomOperations + >( schema, model as GetModels, options, + customOperations, ); return acc; }, - {} as ClientHooks, + {} as ClientHooks, ); } @@ -232,7 +267,13 @@ export function useModelQueries< Schema extends SchemaDef, Model extends GetModels, Options extends QueryOptions, ->(schema: Schema, model: Model, rootOptions?: Accessor): ModelQueryHooks { + CustomOperations extends Record> = {}, +>( + schema: Schema, + model: Model, + rootOptions?: Accessor, + customOperations?: CustomOperations, +): ModelQueryHooks { const modelDef = Object.values(schema.models).find((m) => m.name.toLowerCase() === model.toLowerCase()); if (!modelDef) { throw new Error(`Model "${model}" not found in schema`); @@ -248,7 +289,7 @@ export function useModelQueries< }; }; - return { + const builtIns = { useFindUnique: (args: any, options?: any) => { return useInternalQuery(schema, modelName, 'findUnique', args, merge(rootOptions, options)); }, @@ -313,6 +354,67 @@ export function useModelQueries< return useInternalQuery(schema, modelName, 'groupBy', args, options); }, } as unknown as ModelQueryHooks; + + const custom = createCustomOperationHooks(schema, modelName, rootOptions, customOperations, merge); + + return { ...builtIns, ...custom } as ModelQueryHooks; +} + +function createCustomOperationHooks< + Schema extends SchemaDef, + CustomOperations extends Record> = {}, +>( + schema: Schema, + modelName: string, + rootOptions: Accessor | undefined, + customOperations: CustomOperations | undefined, + mergeOptions: (rootOpt: unknown, opt: unknown) => Accessor, +) { + if (!customOperations) { + return {} as CustomOperationHooks; + } + + const hooks: Record = {}; + for (const [name, def] of Object.entries(customOperations)) { + const hookName = `use${name.charAt(0).toUpperCase()}${name.slice(1)}`; + const merged = (options?: unknown) => mergeOptions(rootOptions, options); + + switch (def.kind) { + case 'query': + case 'suspenseQuery': + hooks[hookName] = (args?: unknown, options?: unknown) => + useInternalQuery(schema, modelName, name, args, merged(options as Accessor | undefined)); + break; + case 'infiniteQuery': + case 'suspenseInfiniteQuery': + hooks[hookName] = (args?: unknown, options?: unknown) => { + const mergedOptions = merged(options as Accessor | undefined); + const withDefault = () => { + const value = mergedOptions?.() as any; + if (value && typeof value.getNextPageParam !== 'function') { + value.getNextPageParam = () => undefined; + } + return value; + }; + return useInternalInfiniteQuery(schema, modelName, name, args, withDefault as any); + }; + break; + case 'mutation': + hooks[hookName] = (options?: unknown) => + useInternalMutation( + schema, + modelName, + (def.method ?? 'POST') as any, + name, + merged(options as Accessor | undefined) as any, + ); + break; + default: + break; + } + } + + return hooks as CustomOperationHooks; } export function useInternalQuery( diff --git a/packages/clients/tanstack-query/src/vue.ts b/packages/clients/tanstack-query/src/vue.ts index bd4dcf74b..e8c78d253 100644 --- a/packages/clients/tanstack-query/src/vue.ts +++ b/packages/clients/tanstack-query/src/vue.ts @@ -54,6 +54,7 @@ import { getQueryKey } from './common/query-key'; import type { ExtraMutationOptions, ExtraQueryOptions, + CustomOperationDefinition, QueryContext, TrimDelegateModelOperations, WithOptimistic, @@ -121,8 +122,32 @@ export type ModelMutationModelResult< ): Promise>; }; -export type ClientHooks = QueryOptions> = { - [Model in GetModels as `${Uncapitalize}`]: ModelQueryHooks; +type CustomOperationHooks> = {}> = { + [K in keyof CustomOperations as `use${Capitalize}`]: CustomOperations[K] extends CustomOperationDefinition< + infer TArgs, + infer TResult + > + ? CustomOperations[K]['kind'] extends 'mutation' + ? (options?: ModelMutationOptions) => ModelMutationResult + : CustomOperations[K]['kind'] extends 'infiniteQuery' | 'suspenseInfiniteQuery' + ? (args?: TArgs, options?: ModelInfiniteQueryOptions) => ModelInfiniteQueryResult< + InfiniteData + > + : (args?: TArgs, options?: ModelQueryOptions) => ModelQueryResult + : never; +}; + +export type ClientHooks< + Schema extends SchemaDef, + Options extends QueryOptions = QueryOptions, + CustomOperations extends Record> = {}, +> = { + [Model in GetModels as `${Uncapitalize}`]: ModelQueryHooks< + Schema, + Model, + Options, + CustomOperations + >; }; // Note that we can potentially use TypeScript's mapped type to directly map from ORM contract, but that seems @@ -131,6 +156,7 @@ export type ModelQueryHooks< Schema extends SchemaDef, Model extends GetModels, Options extends QueryOptions = QueryOptions, + CustomOperations extends Record> = {}, > = TrimDelegateModelOperations< Schema, Model, @@ -205,26 +231,37 @@ export type ModelQueryHooks< args: MaybeRefOrGetter>>, options?: MaybeRefOrGetter>>, ): ModelQueryResult>; - } + } & CustomOperationHooks >; /** * Gets data query hooks for all models in the schema. */ -export function useClientQueries = QueryOptions>( +export function useClientQueries< + Schema extends SchemaDef, + Options extends QueryOptions = QueryOptions, + CustomOperations extends Record> = {}, +>( schema: Schema, options?: MaybeRefOrGetter, -): ClientHooks { + customOperations?: CustomOperations, +): ClientHooks { return Object.keys(schema.models).reduce( (acc, model) => { - (acc as any)[lowerCaseFirst(model)] = useModelQueries, Options>( + (acc as any)[lowerCaseFirst(model)] = useModelQueries< + Schema, + GetModels, + Options, + CustomOperations + >( schema, model as GetModels, options, + customOperations, ); return acc; }, - {} as ClientHooks, + {} as ClientHooks, ); } @@ -235,7 +272,13 @@ export function useModelQueries< Schema extends SchemaDef, Model extends GetModels, Options extends QueryOptions, ->(schema: Schema, model: Model, rootOptions?: MaybeRefOrGetter): ModelQueryHooks { + CustomOperations extends Record> = {}, +>( + schema: Schema, + model: Model, + rootOptions?: MaybeRefOrGetter, + customOperations?: CustomOperations, +): ModelQueryHooks { const modelDef = Object.values(schema.models).find((m) => m.name.toLowerCase() === model.toLowerCase()); if (!modelDef) { throw new Error(`Model "${model}" not found in schema`); @@ -249,7 +292,7 @@ export function useModelQueries< }); }; - return { + const builtIn = { useFindUnique: (args: any, options?: any) => { return useInternalQuery(schema, modelName, 'findUnique', args, merge(rootOptions, options)); }, @@ -314,6 +357,64 @@ export function useModelQueries< return useInternalQuery(schema, modelName, 'groupBy', args, merge(rootOptions, options)); }, } as ModelQueryHooks; + + const custom = createCustomOperationHooks(schema, modelName, rootOptions, customOperations, merge); + + return { ...builtIn, ...custom } as ModelQueryHooks; +} + +function createCustomOperationHooks< + Schema extends SchemaDef, + CustomOperations extends Record> = {}, +>( + schema: Schema, + modelName: string, + rootOptions: MaybeRefOrGetter | undefined, + customOperations: CustomOperations | undefined, + mergeOptions: ( + rootOpt: MaybeRefOrGetter | undefined, + opt: MaybeRefOrGetter | undefined, + ) => MaybeRefOrGetter, +) { + if (!customOperations) { + return {} as CustomOperationHooks; + } + + const hooks: Record = {}; + for (const [name, def] of Object.entries(customOperations)) { + const hookName = `use${name.charAt(0).toUpperCase()}${name.slice(1)}`; + const merged = (opt?: MaybeRefOrGetter) => mergeOptions(rootOptions, opt); + + switch (def.kind) { + case 'query': + case 'suspenseQuery': + hooks[hookName] = (args?: unknown, options?: MaybeRefOrGetter) => + useInternalQuery(schema, modelName, name, args, merged(options) as any); + break; + case 'infiniteQuery': + case 'suspenseInfiniteQuery': + hooks[hookName] = (args?: unknown, options?: MaybeRefOrGetter) => { + const mergedOptions = merged(options) as MaybeRefOrGetter; + const withDefault = computed(() => { + const value = toValue(mergedOptions) as any; + if (value && typeof value.getNextPageParam !== 'function') { + value.getNextPageParam = () => undefined; + } + return value; + }); + return useInternalInfiniteQuery(schema, modelName, name, args, withDefault as any); + }; + break; + case 'mutation': + hooks[hookName] = (options?: MaybeRefOrGetter) => + useInternalMutation(schema, modelName, (def.method ?? 'POST') as any, name, merged(options) as any); + break; + default: + break; + } + } + + return hooks as CustomOperationHooks; } export function useInternalQuery( diff --git a/packages/server/src/api/index.ts b/packages/server/src/api/index.ts index 09d9700eb..28cb65971 100644 --- a/packages/server/src/api/index.ts +++ b/packages/server/src/api/index.ts @@ -1,2 +1,9 @@ export { RestApiHandler, type RestApiHandlerOptions } from './rest'; -export { RPCApiHandler, type RPCApiHandlerOptions } from './rpc'; +export { + RPCBadInputErrorResponse, + RPCGenericErrorResponse, + RPCApiHandler, + type RPCApiHandlerOptions, + type RPCCustomOperation, + type RPCCustomOperationContext, +} from './rpc'; diff --git a/packages/server/src/api/rpc/index.ts b/packages/server/src/api/rpc/index.ts index e4e6ea64c..7939b8f44 100644 --- a/packages/server/src/api/rpc/index.ts +++ b/packages/server/src/api/rpc/index.ts @@ -8,6 +8,40 @@ import { log, registerCustomSerializers } from '../utils'; registerCustomSerializers(); +const BUILT_IN_OPERATIONS = new Set([ + 'create', + 'createMany', + 'createManyAndReturn', + 'upsert', + 'findFirst', + 'findUnique', + 'findMany', + 'aggregate', + 'groupBy', + 'count', + 'update', + 'updateMany', + 'updateManyAndReturn', + 'delete', + 'deleteMany', +]); + +const JS_IDENTIFIER_RE = /^[A-Za-z_$][A-Za-z0-9_$]*$/; + +export class RPCBadInputErrorResponse extends Error {} + +export class RPCGenericErrorResponse extends Error {} + +export type RPCCustomOperationContext = RequestContext & { + model: string; + operation: string; + args?: unknown; +}; + +export type RPCCustomOperation = ( + args: RPCCustomOperationContext, +) => Promise | Response; + /** * Options for {@link RPCApiHandler} */ @@ -21,13 +55,21 @@ export type RPCApiHandlerOptions = { * Logging configuration */ log?: LogConfig; + + /** + * Custom operations callable via RPC path. Keys must be valid JS identifiers and must not + * overlap with built-in operations. + */ + customOperations?: Record>; }; /** * RPC style API request handler that mirrors the ZenStackClient API */ export class RPCApiHandler implements ApiHandler { - constructor(private readonly options: RPCApiHandlerOptions) {} + constructor(private readonly options: RPCApiHandlerOptions) { + this.validateCustomOperations(); + } get schema(): Schema { return this.options.schema; @@ -51,6 +93,11 @@ export class RPCApiHandler implements ApiH let args: unknown; let resCode = 200; + const { query: normalizedQuery, qArgs, error: queryError } = this.normalizeQuery(query); + if (queryError) { + return this.makeBadInputErrorResponse(queryError); + } + switch (op) { case 'create': case 'createMany': @@ -76,13 +123,7 @@ export class RPCApiHandler implements ApiH if (method !== 'GET') { return this.makeBadInputErrorResponse('invalid request method, only GET is supported'); } - try { - args = query?.['q'] - ? this.unmarshalQ(query['q'] as string, query['meta'] as string | undefined) - : {}; - } catch { - return this.makeBadInputErrorResponse('invalid "q" query parameter'); - } + args = qArgs ?? {}; break; case 'update': @@ -103,19 +144,33 @@ export class RPCApiHandler implements ApiH if (method !== 'DELETE') { return this.makeBadInputErrorResponse('invalid request method, only DELETE is supported'); } + args = qArgs ?? {}; + break; + + default: + break; + } + + if (!BUILT_IN_OPERATIONS.has(op)) { + const custom = this.options.customOperations?.[op]; + if (custom) { try { - args = query?.['q'] - ? this.unmarshalQ(query['q'] as string, query['meta'] as string | undefined) - : {}; + return await custom({ + client, + method, + path, + query: normalizedQuery, + requestBody, + model, + operation: op, + args: qArgs, + }); } catch (err) { - return this.makeBadInputErrorResponse( - err instanceof Error ? err.message : 'invalid "q" query parameter', - ); + return this.mapCustomOperationError(err); } - break; + } - default: - return this.makeBadInputErrorResponse('invalid operation: ' + op); + return this.makeBadInputErrorResponse('invalid operation: ' + op); } const { result: processedArgs, error } = await this.processRequestPayload(args); @@ -256,4 +311,67 @@ export class RPCApiHandler implements ApiH return parsedValue; } + + private normalizeQuery(originalQuery: RequestContext['query']) { + if (!originalQuery) { + return { query: originalQuery, qArgs: undefined as unknown }; + } + + const qValue = (originalQuery as any).q; + if (typeof qValue === 'undefined') { + return { query: originalQuery, qArgs: undefined as unknown }; + } + + if (typeof qValue !== 'string') { + return { query: originalQuery, qArgs: undefined as unknown, error: 'invalid "q" query parameter' }; + } + + try { + const parsed = this.unmarshalQ(qValue, (originalQuery as any).meta as string | undefined); + return { query: { ...(originalQuery as any), q: parsed }, qArgs: parsed }; + } catch (err) { + return { + query: originalQuery, + qArgs: undefined as unknown, + error: err instanceof Error ? err.message : 'invalid "q" query parameter', + }; + } + } + + private mapCustomOperationError(err: unknown): Response { + if (err instanceof RPCBadInputErrorResponse) { + return this.makeBadInputErrorResponse(err.message); + } + + if (err instanceof ORMError) { + return this.makeORMErrorResponse(err); + } + + if (err instanceof RPCGenericErrorResponse) { + return this.makeGenericErrorResponse(err); + } + + return this.makeGenericErrorResponse(err); + } + + private validateCustomOperations() { + const customOps = this.options.customOperations; + if (!customOps) { + return; + } + + Object.entries(customOps).forEach(([name, fn]) => { + if (!JS_IDENTIFIER_RE.test(name)) { + throw new Error(`custom operation name must be a valid identifier: ${name}`); + } + + if (BUILT_IN_OPERATIONS.has(name)) { + throw new Error(`custom operation cannot override built-in operation: ${name}`); + } + + if (typeof fn !== 'function') { + throw new Error(`custom operation must be a function: ${name}`); + } + }); + } } diff --git a/packages/server/test/api/rpc.test.ts b/packages/server/test/api/rpc.test.ts index 19e44ca08..6b9f2ee49 100644 --- a/packages/server/test/api/rpc.test.ts +++ b/packages/server/test/api/rpc.test.ts @@ -4,7 +4,12 @@ import { createPolicyTestClient, createTestClient } from '@zenstackhq/testtools' import Decimal from 'decimal.js'; import SuperJSON from 'superjson'; import { beforeAll, describe, expect, it } from 'vitest'; -import { RPCApiHandler } from '../../src/api'; +import { + RPCBadInputErrorResponse, + RPCGenericErrorResponse, + RPCApiHandler, + type RPCApiHandlerOptions, +} from '../../src/api'; import { schema } from '../utils'; describe('RPC API Handler Tests', () => { @@ -353,6 +358,107 @@ describe('RPC API Handler Tests', () => { expect(r.error.message).toContain('invalid "q" query parameter'); }); + it('custom operation works', async () => { + const handleRequest = makeHandler({ + customOperations: { + echo: async ({ requestBody }) => { + if (!requestBody) { + throw new RPCBadInputErrorResponse('missing body'); + } + return { status: 200, body: { data: requestBody } }; + }, + }, + }); + + const r = await handleRequest({ + method: 'post', + path: '/post/echo', + client: rawClient, + requestBody: { message: 'hello' }, + }); + + expect(r.status).toBe(200); + expect(r.data).toEqual({ message: 'hello' }); + }); + + it('custom operation auto unmarshals query', async () => { + const serialized = SuperJSON.serialize({ where: { id: '1', created: new Date() } }); + + const handleRequest = makeHandler({ + customOperations: { + passthrough: async ({ query }) => ({ status: 200, body: { data: query?.q } }), + }, + }); + + const r = await handleRequest({ + method: 'get', + path: '/post/passthrough', + client: rawClient, + query: { + q: JSON.stringify(serialized.json), + meta: JSON.stringify({ serialization: serialized.meta }), + }, + }); + + expect(r.status).toBe(200); + expect(r.data.where.id).toBe('1'); + expect(r.data.where.created).toBeInstanceOf(Date); + }); + + it('custom operation maps errors', async () => { + const handleRequest = makeHandler({ + customOperations: { + bad: async () => { + throw new RPCBadInputErrorResponse('nope'); + }, + boom: async () => { + throw new RPCGenericErrorResponse('boom'); + }, + }, + }); + + const bad = await handleRequest({ method: 'get', path: '/post/bad', client: rawClient }); + expect(bad.status).toBe(400); + expect(bad.error.message).toBe('nope'); + + const boom = await handleRequest({ method: 'get', path: '/post/boom', client: rawClient }); + expect(boom.status).toBe(500); + expect(boom.error.message).toBe('boom'); + }); + + it('custom operation cannot override built-in', () => { + expect(() => + new RPCApiHandler({ + schema: client.$schema, + customOperations: { + findMany: async () => ({ status: 200, body: { data: null } }), + }, + }), + ).toThrow(/cannot override built-in operation/); + }); + + it('custom operation name must be identifier', () => { + expect(() => + new RPCApiHandler({ + schema: client.$schema, + customOperations: { + 'not-valid': async () => ({ status: 200, body: { data: null } }), + }, + }), + ).toThrow(/valid identifier/); + }); + + it('custom operation must be function', () => { + expect(() => + new RPCApiHandler({ + schema: client.$schema, + customOperations: { + nope: 'oops' as any, + }, + }), + ).toThrow(/must be a function/); + }); + it('field types', async () => { const schema = ` model Foo { @@ -508,8 +614,8 @@ describe('RPC API Handler Tests', () => { expect(r.data).toBeNull(); }); - function makeHandler() { - const handler = new RPCApiHandler({ schema: client.$schema }); + function makeHandler(options?: Partial>) { + const handler = new RPCApiHandler({ schema: client.$schema, ...(options ?? {}) }); return async (args: any) => { const r = await handler.handleRequest({ ...args,