Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion packages/clients/client-helpers/src/invalidation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@ export function createInvalidator(
invalidator: InvalidateFunc,
logging: Logger | undefined,
) {
const normalizedModel = normalizeModelName(model, schema);
return async (...args: unknown[]) => {
const [_, variables] = args;
const predicate = await getInvalidationPredicate(
model,
normalizedModel,
operation as ORMWriteActionType,
variables,
schema,
Expand Down Expand Up @@ -87,3 +88,9 @@ function findNestedRead(visitingModel: string, targetModels: string[], schema: S
const modelsRead = getReadModels(visitingModel, schema, args);
return targetModels.some((m) => modelsRead.includes(m));
}

// resolves a model name to its canonical form as defined in the schema (case-insensitive match)
function normalizeModelName(model: string, schema: SchemaDef) {
const target = model.toLowerCase();
return Object.keys(schema.models).find((k) => k.toLowerCase() === target) ?? model;
}
4 changes: 0 additions & 4 deletions packages/clients/client-helpers/src/nested-write-visitor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -297,10 +297,6 @@ export class NestedWriteVisitor {
}
}
break;

default: {
throw new Error(`unhandled action type ${action}`);
}
}
}

Expand Down
5 changes: 5 additions & 0 deletions packages/clients/tanstack-query/src/common/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@ import type { QueryClient } from '@tanstack/query-core';
import type { InvalidationPredicate, QueryInfo } from '@zenstackhq/client-helpers';
import { parseQueryKey } from './query-key.js';

/** Strips a trailing slash from an endpoint URL. */
export function normalizeEndpoint(endpoint: string) {
return endpoint.replace(/\/$/, '');
}

