Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 73 additions & 89 deletions packages/orm/src/client/crud/operations/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@ import { clone, enumerate, invariant, isPlainObject } from '@zenstackhq/common-h
import { default as cuid1 } from 'cuid';
import {
createQueryId,
DeleteResult,
expressionBuilder,
sql,
UpdateResult,
type Compilable,
type ExpressionBuilder,
type IsolationLevel,
Expand Down Expand Up @@ -565,7 +563,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
fromRelation.ids,
m2m.otherModel,
m2m.otherField,
createdEntity,
[createdEntity],
m2m.joinTable,
);
}
Expand Down Expand Up @@ -657,56 +655,55 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
leftEntity: any,
rightModel: string,
rightField: string,
rightEntity: any,
rightEntities: any[],
joinTable: string,
): Promise<Action extends 'connect' ? UpdateResult | undefined : DeleteResult | undefined> {
const sortedRecords = [
{
model: leftModel,
field: leftField,
entity: leftEntity,
},
{
model: rightModel,
field: rightField,
entity: rightEntity,
},
].sort((a, b) =>
// the implicit m2m join table's "A", "B" fk fields' order is determined
// by model name's sort order, and when identical (for self-relations),
// field name's sort order
a.model !== b.model ? a.model.localeCompare(b.model) : a.field.localeCompare(b.field),
);
): Promise<void> {
if (rightEntities.length === 0) {
return;
}

// the implicit m2m join table's "A", "B" fk fields' order is determined
// by model name's sort order, and when identical (for self-relations),
// field name's sort order
const leftFirst =
leftModel !== rightModel
? leftModel.localeCompare(rightModel) <= 0
: leftField.localeCompare(rightField) <= 0;

const firstIds = requireIdFields(this.schema, sortedRecords[0]!.model);
const secondIds = requireIdFields(this.schema, sortedRecords[1]!.model);
invariant(firstIds.length === 1, 'many-to-many relation must have exactly one id field');
invariant(secondIds.length === 1, 'many-to-many relation must have exactly one id field');
const leftIdField = requireIdFields(this.schema, leftModel);
const rightIdField = requireIdFields(this.schema, rightModel);
invariant(leftIdField.length === 1, 'many-to-many relation must have exactly one id field');
invariant(rightIdField.length === 1, 'many-to-many relation must have exactly one id field');

const leftIdValue = leftEntity[leftIdField[0]!];
const rightIdValues = rightEntities.map((e) => e[rightIdField[0]!]);

// Prisma's convention for many-to-many: fk fields are named "A" and "B"
if (action === 'connect') {
const result = await kysely
const values = rightIdValues.map((rightId) => ({
A: leftFirst ? leftIdValue : rightId,
B: leftFirst ? rightId : leftIdValue,
}));

await kysely
.insertInto(joinTable as any)
.values({
A: sortedRecords[0]!.entity[firstIds[0]!],
B: sortedRecords[1]!.entity[secondIds[0]!],
} as any)
.values(values as any)
// case for `INSERT IGNORE` or `ON CONFLICT DO NOTHING` syntax
.$if(this.dialect.insertIgnoreMethod === 'onConflict', (qb) =>
qb.onConflict((oc) => oc.columns(['A', 'B'] as any).doNothing()),
)
// case for `INSERT IGNORE` syntax
.$if(this.dialect.insertIgnoreMethod === 'ignore', (qb) => qb.ignore())
.execute();
return result[0] as any;
} else {
const eb = expressionBuilder<any, any>();
const result = await kysely
const [leftCol, rightCol] = leftFirst ? (['A', 'B'] as const) : (['B', 'A'] as const);

await kysely
.deleteFrom(joinTable as any)
.where(eb(`${joinTable}.A`, '=', sortedRecords[0]!.entity[firstIds[0]!]))
.where(eb(`${joinTable}.B`, '=', sortedRecords[1]!.entity[secondIds[0]!]))
.where(eb(`${joinTable}.${leftCol}`, '=', leftIdValue))
.where(eb(`${joinTable}.${rightCol}`, 'in', rightIdValues))
.execute();
return result[0] as any;
}
}

