From 7b8b5c6e0c7f57ce431b994f4be293fd50dc68d3 Mon Sep 17 00:00:00 2001 From: omar-dulaimi Date: Mon, 21 Jul 2025 18:45:08 +0300 Subject: [PATCH 1/2] Add Swagger and relation splitting features - Add swagger flag to generate @ApiProperty decorators alongside class-validator decorators - Add separateRelationFields flag to split models into base/relation classes for better NestJS integration - Support for all Prisma types including proper Float handling with @IsNumber - Comprehensive test coverage for both new features - Updated documentation with examples and usage patterns - Backward compatible - existing schemas work unchanged Fixes #18 --- CLAUDE.md | 88 ++++++++++- README.md | 55 ++++++- src/generate-class.ts | 229 ++++++++++++++++++++++++++++- src/helpers.ts | 103 ++++++++++++- src/prisma-generator.ts | 14 +- tests/relation-splitting.test.ts | 113 ++++++++++++++ tests/schemas/full-features.prisma | 35 +++++ tests/schemas/swagger.prisma | 34 +++++ tests/swagger-generation.test.ts | 69 +++++++++ 9 files changed, 731 insertions(+), 9 deletions(-) create mode 100644 tests/relation-splitting.test.ts create mode 100644 tests/schemas/full-features.prisma create mode 100644 tests/schemas/swagger.prisma create mode 100644 tests/swagger-generation.test.ts diff --git a/CLAUDE.md b/CLAUDE.md index e2f34a7..a90c770 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -64,13 +64,99 @@ generated/ └── index.ts ``` +### Configuration Options + The generator is configured via Prisma schema: ```prisma +generator class_validator { + provider = "prisma-class-validator-generator" + output = "./generated" // optional, defaults to ./generated + swagger = "true" // optional, adds @ApiProperty decorators + separateRelationFields = "true" // optional, creates separate base/relation classes +} +``` + +#### Configuration Flags + +**`swagger`** (optional, default: `false`) +- Adds NestJS Swagger `@ApiProperty` decorators alongside class-validator decorators +- Includes type information, examples, array handling, and enum values +- Useful for automatic API documentation generation in NestJS applications + +**`separateRelationFields`** (optional, default: `false`) +- Splits models into separate base and relation classes for better NestJS integration +- Creates `ModelBase` (scalar fields only), `ModelRelations` (relations only), and combined `Model` class +- Enables use of NestJS mapped types like `PickType`, `PartialType`, etc. +- Perfect for DTOs that need to exclude relations or work with specific field subsets + +#### Example Usage + +**Basic Usage (class-validator only):** +```prisma +generator class_validator { + provider = "prisma-class-validator-generator" + output = "./generated" +} +``` +Generates: +```typescript +export class User { + @IsDefined() + @IsInt() + id!: number; + + @IsDefined() + @IsString() + email!: string; +} +``` + +**With Swagger Support:** +```prisma generator class_validator { provider = "prisma-class-validator-generator" - output = "./generated" // optional, defaults to ./generated + output = "./generated" + swagger = "true" +} +``` +Generates: +```typescript +export class User { + @IsDefined() + @ApiProperty({ example: 'Generated by autoincrement', type: "integer" }) + @IsInt() + id!: number; + + @IsDefined() + @ApiProperty({ type: "string" }) + @IsString() + email!: string; +} +``` + +**With Relation Splitting:** +```prisma +generator class_validator { + provider = "prisma-class-validator-generator" + output = "./generated" + separateRelationFields = "true" +} +``` +Generates: +- `UserBase.model.ts` - Only scalar fields with decorators +- `UserRelations.model.ts` - Only relation fields with decorators +- `User.model.ts` - Combined class extending UserBase with relations + +**Full NestJS Integration:** +```prisma +generator class_validator { + provider = "prisma-class-validator-generator" + output = "./generated" + swagger = "true" + separateRelationFields = "true" } ``` +Perfect for NestJS APIs with automatic Swagger docs and flexible DTOs. ## Modern Development Setup (Prisma 6+) diff --git a/README.md b/README.md index dcda515..ddcb95d 100644 --- a/README.md +++ b/README.md @@ -235,8 +235,10 @@ Customize the generator behavior: ```prisma generator class_validator { - provider = "prisma-class-validator-generator" - output = "./src/models" // Output directory + provider = "prisma-class-validator-generator" + output = "./src/models" // Output directory + swagger = "true" // Add Swagger decorators + separateRelationFields = "true" // Split base/relation classes } ``` @@ -245,6 +247,53 @@ generator class_validator { | Option | Type | Default | Description | |--------|------|---------|-------------| | `output` | `string` | `"./generated"` | Output directory for generated models | +| `swagger` | `string` | `"false"` | Add NestJS `@ApiProperty` decorators for Swagger docs | +| `separateRelationFields` | `string` | `"false"` | Generate separate base and relation classes for flexible DTOs | + +### 🌟 New in v6.0.0-beta.1: NestJS & Swagger Integration + +#### Swagger Support (`swagger = "true"`) + +Automatically generates NestJS Swagger decorators alongside class-validator decorators: + +```typescript +export class User { + @IsDefined() + @ApiProperty({ example: 'Generated by autoincrement', type: "integer" }) + @IsInt() + id!: number; + + @IsDefined() + @ApiProperty({ type: "string" }) + @IsString() + email!: string; + + @IsOptional() + @ApiProperty({ type: "string", required: false }) + @IsString() + name?: string | null; +} +``` + +#### Relation Field Splitting (`separateRelationFields = "true"`) + +Perfect for NestJS DTOs - generates separate classes for maximum flexibility: + +- **`UserBase.model.ts`** - Only scalar fields with validation decorators +- **`UserRelations.model.ts`** - Only relation fields +- **`User.model.ts`** - Combined class extending UserBase + +This enables powerful NestJS patterns: +```typescript +// Create DTO without relations using PickType +export class CreateUserDto extends PickType(UserBase, ['email', 'name']) {} + +// Update DTO with partial fields +export class UpdateUserDto extends PartialType(UserBase) {} + +// Full model with relations for responses +export class UserResponseDto extends User {} +``` ## 📚 Advanced Usage @@ -457,7 +506,7 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file ---
-

