Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion packages/clients/client-helpers/src/invalidation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@ export function createInvalidator(
invalidator: InvalidateFunc,
logging: Logger | undefined,
) {
const normalizedModel = normalizeModelName(model, schema);
return async (...args: unknown[]) => {
const [_, variables] = args;
const predicate = await getInvalidationPredicate(
model,
normalizedModel,
operation as ORMWriteActionType,
variables,
schema,
Expand Down Expand Up @@ -87,3 +88,9 @@ function findNestedRead(visitingModel: string, targetModels: string[], schema: S
const modelsRead = getReadModels(visitingModel, schema, args);
return targetModels.some((m) => modelsRead.includes(m));
}

// resolves a model name to its canonical form as defined in the schema (case-insensitive match)
function normalizeModelName(model: string, schema: SchemaDef) {
const target = model.toLowerCase();
return Object.keys(schema.models).find((k) => k.toLowerCase() === target) ?? model;
}
4 changes: 0 additions & 4 deletions packages/clients/client-helpers/src/nested-write-visitor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -297,10 +297,6 @@ export class NestedWriteVisitor {
}
}
break;

default: {
throw new Error(`unhandled action type ${action}`);
}
}
}

Expand Down
19 changes: 0 additions & 19 deletions packages/clients/client-helpers/test/nested-write-visitor.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1097,25 +1097,6 @@ describe('NestedWriteVisitor tests', () => {
}),
).resolves.not.toThrow();
});

it('throws error for unhandled action type', async () => {
const schema = createSchema({
User: {
name: 'User',
fields: {
id: createField('id', 'String'),
},
uniqueFields: {},
idFields: ['id'],
},
});

const visitor = new NestedWriteVisitor(schema, {});

await expect(visitor.visit('User', 'invalidAction' as any, { data: {} })).rejects.toThrow(
'unhandled action type',
);
});
});

