Skip to content

Commit e492c93

Browse files
authored
fix(cli): add missing opposite relation fields during db pull when multiple FKs target the same model (#2652)
1 parent 79498da commit e492c93

3 files changed

Lines changed: 162 additions & 35 deletions

File tree

packages/cli/src/actions/db.ts

Lines changed: 71 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { formatDocument, ZModelCodeGenerator } from '@zenstackhq/language';
2-
import { DataModel, Enum, type Model } from '@zenstackhq/language/ast';
2+
import { DataModel, Enum, isDataField, type DataField, type Model } from '@zenstackhq/language/ast';
33
import colors from 'colors';
44
import fs from 'node:fs';
55
import path from 'node:path';
@@ -14,7 +14,7 @@ import {
1414
} from './action-utils';
1515
import { consolidateEnums, syncEnums, syncRelation, syncTable, type Relation } from './pull';
1616
import { providers as pullProviders } from './pull/provider';
17-
import { getDatasource, getDbName, getRelationFieldsKey, getRelationFkName, isDatabaseManagedAttribute } from './pull/utils';
17+
import { getDatasource, getDbName, getRelationFieldsKey, getRelationFkName, getRelationName, isDatabaseManagedAttribute } from './pull/utils';
1818
import type { DataSourceProviderType } from '@zenstackhq/schema';
1919
import { CliError } from '../cli-error';
2020

@@ -35,6 +35,25 @@ export type PullOptions = {
3535
indent: number;
3636
};
3737

38+
function hasRelationFieldsArg(field: DataField) {
39+
const relationAttr = field.attributes.find((a) => a.decl.ref?.name === '@relation');
40+
return !!relationAttr?.args.some((a) => a.name === 'fields');
41+
}
42+
43+
function getReferencedModelName(field: DataField) {
44+
return field.type.reference?.ref ? getDbName(field.type.reference.ref) : undefined;
45+
}
46+
47+
function matchesRelationNameFallback(field: DataField, relationName: string, candidate: DataField) {
48+
const referencedModelName = getReferencedModelName(field);
49+
return (
50+
!!referencedModelName &&
51+
getRelationName(candidate) === relationName &&
52+
hasRelationFieldsArg(candidate) === hasRelationFieldsArg(field) &&
53+
getReferencedModelName(candidate) === referencedModelName
54+
);
55+
}
56+
3857
/**
3958
* CLI action for db related commands
4059
*/
@@ -283,46 +302,52 @@ async function runPull(options: PullOptions) {
283302
}
284303

285304
newDataModel.fields.forEach((f) => {
286-
// Prioritized matching: exact db name > relation fields key > relation FK name > type reference
305+
// Prioritized matching: exact db name > relation fields key > relation FK name > relation name > type reference
287306
let originalFields = originalDataModel.fields.filter((d) => getDbName(d) === getDbName(f));
288307

289-
// If this is a back-reference relation field (has @relation but no `fields` arg), silently skip
290-
const isRelationField =
291-
f.$type === 'DataField' && !!(f as any).attributes?.some((a: any) => a?.decl?.ref?.name === '@relation');
292-
if (originalFields.length === 0 && isRelationField && !getRelationFieldsKey(f as any)) {
293-
return;
294-
}
295-
296308
if (originalFields.length === 0) {
297309
// Try matching by relation fields key (the `fields` attribute in @relation)
298310
// This matches relation fields by their FK field references
299-
const newFieldsKey = getRelationFieldsKey(f as any);
311+
const newFieldsKey = isDataField(f) ? getRelationFieldsKey(f) : undefined;
300312
if (newFieldsKey) {
301313
originalFields = originalDataModel.fields.filter(
302-
(d) => getRelationFieldsKey(d as any) === newFieldsKey,
314+
(d) => isDataField(d) && getRelationFieldsKey(d) === newFieldsKey,
303315
);
304316
}
305317
}
306318

307319
if (originalFields.length === 0) {
308320
// Try matching by relation FK name (the `map` attribute in @relation)
309-
originalFields = originalDataModel.fields.filter(
310-
(d) =>
311-
getRelationFkName(d as any) === getRelationFkName(f as any) &&
312-
!!getRelationFkName(d as any) &&
313-
!!getRelationFkName(f as any),
314-
);
321+
const newFkName = isDataField(f) ? getRelationFkName(f) : undefined;
322+
if (newFkName) {
323+
originalFields = originalDataModel.fields.filter(
324+
(d) => isDataField(d) && getRelationFkName(d) === newFkName,
325+
);
326+
}
327+
}
328+
329+
if (originalFields.length === 0) {
330+
// Try matching by relation name (the `name` arg in @relation)
331+
// This is essential for back-reference fields that only have a relation name
332+
const newRelName = isDataField(f) ? getRelationName(f) : undefined;
333+
if (newRelName) {
334+
originalFields = originalDataModel.fields.filter(
335+
(d) =>
336+
isDataField(d) &&
337+
isDataField(f) &&
338+
matchesRelationNameFallback(f, newRelName, d),
339+
);
340+
}
315341
}
316342

317343
if (originalFields.length === 0) {
318344
// Try matching by type reference
319345
// We need this because for relations that don't have @relation, we can only check if the original exists by the field type.
320346
// Yes, in this case it can potentially result in multiple original fields, but we only want to ensure that at least one relation exists.
321-
// In the future, we might implement some logic to detect how many of these types of relations we need and add/remove fields based on this.
322347
originalFields = originalDataModel.fields.filter(
323348
(d) =>
324-
f.$type === 'DataField' &&
325-
d.$type === 'DataField' &&
349+
isDataField(f) &&
350+
isDataField(d) &&
326351
f.type.reference?.ref &&
327352
d.type.reference?.ref &&
328353
getDbName(f.type.reference.ref) === getDbName(d.type.reference.ref),
@@ -332,7 +357,7 @@ async function runPull(options: PullOptions) {
332357
if (originalFields.length > 1) {
333358
// If this is a back-reference relation field (no `fields` attribute),
334359
// silently skip when there are multiple potential matches
335-
const isBackReferenceField = !getRelationFieldsKey(f as any);
360+
const isBackReferenceField = isDataField(f) && !getRelationFieldsKey(f);
336361
if (!isBackReferenceField) {
337362
console.warn(
338363
colors.yellow(
@@ -499,31 +524,43 @@ async function runPull(options: PullOptions) {
499524
});
500525
originalDataModel.fields
501526
.filter((f) => {
502-
// Prioritized matching: exact db name > relation fields key > relation FK name > type reference
527+
// Prioritized matching: exact db name > relation fields key > relation FK name > relation name > type reference
503528
const matchByDbName = newDataModel.fields.find((d) => getDbName(d) === getDbName(f));
504529
if (matchByDbName) return false;
505530

506531
// Try matching by relation fields key (the `fields` attribute in @relation)
507-
const originalFieldsKey = getRelationFieldsKey(f as any);
532+
const originalFieldsKey = isDataField(f) ? getRelationFieldsKey(f) : undefined;
508533
if (originalFieldsKey) {
509534
const matchByFieldsKey = newDataModel.fields.find(
510-
(d) => getRelationFieldsKey(d as any) === originalFieldsKey,
535+
(d) => isDataField(d) && getRelationFieldsKey(d) === originalFieldsKey,
511536
);
512537
if (matchByFieldsKey) return false;
513538
}
514539

515-
const matchByFkName = newDataModel.fields.find(
516-
(d) =>
517-
getRelationFkName(d as any) === getRelationFkName(f as any) &&
518-
!!getRelationFkName(d as any) &&
519-
!!getRelationFkName(f as any),
520-
);
521-
if (matchByFkName) return false;
540+
const originalFkName = isDataField(f) ? getRelationFkName(f) : undefined;
541+
if (originalFkName) {
542+
const matchByFkName = newDataModel.fields.find(
543+
(d) => isDataField(d) && getRelationFkName(d) === originalFkName,
544+
);
545+
if (matchByFkName) return false;
546+
}
547+
548+
// Try matching by relation name (for named back-reference fields)
549+
const originalRelName = isDataField(f) ? getRelationName(f) : undefined;
550+
if (originalRelName) {
551+
const matchByRelName = newDataModel.fields.find(
552+
(d) =>
553+
isDataField(d) &&
554+
isDataField(f) &&
555+
matchesRelationNameFallback(f, originalRelName, d),
556+
);
557+
if (matchByRelName) return false;
558+
}
522559

523560
const matchByTypeRef = newDataModel.fields.find(
524561
(d) =>
525-
f.$type === 'DataField' &&
526-
d.$type === 'DataField' &&
562+
isDataField(f) &&
563+
isDataField(d) &&
527564
f.type.reference?.ref &&
528565
d.type.reference?.ref &&
529566
getDbName(f.type.reference.ref) === getDbName(d.type.reference.ref),

packages/cli/src/actions/pull/utils.ts

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import {
1414
type StringLiteral,
1515
} from '@zenstackhq/language/ast';
1616
import type { AstFactory, ExpressionBuilder } from '@zenstackhq/language/factory';
17-
import { getLiteralArray, getStringLiteral } from '@zenstackhq/language/utils';
17+
import { getAttributeArgLiteral, getLiteralArray, getStringLiteral } from '@zenstackhq/language/utils';
1818
import type { DataSourceProviderType } from '@zenstackhq/schema';
1919
import type { Reference } from 'langium';
2020
import { CliError } from '../../cli-error';
@@ -122,6 +122,19 @@ export function getRelationFkName(decl: DataField): string | undefined {
122122
return schemaAttrValue?.value;
123123
}
124124

125+
/**
126+
* Gets the relation name from the @relation attribute's `name` argument.
127+
* e.g., @relation('myRelation', fields: [...], references: [...]) -> "myRelation"
128+
* e.g., @relation(name: 'myRelation', fields: [...], references: [...]) -> "myRelation"
129+
* e.g., @relation(fields: [...], references: [...]) -> undefined
130+
* e.g., @relation('backRef') -> "backRef"
131+
*/
132+
export function getRelationName(decl: DataField): string | undefined {
133+
const relationAttr = decl?.attributes?.find((a) => a.decl?.ref?.name === '@relation');
134+
if (!relationAttr) return undefined;
135+
return getAttributeArgLiteral(relationAttr, 'name');
136+
}
137+
125138
/**
126139
* Gets the FK field names from the @relation attribute's `fields` argument.
127140
* Returns a sorted, comma-separated string of field names for comparison.

packages/cli/test/db/pull.test.ts

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,83 @@ model Tag {
152152
expect(restoredSchema).toEqual(schema);
153153
});
154154

155+
it('should restore opposite relation fields when multiple models have FKs to the same target', async () => {
156+
const { workDir, schema } = await createProject(
157+
`model Comment {
158+
id Int @id @default(autoincrement())
159+
text String
160+
commentCreatedBy User? @relation('Comment_createdByToUser', fields: [createdBy], references: [id])
161+
createdBy Int?
162+
commentUpdatedBy User? @relation('Comment_updatedByToUser', fields: [updatedBy], references: [id])
163+
updatedBy Int?
164+
}
165+
166+
model Post {
167+
id Int @id @default(autoincrement())
168+
title String
169+
postCreatedBy User? @relation('Post_createdByToUser', fields: [createdBy], references: [id])
170+
createdBy Int?
171+
postUpdatedBy User? @relation('Post_updatedByToUser', fields: [updatedBy], references: [id])
172+
updatedBy Int?
173+
}
174+
175+
model User {
176+
id Int @id @default(autoincrement())
177+
email String @unique
178+
commentCreatedBy Comment[] @relation('Comment_createdByToUser')
179+
commentUpdatedBy Comment[] @relation('Comment_updatedByToUser')
180+
postCreatedBy Post[] @relation('Post_createdByToUser')
181+
postUpdatedBy Post[] @relation('Post_updatedByToUser')
182+
}`,
183+
);
184+
runCli('db push', workDir);
185+
186+
const schemaFile = path.join(workDir, 'zenstack/schema.zmodel');
187+
188+
fs.writeFileSync(schemaFile, getDefaultPrelude());
189+
runCli('db pull --indent 4', workDir);
190+
191+
const restoredSchema = getSchema(workDir);
192+
expect(restoredSchema).toEqual(schema);
193+
});
194+
195+
it('should preserve opposite relation fields when multiple models have FKs to the same target', async () => {
196+
const { workDir, schema } = await createProject(
197+
`model Comment {
198+
id Int @id @default(autoincrement())
199+
text String
200+
commentCreatedBy User? @relation('Comment_createdByToUser', fields: [createdBy], references: [id])
201+
createdBy Int?
202+
commentUpdatedBy User? @relation('Comment_updatedByToUser', fields: [updatedBy], references: [id])
203+
updatedBy Int?
204+
}
205+
206+
model Post {
207+
id Int @id @default(autoincrement())
208+
title String
209+
postCreatedBy User? @relation('Post_createdByToUser', fields: [createdBy], references: [id])
210+
createdBy Int?
211+
postUpdatedBy User? @relation('Post_updatedByToUser', fields: [updatedBy], references: [id])
212+
updatedBy Int?
213+
}
214+
215+
model User {
216+
id Int @id @default(autoincrement())
217+
email String @unique
218+
commentCreatedBy Comment[] @relation('Comment_createdByToUser')
219+
commentUpdatedBy Comment[] @relation('Comment_updatedByToUser')
220+
postCreatedBy Post[] @relation('Post_createdByToUser')
221+
postUpdatedBy Post[] @relation('Post_updatedByToUser')
222+
}`,
223+
);
224+
runCli('db push', workDir);
225+
226+
runCli('db pull --indent 4', workDir);
227+
228+
const restoredSchema = getSchema(workDir);
229+
expect(restoredSchema).toEqual(schema);
230+
});
231+
155232
it('should restore one-to-one relation when FK is the single-column primary key', async () => {
156233
const { workDir, schema } = await createProject(
157234
`model Profile {

0 commit comments

Comments
 (0)