Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import {
resolveTableForContract,
storageTableForContract,
} from './storage-resolution';
import type { RelationCardinalityTag } from './types';
import type { IncludeThroughDescriptor, RelationCardinalityTag } from './types';

type ModelStorageFields = Record<string, { column?: string }>;
type ModelEntry = {
Expand Down Expand Up @@ -218,6 +218,7 @@ export interface ResolvedIncludeRelation {
readonly targetColumn: string;
readonly localColumn: string;
readonly cardinality: RelationCardinalityTag | undefined;
readonly through?: IncludeThroughDescriptor;
}

export function resolveIncludeRelation(
Expand All @@ -242,12 +243,28 @@ export function resolveIncludeRelation(
const localColumn = resolveFieldToColumn(contract, modelName, localField);
const targetColumn = resolveFieldToColumn(contract, relation.to, targetField);

let through: IncludeThroughDescriptor | undefined;
if (relation.through !== undefined) {
const parentLocalColumns = relation.on.localFields.map((field) =>
resolveFieldToColumn(contract, modelName, field),
);
through = {
table: relation.through.table,
namespaceId: relation.through.namespaceId,
parentColumns: relation.through.parentColumns,
childColumns: relation.through.childColumns,
targetColumns: relation.through.targetColumns,
parentLocalColumns,
};
}

return {
relatedModelName: relation.to,
relatedTableName,
targetColumn,
localColumn,
cardinality: relation.cardinality,
...(through !== undefined ? { through } : {}),
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ function dispatchWithIncludes<Row>(options: {
const generator = async function* (): AsyncGenerator<Row, void, unknown> {
const { scope, release } = await acquireRuntimeScope(runtime);
try {
const parentJoinColumns = state.includes.map((include) => include.localColumn);
const parentJoinColumns = state.includes.flatMap((include) =>
include.through !== undefined ? include.through.parentLocalColumns : [include.localColumn],
);
const { selectedForQuery: parentSelectedForQuery, hiddenColumns: hiddenParentColumns } =
augmentSelectionForJoinColumns(state.selectedFields, parentJoinColumns);
const compiled = compileSelectWithIncludes(
Expand Down
1 change: 1 addition & 0 deletions packages/3-extensions/sql-orm-client/src/collection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,7 @@ export class Collection<
targetColumn: relation.targetColumn,
localColumn: relation.localColumn,
cardinality: relation.cardinality,
...(relation.through !== undefined ? { through: relation.through } : {}),
nested: nestedState,
scalar: scalarSelector,
combine: combineBranches,
Expand Down
89 changes: 83 additions & 6 deletions packages/3-extensions/sql-orm-client/src/query-plan-select.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import {
} from '@prisma-next/sql-relational-core/ast';
import { codecRefForStorageColumn } from '@prisma-next/sql-relational-core/codec-descriptor-registry';
import type { SqlQueryPlan } from '@prisma-next/sql-relational-core/plan';
import { castAs } from '@prisma-next/utils/casts';
import { ifDefined } from '@prisma-next/utils/defined';
import {
type PolymorphismInfo,
Expand Down Expand Up @@ -292,6 +293,58 @@ function buildNestedIncludeProjections(
);
}

/**
* Build the correlated WHERE and junction JOIN artifacts for a many-to-many
* include. The resulting WHERE correlates the junction to the parent rows
* (AND-ed across all column pairs for composite keys). The junction JOIN
* connects child rows to the junction via the child columns.
*/
function buildManyToManyJunctionArtifacts(
parentTableName: string,
childTableRef: string,
through: NonNullable<IncludeExpr['through']>,
): {
readonly whereExpr: AnyExpression;
readonly junctionJoin: JoinAst;
} {
const {
table: junctionTable,
parentColumns,
childColumns,
targetColumns,
parentLocalColumns,
namespaceId,
} = through;

const joinOnPairs = childColumns.map((junctionCol, i) =>
BinaryExpr.eq(
ColumnRef.of(junctionTable, junctionCol),
ColumnRef.of(childTableRef, targetColumns[i] ?? junctionCol),
),
);
const joinOn: AnyExpression =
joinOnPairs.length === 1 ? castAs<AnyExpression>(joinOnPairs[0]!) : AndExpr.of(joinOnPairs);

const correlationPairs = parentColumns.map((junctionCol, i) =>
BinaryExpr.eq(
ColumnRef.of(junctionTable, junctionCol),
ColumnRef.of(parentTableName, parentLocalColumns[i] ?? junctionCol),
),
);
const whereExpr: AnyExpression =
correlationPairs.length === 1
? castAs<AnyExpression>(correlationPairs[0]!)
: AndExpr.of(correlationPairs);

const junctionJoin = JoinAst.inner(
TableSource.named(junctionTable, undefined, namespaceId),
joinOn,
false,
);

return { whereExpr, junctionJoin };
}

function buildIncludeChildRowsSelect(
contract: Contract<SqlStorage>,
parentTableName: string,
Expand Down Expand Up @@ -327,11 +380,25 @@ function buildIncludeChildRowsSelect(
const childWhere = buildStateWhere(contract, childTableRef, childState, {
filterTableName: include.relatedTableName,
});
const joinExpr = BinaryExpr.eq(
ColumnRef.of(childTableRef, include.targetColumn),
ColumnRef.of(parentTableName, include.localColumn),
);
const whereExpr = childWhere ? AndExpr.of([joinExpr, childWhere]) : joinExpr;

let whereExpr: AnyExpression;
let junctionJoins: JoinAst[] = [];

if (include.through !== undefined) {
const artifacts = buildManyToManyJunctionArtifacts(
parentTableName,
childTableRef,
include.through,
);
whereExpr = childWhere ? AndExpr.of([artifacts.whereExpr, childWhere]) : artifacts.whereExpr;
junctionJoins = [artifacts.junctionJoin];
} else {
const joinExpr = BinaryExpr.eq(
ColumnRef.of(childTableRef, include.targetColumn),
ColumnRef.of(parentTableName, include.localColumn),
);
whereExpr = childWhere ? AndExpr.of([joinExpr, childWhere]) : joinExpr;
}

// `distinct()` on a non-leaf include cannot be lowered as
// `SELECT DISTINCT <scalars>, json_agg(<grandchild>) FROM ...`:
Expand Down Expand Up @@ -359,6 +426,7 @@ function buildIncludeChildRowsSelect(
hiddenOrderProjection,
aggregateOrderBy,
whereExpr,
junctionJoins,
});
}

Expand Down Expand Up @@ -394,6 +462,10 @@ function buildIncludeChildRowsSelect(
.withProjection([...childProjection, ...hiddenOrderProjection])
.withWhere(whereExpr);

if (junctionJoins.length > 0) {
childRows = childRows.withJoins(junctionJoins);
}

if (childState.distinctOn && childState.distinctOn.length > 0) {
childRows = childRows.withDistinctOn(
childState.distinctOn.map((column) => ColumnRef.of(childTableRef, column)),
Expand Down Expand Up @@ -456,6 +528,7 @@ function buildDistinctNonLeafChildRowsSelect(options: {
readonly hiddenOrderProjection: ReadonlyArray<ProjectionItem>;
readonly aggregateOrderBy: ReadonlyArray<OrderByItem> | undefined;
readonly whereExpr: AnyExpression;
readonly junctionJoins: ReadonlyArray<JoinAst>;
}): {
readonly childRows: SelectAst;
readonly childProjection: ReadonlyArray<ProjectionItem>;
Expand All @@ -472,6 +545,7 @@ function buildDistinctNonLeafChildRowsSelect(options: {
hiddenOrderProjection,
aggregateOrderBy,
whereExpr,
junctionJoins,
} = options;
const childState = include.nested;

Expand Down Expand Up @@ -513,11 +587,14 @@ function buildDistinctNonLeafChildRowsSelect(options: {
selectedForQuery,
childTableRef,
);
const baseInner = SelectAst.from(
let baseInner = SelectAst.from(
tableSourceForContract(contract, include.relatedTableName, childTableAlias),
)
.withProjection([...innerScalarProjection, ...hiddenOrderProjection])
.withWhere(whereExpr);
if (junctionJoins.length > 0) {
baseInner = baseInner.withJoins(junctionJoins);
}

// `childState.distinct` is non-empty by the `isDistinctNonLeaf` guard
// at the only caller (`buildIncludeChildRowsSelect`); assert here so
Expand Down
15 changes: 15 additions & 0 deletions packages/3-extensions/sql-orm-client/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,28 @@ export interface IncludeCombine<ResultShape extends Record<string, unknown>>
readonly branches: Readonly<Record<string, IncludeCombineBranch>>;
}

export interface IncludeThroughDescriptor {
readonly table: string;
/** Namespace the junction table lives in, as declared in the contract. */
readonly namespaceId: string;
/** FK columns in the junction table that point to the parent. */
readonly parentColumns: readonly string[];
/** FK columns in the junction table that point to the target (child). */
readonly childColumns: readonly string[];
/** PK columns in the target table that the junction's childColumns reference. */
readonly targetColumns: readonly string[];
/** Resolved column names in the parent table that junction.parentColumns reference. */
readonly parentLocalColumns: readonly string[];
}

export interface IncludeExpr {
readonly relationName: string;
readonly relatedModelName: string;
readonly relatedTableName: string;
readonly targetColumn: string;
readonly localColumn: string;
readonly cardinality: RelationCardinalityTag | undefined;
readonly through?: IncludeThroughDescriptor;
readonly nested: CollectionState;
readonly scalar: IncludeScalar<unknown> | undefined;
readonly combine: Readonly<Record<string, IncludeCombineBranch>> | undefined;
Expand Down
142 changes: 142 additions & 0 deletions packages/3-extensions/sql-orm-client/test/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,148 @@ export function buildStiPolyContract(): TestContract {
return deserializeTestContract(raw);
}

type RawColumn = { nativeType: string; codecId: string; nullable: boolean; default?: unknown };

/**
* Builds a minimal M:N contract with Parent ↔ Child via a junction table.
* Used by unit tests that assert the correlated subquery shape for M:N includes.
*
* `localFields` defaults to `['id']` (single-column parent PK). For composite-key parents,
* pass `localFields: ['tenant_id', 'id']` — each entry resolves positionally to the
* corresponding junction `parentColumns` entry.
*/
export function buildManyToManyContract(opts: {
junctionTable: string;
parentColumns: string[];
childColumns: string[];
targetColumns: string[];
localFields?: string[];
extraColumns?: Record<string, RawColumn>;
}): FrameworkContract<SqlStorage> {
const {
junctionTable,
parentColumns,
childColumns,
targetColumns,
localFields = ['id'],
extraColumns = {},
} = opts;

const junctionStorageColumns: Record<string, RawColumn> = {};
for (const col of parentColumns) {
junctionStorageColumns[col] = { nativeType: 'int4', codecId: 'pg/int4@1', nullable: false };
}
for (const col of childColumns) {
junctionStorageColumns[col] = { nativeType: 'int4', codecId: 'pg/int4@1', nullable: false };
}
for (const [name, col] of Object.entries(extraColumns)) {
junctionStorageColumns[name] = col;
}

const parentStorageColumns: Record<string, RawColumn> = {};
for (const col of localFields) {
parentStorageColumns[col] = { nativeType: 'int4', codecId: 'pg/int4@1', nullable: false };
}

const parentStorageFields: Record<string, { column: string }> = {};
for (const col of localFields) {
parentStorageFields[col] = { column: col };
}

const parentFields: Record<
string,
{ nullable: boolean; type: { kind: string; codecId: string } }
> = {};
for (const col of localFields) {
parentFields[col] = { nullable: false, type: { kind: 'scalar', codecId: 'pg/int4@1' } };
}

return {
domain: {
namespaces: {
public: {
id: 'public',
models: {
Parent: {
fields: parentFields,
relations: {
children: {
to: { model: 'Child', namespace: 'public' },
cardinality: 'N:M',
on: { localFields, targetFields: targetColumns },
through: {
table: junctionTable,
parentColumns,
childColumns,
targetColumns,
},
},
},
storage: { table: 'parents', fields: parentStorageFields },
},
Child: {
fields: Object.fromEntries(
targetColumns.map((col) => [
col,
{ nullable: false, type: { kind: 'scalar', codecId: 'pg/int4@1' } },
]),
),
relations: {},
storage: {
table: 'children',
fields: Object.fromEntries(targetColumns.map((col) => [col, { column: col }])),
},
},
Junction: {
fields: {},
relations: {},
storage: { table: junctionTable, fields: {} },
},
},
},
},
},
storage: {
namespaces: {
public: {
id: 'public',
entries: {
table: {
parents: {
columns: parentStorageColumns,
primaryKey: { columns: localFields },
uniques: [],
indexes: [],
foreignKeys: [],
},
children: {
columns: Object.fromEntries(
targetColumns.map((col) => [
col,
{ nativeType: 'int4', codecId: 'pg/int4@1', nullable: false },
]),
),
primaryKey: { columns: targetColumns },
uniques: [],
indexes: [],
foreignKeys: [],
},
[junctionTable]: {
columns: junctionStorageColumns,
primaryKey: { columns: [...parentColumns, ...childColumns] },
uniques: [],
indexes: [],
foreignKeys: [],
},
},
},
},
},
},
capabilities: {},
} as unknown as FrameworkContract<SqlStorage>;
}

export function createMockRuntime(): MockRuntime {
const executions: MockExecution[] = [];
let nextResult: Record<string, unknown>[][] = [];
Expand Down
Loading
Loading