describe('complex real-world scenarios', () => {
Expand Down
5 changes: 5 additions & 0 deletions packages/clients/tanstack-query/src/common/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@ import type { QueryClient } from '@tanstack/query-core';
import type { InvalidationPredicate, QueryInfo } from '@zenstackhq/client-helpers';
import { parseQueryKey } from './query-key.js';

/** Strips a trailing slash from an endpoint URL. */
export function normalizeEndpoint(endpoint: string) {
return endpoint.replace(/\/$/, '');
}

export function invalidateQueriesMatchingPredicate(queryClient: QueryClient, predicate: InvalidationPredicate) {
return queryClient.invalidateQueries({
predicate: ({ queryKey }) => {
Expand Down
4 changes: 4 additions & 0 deletions packages/clients/tanstack-query/src/common/constants.ts
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
/** Route segment for custom procedures. */
export const CUSTOM_PROC_ROUTE_NAME = '$procs';

/** Route prefix for transaction endpoints. */
export const TRANSACTION_ROUTE_PREFIX = '$transaction';
56 changes: 56 additions & 0 deletions packages/clients/tanstack-query/src/common/transaction.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import type { Logger } from '@zenstackhq/client-helpers';
import { createInvalidator, type InvalidateFunc } from '@zenstackhq/client-helpers';
import type { FetchFn } from '@zenstackhq/client-helpers/fetch';
import { fetcher, marshal } from '@zenstackhq/client-helpers/fetch';
import { CoreReadOperations } from '@zenstackhq/orm';
import type { SchemaDef } from '@zenstackhq/schema';
import { TRANSACTION_ROUTE_PREFIX } from './constants.js';
import type { TransactionOperation } from './types.js';

/**
* Builds the mutation function for a sequential transaction request.
*/
export function makeTransactionMutationFn<Schema extends SchemaDef>(endpoint: string, fetch: FetchFn | undefined) {
return (operations: TransactionOperation<Schema>[]) => {
const reqUrl = `${endpoint}/${TRANSACTION_ROUTE_PREFIX}/sequential`;
Comment thread
ymc9 marked this conversation as resolved.
const fetchInit = {
method: 'POST',
headers: { 'content-type': 'application/json' },
body: marshal(operations),
};
return fetcher<unknown[]>(reqUrl, fetchInit, fetch);
};
}

/**
* Builds the `onSuccess` handler for a sequential transaction mutation that invalidates
* all queries affected by the operations in the transaction.
*
* @param schema The schema definition.
* @param invalidateFunc Function that invalidates queries matching a predicate.
* @param logging Logging option.
* @param origOnSuccess The user-provided `onSuccess` callback to call after invalidation.
*/
export function makeTransactionOnSuccess(
schema: SchemaDef,
invalidateFunc: InvalidateFunc,
logging: Logger | undefined,
origOnSuccess: ((...args: any[]) => any) | undefined,
) {
return async (...args: any[]) => {
const variables = Array.isArray(args[1]) ? args[1] : [];
for (const op of variables) {
if (typeof op?.model !== 'string' || typeof op?.op !== 'string') {
continue;
}
// read-only ops don't mutate state, so they don't trigger invalidation
if (CoreReadOperations.includes(op.op)) {
continue;
}
const invalidator = createInvalidator(op.model, op.op, schema, invalidateFunc, logging);
Comment thread
coderabbitai[bot] marked this conversation as resolved.
// pass op.args as mutation variables so the invalidator can analyze nested writes
await invalidator(args[0], op.args, args[2]);
}
await origOnSuccess?.(...args);
};
}
61 changes: 61 additions & 0 deletions packages/clients/tanstack-query/src/common/types.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,28 @@
import type { Logger, OptimisticDataProvider } from '@zenstackhq/client-helpers';
import type { FetchFn } from '@zenstackhq/client-helpers/fetch';
import type {
AggregateArgs,
CountArgs,
CreateArgs,
CreateManyAndReturnArgs,
CreateManyArgs,
DeleteArgs,
DeleteManyArgs,
ExistsArgs,
FindFirstArgs,
FindManyArgs,
FindUniqueArgs,
GetProcedureNames,
GetSlicedOperations,
GroupByArgs,
ModelAllowsCreate,
OperationsRequiringCreate,
ProcedureFunc,
QueryOptions,
UpdateArgs,
UpdateManyAndReturnArgs,
UpdateManyArgs,
UpsertArgs,
} from '@zenstackhq/orm';
import type { GetModels, SchemaDef } from '@zenstackhq/schema';

Expand Down Expand Up @@ -100,3 +116,48 @@ export type WithOptimistic<T> = T extends Array<infer U> ? Array<WithOptimisticF
export type ProcedureReturn<Schema extends SchemaDef, Name extends GetProcedureNames<Schema>> = Awaited<
ReturnType<ProcedureFunc<Schema, Name>>
>;

/**
* Maps each core CRUD operation to its argument type for a given model.
*/
type CrudArgsMap<Schema extends SchemaDef, Model extends GetModels<Schema>> = {
findMany: FindManyArgs<Schema, Model>;
findUnique: FindUniqueArgs<Schema, Model>;
findFirst: FindFirstArgs<Schema, Model>;
create: CreateArgs<Schema, Model>;
createMany: CreateManyArgs<Schema, Model>;
createManyAndReturn: CreateManyAndReturnArgs<Schema, Model>;
update: UpdateArgs<Schema, Model>;
updateMany: UpdateManyArgs<Schema, Model>;
updateManyAndReturn: UpdateManyAndReturnArgs<Schema, Model>;
upsert: UpsertArgs<Schema, Model>;
delete: DeleteArgs<Schema, Model>;
deleteMany: DeleteManyArgs<Schema, Model>;
count: CountArgs<Schema, Model>;
aggregate: AggregateArgs<Schema, Model>;
groupBy: GroupByArgs<Schema, Model>;
exists: ExistsArgs<Schema, Model>;
};

/**
* Operations available for a given model, omitting create-style operations
* for models that don't allow them (e.g. delegate models).
*/
type AllowedTransactionOps<Schema extends SchemaDef, Model extends GetModels<Schema>> =
ModelAllowsCreate<Schema, Model> extends true
? keyof CrudArgsMap<Schema, Model>
: Exclude<keyof CrudArgsMap<Schema, Model>, OperationsRequiringCreate>;

/**
* Represents a single operation to execute within a sequential transaction.
*
* The `model`, `op`, and `args` fields are correlated: `op` is constrained to
* the CRUD operations available on `model`, and `args` is typed accordingly.
*/
export type TransactionOperation<Schema extends SchemaDef> = {
[Model in GetModels<Schema>]: {
[Op in AllowedTransactionOps<Schema, Model>]: {} extends CrudArgsMap<Schema, Model>[Op]
? { model: Model; op: Op; args?: CrudArgsMap<Schema, Model>[Op] }
: { model: Model; op: Op; args: CrudArgsMap<Schema, Model>[Op] };
}[AllowedTransactionOps<Schema, Model>];
}[GetModels<Schema>];
Comment thread
coderabbitai[bot] marked this conversation as resolved.
48 changes: 45 additions & 3 deletions packages/clients/tanstack-query/src/react.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,16 @@ import type {
} from '@zenstackhq/orm';
import type { GetModels, SchemaDef } from '@zenstackhq/schema';
import { createContext, useContext } from 'react';
import { getAllQueries, invalidateQueriesMatchingPredicate } from './common/client.js';
import { getAllQueries, invalidateQueriesMatchingPredicate, normalizeEndpoint } from './common/client.js';
import { CUSTOM_PROC_ROUTE_NAME } from './common/constants.js';
import { getQueryKey } from './common/query-key.js';
import { makeTransactionMutationFn, makeTransactionOnSuccess } from './common/transaction.js';
import type {
ExtraMutationOptions,
ExtraQueryOptions,
ProcedureReturn,
QueryContext,
TransactionOperation,
TrimSlicedOperations,
WithOptimistic,
} from './common/types.js';
Expand Down Expand Up @@ -165,6 +167,12 @@ export type ModelMutationModelResult<
): Promise<SimplifiedResult<Schema, Model, T, Options, false, Array, ExtResult>>;
};

export type TransactionMutationOptions<Schema extends SchemaDef> = Omit<
UseMutationOptions<unknown[], DefaultError, TransactionOperation<Schema>[]>,
'mutationFn'
> &
Omit<ExtraMutationOptions, 'optimisticUpdate' | 'optimisticDataProvider'>;

export type ClientHooks<
Schema extends SchemaDef,
Options extends QueryOptions<Schema> = QueryOptions<Schema>,
Expand All @@ -176,7 +184,13 @@ export type ClientHooks<
Options,
ExtResult
>;
} & ProcedureHooks<Schema, Options>;
} & ProcedureHooks<Schema, Options> & {
$transaction: {
useSequential(
options?: TransactionMutationOptions<Schema>,
): UseMutationResult<unknown[], DefaultError, TransactionOperation<Schema>[]>;
};
};

type ProcedureHookGroup<Schema extends SchemaDef, Options extends QueryOptions<Schema>> = {
[Name in GetSlicedProcedures<Schema, Options>]: GetProcedure<Schema, Name> extends { mutation: true }
Expand Down Expand Up @@ -448,6 +462,10 @@ export function useClientQueries<SchemaOrClient extends SchemaDef | ClientContra
(result as any).$procs = buildProcedureHooks();
}

(result as any).$transaction = {
useSequential: (hookOptions?: any) => useInternalTransactionMutation(schema, { ...options, ...hookOptions }),
};

return result;
}

Expand Down Expand Up @@ -789,11 +807,35 @@ export function useInternalMutation<TArgs, R = any>(
return useMutation(finalOptions);
}

export function useInternalTransactionMutation<Schema extends SchemaDef>(
schema: Schema,
options?: TransactionMutationOptions<Schema>,
) {
const { endpoint, fetch, logging } = useFetchOptions(options);
const queryClient = useQueryClient();

const mutationFn = makeTransactionMutationFn<Schema>(endpoint, fetch);

const finalOptions = { ...options, mutationFn };

if (options?.invalidateQueries !== false) {
const origOnSuccess = finalOptions.onSuccess;
finalOptions.onSuccess = makeTransactionOnSuccess(
schema,
(predicate) => invalidateQueriesMatchingPredicate(queryClient, predicate),
logging,
origOnSuccess as any,
);
}

return useMutation(finalOptions);
}

function useFetchOptions(options: QueryContext | undefined) {
const { endpoint, fetch, logging } = useHooksContext();
// options take precedence over context
return {
endpoint: options?.endpoint ?? endpoint,
endpoint: normalizeEndpoint(options?.endpoint ?? endpoint),
fetch: options?.fetch ?? fetch,
logging: options?.logging ?? logging,
};
Expand Down
Loading
Loading