diff --git a/src/decorators/enums.ts b/src/decorators/enums.ts index e798f8402..b4b8260c1 100644 --- a/src/decorators/enums.ts +++ b/src/decorators/enums.ts @@ -10,5 +10,6 @@ export function registerEnumType( name: enumConfig.name, description: enumConfig.description, valuesConfig: enumConfig.valuesConfig || {}, + directives: enumConfig.directives?.map(nameOrDefinition => ({ nameOrDefinition, args: {} })), }); } diff --git a/src/decorators/types.ts b/src/decorators/types.ts index 206c625c4..92506b068 100644 --- a/src/decorators/types.ts +++ b/src/decorators/types.ts @@ -77,6 +77,7 @@ export interface EnumConfig { name: string; description?: string; valuesConfig?: EnumValuesConfig; + directives?: string[]; } export type EnumValuesConfig = Partial< Record diff --git a/src/metadata/definitions/enum-metadata.ts b/src/metadata/definitions/enum-metadata.ts index 2fb295423..807927e6e 100644 --- a/src/metadata/definitions/enum-metadata.ts +++ b/src/metadata/definitions/enum-metadata.ts @@ -1,8 +1,10 @@ import { type EnumValuesConfig } from "@/decorators/types"; +import { type DirectiveMetadata } from "./directive-metadata"; export interface EnumMetadata { enumObj: object; name: string; description: string | undefined; valuesConfig: EnumValuesConfig; + directives?: DirectiveMetadata[]; } diff --git a/src/schema/definition-node.ts b/src/schema/definition-node.ts index f75e5dc2e..b77f33e16 100644 --- a/src/schema/definition-node.ts +++ b/src/schema/definition-node.ts @@ -2,6 +2,7 @@ import { type ConstArgumentNode, type ConstDirectiveNode, type DocumentNode, + type EnumTypeDefinitionNode, type FieldDefinitionNode, type GraphQLInputType, type GraphQLOutputType, @@ -181,3 +182,21 @@ export function getInterfaceTypeDefinitionNode( directives: directiveMetadata.map(getDirectiveNode), }; } + +export function getEnumTypeDefinitionNode( + name: string, + directiveMetadata?: DirectiveMetadata[], +): EnumTypeDefinitionNode | undefined { + if (!directiveMetadata || !directiveMetadata.length) { + return undefined; + } + + return { + kind: Kind.ENUM_TYPE_DEFINITION, + name: { + kind: Kind.NAME, + value: name, + }, + directives: directiveMetadata.map(getDirectiveNode), + }; +} diff --git a/src/schema/schema-generator.ts b/src/schema/schema-generator.ts index fd79618c3..1acf5bc12 100644 --- a/src/schema/schema-generator.ts +++ b/src/schema/schema-generator.ts @@ -56,6 +56,7 @@ import { import { ensureInstalledCorrectGraphQLPackage } from "@/utils/graphql-version"; import { BuildContext, type BuildContextOptions } from "./build-context"; import { + getEnumTypeDefinitionNode, getFieldDefinitionNode, getInputObjectTypeDefinitionNode, getInputValueDefinitionNode, @@ -255,6 +256,7 @@ export abstract class SchemaGenerator { type: new GraphQLEnumType({ name: enumMetadata.name, description: enumMetadata.description, + astNode: getEnumTypeDefinitionNode(enumMetadata.name, enumMetadata.directives), values: Object.keys(enumMap).reduce( (enumConfig, enumKey) => { const valueConfig = enumMetadata.valuesConfig[enumKey] || {}; diff --git a/tests/functional/enums.ts b/tests/functional/enums.ts index 65257a5be..59b9f0b45 100644 --- a/tests/functional/enums.ts +++ b/tests/functional/enums.ts @@ -1,5 +1,6 @@ import "reflect-metadata"; import { + type GraphQLEnumType, type GraphQLSchema, type IntrospectionEnumType, type IntrospectionInputObjectType, @@ -10,6 +11,7 @@ import { } from "graphql"; import { Arg, Field, InputType, Query, registerEnumType } from "type-graphql"; import { getMetadataStorage } from "@/metadata/getMetadataStorage"; +import { assertValidDirective } from "../helpers/directives/assertValidDirective"; import { getInnerInputFieldType, getInnerTypeOfNonNullableType, @@ -51,6 +53,15 @@ describe("Enums", () => { }, }); + enum DirectiveEnum { + Active = "ACTIVE", + Inactive = "INACTIVE", + } + registerEnumType(DirectiveEnum, { + name: "DirectiveEnum", + directives: ["@test"], + }); + @InputType() class NumberEnumInput { @Field(() => NumberEnum) @@ -88,6 +99,11 @@ describe("Enums", () => { isStringEnumEqualOne(@Arg("enum", () => StringEnum) stringEnum: StringEnum): boolean { return stringEnum === StringEnum.One; } + + @Query(() => DirectiveEnum) + getDirectiveEnumValue(): DirectiveEnum { + return DirectiveEnum.Active; + } } const schemaInfo = await getSchemaInfo({ @@ -201,6 +217,84 @@ describe("Enums", () => { "Two field deprecation reason", ); }); + + it("should properly emit directive in AST when directives are provided", async () => { + const enumType = schema.getType("DirectiveEnum") as GraphQLEnumType; + + expect(enumType).toBeDefined(); + expect(enumType.astNode).toBeDefined(); + assertValidDirective(enumType.astNode, "test"); + }); + + it("should leave astNode undefined when no directives are provided", async () => { + const enumType = schema.getType("NumberEnum") as GraphQLEnumType; + + expect(enumType).toBeDefined(); + expect(enumType.astNode).toBeUndefined(); + }); + + it("should properly emit directive with args in AST", async () => { + getMetadataStorage().clear(); + + enum ArgsDirectiveEnum { + On = "ON", + Off = "OFF", + } + registerEnumType(ArgsDirectiveEnum, { + name: "ArgsDirectiveEnum", + directives: ['@test(argNonNullDefault: "custom", argNull: "value")'], + }); + + class ArgsDirectiveEnumResolver { + @Query(() => ArgsDirectiveEnum) + getArgsDirectiveEnumValue(): ArgsDirectiveEnum { + return ArgsDirectiveEnum.On; + } + } + + const { schema: argsSchema } = await getSchemaInfo({ + resolvers: [ArgsDirectiveEnumResolver], + }); + + const enumType = argsSchema.getType("ArgsDirectiveEnum") as GraphQLEnumType; + expect(enumType.astNode).toBeDefined(); + assertValidDirective(enumType.astNode, "test", { + argNonNullDefault: `"custom"`, + argNull: `"value"`, + }); + }); + + it("should properly emit multiple directives in AST", async () => { + getMetadataStorage().clear(); + + enum MultiDirectiveEnum { + Yes = "YES", + No = "NO", + } + registerEnumType(MultiDirectiveEnum, { + name: "MultiDirectiveEnum", + directives: ["@test", '@deprecated(reason: "use something else")'], + }); + + class MultiDirectiveEnumResolver { + @Query(() => MultiDirectiveEnum) + getMultiDirectiveEnumValue(): MultiDirectiveEnum { + return MultiDirectiveEnum.Yes; + } + } + + const { schema: multiSchema } = await getSchemaInfo({ + resolvers: [MultiDirectiveEnumResolver], + }); + + const enumType = multiSchema.getType("MultiDirectiveEnum") as GraphQLEnumType; + expect(enumType.astNode).toBeDefined(); + expect(enumType.astNode!.directives).toHaveLength(2); + assertValidDirective(enumType.astNode, "test"); + assertValidDirective(enumType.astNode, "deprecated", { + reason: `"use something else"`, + }); + }); }); describe("Functional", () => { diff --git a/tests/helpers/directives/assertValidDirective.ts b/tests/helpers/directives/assertValidDirective.ts index 1fb87ab93..87ea3d852 100644 --- a/tests/helpers/directives/assertValidDirective.ts +++ b/tests/helpers/directives/assertValidDirective.ts @@ -1,4 +1,5 @@ import { + type EnumTypeDefinitionNode, type FieldDefinitionNode, type InputObjectTypeDefinitionNode, type InputValueDefinitionNode, @@ -10,6 +11,7 @@ import { type Maybe } from "@/typings"; export function assertValidDirective( astNode: Maybe< + | EnumTypeDefinitionNode | FieldDefinitionNode | ObjectTypeDefinitionNode | InputObjectTypeDefinitionNode