Skip to content

Commit adf8b00

Browse files
committed
Use Annotated to model discord Enums
1 parent 7e7fd07 commit adf8b00

6 files changed

Lines changed: 105 additions & 110 deletions

File tree

configs/europython_2025.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import discord
4+
35
from discord_guild_configurator.models import (
46
Category,
57
CommunityFeatures,
@@ -44,10 +46,10 @@
4446
ROLE_BEGINNERS_DAY,
4547
]
4648
CONFIG = GuildConfig(
47-
verification_level="medium",
48-
default_notifications="only_mentions",
49-
explicit_content_filter="all_members",
50-
preferred_locale="american_english",
49+
verification_level=discord.VerificationLevel.medium,
50+
default_notifications=discord.NotificationLevel.only_mentions,
51+
explicit_content_filter=discord.ContentFilter.all_members,
52+
preferred_locale=discord.Locale.american_english,
5153
roles=[
5254
Role(
5355
name=ROLE_COC,

discord_guild_config.schema.json

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -540,12 +540,13 @@
540540
"preferred_locale": {
541541
"enum": [
542542
"american_english",
543-
"brazil_portuguese",
544543
"british_english",
545544
"bulgarian",
546545
"chinese",
546+
"taiwan_chinese",
547547
"croatian",
548548
"czech",
549+
"indonesian",
549550
"danish",
550551
"dutch",
551552
"finnish",
@@ -554,19 +555,18 @@
554555
"greek",
555556
"hindi",
556557
"hungarian",
557-
"indonesian",
558558
"italian",
559559
"japanese",
560560
"korean",
561561
"latin_american_spanish",
562562
"lithuanian",
563563
"norwegian",
564564
"polish",
565+
"brazil_portuguese",
565566
"romanian",
566567
"russian",
567568
"spain_spanish",
568569
"swedish",
569-
"taiwan_chinese",
570570
"thai",
571571
"turkish",
572572
"ukrainian",

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ fixable = [
6767
"F841", # unused variable
6868
"B007", # unused loop variable
6969
"PIE790", # unnecessary pass
70+
"Q003", # change single/double quotes
7071
]
7172

7273
[tool.ruff.lint.per-file-ignores]

scripts/update-generated-types.py

Lines changed: 41 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,54 @@
1+
# ruff: noqa: F541
12
"""Generate types based on Discord API."""
23

3-
from collections import defaultdict
44
from pathlib import Path
55

66
import discord
7-
from attr import dataclass
87

98
import discord_guild_configurator
109

1110
generated_file = Path(discord_guild_configurator.__file__).parent / "generated_models.py"
1211

1312

14-
@dataclass(frozen=True)
15-
class CodegenOutput:
16-
imports: dict[str, list[str]]
17-
code_lines: list[str]
18-
19-
20-
def generate_flag_lines(flags_cls: type[discord.flags.BaseFlags]) -> CodegenOutput:
21-
return CodegenOutput(
22-
imports={"typing": ["Literal"]},
23-
code_lines=[
24-
f"{flags_cls.__name__} = Literal[",
25-
*(f' "{option}",' for option in sorted(flags_cls.VALID_FLAGS)),
26-
"]",
27-
],
28-
)
29-
30-
31-
def generate_enum_lines(enum_cls: type[discord.Enum]) -> CodegenOutput:
32-
return CodegenOutput(
33-
imports={"typing": ["Literal"]},
34-
code_lines=[
35-
f"{enum_cls.__name__} = Literal[",
36-
*(f' "{option.name}",' for option in sorted(enum_cls)),
37-
"]",
38-
],
39-
)
40-
41-
42-
def generate_lines(outputs: list[CodegenOutput]) -> list[str]:
43-
imports: dict[str, set[str]] = defaultdict(set)
44-
code_lines = []
45-
for output in outputs:
46-
for import_module, import_names in output.imports.items():
47-
imports[import_module].update(import_names)
48-
49-
code_lines.extend(output.code_lines)
50-
51-
lines: list[str] = []
52-
for import_module, import_names in sorted(imports.items()):
53-
lines.append(f"from {import_module} import {', '.join(sorted(import_names))}")
54-
lines.append("")
55-
lines.extend(code_lines)
56-
lines.append("")
57-
return lines
58-
59-
60-
outputs = [
61-
generate_flag_lines(discord.Permissions),
62-
generate_enum_lines(discord.VerificationLevel),
63-
generate_enum_lines(discord.NotificationLevel),
64-
generate_enum_lines(discord.Locale),
65-
generate_enum_lines(discord.ContentFilter),
13+
def generate_flag_lines(flags_cls: type[discord.flags.BaseFlags]) -> list[str]:
14+
return [
15+
f"{flags_cls.__name__} = Literal[",
16+
*(f' "{option}",' for option in sorted(flags_cls.VALID_FLAGS)),
17+
"]",
18+
]
19+
20+
21+
def generate_enum_annotated_lines(enum_cls: type[discord.Enum]) -> list[str]:
22+
name = enum_cls.__name__
23+
return [
24+
f"{name} = Annotated[",
25+
f" discord.{name},",
26+
f" pydantic.PlainValidator(",
27+
f" lambda value: discord.{name}[value] if isinstance(value, str) else value",
28+
f" ),",
29+
f" pydantic.PlainSerializer(lambda value: value.name),",
30+
f" pydantic.Field(",
31+
f" json_schema_extra={{",
32+
f' "enum": [option.name for option in discord.{name}],',
33+
f' "type": "string",',
34+
f" }},",
35+
f" ),",
36+
f"]",
37+
]
38+
39+
40+
lines: list[str] = [
41+
"from typing import Annotated, Literal",
42+
"",
43+
"import discord",
44+
"import pydantic",
45+
"",
46+
*generate_flag_lines(discord.Permissions),
47+
*generate_enum_annotated_lines(discord.VerificationLevel),
48+
*generate_enum_annotated_lines(discord.NotificationLevel),
49+
*generate_enum_annotated_lines(discord.Locale),
50+
*generate_enum_annotated_lines(discord.ContentFilter),
51+
"",
6652
]
6753

68-
generated_file.write_text("\n".join(generate_lines(outputs)), encoding="UTF-8", newline="\n")
54+
generated_file.write_text("\n".join(lines), encoding="UTF-8", newline="\n")

src/discord_guild_configurator/generated_models.py

Lines changed: 52 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
from typing import Literal
1+
from typing import Annotated, Literal
2+
3+
import discord
4+
import pydantic
25

36
Permissions = Literal[
47
"add_reactions",
@@ -58,53 +61,55 @@
5861
"view_creator_monetization_analytics",
5962
"view_guild_insights",
6063
]
61-
VerificationLevel = Literal[
62-
"none",
63-
"low",
64-
"medium",
65-
"high",
66-
"highest",
64+
VerificationLevel = Annotated[
65+
discord.VerificationLevel,
66+
pydantic.PlainValidator(
67+
lambda value: discord.VerificationLevel[value] if isinstance(value, str) else value
68+
),
69+
pydantic.PlainSerializer(lambda value: value.name),
70+
pydantic.Field(
71+
json_schema_extra={
72+
"enum": [option.name for option in discord.VerificationLevel],
73+
"type": "string",
74+
},
75+
),
6776
]
68-
NotificationLevel = Literal[
69-
"all_messages",
70-
"only_mentions",
77+
NotificationLevel = Annotated[
78+
discord.NotificationLevel,
79+
pydantic.PlainValidator(
80+
lambda value: discord.NotificationLevel[value] if isinstance(value, str) else value
81+
),
82+
pydantic.PlainSerializer(lambda value: value.name),
83+
pydantic.Field(
84+
json_schema_extra={
85+
"enum": [option.name for option in discord.NotificationLevel],
86+
"type": "string",
87+
},
88+
),
7189
]
72-
Locale = Literal[
73-
"american_english",
74-
"brazil_portuguese",
75-
"british_english",
76-
"bulgarian",
77-
"chinese",
78-
"croatian",
79-
"czech",
80-
"danish",
81-
"dutch",
82-
"finnish",
83-
"french",
84-
"german",
85-
"greek",
86-
"hindi",
87-
"hungarian",
88-
"indonesian",
89-
"italian",
90-
"japanese",
91-
"korean",
92-
"latin_american_spanish",
93-
"lithuanian",
94-
"norwegian",
95-
"polish",
96-
"romanian",
97-
"russian",
98-
"spain_spanish",
99-
"swedish",
100-
"taiwan_chinese",
101-
"thai",
102-
"turkish",
103-
"ukrainian",
104-
"vietnamese",
90+
Locale = Annotated[
91+
discord.Locale,
92+
pydantic.PlainValidator(
93+
lambda value: discord.Locale[value] if isinstance(value, str) else value
94+
),
95+
pydantic.PlainSerializer(lambda value: value.name),
96+
pydantic.Field(
97+
json_schema_extra={
98+
"enum": [option.name for option in discord.Locale],
99+
"type": "string",
100+
},
101+
),
105102
]
106-
ContentFilter = Literal[
107-
"disabled",
108-
"no_role",
109-
"all_members",
103+
ContentFilter = Annotated[
104+
discord.ContentFilter,
105+
pydantic.PlainValidator(
106+
lambda value: discord.ContentFilter[value] if isinstance(value, str) else value
107+
),
108+
pydantic.PlainSerializer(lambda value: value.name),
109+
pydantic.Field(
110+
json_schema_extra={
111+
"enum": [option.name for option in discord.ContentFilter],
112+
"type": "string",
113+
},
114+
),
110115
]

src/discord_guild_configurator/models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import textwrap
55
from typing import Annotated, Literal, Self
66

7+
import discord
78
from pydantic import (
89
AfterValidator,
910
Field,
@@ -128,7 +129,7 @@ def verify_system_channel_names(self) -> Self:
128129

129130
@model_validator(mode="after")
130131
def verify_verification_level(self) -> Self:
131-
if self.community_features and self.verification_level in ("none", "low"):
132+
if self.community_features and self.verification_level < discord.VerificationLevel.medium: # type: ignore[unsupported-operator]
132133
raise ValueError(
133134
"The Community feature requires a verification level of at least medium"
134135
)

0 commit comments

Comments
 (0)