Expand Down Expand Up @@ -1921,31 +1918,22 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {

const m2m = getManyToManyRelation(this.schema, fromRelation.model, fromRelation.field);
if (m2m) {
// handle many-to-many relation
const results: (unknown | undefined)[] = [];
for (const d of _data) {
const ids = await this.getEntityIds(kysely, model, d);
if (!ids) {
throw createNotFoundError(model);
}
const r = await this.handleManyToManyRelation(
kysely,
'connect',
fromRelation.model,
fromRelation.field,
fromRelation.ids,
m2m.otherModel!,
m2m.otherField!,
ids,
m2m.joinTable,
);
results.push(r);
}

// validate connect result
if (_data.length > results.filter((r) => !!r).length) {
// handle many-to-many relation: batch fetch all entity IDs in one query
const allIds = await this.getEntitiesIds(kysely, model, _data);
if (allIds.length !== _data.length) {
throw createNotFoundError(model);
}
await this.handleManyToManyRelation(
kysely,
'connect',
fromRelation.model,
fromRelation.field,
fromRelation.ids,
m2m.otherModel!,
m2m.otherField!,
allIds,
m2m.joinTable,
);
} else {
const { ownedByModel, keyPairs } = getRelationForeignKeyFieldPairs(
this.schema,
Expand Down Expand Up @@ -2057,13 +2045,9 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {

const m2m = getManyToManyRelation(this.schema, fromRelation.model, fromRelation.field);
if (m2m) {
// handle many-to-many relation
for (const d of disconnectConditions) {
const ids = await this.getEntityIds(kysely, model, d);
if (!ids) {
// not found
return;
}
// handle many-to-many relation: batch fetch all entity IDs in one query
const allIds = await this.getEntitiesIds(kysely, model, disconnectConditions);
if (allIds.length > 0) {
await this.handleManyToManyRelation(
kysely,
'disconnect',
Expand All @@ -2072,7 +2056,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
fromRelation.ids,
m2m.otherModel,
m2m.otherField,
ids,
allIds,
m2m.joinTable,
);
}
Expand Down Expand Up @@ -2158,32 +2142,24 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
// reset for the parent
await this.resetManyToManyRelation(kysely, fromRelation.model, fromRelation.field, fromRelation.ids);

// connect new entities
const results: (unknown | undefined)[] = [];
for (const d of _data) {
const ids = await this.getEntityIds(kysely, model, d);
if (!ids) {
if (_data.length > 0) {
// batch fetch all entity IDs in one query
const allIds = await this.getEntitiesIds(kysely, model, _data);
if (allIds.length !== _data.length) {
throw createNotFoundError(model);
}
results.push(
await this.handleManyToManyRelation(
kysely,
'connect',
fromRelation.model,
fromRelation.field,
fromRelation.ids,
m2m.otherModel,
m2m.otherField,
ids,
m2m.joinTable,
),
await this.handleManyToManyRelation(
kysely,
'connect',
fromRelation.model,
fromRelation.field,
fromRelation.ids,
m2m.otherModel,
m2m.otherField,
allIds,
m2m.joinTable,
);
}

// validate connect result
if (_data.length > results.filter((r) => !!r).length) {
throw createNotFoundError(model);
}
} else {
const { ownedByModel, keyPairs } = getRelationForeignKeyFieldPairs(
this.schema,
Expand Down Expand Up @@ -2533,6 +2509,14 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
});
}

// Given multiple unique filters, load all matching entities and return their id fields in one query
private getEntitiesIds(kysely: AnyKysely, model: string, uniqueFilters: any[]) {
return this.read(kysely, model, {
where: { OR: uniqueFilters },
select: this.makeIdSelect(model),
} as any);
}

/**
* Normalize input args to strip `undefined` fields
*/
Expand Down
Loading