Made with ❤️ by the Prisma Class Validator Generator team

+

Made with ❤️ by Omar Dulaimi

⭐ Star us on GitHub🐛 Report Issues • diff --git a/src/generate-class.ts b/src/generate-class.ts index ef3311c..9a46232 100644 --- a/src/generate-class.ts +++ b/src/generate-class.ts @@ -1,25 +1,41 @@ import type { DMMF as PrismaDMMF } from '@prisma/generator-helper'; import path from 'path'; import { OptionalKind, Project, PropertyDeclarationStructure } from 'ts-morph'; +import type { GeneratorConfig } from './prisma-generator'; import { generateClassValidatorImport, generateEnumImports, generateHelpersImports, generatePrismaImport, generateRelationImportsImport, + generateSwaggerImport, getDecoratorsByFieldType, getDecoratorsImportsByType, + getSwaggerImportsByType, getTSDataTypeFromFieldType, shouldImportHelpers, shouldImportPrisma, + shouldImportSwagger, } from './helpers'; export default async function generateClass( project: Project, - outputDir: string, + config: GeneratorConfig, model: PrismaDMMF.Model, ) { - const dirPath = path.resolve(outputDir, 'models'); + if (config.separateRelationFields) { + generateSeparateRelationClasses(project, config, model); + } else { + generateSingleClass(project, config, model); + } +} + +function generateSingleClass( + project: Project, + config: GeneratorConfig, + model: PrismaDMMF.Model, +) { + const dirPath = path.resolve(config.outputDir, 'models'); const filePath = path.resolve(dirPath, `${model.name}.model.ts`); const sourceFile = project.createSourceFile(filePath, undefined, { overwrite: true, @@ -38,6 +54,12 @@ export default async function generateClass( } generateClassValidatorImport(sourceFile, validatorImports as Array); + + // Add Swagger imports if enabled + if (config.swagger && shouldImportSwagger(model.fields as PrismaDMMF.Field[])) { + const swaggerImports = getSwaggerImportsByType(model.fields as PrismaDMMF.Field[]); + generateSwaggerImport(sourceFile, swaggerImports); + } const relationImports = new Set(); model.fields.forEach((field) => { if (field.relationName && model.name !== field.type) { @@ -67,10 +89,211 @@ export default async function generateClass( hasExclamationToken: field.isRequired, hasQuestionToken: !field.isRequired, trailingTrivia: '\r\n', - decorators: getDecoratorsByFieldType(field), + decorators: getDecoratorsByFieldType(field, config.swagger), + }; + }, + ), + ], + }); +} + +function generateSeparateRelationClasses( + project: Project, + config: GeneratorConfig, + model: PrismaDMMF.Model, +) { + // Separate base fields from relation fields + const baseFields = model.fields.filter((field) => !field.relationName); + const relationFields = model.fields.filter((field) => field.relationName); + + // Generate base class (without relations) + generateBaseClass(project, config, model, baseFields); + + // Generate relation class (only relations) + if (relationFields.length > 0) { + generateRelationClass(project, config, model, relationFields); + } + + // Generate combined class that extends base and includes relations + generateCombinedClass(project, config, model, baseFields, relationFields); +} + +function generateBaseClass( + project: Project, + config: GeneratorConfig, + model: PrismaDMMF.Model, + fields: PrismaDMMF.Field[], +) { + const dirPath = path.resolve(config.outputDir, 'models'); + const filePath = path.resolve(dirPath, `${model.name}Base.model.ts`); + const sourceFile = project.createSourceFile(filePath, undefined, { + overwrite: true, + }); + + const validatorImports = [ + ...new Set( + fields + .map((field) => getDecoratorsImportsByType(field)) + .flatMap((item) => item), + ), + ]; + + if (shouldImportPrisma(fields as PrismaDMMF.Field[])) { + generatePrismaImport(sourceFile); + } + + generateClassValidatorImport(sourceFile, validatorImports as Array); + + // Add Swagger imports if enabled + if (config.swagger && shouldImportSwagger(fields as PrismaDMMF.Field[])) { + const swaggerImports = getSwaggerImportsByType(fields as PrismaDMMF.Field[]); + generateSwaggerImport(sourceFile, swaggerImports); + } + + if (shouldImportHelpers(fields as PrismaDMMF.Field[])) { + generateHelpersImports(sourceFile, ['getEnumValues']); + } + + generateEnumImports(sourceFile, fields as PrismaDMMF.Field[]); + + sourceFile.addClass({ + name: `${model.name}Base`, + isExported: true, + properties: [ + ...fields.map>( + (field) => { + return { + name: field.name, + type: getTSDataTypeFromFieldType(field), + hasExclamationToken: field.isRequired, + hasQuestionToken: !field.isRequired, + trailingTrivia: '\r\n', + decorators: getDecoratorsByFieldType(field, config.swagger), + }; + }, + ), + ], + }); +} + +function generateRelationClass( + project: Project, + config: GeneratorConfig, + model: PrismaDMMF.Model, + relationFields: PrismaDMMF.Field[], +) { + const dirPath = path.resolve(config.outputDir, 'models'); + const filePath = path.resolve(dirPath, `${model.name}Relations.model.ts`); + const sourceFile = project.createSourceFile(filePath, undefined, { + overwrite: true, + }); + + const validatorImports = [ + ...new Set( + relationFields + .map((field) => getDecoratorsImportsByType(field)) + .flatMap((item) => item), + ), + ]; + + generateClassValidatorImport(sourceFile, validatorImports as Array); + + // Add Swagger imports if enabled + if (config.swagger && shouldImportSwagger(relationFields as PrismaDMMF.Field[])) { + const swaggerImports = getSwaggerImportsByType(relationFields as PrismaDMMF.Field[]); + generateSwaggerImport(sourceFile, swaggerImports); + } + + const relationImports = new Set(); + relationFields.forEach((field) => { + if (field.relationName && model.name !== field.type) { + relationImports.add(field.type); + } + }); + + generateRelationImportsImport(sourceFile, [ + ...relationImports, + ] as Array); + + sourceFile.addClass({ + name: `${model.name}Relations`, + isExported: true, + properties: [ + ...relationFields.map>( + (field) => { + return { + name: field.name, + type: getTSDataTypeFromFieldType(field), + hasExclamationToken: field.isRequired, + hasQuestionToken: !field.isRequired, + trailingTrivia: '\r\n', + decorators: getDecoratorsByFieldType(field, config.swagger), }; }, ), ], }); } + +function generateCombinedClass( + project: Project, + config: GeneratorConfig, + model: PrismaDMMF.Model, + baseFields: PrismaDMMF.Field[], + relationFields: PrismaDMMF.Field[], +) { + const dirPath = path.resolve(config.outputDir, 'models'); + const filePath = path.resolve(dirPath, `${model.name}.model.ts`); + const sourceFile = project.createSourceFile(filePath, undefined, { + overwrite: true, + }); + + // Import base class + sourceFile.addImportDeclaration({ + moduleSpecifier: `./${model.name}Base.model`, + namedImports: [`${model.name}Base`], + }); + + // Import relation types for the combined class + const relationImports = new Set(); + relationFields.forEach((field) => { + if (field.relationName && model.name !== field.type) { + relationImports.add(field.type); + } + }); + + if (relationImports.size > 0) { + generateRelationImportsImport(sourceFile, [ + ...relationImports, + ] as Array); + } + + // Combined class extends base and includes relations + if (relationFields.length > 0) { + sourceFile.addClass({ + name: model.name, + isExported: true, + extends: `${model.name}Base`, + properties: [ + ...relationFields.map>( + (field) => { + return { + name: field.name, + type: getTSDataTypeFromFieldType(field), + hasExclamationToken: field.isRequired, + hasQuestionToken: !field.isRequired, + trailingTrivia: '\r\n', + }; + }, + ), + ], + }); + } else { + // If no relations, just extend base + sourceFile.addClass({ + name: model.name, + isExported: true, + extends: `${model.name}Base`, + }); + } +} diff --git a/src/helpers.ts b/src/helpers.ts index e238bd2..c570ebc 100644 --- a/src/helpers.ts +++ b/src/helpers.ts @@ -77,8 +77,18 @@ export const getTSDataTypeFromFieldType = (field: PrismaDMMF.Field) => { return type; }; -export const getDecoratorsByFieldType = (field: PrismaDMMF.Field) => { +export const getDecoratorsByFieldType = (field: PrismaDMMF.Field, includeSwagger: boolean = false) => { const decorators: OptionalKind[] = []; + + // Add Swagger decorators first if enabled + if (includeSwagger) { + const swaggerDecorator = getSwaggerDecoratorByFieldType(field); + if (swaggerDecorator) { + decorators.push(swaggerDecorator); + } + } + + // Add class-validator decorators switch (field.type) { case 'Int': decorators.push({ @@ -86,6 +96,12 @@ export const getDecoratorsByFieldType = (field: PrismaDMMF.Field) => { arguments: [], }); break; + case 'Float': + decorators.push({ + name: 'IsNumber', + arguments: [], + }); + break; case 'DateTime': decorators.push({ name: 'IsDate', @@ -125,12 +141,77 @@ export const getDecoratorsByFieldType = (field: PrismaDMMF.Field) => { return decorators; }; +export const getSwaggerDecoratorByFieldType = (field: PrismaDMMF.Field) => { + const args: string[] = []; + + // Base properties + if (field.hasDefaultValue && field.default !== null) { + if (typeof field.default === 'object' && 'name' in field.default) { + // Handle function defaults like autoincrement(), now() + args.push(`example: 'Generated by ${field.default.name}'`); + } else { + args.push(`example: ${JSON.stringify(field.default)}`); + } + } + + // Type-specific properties + switch (field.type) { + case 'Int': + args.push('type: "integer"'); + break; + case 'Float': + args.push('type: "number"'); + break; + case 'String': + args.push('type: "string"'); + break; + case 'Boolean': + args.push('type: "boolean"'); + break; + case 'DateTime': + args.push('type: "string"', 'format: "date-time"'); + break; + case 'Decimal': + args.push('type: "string"', 'description: "Decimal value as string"'); + break; + case 'Json': + args.push('type: "object"'); + break; + case 'Bytes': + args.push('type: "string"', 'format: "byte"'); + break; + } + + // Array handling + if (field.isList) { + args.push('isArray: true'); + } + + // Required/optional + if (!field.isRequired) { + args.push('required: false'); + } + + // Enum handling + if (field.kind === 'enum') { + args.push(`enum: Object.values(${field.type})`); + } + + return { + name: 'ApiProperty', + arguments: args.length > 0 ? [`{ ${args.join(', ')} }`] : [], + }; +}; + export const getDecoratorsImportsByType = (field: PrismaDMMF.Field) => { const validatorImports = new Set(); switch (field.type) { case 'Int': validatorImports.add('IsInt'); break; + case 'Float': + validatorImports.add('IsNumber'); + break; case 'DateTime': validatorImports.add('IsDate'); break; @@ -204,6 +285,26 @@ export const generateEnumImports = ( } }; +export const shouldImportSwagger = (fields: PrismaDMMF.Field[]) => { + return fields.length > 0; // Always import if we have fields and swagger is enabled +}; + +export const getSwaggerImportsByType = (fields: PrismaDMMF.Field[]) => { + const swaggerImports = new Set(['ApiProperty']); + // Add more swagger imports as needed + return [...swaggerImports]; +}; + +export const generateSwaggerImport = ( + sourceFile: SourceFile, + swaggerImports: Array, +) => { + sourceFile.addImportDeclaration({ + moduleSpecifier: '@nestjs/swagger', + namedImports: swaggerImports, + }); +}; + export function generateEnumsIndexFile( sourceFile: SourceFile, enumNames: string[], diff --git a/src/prisma-generator.ts b/src/prisma-generator.ts index 232fced..bb0d303 100644 --- a/src/prisma-generator.ts +++ b/src/prisma-generator.ts @@ -9,8 +9,20 @@ import { generateEnumsIndexFile, generateModelsIndexFile } from './helpers'; import { project } from './project'; import removeDir from './utils/removeDir'; +export interface GeneratorConfig { + outputDir: string; + swagger: boolean; + separateRelationFields: boolean; +} + export async function generate(options: GeneratorOptions) { const outputDir = parseEnvValue(options.generator.output as EnvValue); + + const config: GeneratorConfig = { + outputDir, + swagger: options.generator.config?.swagger === 'true', + separateRelationFields: options.generator.config?.separateRelationFields === 'true', + }; await fs.mkdir(outputDir, { recursive: true }); await removeDir(outputDir, true); @@ -39,7 +51,7 @@ export async function generate(options: GeneratorOptions) { } prismaClientDmmf.datamodel.models.forEach((model) => - generateClass(project, outputDir, model), + generateClass(project, config, model), ); const helpersIndexSourceFile = project.createSourceFile( diff --git a/tests/relation-splitting.test.ts b/tests/relation-splitting.test.ts new file mode 100644 index 0000000..32f43f1 --- /dev/null +++ b/tests/relation-splitting.test.ts @@ -0,0 +1,113 @@ +import { exec } from 'child_process'; +import { promisify } from 'util'; +import { existsSync, readFileSync } from 'fs'; +import { describe, it, expect, beforeAll } from 'vitest'; +import path from 'path'; + +const execAsync = promisify(exec); + +describe('Relation Splitting Generation', () => { + beforeAll(async () => { + // Build the generator first + await execAsync('npm run build'); + + // Generate models for full-features schema + const schemaPath = path.resolve(__dirname, 'schemas/full-features.prisma'); + await execAsync(`npx prisma generate --schema="${schemaPath}"`); + }, 60000); + + it('should generate separate base and relation classes', () => { + const outputPath = path.resolve(__dirname, 'generated/full-features'); + const modelsDir = path.join(outputPath, 'models'); + + // Check that all expected files are generated + expect(() => readFileSync(path.join(modelsDir, 'UserBase.model.ts'))).not.toThrow(); + expect(() => readFileSync(path.join(modelsDir, 'UserRelations.model.ts'))).not.toThrow(); + expect(() => readFileSync(path.join(modelsDir, 'User.model.ts'))).not.toThrow(); + + expect(() => readFileSync(path.join(modelsDir, 'PostBase.model.ts'))).not.toThrow(); + expect(() => readFileSync(path.join(modelsDir, 'PostRelations.model.ts'))).not.toThrow(); + expect(() => readFileSync(path.join(modelsDir, 'Post.model.ts'))).not.toThrow(); + }); + + it('should generate UserBase with only non-relation fields', () => { + const outputPath = path.resolve(__dirname, 'generated/full-features'); + const userBasePath = path.join(outputPath, 'models', 'UserBase.model.ts'); + const userBase = readFileSync(userBasePath, 'utf-8'); + + // Should contain scalar fields + expect(userBase).toContain('id!: number'); + expect(userBase).toContain('email!: string'); + expect(userBase).toContain('name?: string | null'); + + // Should NOT contain relation fields + expect(userBase).not.toContain('posts'); + + // Should have both class-validator and Swagger decorators + expect(userBase).toContain('@IsInt()'); + expect(userBase).toContain('@ApiProperty({'); + }); + + it('should generate UserRelations with only relation fields', () => { + const outputPath = path.resolve(__dirname, 'generated/full-features'); + const userRelationsPath = path.join(outputPath, 'models', 'UserRelations.model.ts'); + const userRelations = readFileSync(userRelationsPath, 'utf-8'); + + // Should contain relation fields + expect(userRelations).toContain('posts!: Post[]'); + + // Should NOT contain scalar fields + expect(userRelations).not.toContain('id!: number'); + expect(userRelations).not.toContain('email!: string'); + + // Should import related models + expect(userRelations).toContain('import { Post } from "./"'); + + // Should have decorators for relations + expect(userRelations).toContain('@ApiProperty({ isArray: true })'); + }); + + it('should generate combined User class extending base', () => { + const outputPath = path.resolve(__dirname, 'generated/full-features'); + const userPath = path.join(outputPath, 'models', 'User.model.ts'); + const user = readFileSync(userPath, 'utf-8'); + + // Should extend base class + expect(user).toContain('extends UserBase'); + + // Should import base class + expect(user).toContain('import { UserBase } from "./UserBase.model"'); + + // Should import relation types + expect(user).toContain('import { Post } from "./"'); + + // Should include relation properties + expect(user).toContain('posts!: Post[]'); + }); + + it('should generate PostBase without relation fields', () => { + const outputPath = path.resolve(__dirname, 'generated/full-features'); + const postBasePath = path.join(outputPath, 'models', 'PostBase.model.ts'); + const postBase = readFileSync(postBasePath, 'utf-8'); + + // Should contain all scalar fields including foreign key + expect(postBase).toContain('id!: number'); + expect(postBase).toContain('title!: string'); + expect(postBase).toContain('authorId?: number | null'); + expect(postBase).toContain('rating!: number'); + + // Should NOT contain relation fields (but should contain foreign key) + expect(postBase).not.toContain('author?: User'); + expect(postBase).not.toContain('import { User } from "./"'); + }); + + it('should handle models with no relations correctly', () => { + // If we had a model with no relations, it should still work + const outputPath = path.resolve(__dirname, 'generated/full-features'); + const userPath = path.join(outputPath, 'models', 'User.model.ts'); + const user = readFileSync(userPath, 'utf-8'); + + // Should be valid TypeScript + expect(user).toContain('export class User extends UserBase'); + }); +}); \ No newline at end of file diff --git a/tests/schemas/full-features.prisma b/tests/schemas/full-features.prisma new file mode 100644 index 0000000..883ba6b --- /dev/null +++ b/tests/schemas/full-features.prisma @@ -0,0 +1,35 @@ +generator client { + provider = "prisma-client-js" +} + +generator class_validator { + provider = "node ./lib/generator.js" + output = "../generated/full-features" + swagger = "true" + separateRelationFields = "true" +} + +datasource db { + provider = "sqlite" + url = "file:./test.db" +} + +model User { + id Int @id @default(autoincrement()) + email String @unique + name String? + posts Post[] +} + +model Post { + id Int @id @default(autoincrement()) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + title String + content String? + published Boolean @default(false) + viewCount Int @default(0) + author User? @relation(fields: [authorId], references: [id]) + authorId Int? + rating Float +} \ No newline at end of file diff --git a/tests/schemas/swagger.prisma b/tests/schemas/swagger.prisma new file mode 100644 index 0000000..28d16d8 --- /dev/null +++ b/tests/schemas/swagger.prisma @@ -0,0 +1,34 @@ +generator client { + provider = "prisma-client-js" +} + +generator class_validator { + provider = "node ./lib/generator.js" + output = "../generated/swagger" + swagger = "true" +} + +datasource db { + provider = "sqlite" + url = "file:./test.db" +} + +model User { + id Int @id @default(autoincrement()) + email String @unique + name String? + posts Post[] +} + +model Post { + id Int @id @default(autoincrement()) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + title String + content String? + published Boolean @default(false) + viewCount Int @default(0) + author User? @relation(fields: [authorId], references: [id]) + authorId Int? + rating Float +} \ No newline at end of file diff --git a/tests/swagger-generation.test.ts b/tests/swagger-generation.test.ts new file mode 100644 index 0000000..3c6c65a --- /dev/null +++ b/tests/swagger-generation.test.ts @@ -0,0 +1,69 @@ +import { exec } from 'child_process'; +import { promisify } from 'util'; +import { existsSync, readFileSync } from 'fs'; +import { describe, it, expect, beforeAll } from 'vitest'; +import path from 'path'; + +const execAsync = promisify(exec); + +describe('Swagger Generation', () => { + beforeAll(async () => { + // Build the generator first + await execAsync('npm run build'); + + // Generate models for swagger schema + const schemaPath = path.resolve(__dirname, 'schemas/swagger.prisma'); + await execAsync(`npx prisma generate --schema="${schemaPath}"`); + }, 60000); + + it('should generate models with Swagger decorators when enabled', () => { + const outputPath = path.resolve(__dirname, 'generated/swagger'); + const userModelPath = path.join(outputPath, 'models', 'User.model.ts'); + const userModel = readFileSync(userModelPath, 'utf-8'); + + // Check for Swagger imports + expect(userModel).toContain('import { ApiProperty } from "@nestjs/swagger"'); + + // Check for ApiProperty decorators + expect(userModel).toContain('@ApiProperty({'); + expect(userModel).toContain('type: "integer"'); + expect(userModel).toContain('type: "string"'); + expect(userModel).toContain('required: false'); + expect(userModel).toContain('isArray: true'); + }); + + it('should generate Post model with correct Swagger decorators', () => { + const outputPath = path.resolve(__dirname, 'generated/swagger'); + const postModelPath = path.join(outputPath, 'models', 'Post.model.ts'); + const postModel = readFileSync(postModelPath, 'utf-8'); + + // Check for DateTime format + expect(postModel).toContain('format: "date-time"'); + + // Check for boolean type + expect(postModel).toContain('type: "boolean"'); + + // Check for Float handling + expect(postModel).toContain('type: "number"'); + expect(postModel).toContain('@IsNumber()'); + + // Check for default value examples + expect(postModel).toContain('example: false'); + expect(postModel).toContain('example: 0'); + }); + + it('should include both class-validator and Swagger decorators', () => { + const outputPath = path.resolve(__dirname, 'generated/swagger'); + const userModelPath = path.join(outputPath, 'models', 'User.model.ts'); + const userModel = readFileSync(userModelPath, 'utf-8'); + + // Check for class-validator decorators + expect(userModel).toContain('@IsInt()'); + expect(userModel).toContain('@IsString()'); + expect(userModel).toContain('@IsDefined()'); + expect(userModel).toContain('@IsOptional()'); + + // Check for Swagger decorators + expect(userModel).toContain('@ApiProperty({'); + }); +}); \ No newline at end of file From b6c184ee3f4ab251d37e10b9acd4374c4be17003 Mon Sep 17 00:00:00 2001 From: omar-dulaimi Date: Mon, 21 Jul 2025 18:47:06 +0300 Subject: [PATCH 2/2] Fix code formatting with Prettier --- src/generate-class.ts | 50 ++++++++++++++---------- src/helpers.ts | 25 ++++++------ src/prisma-generator.ts | 5 ++- tests/relation-splitting.test.ts | 66 ++++++++++++++++++++------------ tests/swagger-generation.test.ts | 24 ++++++------ 5 files changed, 101 insertions(+), 69 deletions(-) diff --git a/src/generate-class.ts b/src/generate-class.ts index 9a46232..b66aaf9 100644 --- a/src/generate-class.ts +++ b/src/generate-class.ts @@ -56,8 +56,13 @@ function generateSingleClass( generateClassValidatorImport(sourceFile, validatorImports as Array); // Add Swagger imports if enabled - if (config.swagger && shouldImportSwagger(model.fields as PrismaDMMF.Field[])) { - const swaggerImports = getSwaggerImportsByType(model.fields as PrismaDMMF.Field[]); + if ( + config.swagger && + shouldImportSwagger(model.fields as PrismaDMMF.Field[]) + ) { + const swaggerImports = getSwaggerImportsByType( + model.fields as PrismaDMMF.Field[], + ); generateSwaggerImport(sourceFile, swaggerImports); } const relationImports = new Set(); @@ -105,15 +110,15 @@ function generateSeparateRelationClasses( // Separate base fields from relation fields const baseFields = model.fields.filter((field) => !field.relationName); const relationFields = model.fields.filter((field) => field.relationName); - + // Generate base class (without relations) generateBaseClass(project, config, model, baseFields); - + // Generate relation class (only relations) if (relationFields.length > 0) { generateRelationClass(project, config, model, relationFields); } - + // Generate combined class that extends base and includes relations generateCombinedClass(project, config, model, baseFields, relationFields); } @@ -146,7 +151,9 @@ function generateBaseClass( // Add Swagger imports if enabled if (config.swagger && shouldImportSwagger(fields as PrismaDMMF.Field[])) { - const swaggerImports = getSwaggerImportsByType(fields as PrismaDMMF.Field[]); + const swaggerImports = getSwaggerImportsByType( + fields as PrismaDMMF.Field[], + ); generateSwaggerImport(sourceFile, swaggerImports); } @@ -160,18 +167,16 @@ function generateBaseClass( name: `${model.name}Base`, isExported: true, properties: [ - ...fields.map>( - (field) => { - return { - name: field.name, - type: getTSDataTypeFromFieldType(field), - hasExclamationToken: field.isRequired, - hasQuestionToken: !field.isRequired, - trailingTrivia: '\r\n', - decorators: getDecoratorsByFieldType(field, config.swagger), - }; - }, - ), + ...fields.map>((field) => { + return { + name: field.name, + type: getTSDataTypeFromFieldType(field), + hasExclamationToken: field.isRequired, + hasQuestionToken: !field.isRequired, + trailingTrivia: '\r\n', + decorators: getDecoratorsByFieldType(field, config.swagger), + }; + }), ], }); } @@ -199,8 +204,13 @@ function generateRelationClass( generateClassValidatorImport(sourceFile, validatorImports as Array); // Add Swagger imports if enabled - if (config.swagger && shouldImportSwagger(relationFields as PrismaDMMF.Field[])) { - const swaggerImports = getSwaggerImportsByType(relationFields as PrismaDMMF.Field[]); + if ( + config.swagger && + shouldImportSwagger(relationFields as PrismaDMMF.Field[]) + ) { + const swaggerImports = getSwaggerImportsByType( + relationFields as PrismaDMMF.Field[], + ); generateSwaggerImport(sourceFile, swaggerImports); } diff --git a/src/helpers.ts b/src/helpers.ts index c570ebc..f6ce8b3 100644 --- a/src/helpers.ts +++ b/src/helpers.ts @@ -68,18 +68,21 @@ export const getTSDataTypeFromFieldType = (field: PrismaDMMF.Field) => { if (field.isList) { type = `${type}[]`; } - + // Add null union for optional fields to match Prisma client behavior if (!field.isRequired) { type = `${type} | null`; } - + return type; }; -export const getDecoratorsByFieldType = (field: PrismaDMMF.Field, includeSwagger: boolean = false) => { +export const getDecoratorsByFieldType = ( + field: PrismaDMMF.Field, + includeSwagger: boolean = false, +) => { const decorators: OptionalKind[] = []; - + // Add Swagger decorators first if enabled if (includeSwagger) { const swaggerDecorator = getSwaggerDecoratorByFieldType(field); @@ -87,7 +90,7 @@ export const getDecoratorsByFieldType = (field: PrismaDMMF.Field, includeSwagger decorators.push(swaggerDecorator); } } - + // Add class-validator decorators switch (field.type) { case 'Int': @@ -143,7 +146,7 @@ export const getDecoratorsByFieldType = (field: PrismaDMMF.Field, includeSwagger export const getSwaggerDecoratorByFieldType = (field: PrismaDMMF.Field) => { const args: string[] = []; - + // Base properties if (field.hasDefaultValue && field.default !== null) { if (typeof field.default === 'object' && 'name' in field.default) { @@ -153,7 +156,7 @@ export const getSwaggerDecoratorByFieldType = (field: PrismaDMMF.Field) => { args.push(`example: ${JSON.stringify(field.default)}`); } } - + // Type-specific properties switch (field.type) { case 'Int': @@ -181,22 +184,22 @@ export const getSwaggerDecoratorByFieldType = (field: PrismaDMMF.Field) => { args.push('type: "string"', 'format: "byte"'); break; } - + // Array handling if (field.isList) { args.push('isArray: true'); } - + // Required/optional if (!field.isRequired) { args.push('required: false'); } - + // Enum handling if (field.kind === 'enum') { args.push(`enum: Object.values(${field.type})`); } - + return { name: 'ApiProperty', arguments: args.length > 0 ? [`{ ${args.join(', ')} }`] : [], diff --git a/src/prisma-generator.ts b/src/prisma-generator.ts index bb0d303..7286784 100644 --- a/src/prisma-generator.ts +++ b/src/prisma-generator.ts @@ -17,11 +17,12 @@ export interface GeneratorConfig { export async function generate(options: GeneratorOptions) { const outputDir = parseEnvValue(options.generator.output as EnvValue); - + const config: GeneratorConfig = { outputDir, swagger: options.generator.config?.swagger === 'true', - separateRelationFields: options.generator.config?.separateRelationFields === 'true', + separateRelationFields: + options.generator.config?.separateRelationFields === 'true', }; await fs.mkdir(outputDir, { recursive: true }); await removeDir(outputDir, true); diff --git a/tests/relation-splitting.test.ts b/tests/relation-splitting.test.ts index 32f43f1..8594afd 100644 --- a/tests/relation-splitting.test.ts +++ b/tests/relation-splitting.test.ts @@ -10,7 +10,7 @@ describe('Relation Splitting Generation', () => { beforeAll(async () => { // Build the generator first await execAsync('npm run build'); - + // Generate models for full-features schema const schemaPath = path.resolve(__dirname, 'schemas/full-features.prisma'); await execAsync(`npx prisma generate --schema="${schemaPath}"`); @@ -19,30 +19,42 @@ describe('Relation Splitting Generation', () => { it('should generate separate base and relation classes', () => { const outputPath = path.resolve(__dirname, 'generated/full-features'); const modelsDir = path.join(outputPath, 'models'); - + // Check that all expected files are generated - expect(() => readFileSync(path.join(modelsDir, 'UserBase.model.ts'))).not.toThrow(); - expect(() => readFileSync(path.join(modelsDir, 'UserRelations.model.ts'))).not.toThrow(); - expect(() => readFileSync(path.join(modelsDir, 'User.model.ts'))).not.toThrow(); - - expect(() => readFileSync(path.join(modelsDir, 'PostBase.model.ts'))).not.toThrow(); - expect(() => readFileSync(path.join(modelsDir, 'PostRelations.model.ts'))).not.toThrow(); - expect(() => readFileSync(path.join(modelsDir, 'Post.model.ts'))).not.toThrow(); + expect(() => + readFileSync(path.join(modelsDir, 'UserBase.model.ts')), + ).not.toThrow(); + expect(() => + readFileSync(path.join(modelsDir, 'UserRelations.model.ts')), + ).not.toThrow(); + expect(() => + readFileSync(path.join(modelsDir, 'User.model.ts')), + ).not.toThrow(); + + expect(() => + readFileSync(path.join(modelsDir, 'PostBase.model.ts')), + ).not.toThrow(); + expect(() => + readFileSync(path.join(modelsDir, 'PostRelations.model.ts')), + ).not.toThrow(); + expect(() => + readFileSync(path.join(modelsDir, 'Post.model.ts')), + ).not.toThrow(); }); it('should generate UserBase with only non-relation fields', () => { const outputPath = path.resolve(__dirname, 'generated/full-features'); const userBasePath = path.join(outputPath, 'models', 'UserBase.model.ts'); const userBase = readFileSync(userBasePath, 'utf-8'); - + // Should contain scalar fields expect(userBase).toContain('id!: number'); expect(userBase).toContain('email!: string'); expect(userBase).toContain('name?: string | null'); - + // Should NOT contain relation fields expect(userBase).not.toContain('posts'); - + // Should have both class-validator and Swagger decorators expect(userBase).toContain('@IsInt()'); expect(userBase).toContain('@ApiProperty({'); @@ -50,19 +62,23 @@ describe('Relation Splitting Generation', () => { it('should generate UserRelations with only relation fields', () => { const outputPath = path.resolve(__dirname, 'generated/full-features'); - const userRelationsPath = path.join(outputPath, 'models', 'UserRelations.model.ts'); + const userRelationsPath = path.join( + outputPath, + 'models', + 'UserRelations.model.ts', + ); const userRelations = readFileSync(userRelationsPath, 'utf-8'); - + // Should contain relation fields expect(userRelations).toContain('posts!: Post[]'); - + // Should NOT contain scalar fields expect(userRelations).not.toContain('id!: number'); expect(userRelations).not.toContain('email!: string'); - + // Should import related models expect(userRelations).toContain('import { Post } from "./"'); - + // Should have decorators for relations expect(userRelations).toContain('@ApiProperty({ isArray: true })'); }); @@ -71,16 +87,16 @@ describe('Relation Splitting Generation', () => { const outputPath = path.resolve(__dirname, 'generated/full-features'); const userPath = path.join(outputPath, 'models', 'User.model.ts'); const user = readFileSync(userPath, 'utf-8'); - + // Should extend base class expect(user).toContain('extends UserBase'); - + // Should import base class expect(user).toContain('import { UserBase } from "./UserBase.model"'); - + // Should import relation types expect(user).toContain('import { Post } from "./"'); - + // Should include relation properties expect(user).toContain('posts!: Post[]'); }); @@ -89,13 +105,13 @@ describe('Relation Splitting Generation', () => { const outputPath = path.resolve(__dirname, 'generated/full-features'); const postBasePath = path.join(outputPath, 'models', 'PostBase.model.ts'); const postBase = readFileSync(postBasePath, 'utf-8'); - + // Should contain all scalar fields including foreign key expect(postBase).toContain('id!: number'); expect(postBase).toContain('title!: string'); expect(postBase).toContain('authorId?: number | null'); expect(postBase).toContain('rating!: number'); - + // Should NOT contain relation fields (but should contain foreign key) expect(postBase).not.toContain('author?: User'); expect(postBase).not.toContain('import { User } from "./"'); @@ -106,8 +122,8 @@ describe('Relation Splitting Generation', () => { const outputPath = path.resolve(__dirname, 'generated/full-features'); const userPath = path.join(outputPath, 'models', 'User.model.ts'); const user = readFileSync(userPath, 'utf-8'); - + // Should be valid TypeScript expect(user).toContain('export class User extends UserBase'); }); -}); \ No newline at end of file +}); diff --git a/tests/swagger-generation.test.ts b/tests/swagger-generation.test.ts index 3c6c65a..0d08d8b 100644 --- a/tests/swagger-generation.test.ts +++ b/tests/swagger-generation.test.ts @@ -10,7 +10,7 @@ describe('Swagger Generation', () => { beforeAll(async () => { // Build the generator first await execAsync('npm run build'); - + // Generate models for swagger schema const schemaPath = path.resolve(__dirname, 'schemas/swagger.prisma'); await execAsync(`npx prisma generate --schema="${schemaPath}"`); @@ -20,10 +20,12 @@ describe('Swagger Generation', () => { const outputPath = path.resolve(__dirname, 'generated/swagger'); const userModelPath = path.join(outputPath, 'models', 'User.model.ts'); const userModel = readFileSync(userModelPath, 'utf-8'); - + // Check for Swagger imports - expect(userModel).toContain('import { ApiProperty } from "@nestjs/swagger"'); - + expect(userModel).toContain( + 'import { ApiProperty } from "@nestjs/swagger"', + ); + // Check for ApiProperty decorators expect(userModel).toContain('@ApiProperty({'); expect(userModel).toContain('type: "integer"'); @@ -36,17 +38,17 @@ describe('Swagger Generation', () => { const outputPath = path.resolve(__dirname, 'generated/swagger'); const postModelPath = path.join(outputPath, 'models', 'Post.model.ts'); const postModel = readFileSync(postModelPath, 'utf-8'); - + // Check for DateTime format expect(postModel).toContain('format: "date-time"'); - + // Check for boolean type expect(postModel).toContain('type: "boolean"'); - + // Check for Float handling expect(postModel).toContain('type: "number"'); expect(postModel).toContain('@IsNumber()'); - + // Check for default value examples expect(postModel).toContain('example: false'); expect(postModel).toContain('example: 0'); @@ -56,14 +58,14 @@ describe('Swagger Generation', () => { const outputPath = path.resolve(__dirname, 'generated/swagger'); const userModelPath = path.join(outputPath, 'models', 'User.model.ts'); const userModel = readFileSync(userModelPath, 'utf-8'); - + // Check for class-validator decorators expect(userModel).toContain('@IsInt()'); expect(userModel).toContain('@IsString()'); expect(userModel).toContain('@IsDefined()'); expect(userModel).toContain('@IsOptional()'); - + // Check for Swagger decorators expect(userModel).toContain('@ApiProperty({'); }); -}); \ No newline at end of file +});