export function invalidateQueriesMatchingPredicate(queryClient: QueryClient, predicate: InvalidationPredicate) {
return queryClient.invalidateQueries({
predicate: ({ queryKey }) => {
Expand Down
4 changes: 4 additions & 0 deletions packages/clients/tanstack-query/src/common/constants.ts
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
/** Route segment for custom procedures. */
export const CUSTOM_PROC_ROUTE_NAME = '$procs';

/** Route prefix for transaction endpoints. */
export const TRANSACTION_ROUTE_PREFIX = '$transaction';
54 changes: 54 additions & 0 deletions packages/clients/tanstack-query/src/common/transaction.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import type { Logger } from '@zenstackhq/client-helpers';
import { createInvalidator, type InvalidateFunc } from '@zenstackhq/client-helpers';
import type { FetchFn } from '@zenstackhq/client-helpers/fetch';
import { fetcher, marshal } from '@zenstackhq/client-helpers/fetch';
import type { SchemaDef } from '@zenstackhq/schema';
import { TRANSACTION_ROUTE_PREFIX } from './constants.js';
import type { TransactionOperation } from './types.js';

/**
* Builds the mutation function for a sequential transaction request.
*/
export function makeTransactionMutationFn<Schema extends SchemaDef>(
endpoint: string,
fetch: FetchFn | undefined,
) {
return (operations: TransactionOperation<Schema>[]) => {
const reqUrl = `${endpoint}/${TRANSACTION_ROUTE_PREFIX}/sequential`;
Comment thread
ymc9 marked this conversation as resolved.
const fetchInit = {
method: 'POST',
headers: { 'content-type': 'application/json' },
body: marshal(operations),
};
return fetcher<unknown[]>(reqUrl, fetchInit, fetch);
};
}

/**
* Builds the `onSuccess` handler for a sequential transaction mutation that invalidates
* all queries affected by the operations in the transaction.
*
* @param schema The schema definition.
* @param invalidateFunc Function that invalidates queries matching a predicate.
* @param logging Logging option.
* @param origOnSuccess The user-provided `onSuccess` callback to call after invalidation.
*/
export function makeTransactionOnSuccess(
schema: SchemaDef,
invalidateFunc: InvalidateFunc,
logging: Logger | undefined,
origOnSuccess: ((...args: any[]) => any) | undefined,
) {
return async (...args: any[]) => {
const variables = Array.isArray(args[1]) ? args[1] : [];
for (const op of variables) {
if (typeof op?.model !== 'string' || typeof op?.op !== 'string') {
continue;
}
const invalidator = createInvalidator(op.model, op.op, schema, invalidateFunc, logging);
Comment thread
coderabbitai[bot] marked this conversation as resolved.
// pass op.args as mutation variables so the invalidator can analyze nested writes
await invalidator(args[0], op.args, args[2]);
}
await origOnSuccess?.(...args);
};
}
63 changes: 63 additions & 0 deletions packages/clients/tanstack-query/src/common/types.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,28 @@
import type { Logger, OptimisticDataProvider } from '@zenstackhq/client-helpers';
import type { FetchFn } from '@zenstackhq/client-helpers/fetch';
import type {
AggregateArgs,
CountArgs,
CreateArgs,
CreateManyAndReturnArgs,
CreateManyArgs,
DeleteArgs,
DeleteManyArgs,
ExistsArgs,
FindFirstArgs,
FindManyArgs,
FindUniqueArgs,
GetProcedureNames,
GetSlicedOperations,
GroupByArgs,
ModelAllowsCreate,
OperationsRequiringCreate,
ProcedureFunc,
QueryOptions,
UpdateArgs,
UpdateManyAndReturnArgs,
UpdateManyArgs,
UpsertArgs,
} from '@zenstackhq/orm';
import type { GetModels, SchemaDef } from '@zenstackhq/schema';

Expand Down Expand Up @@ -100,3 +116,50 @@ export type WithOptimistic<T> = T extends Array<infer U> ? Array<WithOptimisticF
export type ProcedureReturn<Schema extends SchemaDef, Name extends GetProcedureNames<Schema>> = Awaited<
ReturnType<ProcedureFunc<Schema, Name>>
>;

/**
* Maps each core CRUD operation to its argument type for a given model.
*/
type CrudArgsMap<Schema extends SchemaDef, Model extends GetModels<Schema>> = {
findMany: FindManyArgs<Schema, Model>;
findUnique: FindUniqueArgs<Schema, Model>;
findFirst: FindFirstArgs<Schema, Model>;
create: CreateArgs<Schema, Model>;
createMany: CreateManyArgs<Schema, Model>;
createManyAndReturn: CreateManyAndReturnArgs<Schema, Model>;
update: UpdateArgs<Schema, Model>;
updateMany: UpdateManyArgs<Schema, Model>;
updateManyAndReturn: UpdateManyAndReturnArgs<Schema, Model>;
upsert: UpsertArgs<Schema, Model>;
delete: DeleteArgs<Schema, Model>;
deleteMany: DeleteManyArgs<Schema, Model>;
count: CountArgs<Schema, Model>;
aggregate: AggregateArgs<Schema, Model>;
groupBy: GroupByArgs<Schema, Model>;
exists: ExistsArgs<Schema, Model>;
};

/**
* Operations available for a given model, omitting create-style operations
* for models that don't allow them (e.g. delegate models).
*/
type AllowedTransactionOps<Schema extends SchemaDef, Model extends GetModels<Schema>> =
ModelAllowsCreate<Schema, Model> extends true
? keyof CrudArgsMap<Schema, Model>
: Exclude<keyof CrudArgsMap<Schema, Model>, OperationsRequiringCreate>;

/**
* Represents a single operation to execute within a sequential transaction.
*
* The `model`, `op`, and `args` fields are correlated: `op` is constrained to
* the CRUD operations available on `model`, and `args` is typed accordingly.
*/
export type TransactionOperation<Schema extends SchemaDef> = {
[Model in GetModels<Schema>]: {
[Op in AllowedTransactionOps<Schema, Model>]: {
model: Model;
op: Op;
args?: CrudArgsMap<Schema, Model>[Op];
};
}[AllowedTransactionOps<Schema, Model>];
}[GetModels<Schema>];
Comment thread
coderabbitai[bot] marked this conversation as resolved.
48 changes: 45 additions & 3 deletions packages/clients/tanstack-query/src/react.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,16 @@ import type {
} from '@zenstackhq/orm';
import type { GetModels, SchemaDef } from '@zenstackhq/schema';
import { createContext, useContext } from 'react';
import { getAllQueries, invalidateQueriesMatchingPredicate } from './common/client.js';
import { getAllQueries, invalidateQueriesMatchingPredicate, normalizeEndpoint } from './common/client.js';
import { CUSTOM_PROC_ROUTE_NAME } from './common/constants.js';
import { getQueryKey } from './common/query-key.js';
import { makeTransactionMutationFn, makeTransactionOnSuccess } from './common/transaction.js';
import type {
ExtraMutationOptions,
ExtraQueryOptions,
ProcedureReturn,
QueryContext,
TransactionOperation,
TrimSlicedOperations,
WithOptimistic,
} from './common/types.js';
Expand Down Expand Up @@ -165,6 +167,12 @@ export type ModelMutationModelResult<
): Promise<SimplifiedResult<Schema, Model, T, Options, false, Array, ExtResult>>;
};

export type TransactionMutationOptions<Schema extends SchemaDef> = Omit<
UseMutationOptions<unknown[], DefaultError, TransactionOperation<Schema>[]>,
'mutationFn'
> &
Omit<ExtraMutationOptions, 'optimisticUpdate' | 'optimisticDataProvider'>;

export type ClientHooks<
Schema extends SchemaDef,
Options extends QueryOptions<Schema> = QueryOptions<Schema>,
Expand All @@ -176,7 +184,13 @@ export type ClientHooks<
Options,
ExtResult
>;
} & ProcedureHooks<Schema, Options>;
} & ProcedureHooks<Schema, Options> & {
$transaction: {
useSequential(
options?: TransactionMutationOptions<Schema>,
): UseMutationResult<unknown[], DefaultError, TransactionOperation<Schema>[]>;
};
};

