Skip to content
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions packages/language/res/stdlib.zmodel
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,13 @@ function nanoid(length: Int?, format: String?): String {
function ulid(format: String?): String {
} @@@expressionContext([DefaultValue])

/**
* Generates a custom identifier. The ORM client must be initialized with an
* implementation of this function.
*/
function customId(length: Int?): String {
} @@@expressionContext([DefaultValue])

/**
* Creates a sequence of integers in the underlying database and assign the incremented
* values to the ID values of the created records based on the sequence.
Expand Down
14 changes: 14 additions & 0 deletions packages/language/src/validators/function-invocation-validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,20 @@ export default class FunctionInvocationValidator implements AstValidator<Express
}
}

@func('customId')
private _checkCustomId(expr: InvocationExpr, accept: ValidationAcceptor) {
// first argument must be positive if provided
const lengthArg = expr.args[0]?.value;
if (lengthArg) {
const length = getLiteral<number>(lengthArg);
if (length !== undefined && length <= 0) {
accept('error', 'first argument must be a positive number', {
node: expr.args[0]!,
});
}
}
}

@func('auth')
private _checkAuth(expr: InvocationExpr, accept: ValidationAcceptor) {
if (!expr.$resolvedType) {
Expand Down
32 changes: 32 additions & 0 deletions packages/language/test/function-invocation.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -414,4 +414,36 @@ describe('Function Invocation Tests', () => {
);
});
});

describe('customId() length validation', () => {
it('should reject non-positive lengths', async () => {
await loadSchemaWithError(
`
datasource db {
provider = 'sqlite'
url = 'file:./dev.db'
}

model User {
id String @id @default(customId(0))
}
`,
'first argument must be a positive number',
);

await loadSchemaWithError(
`
datasource db {
provider = 'sqlite'
url = 'file:./dev.db'
}

model User {
id String @id @default(customId(-1))
}
`,
'first argument must be a positive number',
);
});
});
});
16 changes: 14 additions & 2 deletions packages/orm/src/client/crud/operations/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1044,7 +1044,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
}
if (!(field in data)) {
if (typeof fieldDef?.default === 'object' && 'kind' in fieldDef.default) {
const generated = this.evalGenerator(fieldDef.default);
const generated = this.evalGenerator(fieldDef.default, modelDef.name, field);
if (generated !== undefined) {
values[field] = this.dialect.transformInput(
generated,
Expand Down Expand Up @@ -1072,7 +1072,7 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
return values;
}

private evalGenerator(defaultValue: Expression) {
private evalGenerator(defaultValue: Expression, model: string, field: string) {
if (ExpressionUtils.isCall(defaultValue)) {
const firstArgVal =
defaultValue.args?.[0] && ExpressionUtils.isLiteral(defaultValue.args[0])
Expand All @@ -1095,6 +1095,18 @@ export abstract class BaseOperationHandler<Schema extends SchemaDef> {
return this.formatGeneratedValue(generated, defaultValue.args?.[1]);
})
.with('ulid', () => this.formatGeneratedValue(ulid(), defaultValue.args?.[0]))
.with('customId', () => {
invariant(this.client.$options.customId, '"customId" implementation not provided');
const length = typeof firstArgVal === 'number' ? firstArgVal : undefined;
const generated = this.client.$options.customId({
client: this.client,
model,
field,
length,
});
invariant(generated && typeof generated === 'string', '"customId" must return a non-empty string');
Comment thread
sanny-io marked this conversation as resolved.
Outdated
return generated;
})
Comment thread
sanny-io marked this conversation as resolved.
.otherwise(() => undefined);
} else if (
ExpressionUtils.isMember(defaultValue) &&
Expand Down
30 changes: 30 additions & 0 deletions packages/orm/src/client/options.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,30 @@ export type ZModelFunction<Schema extends SchemaDef> = (
context: ZModelFunctionContext<Schema>,
) => Expression<unknown>;

export type CustomIdFunctionContext<Schema extends SchemaDef> = {
/**
* ZenStack client instance.
*/
client: ClientContract<Schema>;

/**
* The model for which the ID should be generated for.
*/
model: string;

/**
* The field for which the ID should be generated for.
*/
field: string;

/**
* The length of the ID as requested by the schema.
*/
length?: number;
};

export type CustomIdFunction<Schema extends SchemaDef> = (ctx: CustomIdFunctionContext<Schema>) => string;

