Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions packages/sdk/src/model-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,36 @@ export function isDelegateModel(node: AstNode) {
return isDataModel(node) && hasAttribute(node, '@@delegate');
}

/**
* Returns all fields that physically belong to a model's table: its directly declared
* fields plus fields from its mixins (recursively).
*/
export function getOwnedFields(model: DataModel | TypeDef): DataField[] {
const fields: DataField[] = [...model.fields];
for (const mixin of model.mixins) {
if (mixin.ref) {
fields.push(...getOwnedFields(mixin.ref));
}
}
return fields;
}

/**
* Returns the name of the delegate base model that "owns" the given field in the context of
* `contextModel`. This handles both direct fields of delegate models and mixin fields that
* belong to a mixin used by a delegate base model.
*/
export function getDelegateOriginModel(field: DataField, contextModel: DataModel): string | undefined {
let base = contextModel.baseModel?.ref;
while (base) {
if (isDelegateModel(base) && getOwnedFields(base).includes(field)) {
return base.name;
}
base = base.baseModel?.ref;
}
return undefined;
}

export function isUniqueField(field: DataField) {
if (hasAttribute(field, '@unique')) {
return true;
Expand Down
10 changes: 3 additions & 7 deletions packages/sdk/src/prisma/prisma-schema-generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ import {
import { AstUtils } from 'langium';
import { match } from 'ts-pattern';
import { ModelUtils } from '..';
import { DELEGATE_AUX_RELATION_PREFIX, getIdFields } from '../model-utils';
import { DELEGATE_AUX_RELATION_PREFIX, getDelegateOriginModel, getIdFields } from '../model-utils';
import {
AttributeArgValue,
ModelFieldType,
Expand Down Expand Up @@ -204,7 +204,7 @@ export class PrismaSchemaGenerator {
continue; // skip computed fields
}
// exclude non-id fields inherited from delegate
if (ModelUtils.isIdField(field, decl) || !this.isInheritedFromDelegate(field, decl)) {
if (ModelUtils.isIdField(field, decl) || !getDelegateOriginModel(field, decl)) {
this.generateModelField(model, field, decl);
}
}
Expand Down Expand Up @@ -311,7 +311,7 @@ export class PrismaSchemaGenerator {
// when building physical schema, exclude `@default` for id fields inherited from delegate base
!(
ModelUtils.isIdField(field, contextModel) &&
this.isInheritedFromDelegate(field, contextModel) &&
getDelegateOriginModel(field, contextModel) &&
attr.decl.$refText === '@default'
),
)
Expand All @@ -335,10 +335,6 @@ export class PrismaSchemaGenerator {
return AstUtils.streamAst(expr).some(isAuthInvocation);
}

private isInheritedFromDelegate(field: DataField, contextModel: DataModel) {
return field.$container !== contextModel && ModelUtils.isDelegateModel(field.$container);
}

private makeFieldAttribute(attr: DataFieldAttribute) {
const attrName = attr.decl.ref!.name;
return new PrismaFieldAttribute(
Expand Down
18 changes: 8 additions & 10 deletions packages/sdk/src/ts-schema-generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import { ModelUtils } from '.';
import {
getAttribute,
getAuthDecl,
getDelegateOriginModel,
getIdFields,
hasAttribute,
isDelegateModel,
Expand Down Expand Up @@ -587,17 +588,14 @@ export class TsSchemaGenerator {
if (
contextModel &&
// id fields are duplicated in inherited models
!isIdField(field, contextModel) &&
field.$container !== contextModel &&
isDelegateModel(field.$container)
!isIdField(field, contextModel)
) {
// field is inherited from delegate
objectFields.push(
ts.factory.createPropertyAssignment(
'originModel',
ts.factory.createStringLiteral(field.$container.name),
),
);
const delegateOrigin = getDelegateOriginModel(field, contextModel);
if (delegateOrigin) {
objectFields.push(
ts.factory.createPropertyAssignment('originModel', ts.factory.createStringLiteral(delegateOrigin)),
);
}
}

// discriminator
Expand Down
4 changes: 4 additions & 0 deletions packages/testtools/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,10 @@ export async function createTestClient(
execSync('npx prisma db push --schema ./schema.prisma --skip-generate --force-reset', {
cwd: workDir,
stdio: options.debug ? 'inherit' : 'ignore',
env: {
...process.env,
PRISMA_USER_CONSENT_FOR_DANGEROUS_AI_ACTION: 'true',
},
});
} else {
await prepareDatabase(provider, dbName);
Expand Down
79 changes: 79 additions & 0 deletions tests/regression/test/issue-2351.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import { createPolicyTestClient } from '@zenstackhq/testtools';
import { describe, expect, it } from 'vitest';

// https://github.com/zenstackhq/zenstack/issues/2351
describe('Regression for issue 2351', () => {
it('should correctly query delegate model that inherits from a model using a mixin abstract type', async () => {
const db = await createPolicyTestClient(
`
type BaseEntity {
id String @id @default(cuid())
createdOn DateTime @default(now())
updatedOn DateTime @updatedAt
isDeleted Boolean @default(false)
isArchived Boolean @default(false)
}

enum DataType {
DataText
DataNumber
}

model RoutineData with BaseEntity {
dataType DataType
routineId String
Routine Routine @relation(fields: [routineId], references: [id])
@@delegate(dataType)
@@allow('all', auth().id == Routine.userId)
}

model Routine {
id String @id @default(cuid())
userId String
User User @relation(fields: [userId], references: [id])
data RoutineData[]
@@allow('all', true)
}

model User {
id String @id @default(cuid())
name String
routines Routine[]
@@allow('all', true)
}

model DataText extends RoutineData {
textValue String
}
`,
{ usePrismaPush: true },
);

const user = await db.user.create({
data: {
name: 'Test User',
},
});

const routine = await db.routine.create({
data: {
userId: user.id,
},
});

const authDb = db.$setAuth({ id: user.id });
const created = await authDb.dataText.create({
data: { textValue: 'hello', routineId: routine.id },
});
expect(created.textValue).toBe('hello');
expect(created.isDeleted).toBe(false);
expect(created.isArchived).toBe(false);

const found = await authDb.dataText.findUnique({
where: { id: created.id },
});
expect(found).not.toBeNull();
expect(found!.textValue).toBe('hello');
expect(found!.createdOn).toBeDefined();
});
});
Loading