type ProcedureHookGroup<Schema extends SchemaDef, Options extends QueryOptions<Schema>> = {
[Name in GetSlicedProcedures<Schema, Options>]: GetProcedure<Schema, Name> extends { mutation: true }
Expand Down Expand Up @@ -448,6 +462,10 @@ export function useClientQueries<SchemaOrClient extends SchemaDef | ClientContra
(result as any).$procs = buildProcedureHooks();
}

(result as any).$transaction = {
useSequential: (hookOptions?: any) => useInternalTransactionMutation(schema, { ...options, ...hookOptions }),
};

return result;
}

Expand Down Expand Up @@ -789,11 +807,35 @@ export function useInternalMutation<TArgs, R = any>(
return useMutation(finalOptions);
}

export function useInternalTransactionMutation<Schema extends SchemaDef>(
schema: Schema,
options?: TransactionMutationOptions<Schema>,
) {
const { endpoint, fetch, logging } = useFetchOptions(options);
const queryClient = useQueryClient();

const mutationFn = makeTransactionMutationFn<Schema>(endpoint, fetch);

const finalOptions = { ...options, mutationFn };

if (options?.invalidateQueries !== false) {
const origOnSuccess = finalOptions.onSuccess;
finalOptions.onSuccess = makeTransactionOnSuccess(
schema,
(predicate) => invalidateQueriesMatchingPredicate(queryClient, predicate),
logging,
origOnSuccess as any,
);
}

return useMutation(finalOptions);
}

