diff --git a/packages/sdk/src/model-utils.ts b/packages/sdk/src/model-utils.ts index 2473f284b..e1cc43f0c 100644 --- a/packages/sdk/src/model-utils.ts +++ b/packages/sdk/src/model-utils.ts @@ -71,6 +71,36 @@ export function isDelegateModel(node: AstNode) { return isDataModel(node) && hasAttribute(node, '@@delegate'); } +/** + * Returns all fields that physically belong to a model's table: its directly declared + * fields plus fields from its mixins (recursively). + */ +export function getOwnedFields(model: DataModel | TypeDef): DataField[] { + const fields: DataField[] = [...model.fields]; + for (const mixin of model.mixins) { + if (mixin.ref) { + fields.push(...getOwnedFields(mixin.ref)); + } + } + return fields; +} + +/** + * Returns the name of the delegate base model that "owns" the given field in the context of + * `contextModel`. This handles both direct fields of delegate models and mixin fields that + * belong to a mixin used by a delegate base model. + */ +export function getDelegateOriginModel(field: DataField, contextModel: DataModel): string | undefined { + let base = contextModel.baseModel?.ref; + while (base) { + if (isDelegateModel(base) && getOwnedFields(base).includes(field)) { + return base.name; + } + base = base.baseModel?.ref; + } + return undefined; +} + export function isUniqueField(field: DataField) { if (hasAttribute(field, '@unique')) { return true; diff --git a/packages/sdk/src/prisma/prisma-schema-generator.ts b/packages/sdk/src/prisma/prisma-schema-generator.ts index a524755da..815403cac 100644 --- a/packages/sdk/src/prisma/prisma-schema-generator.ts +++ b/packages/sdk/src/prisma/prisma-schema-generator.ts @@ -42,7 +42,7 @@ import { import { AstUtils } from 'langium'; import { match } from 'ts-pattern'; import { ModelUtils } from '..'; -import { DELEGATE_AUX_RELATION_PREFIX, getIdFields } from '../model-utils'; +import { DELEGATE_AUX_RELATION_PREFIX, getDelegateOriginModel, getIdFields } from '../model-utils'; import { AttributeArgValue, ModelFieldType, @@ -204,7 +204,7 @@ export class PrismaSchemaGenerator { continue; // skip computed fields } // exclude non-id fields inherited from delegate - if (ModelUtils.isIdField(field, decl) || !this.isInheritedFromDelegate(field, decl)) { + if (ModelUtils.isIdField(field, decl) || !getDelegateOriginModel(field, decl)) { this.generateModelField(model, field, decl); } } @@ -311,7 +311,7 @@ export class PrismaSchemaGenerator { // when building physical schema, exclude `@default` for id fields inherited from delegate base !( ModelUtils.isIdField(field, contextModel) && - this.isInheritedFromDelegate(field, contextModel) && + getDelegateOriginModel(field, contextModel) && attr.decl.$refText === '@default' ), ) @@ -335,10 +335,6 @@ export class PrismaSchemaGenerator { return AstUtils.streamAst(expr).some(isAuthInvocation); } - private isInheritedFromDelegate(field: DataField, contextModel: DataModel) { - return field.$container !== contextModel && ModelUtils.isDelegateModel(field.$container); - } - private makeFieldAttribute(attr: DataFieldAttribute) { const attrName = attr.decl.ref!.name; return new PrismaFieldAttribute( diff --git a/packages/sdk/src/ts-schema-generator.ts b/packages/sdk/src/ts-schema-generator.ts index 90f6ceafa..b10068a3f 100644 --- a/packages/sdk/src/ts-schema-generator.ts +++ b/packages/sdk/src/ts-schema-generator.ts @@ -45,6 +45,7 @@ import { ModelUtils } from '.'; import { getAttribute, getAuthDecl, + getDelegateOriginModel, getIdFields, hasAttribute, isDelegateModel, @@ -587,17 +588,14 @@ export class TsSchemaGenerator { if ( contextModel && // id fields are duplicated in inherited models - !isIdField(field, contextModel) && - field.$container !== contextModel && - isDelegateModel(field.$container) + !isIdField(field, contextModel) ) { - // field is inherited from delegate - objectFields.push( - ts.factory.createPropertyAssignment( - 'originModel', - ts.factory.createStringLiteral(field.$container.name), - ), - ); + const delegateOrigin = getDelegateOriginModel(field, contextModel); + if (delegateOrigin) { + objectFields.push( + ts.factory.createPropertyAssignment('originModel', ts.factory.createStringLiteral(delegateOrigin)), + ); + } } // discriminator diff --git a/packages/testtools/src/client.ts b/packages/testtools/src/client.ts index 69513eeb4..c9d0ca8b3 100644 --- a/packages/testtools/src/client.ts +++ b/packages/testtools/src/client.ts @@ -238,6 +238,10 @@ export async function createTestClient( execSync('npx prisma db push --schema ./schema.prisma --skip-generate --force-reset', { cwd: workDir, stdio: options.debug ? 'inherit' : 'ignore', + env: { + ...process.env, + PRISMA_USER_CONSENT_FOR_DANGEROUS_AI_ACTION: 'true', + }, }); } else { await prepareDatabase(provider, dbName); diff --git a/tests/regression/test/issue-2351.test.ts b/tests/regression/test/issue-2351.test.ts new file mode 100644 index 000000000..9d883200d --- /dev/null +++ b/tests/regression/test/issue-2351.test.ts @@ -0,0 +1,79 @@ +import { createPolicyTestClient } from '@zenstackhq/testtools'; +import { describe, expect, it } from 'vitest'; + +// https://github.com/zenstackhq/zenstack/issues/2351 +describe('Regression for issue 2351', () => { + it('should correctly query delegate model that inherits from a model using a mixin abstract type', async () => { + const db = await createPolicyTestClient( + ` +type BaseEntity { + id String @id @default(cuid()) + createdOn DateTime @default(now()) + updatedOn DateTime @updatedAt + isDeleted Boolean @default(false) + isArchived Boolean @default(false) +} + +enum DataType { + DataText + DataNumber +} + +model RoutineData with BaseEntity { + dataType DataType + routineId String + Routine Routine @relation(fields: [routineId], references: [id]) + @@delegate(dataType) + @@allow('all', auth().id == Routine.userId) +} + +model Routine { + id String @id @default(cuid()) + userId String + User User @relation(fields: [userId], references: [id]) + data RoutineData[] + @@allow('all', true) +} + +model User { + id String @id @default(cuid()) + name String + routines Routine[] + @@allow('all', true) +} + +model DataText extends RoutineData { + textValue String +} + `, + { usePrismaPush: true }, + ); + + const user = await db.user.create({ + data: { + name: 'Test User', + }, + }); + + const routine = await db.routine.create({ + data: { + userId: user.id, + }, + }); + + const authDb = db.$setAuth({ id: user.id }); + const created = await authDb.dataText.create({ + data: { textValue: 'hello', routineId: routine.id }, + }); + expect(created.textValue).toBe('hello'); + expect(created.isDeleted).toBe(false); + expect(created.isArchived).toBe(false); + + const found = await authDb.dataText.findUnique({ + where: { id: created.id }, + }); + expect(found).not.toBeNull(); + expect(found!.textValue).toBe('hello'); + expect(found!.createdOn).toBeDefined(); + }); +});