Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
### Testing

- E2E tests are in `tests/e2e/` directory
- Regression tests for GitHub issues go in `tests/regression/test/` as `issue-{number}.test.ts`

### ZenStack CLI Commands

Expand Down
11 changes: 6 additions & 5 deletions packages/orm/src/client/crud/dialects/base-dialect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import {
requireIdFields,
requireModel,
requireTypeDef,
tmpAlias,
} from '../../query-utils';

export abstract class BaseCrudDialect<Schema extends SchemaDef> {
Expand Down Expand Up @@ -298,7 +299,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
}
}

const joinAlias = `${modelAlias}$${field}`;
const joinAlias = tmpAlias(`${modelAlias}$${field}`);
const joinPairs = buildJoinPairs(
this.schema,
model,
Expand All @@ -307,7 +308,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
field,
joinAlias,
);
const filterResultField = `${field}$filter`;
const filterResultField = tmpAlias(`${field}$flt`);

const joinSelect = this.eb
.selectFrom(`${fieldDef.type} as ${joinAlias}`)
Expand Down Expand Up @@ -383,7 +384,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {

// evaluating the filter involves creating an inner select,
// give it an alias to avoid conflict
const relationFilterSelectAlias = `${modelAlias}$${field}$filter`;
const relationFilterSelectAlias = tmpAlias(`${modelAlias}$${field}$flt`);

const buildPkFkWhereRefs = (eb: ExpressionBuilder<any, any>) => {
const m2m = getManyToManyRelation(this.schema, model, field);
Expand Down Expand Up @@ -1083,7 +1084,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
);
const sort = this.negateSort(value._count, negated);
result = result.orderBy((eb) => {
const subQueryAlias = `${modelAlias}$orderBy$${field}$count`;
const subQueryAlias = tmpAlias(`${modelAlias}$ob$${field}$ct`);
let subQuery = this.buildSelectModel(relationModel, subQueryAlias);
const joinPairs = buildJoinPairs(this.schema, model, modelAlias, field, subQueryAlias);
subQuery = subQuery.where(() =>
Expand All @@ -1099,7 +1100,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
}
} else {
// order by to-one relation
const joinAlias = `${modelAlias}$orderBy$${index}`;
const joinAlias = tmpAlias(`${modelAlias}$ob$${index}`);
result = result.leftJoin(`${relationModel} as ${joinAlias}`, (join) => {
const joinPairs = buildJoinPairs(this.schema, model, modelAlias, field, joinAlias);
return join.on((eb) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
requireField,
requireIdFields,
requireModel,
tmpAlias,
} from '../../query-utils';
import { BaseCrudDialect } from './base-dialect';

Expand All @@ -31,7 +32,7 @@ export abstract class LateralJoinDialectBase<Schema extends SchemaDef> extends B
parentAlias: string,
payload: true | FindArgs<Schema, GetModels<Schema>, any, true>,
): SelectQueryBuilder<any, any, any> {
const relationResultName = `${parentAlias}$${relationField}`;
const relationResultName = tmpAlias(`${parentAlias}$${relationField}`);
const joinedQuery = this.buildRelationJSON(
model,
query,
Expand All @@ -56,7 +57,7 @@ export abstract class LateralJoinDialectBase<Schema extends SchemaDef> extends B

return qb.leftJoinLateral(
(eb) => {
const relationSelectName = `${resultName}$sub`;
const relationSelectName = tmpAlias(`${resultName}$sub`);
const relationModelDef = requireModel(this.schema, relationModel);

let tbl: SelectQueryBuilder<any, any, any>;
Expand Down
2 changes: 1 addition & 1 deletion packages/orm/src/client/crud/dialects/mysql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ export class MySqlCrudDialect<Schema extends SchemaDef> extends LateralJoinDiale
return this.eb.exists(
this.eb
.selectFrom(sql`JSON_TABLE(${receiver}, '$[*]' COLUMNS(value JSON PATH '$'))`.as('$items'))
.select(this.eb.lit(1).as('$t'))
.select(this.eb.lit(1).as('_'))
.where(buildFilter(this.eb.ref('$items.value'))),
);
}
Expand Down
2 changes: 1 addition & 1 deletion packages/orm/src/client/crud/dialects/postgresql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends LateralJoinDi
return this.eb.exists(
this.eb
.selectFrom(this.eb.fn('jsonb_array_elements', [receiver]).as('$items'))
.select(this.eb.lit(1).as('$t'))
.select(this.eb.lit(1).as('_'))
.where(buildFilter(this.eb.ref('$items.value'))),
);
}
Expand Down
13 changes: 7 additions & 6 deletions packages/orm/src/client/crud/dialects/sqlite.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import {
requireField,
requireIdFields,
requireModel,
tmpAlias,
} from '../../query-utils';
import { BaseCrudDialect } from './base-dialect';

Expand Down Expand Up @@ -201,7 +202,7 @@ export class SqliteCrudDialect<Schema extends SchemaDef> extends BaseCrudDialect
const relationModel = relationFieldDef.type as GetModels<Schema>;
const relationModelDef = requireModel(this.schema, relationModel);

const subQueryName = `${parentAlias}$${relationField}`;
const subQueryName = tmpAlias(`${parentAlias}$${relationField}`);
let tbl: SelectQueryBuilder<any, any, any>;

if (this.canJoinWithoutNestedSelect(relationModelDef, payload)) {
Expand All @@ -214,7 +215,7 @@ export class SqliteCrudDialect<Schema extends SchemaDef> extends BaseCrudDialect
// need to make a nested select on relation model
tbl = eb.selectFrom(() => {
// nested query name
const selectModelAlias = `${parentAlias}$${relationField}$sub`;
const selectModelAlias = tmpAlias(`${parentAlias}$${relationField}$sub`);

// select all fields
let selectModelQuery = this.buildModelSelect(relationModel, selectModelAlias, payload, true);
Expand Down Expand Up @@ -268,7 +269,7 @@ export class SqliteCrudDialect<Schema extends SchemaDef> extends BaseCrudDialect
const subJson = this.buildCountJson(
relationModel,
eb,
`${parentAlias}$${relationField}`,
tmpAlias(`${parentAlias}$${relationField}`),
value,
);
return [sql.lit(field), subJson];
Expand All @@ -279,7 +280,7 @@ export class SqliteCrudDialect<Schema extends SchemaDef> extends BaseCrudDialect
relationModel,
eb,
field,
`${parentAlias}$${relationField}`,
tmpAlias(`${parentAlias}$${relationField}`),
value,
);
return [sql.lit(field), subJson];
Expand All @@ -305,7 +306,7 @@ export class SqliteCrudDialect<Schema extends SchemaDef> extends BaseCrudDialect
relationModel,
eb,
field,
`${parentAlias}$${relationField}`,
tmpAlias(`${parentAlias}$${relationField}`),
value,
);
return [sql.lit(field), subJson];
Expand Down Expand Up @@ -440,7 +441,7 @@ export class SqliteCrudDialect<Schema extends SchemaDef> extends BaseCrudDialect
return this.eb.exists(
this.eb
.selectFrom(this.eb.fn('json_each', [receiver]).as('$items'))
.select(this.eb.lit(1).as('$t'))
.select(this.eb.lit(1).as('_'))
.where(buildFilter(this.eb.ref('$items.value'))),
);
}
Expand Down
10 changes: 5 additions & 5 deletions packages/orm/src/client/crud/operations/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -260,23 +260,23 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
.exists(
this.dialect
.buildSelectModel(model, model)
.select(sql.lit(1).as('$t'))
.select(sql.lit(1).as('_'))
.where(() => this.dialect.buildFilter(model, model, filter)),
)
.as('exists'),
.as('$exists'),
)
.modifyEnd(this.makeContextComment({ model, operation: 'read' }));

let result: { exists: number | boolean }[] = [];
let result: { $exists: number | boolean }[] = [];
const compiled = kysely.getExecutor().compileQuery(query.toOperationNode(), createQueryId());
try {
const r = await kysely.getExecutor().executeQuery(compiled);
result = r.rows as { exists: number | boolean }[];
result = r.rows as { $exists: number | boolean }[];
} catch (err) {
throw createDBQueryError(`Failed to execute query: ${err}`, err, compiled.sql, compiled.parameters);
}

return !!result[0]?.exists;
return !!result[0]?.$exists;
}

protected async read(
Expand Down
27 changes: 27 additions & 0 deletions packages/orm/src/client/executor/temp-alias-transformer.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import { IdentifierNode, OperationNodeTransformer, type OperationNode, type QueryId } from 'kysely';
import { TEMP_ALIAS_PREFIX } from '../query-utils';

/**
* Kysely node transformer that replaces temporary aliases created during query construction with
* shorter names while ensuring the same temp alias gets replaced with the same name.
*/
export class TempAliasTransformer extends OperationNodeTransformer {
private aliasMap = new Map<string, string>();

run<T extends OperationNode>(node: T): T {
this.aliasMap.clear();
return this.transformNode(node);
}

protected override transformIdentifier(node: IdentifierNode, queryId?: QueryId): IdentifierNode {
if (node.name.startsWith(TEMP_ALIAS_PREFIX)) {
let mapped = this.aliasMap.get(node.name);
if (!mapped) {
mapped = `$$t${this.aliasMap.size + 1}`;
this.aliasMap.set(node.name, mapped);
}
return IdentifierNode.create(mapped);
}
return super.transformIdentifier(node, queryId);
}
}
19 changes: 17 additions & 2 deletions packages/orm/src/client/executor/zenstack-query-executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import { createDBQueryError, createInternalError, ORMError } from '../errors';
import type { AfterEntityMutationCallback, OnKyselyQueryCallback } from '../plugin';
import { requireIdFields, stripAlias } from '../query-utils';
import { QueryNameMapper } from './name-mapper';
import { TempAliasTransformer } from './temp-alias-transformer';
import type { ZenStackDriver } from './zenstack-driver';

type MutationQueryNode = InsertQueryNode | UpdateQueryNode | DeleteQueryNode;
Expand Down Expand Up @@ -620,10 +621,24 @@ In such cases, ZenStack cannot reliably determine the IDs of the mutated entitie
}) as string;
}

private processQueryNode<Node extends RootOperationNode>(query: Node): Node {
let result = query;
result = this.processNameMapping(result);
result = this.processTempAlias(result);
return result;
}

private processNameMapping<Node extends RootOperationNode>(query: Node): Node {
return this.nameMapper?.transformNode(query) ?? query;
}

private processTempAlias<Node extends RootOperationNode>(query: Node): Node {
if (this.options.useCompactAliasNames === false) {
return query;
}
return new TempAliasTransformer().run(query);
}

private createClientForConnection(connection: DatabaseConnection, inTx: boolean) {
const innerExecutor = this.withConnectionProvider(new SingleConnectionProvider(connection));
innerExecutor.suppressMutationHooks = true;
Expand All @@ -650,8 +665,8 @@ In such cases, ZenStack cannot reliably determine the IDs of the mutated entitie
queryId?: QueryId,
parameters?: readonly unknown[],
) {
// no need to handle mutation hooks, just proceed
const finalQuery = this.processNameMapping(query);
// run query node processors: name mapping, temp alias renaming, etc.
const finalQuery = this.processQueryNode(query);

// inherit the original queryId
let compiledQuery = this.compileQuery(finalQuery, queryId ?? createQueryId());
Expand Down
6 changes: 6 additions & 0 deletions packages/orm/src/client/options.ts
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,12 @@ export type ClientOptions<Schema extends SchemaDef> = QueryOptions<Schema> & {
* `@@validate`, etc. Defaults to `true`.
*/
validateInput?: boolean;

/**
* Whether to use compact alias names (e.g., "$t1", "$t2") when transforming ORM queries to SQL.
* Defaults to `true`.
*/
useCompactAliasNames?: boolean;
} & (HasComputedFields<Schema> extends true
? {
/**
Expand Down
13 changes: 13 additions & 0 deletions packages/orm/src/client/query-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -427,3 +427,16 @@ export function extractFieldName(node: OperationNode) {
return undefined;
}
}

export const TEMP_ALIAS_PREFIX = '$$_';

/**
* Create an alias name for a temporary table or column name.
*/
export function tmpAlias(name: string) {
if (!name.startsWith(TEMP_ALIAS_PREFIX)) {
return `${TEMP_ALIAS_PREFIX}${name}`;
} else {
return name;
}
}
4 changes: 2 additions & 2 deletions packages/plugins/policy/src/expression-transformer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ export class ExpressionTransformer<Schema extends SchemaDef> {

return this.transform(expr.left, {
...context,
memberSelect: SelectionNode.create(AliasNode.create(predicateResult, IdentifierNode.create('$t'))),
memberSelect: SelectionNode.create(AliasNode.create(predicateResult, IdentifierNode.create('_'))),
memberFilter: predicateFilter,
});
}
Expand Down Expand Up @@ -776,7 +776,7 @@ export class ExpressionTransformer<Schema extends SchemaDef> {

return {
...receiver,
selections: [SelectionNode.create(AliasNode.create(currNode!, IdentifierNode.create('$t')))],
selections: [SelectionNode.create(AliasNode.create(currNode!, IdentifierNode.create('_')))],
};
}

Expand Down
4 changes: 2 additions & 2 deletions packages/plugins/policy/src/policy-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -825,13 +825,13 @@ export class PolicyHandler<Schema extends SchemaDef> extends OperationNodeTransf
const queryA = eb
.selectFrom(m2m.firstModel)
.where(eb(eb.ref(`${m2m.firstModel}.${m2m.firstIdField}`), '=', aValue))
.select(() => new ExpressionWrapper(filterA).as('$t'));
.select(() => new ExpressionWrapper(filterA).as('_'));

const filterB = this.buildPolicyFilter(m2m.secondModel, undefined, 'update');
const queryB = eb
.selectFrom(m2m.secondModel)
.where(eb(eb.ref(`${m2m.secondModel}.${m2m.secondIdField}`), '=', bValue))
.select(() => new ExpressionWrapper(filterB).as('$t'));
.select(() => new ExpressionWrapper(filterB).as('_'));

// select both conditions in one query
const queryNode: SelectQueryNode = {
Expand Down
Loading
Loading