function useFetchOptions(options: QueryContext | undefined) {
const { endpoint, fetch, logging } = useHooksContext();
// options take precedence over context
return {
endpoint: options?.endpoint ?? endpoint,
endpoint: normalizeEndpoint(options?.endpoint ?? endpoint),
fetch: options?.fetch ?? fetch,
logging: options?.logging ?? logging,
};
Expand Down
54 changes: 51 additions & 3 deletions packages/clients/tanstack-query/src/svelte/index.svelte.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,16 @@ import type {
} from '@zenstackhq/orm';
import type { GetModels, SchemaDef } from '@zenstackhq/schema';
import { getContext, setContext } from 'svelte';
import { getAllQueries, invalidateQueriesMatchingPredicate } from '../common/client.js';
import { getAllQueries, invalidateQueriesMatchingPredicate, normalizeEndpoint } from '../common/client.js';
import { CUSTOM_PROC_ROUTE_NAME } from '../common/constants.js';
import { getQueryKey } from '../common/query-key.js';
import { makeTransactionMutationFn, makeTransactionOnSuccess } from '../common/transaction.js';
import type {
ExtraMutationOptions,
ExtraQueryOptions,
ProcedureReturn,
QueryContext,
TransactionOperation,
TrimSlicedOperations,
WithOptimistic,
} from '../common/types.js';
Expand Down Expand Up @@ -158,6 +160,12 @@ export type ModelMutationModelResult<
): Promise<SimplifiedResult<Schema, Model, T, Options, false, Array, ExtResult>>;
};

export type TransactionMutationOptions<Schema extends SchemaDef> = Omit<
CreateMutationOptions<unknown[], DefaultError, TransactionOperation<Schema>[]>,
'mutationFn'
> &
Omit<ExtraMutationOptions, 'optimisticUpdate' | 'optimisticDataProvider'>;

export type ClientHooks<
Schema extends SchemaDef,
Options extends QueryOptions<Schema> = QueryOptions<Schema>,
Expand All @@ -169,7 +177,13 @@ export type ClientHooks<
Options,
ExtResult
>;
} & ProcedureHooks<Schema, Options>;
} & ProcedureHooks<Schema, Options> & {
$transaction: {
useSequential(
options?: TransactionMutationOptions<Schema>,
): CreateMutationResult<unknown[], DefaultError, TransactionOperation<Schema>[]>;
};
};

type ProcedureHookGroup<Schema extends SchemaDef, Options extends QueryOptions<Schema>> = {
[Name in GetSlicedProcedures<Schema, Options>]: GetProcedure<Schema, Name> extends { mutation: true }
Expand Down Expand Up @@ -375,6 +389,10 @@ export function useClientQueries<SchemaOrClient extends SchemaDef | ClientContra
(result as any).$procs = buildProcedureHooks();
}

(result as any).$transaction = {
useSequential: (hookOptions?: any) => useInternalTransactionMutation(schema, merge(options, hookOptions)),
};

return result;
}

Expand Down Expand Up @@ -691,12 +709,42 @@ export function useInternalMutation<TArgs, R = any>(
return createMutation(finalOptions);
}

export function useInternalTransactionMutation<Schema extends SchemaDef>(
schema: Schema,
options?: Accessor<TransactionMutationOptions<Schema>>,
) {
const { endpoint, fetch, logging } = useFetchOptions(options);
const queryClient = useQueryClient();

const mutationFn = makeTransactionMutationFn<Schema>(endpoint, fetch);

const finalOptions = () => {
const optionsValue = options?.();
const result: any = { ...optionsValue, mutationFn };

if (optionsValue?.invalidateQueries !== false) {
result.onSuccess = makeTransactionOnSuccess(
schema,
(predicate: InvalidationPredicate) =>
// @ts-ignore
invalidateQueriesMatchingPredicate(queryClient, predicate),
logging,
optionsValue?.onSuccess as any,
);
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}

return result;
};

return createMutation(finalOptions);
}

function useFetchOptions(options: Accessor<QueryContext> | undefined) {
const { endpoint, fetch, logging } = useQuerySettings();
const optionsValue = options?.();
// options take precedence over context
return {
endpoint: optionsValue?.endpoint ?? endpoint,
endpoint: normalizeEndpoint(optionsValue?.endpoint ?? endpoint),
fetch: optionsValue?.fetch ?? fetch,
logging: optionsValue?.logging ?? logging,
};
Expand Down
Loading
Loading