Skip to content

Commit 017a327

Browse files
committed
feat(sql-orm-client): M:N correlated read-through-junction include
Add many-to-many include support via a single correlated subquery that joins the child table to the junction on junction.childColumns = child.targetColumns and correlates to the parent via WHERE junction.parentColumns = parent.parentLocalColumns. Composite keys AND across all column pairs. No LATERAL joins. - IncludeExpr gains `through?: IncludeThroughDescriptor` carrying junction table name, parentColumns, childColumns, targetColumns, and parentLocalColumns. - `resolveIncludeRelation` in collection-contract.ts surfaces `through` from the contract relation when present, resolving field names to column names for the parent local columns. - `Collection.include()` propagates `through` into IncludeExpr via spread. - `buildManyToManyJunctionArtifacts` in query-plan-select.ts builds the JOIN ON expression (BinaryExpr or AndExpr over child column pairs) and the correlated WHERE (BinaryExpr or AndExpr over parent column pairs), producing a non-lateral inner JoinAst to the junction table. - `buildIncludeChildRowsSelect` detects `include.through` and uses the M:N artifacts instead of the FK equality WHERE; `buildDistinctNonLeafChildRowsSelect` receives and applies the same junction joins. - `dispatchWithIncludes` in collection-dispatch.ts forces all `through.parentLocalColumns` (not just `localColumn`) into the parent SELECT augmentation for composite M:N keys. - `buildManyToManyContract` test helper and M:N unit tests covering single-column and composite-key junction shapes, plus a FK path non-regression test. Signed-off-by: Alexey Orlenko's AI Agent <robot@aqrln.net>
1 parent b6a91c5 commit 017a327

7 files changed

Lines changed: 410 additions & 10 deletions

File tree

packages/3-extensions/sql-orm-client/src/collection-contract.ts

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import {
66
} from '@prisma-next/contract/types';
77
import type { SqlStorage, StorageTable } from '@prisma-next/sql-contract/types';
88
import { castAs } from '@prisma-next/utils/casts';
9-
import type { RelationCardinalityTag } from './types';
9+
import type { IncludeThroughDescriptor, RelationCardinalityTag } from './types';
1010