/**
* ZenStack client options.
*/
Expand Down Expand Up @@ -82,6 +106,12 @@ export type ClientOptions<Schema extends SchemaDef> = {
*/
validateInput?: boolean;

/**
* Implementation of a custom ID generation function, which is called from ZModel as
* `@default(customId())`.
*/
customId?: CustomIdFunction<Schema>;

/**
* Options for omitting fields in ORM query results.
*/
Expand Down
164 changes: 164 additions & 0 deletions tests/e2e/orm/client-api/custom-id.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import { createTestClient } from '@zenstackhq/testtools';
import { describe, expect, it } from 'vitest';

const schema = `
model User {
uid String @id @default(customId())
posts Post[]
}

model Post {
pid String @id @default(customId())
userId String?
user User? @relation(fields: [userId], references: [uid])
comments Comment[]
}

model Comment {
cid String @id @default(customId())
postId String?
post Post? @relation(fields: [postId], references: [pid])
}
`;

describe('customId', () => {
it('works with no arguments', async () => {
const client = await createTestClient(schema, {
customId: ({ model, field, length }) => `${model}.${field}.${length ?? 16}`,
});

await expect(client.user.create({ data: {} })).resolves.toMatchObject({
uid: 'User.uid.16',
});

await expect(client.post.create({ data: {} })).resolves.toMatchObject({
pid: 'Post.pid.16',
});

await expect(client.comment.create({ data: {} })).resolves.toMatchObject({
cid: 'Comment.cid.16',
});
});

it('works with arguments', async () => {
const schema = `
model User {
uid String @id @default(customId(8))
posts Post[]
}

model Post {
pid String @id @default(customId(8))
userId String?
user User? @relation(fields: [userId], references: [uid])
comments Comment[]
}

model Comment {
cid String @id @default(customId(8))
postId String?
post Post? @relation(fields: [postId], references: [pid])
}
`;

const client = await createTestClient(schema, {
customId: ({ model, field, length }) => `${model}.${field}.${length}`,
});

await expect(client.user.create({ data: {} })).resolves.toMatchObject({
uid: 'User.uid.8',
});

await expect(client.post.create({ data: {} })).resolves.toMatchObject({
pid: 'Post.pid.8',
});

await expect(client.comment.create({ data: {} })).resolves.toMatchObject({
cid: 'Comment.cid.8',
});
});

it('works with nested', async () => {
const client = await createTestClient(schema, {
customId: ({ model, field, length }) => `${model}.${field}.${length ?? 16}`,
});

await expect(client.user.create({
data: {
posts: {
create: {},
},
},
})).resolves.toMatchObject({
uid: 'User.uid.16',
});

await expect(client.post.findUnique({
where: {
pid: 'Post.pid.16',
}
})).resolves.toBeTruthy();
});

it('works with deeply nested', async () => {
const client = await createTestClient(schema, {
customId: ({ model, field, length }) => `${model}.${field}.${length ?? 16}`,
});

await expect(client.user.create({
data: {
posts: {
create: {
comments: {
create: {},
},
},
},
},
})).resolves.toMatchObject({
uid: 'User.uid.16',
});

await expect(client.post.findUnique({
where: {
pid: 'Post.pid.16',
}
})).resolves.toBeTruthy();

await expect(client.comment.findUnique({
where: {
cid: 'Comment.cid.16',
}
})).resolves.toBeTruthy();
});

it('rejects without an implementation', async () => {
const client = await createTestClient(schema);
await expect(client.user.create({ data: {} })).rejects.toThrowError('implementation not provided');
});

it('rejects without a valid implementation (undefined)', async () => {
// @ts-expect-error
const client = await createTestClient(schema, {
customId: () => undefined,
});
// @ts-expect-error
await expect(client.user.create({ data: {} })).rejects.toThrowError('non-empty string');
});

it('rejects without a valid implementation (empty string)', async () => {
const client = await createTestClient(schema, {
customId: () => '',
});
await expect(client.user.create({ data: {} })).rejects.toThrowError('non-empty string');
});

it('rejects without a valid implementation (non-string)', async () => {
// @ts-expect-error
const client = await createTestClient(schema, {
customId: () => 1,
});
// @ts-expect-error
await expect(client.user.create({ data: {} })).rejects.toThrowError('non-empty string');
});
});
Loading