1111
type ModelStorageFields = Record<string, { column?: string }>;
1212
type ModelEntry = {
@@ -221,6 +221,7 @@ export interface ResolvedIncludeRelation {
221221
readonly targetColumn: string;
222222
readonly localColumn: string;
223223
readonly cardinality: RelationCardinalityTag | undefined;
224+
readonly through?: IncludeThroughDescriptor;
224225
}
225226

226227
export function resolveIncludeRelation(
@@ -245,12 +246,27 @@ export function resolveIncludeRelation(
245246
const localColumn = resolveFieldToColumn(contract, modelName, localField);
246247
const targetColumn = resolveFieldToColumn(contract, relation.to, targetField);
247248

249+
let through: IncludeThroughDescriptor | undefined;
250+
if (relation.through !== undefined) {
251+
const parentLocalColumns = relation.on.localFields.map((field) =>
252+
resolveFieldToColumn(contract, modelName, field),
253+
);
254+
through = {
255+
table: relation.through.table,
256+
parentColumns: relation.through.parentColumns,
257+
childColumns: relation.through.childColumns,
258+
targetColumns: relation.through.targetColumns,
259+
parentLocalColumns,
260+
};
261+
}
262+
248263
return {
249264
relatedModelName: relation.to,
250265
relatedTableName,
251266
targetColumn,
252267
localColumn,
253268
cardinality: relation.cardinality,
269+
...(through !== undefined ? { through } : {}),
254270
};
255271
}
256272

packages/3-extensions/sql-orm-client/src/collection-dispatch.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,9 @@ function dispatchWithIncludes<Row>(options: {
9191
const generator = async function* (): AsyncGenerator<Row, void, unknown> {
9292
const { scope, release } = await acquireRuntimeScope(runtime);
9393
try {
94-
const parentJoinColumns = state.includes.map((include) => include.localColumn);
94+
const parentJoinColumns = state.includes.flatMap((include) =>
95+
include.through !== undefined ? include.through.parentLocalColumns : [include.localColumn],
96+
);
9597
const { selectedForQuery: parentSelectedForQuery, hiddenColumns: hiddenParentColumns } =
9698
augmentSelectionForJoinColumns(state.selectedFields, parentJoinColumns);
9799
const compiled = compileSelectWithIncludes(

packages/3-extensions/sql-orm-client/src/collection.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,7 @@ export class Collection<
481481
targetColumn: relation.targetColumn,
482482
localColumn: relation.localColumn,
483483
cardinality: relation.cardinality,
484+
...(relation.through !== undefined ? { through: relation.through } : {}),
484485
nested: nestedState,
485486
scalar: scalarSelector,
486487
combine: combineBranches,

packages/3-extensions/sql-orm-client/src/query-plan-select.ts

Lines changed: 77 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,53 @@ function buildNestedIncludeProjections(
288288
);
289289
}
290290

291+
/**
292+
* Build the correlated WHERE and junction JOIN artifacts for a many-to-many
293+
* include. The resulting WHERE correlates the junction to the parent rows
294+
* (AND-ed across all column pairs for composite keys). The junction JOIN
295+
* connects child rows to the junction via the child columns.
296+
*/
297+
function buildManyToManyJunctionArtifacts(
298+
parentTableName: string,
299+
childTableRef: string,
300+
through: NonNullable<IncludeExpr['through']>,
301+
): {
302+
readonly whereExpr: AnyExpression;
303+
readonly junctionJoin: JoinAst;
304+
} {
305+
const {
306+
table: junctionTable,
307+
parentColumns,
308+
childColumns,
309+
targetColumns,
310+
parentLocalColumns,
311+
} = through;
312+
313+
const joinOnPairs = childColumns.map((junctionCol, i) =>
314+
BinaryExpr.eq(
315+
ColumnRef.of(junctionTable, junctionCol),
316+
ColumnRef.of(childTableRef, targetColumns[i] ?? junctionCol),
317+
),
318+
);
319+
const joinOn: AnyExpression =
320+
joinOnPairs.length === 1 ? (joinOnPairs[0] as AnyExpression) : AndExpr.of(joinOnPairs);
321+
322+
const correlationPairs = parentColumns.map((junctionCol, i) =>
323+
BinaryExpr.eq(
324+
ColumnRef.of(junctionTable, junctionCol),
325+
ColumnRef.of(parentTableName, parentLocalColumns[i] ?? junctionCol),
326+
),
327+
);
328+
const whereExpr: AnyExpression =
329+
correlationPairs.length === 1
330+
? (correlationPairs[0] as AnyExpression)
331+
: AndExpr.of(correlationPairs);
332+
333+
const junctionJoin = JoinAst.inner(TableSource.named(junctionTable), joinOn, false);
334+
335+
return { whereExpr, junctionJoin };
336+
}
337+
291338
function buildIncludeChildRowsSelect(
292339
contract: Contract<SqlStorage>,
293340
parentTableName: string,
@@ -323,11 +370,25 @@ function buildIncludeChildRowsSelect(
323370
const childWhere = buildStateWhere(contract, childTableRef, childState, {
324371
filterTableName: include.relatedTableName,
325372
});
326-
const joinExpr = BinaryExpr.eq(
327-
ColumnRef.of(childTableRef, include.targetColumn),
328-
ColumnRef.of(parentTableName, include.localColumn),
329-
);
330-
const whereExpr = childWhere ? AndExpr.of([joinExpr, childWhere]) : joinExpr;
373+
374+
let whereExpr: AnyExpression;
375+
let junctionJoins: JoinAst[] = [];
376+
377+
if (include.through !== undefined) {
378+
const artifacts = buildManyToManyJunctionArtifacts(
379+
parentTableName,
380+
childTableRef,
381+
include.through,
382+
);
383+
whereExpr = childWhere ? AndExpr.of([artifacts.whereExpr, childWhere]) : artifacts.whereExpr;
384+
junctionJoins = [artifacts.junctionJoin];
385+
} else {
386+
const joinExpr = BinaryExpr.eq(
387+
ColumnRef.of(childTableRef, include.targetColumn),
388+
ColumnRef.of(parentTableName, include.localColumn),
389+
);
390+
whereExpr = childWhere ? AndExpr.of([joinExpr, childWhere]) : joinExpr;
391+
}
331392

332393
// `distinct()` on a non-leaf include cannot be lowered as
333394
// `SELECT DISTINCT <scalars>, json_agg(<grandchild>) FROM ...`:
@@ -355,6 +416,7 @@ function buildIncludeChildRowsSelect(
355416
hiddenOrderProjection,
356417
aggregateOrderBy,
357418
whereExpr,
419+
junctionJoins,
358420
});
359421
}
360422

@@ -388,6 +450,10 @@ function buildIncludeChildRowsSelect(
388450
.withProjection([...childProjection, ...hiddenOrderProjection])
389451
.withWhere(whereExpr);
390452

453+
if (junctionJoins.length > 0) {
454+
childRows = childRows.withJoins(junctionJoins);
455+
}
456+
391457
if (childState.distinctOn && childState.distinctOn.length > 0) {
392458
childRows = childRows.withDistinctOn(
393459
childState.distinctOn.map((column) => ColumnRef.of(childTableRef, column)),
@@ -450,6 +516,7 @@ function buildDistinctNonLeafChildRowsSelect(options: {
450516
readonly hiddenOrderProjection: ReadonlyArray<ProjectionItem>;
451517
readonly aggregateOrderBy: ReadonlyArray<OrderByItem> | undefined;
452518
readonly whereExpr: AnyExpression;
519+
readonly junctionJoins: ReadonlyArray<JoinAst>;
453520
}): {
454521
readonly childRows: SelectAst;
455522
readonly childProjection: ReadonlyArray<ProjectionItem>;
@@ -466,6 +533,7 @@ function buildDistinctNonLeafChildRowsSelect(options: {
466533
hiddenOrderProjection,
467534
aggregateOrderBy,
468535
whereExpr,
536+
junctionJoins,
469537
} = options;
470538
const childState = include.nested;
471539

@@ -507,9 +575,12 @@ function buildDistinctNonLeafChildRowsSelect(options: {
507575
selectedForQuery,
508576
childTableRef,
509577
);
510-
const baseInner = SelectAst.from(TableSource.named(include.relatedTableName, childTableAlias))
578+
let baseInner = SelectAst.from(TableSource.named(include.relatedTableName, childTableAlias))
511579
.withProjection([...innerScalarProjection, ...hiddenOrderProjection])
512580
.withWhere(whereExpr);
581+
if (junctionJoins.length > 0) {
582+
baseInner = baseInner.withJoins(junctionJoins);
583+
}
513584

514585
// `childState.distinct` is non-empty by the `isDistinctNonLeaf` guard
515586
// at the only caller (`buildIncludeChildRowsSelect`); assert here so

packages/3-extensions/sql-orm-client/src/types.ts

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,26 @@ export interface IncludeCombine<ResultShape extends Record<string, unknown>>
5050
readonly branches: Readonly<Record<string, IncludeCombineBranch>>;
5151
}
5252

53+
export interface IncludeThroughDescriptor {
54+
readonly table: string;
55+
/** FK columns in the junction table that point to the parent. */
56+
readonly parentColumns: readonly string[];
57+
/** FK columns in the junction table that point to the target (child). */
58+
readonly childColumns: readonly string[];
59+
/** PK columns in the target table that the junction's childColumns reference. */
60+
readonly targetColumns: readonly string[];
61+
/** Resolved column names in the parent table that junction.parentColumns reference. */
62+
readonly parentLocalColumns: readonly string[];
63+
}
64+
5365
export interface IncludeExpr {
5466
readonly relationName: string;
5567
readonly relatedModelName: string;
5668
readonly relatedTableName: string;
5769
readonly targetColumn: string;
5870
readonly localColumn: string;
5971
readonly cardinality: RelationCardinalityTag | undefined;
72+
readonly through?: IncludeThroughDescriptor;
6073
readonly nested: CollectionState;
6174
readonly scalar: IncludeScalar<unknown> | undefined;
6275
readonly combine: Readonly<Record<string, IncludeCombineBranch>> | undefined;

packages/3-extensions/sql-orm-client/test/helpers.ts

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,146 @@ export function buildStiPolyContract(): TestContract {
288288
return raw as TestContract;
289289
}
290290

291+
type RawColumn = { nativeType: string; codecId: string; nullable: boolean; default?: unknown };
292+
293+
/**
294+
* Builds a minimal M:N contract with Parent ↔ Child via a junction table.
295+
* Used by unit tests that assert the correlated subquery shape for M:N includes.
296+
*
297+
* `localFields` defaults to `['id']` (single-column parent PK). For composite-key parents,
298+
* pass `localFields: ['tenant_id', 'id']` — each entry resolves positionally to the
299+
* corresponding junction `parentColumns` entry.
300+
*/
301+
export function buildManyToManyContract(opts: {
302+
junctionTable: string;
303+
parentColumns: string[];
304+
childColumns: string[];
305+
targetColumns: string[];
306+
localFields?: string[];
307+
extraColumns?: Record<string, RawColumn>;
308+
}): FrameworkContract<SqlStorage> {
309+
const {
310+
junctionTable,
311+
parentColumns,
312+
childColumns,
313+
targetColumns,
314+
localFields = ['id'],
315+
extraColumns = {},
316+
} = opts;
317+
318+
const junctionStorageColumns: Record<string, RawColumn> = {};
319+
for (const col of parentColumns) {
320+
junctionStorageColumns[col] = { nativeType: 'int4', codecId: 'pg/int4@1', nullable: false };
321+
}
322+
for (const col of childColumns) {
323+
junctionStorageColumns[col] = { nativeType: 'int4', codecId: 'pg/int4@1', nullable: false };
324+
}
325+
for (const [name, col] of Object.entries(extraColumns)) {
326+
junctionStorageColumns[name] = col;
327+
}
328+
329+
const parentStorageColumns: Record<string, RawColumn> = {};
330+
for (const col of localFields) {
331+
parentStorageColumns[col] = { nativeType: 'int4', codecId: 'pg/int4@1', nullable: false };
332+
}
333+
334+
const parentStorageFields: Record<string, { column: string }> = {};
335+
for (const col of localFields) {
336+
parentStorageFields[col] = { column: col };
337+
}
338+
339+
const parentFields: Record<
340+
string,
341+
{ nullable: boolean; type: { kind: string; codecId: string } }
342+
> = {};
343+
for (const col of localFields) {
344+
parentFields[col] = { nullable: false, type: { kind: 'scalar', codecId: 'pg/int4@1' } };
345+
}
346+
347+
return {
348+
domain: {
349+
namespaces: {
350+
public: {
351+
id: 'public',
352+
models: {
353+
Parent: {
354+
fields: parentFields,
355+
relations: {
356+
children: {
357+
to: { model: 'Child', namespace: 'public' },
358+
cardinality: 'N:M',
359+
on: { localFields, targetFields: targetColumns },
360+
through: {
361+
table: junctionTable,
362+
parentColumns,
363+
childColumns,
364+
targetColumns,
365+
},
366+
},
367+
},
368+
storage: { table: 'parents', fields: parentStorageFields },
369+
},
370+
Child: {
371+
fields: Object.fromEntries(
372+
targetColumns.map((col) => [
373+
col,
374+
{ nullable: false, type: { kind: 'scalar', codecId: 'pg/int4@1' } },
375+
]),
376+
),
377+
relations: {},
378+
storage: {
379+
table: 'children',
380+
fields: Object.fromEntries(targetColumns.map((col) => [col, { column: col }])),
381+
},
382+
},
383+
Junction: {
384+
fields: {},
385+
relations: {},
386+
storage: { table: junctionTable, fields: {} },
387+
},
388+
},
389+
},
390+
},
391+
},
392+
storage: {
393+
namespaces: {
394+
public: {
395+
id: 'public',
396+
tables: {
397+
parents: {
398+
columns: parentStorageColumns,
399+
primaryKey: { columns: localFields },
400+
uniques: [],
401+
indexes: [],
402+
foreignKeys: [],
403+
},
404+
children: {
405+
columns: Object.fromEntries(
406+
targetColumns.map((col) => [
407+
col,
408+
{ nativeType: 'int4', codecId: 'pg/int4@1', nullable: false },
409+
]),
410+
),
411+
primaryKey: { columns: targetColumns },
412+
uniques: [],
413+
indexes: [],
414+
foreignKeys: [],
415+
},
416+
[junctionTable]: {
417+
columns: junctionStorageColumns,
418+
primaryKey: { columns: [...parentColumns, ...childColumns] },
419+
uniques: [],
420+
indexes: [],
421+
foreignKeys: [],
422+
},
423+
},
424+
},
425+
},
426+
},
427+
capabilities: {},
428+
} as unknown as FrameworkContract<SqlStorage>;
429+
}
430+
291431
export function createMockRuntime(): MockRuntime {
292432
const executions: MockExecution[] = [];
293433
let nextResult: Record<string, unknown>[][] = [];

0 commit comments

Comments
 (0)