|
| 1 | +# ruff: noqa: F541 |
1 | 2 | """Generate types based on Discord API.""" |
2 | 3 |
|
3 | | -from collections import defaultdict |
4 | 4 | from pathlib import Path |
5 | 5 |
|
6 | 6 | import discord |
7 | | -from attr import dataclass |
8 | 7 |
|
9 | 8 | import discord_guild_configurator |
10 | 9 |
|
11 | 10 | generated_file = Path(discord_guild_configurator.__file__).parent / "generated_models.py" |
12 | 11 |
|
13 | 12 |
|
14 | | -@dataclass(frozen=True) |
15 | | -class CodegenOutput: |
16 | | - imports: dict[str, list[str]] |
17 | | - code_lines: list[str] |
18 | | - |
19 | | - |
20 | | -def generate_flag_lines(flags_cls: type[discord.flags.BaseFlags]) -> CodegenOutput: |
21 | | - return CodegenOutput( |
22 | | - imports={"typing": ["Literal"]}, |
23 | | - code_lines=[ |
24 | | - f"{flags_cls.__name__} = Literal[", |
25 | | - *(f' "{option}",' for option in sorted(flags_cls.VALID_FLAGS)), |
26 | | - "]", |
27 | | - ], |
28 | | - ) |
29 | | - |
30 | | - |
31 | | -def generate_enum_lines(enum_cls: type[discord.Enum]) -> CodegenOutput: |
32 | | - return CodegenOutput( |
33 | | - imports={"typing": ["Literal"]}, |
34 | | - code_lines=[ |
35 | | - f"{enum_cls.__name__} = Literal[", |
36 | | - *(f' "{option.name}",' for option in sorted(enum_cls)), |
37 | | - "]", |
38 | | - ], |
39 | | - ) |
40 | | - |
41 | | - |
42 | | -def generate_lines(outputs: list[CodegenOutput]) -> list[str]: |
43 | | - imports: dict[str, set[str]] = defaultdict(set) |
44 | | - code_lines = [] |
45 | | - for output in outputs: |
46 | | - for import_module, import_names in output.imports.items(): |
47 | | - imports[import_module].update(import_names) |
48 | | - |
49 | | - code_lines.extend(output.code_lines) |
50 | | - |
51 | | - lines: list[str] = [] |
52 | | - for import_module, import_names in sorted(imports.items()): |
53 | | - lines.append(f"from {import_module} import {', '.join(sorted(import_names))}") |
54 | | - lines.append("") |
55 | | - lines.extend(code_lines) |
56 | | - lines.append("") |
57 | | - return lines |
58 | | - |
59 | | - |
60 | | -outputs = [ |
61 | | - generate_flag_lines(discord.Permissions), |
62 | | - generate_enum_lines(discord.VerificationLevel), |
63 | | - generate_enum_lines(discord.NotificationLevel), |
64 | | - generate_enum_lines(discord.Locale), |
65 | | - generate_enum_lines(discord.ContentFilter), |
| 13 | +def generate_flag_lines(flags_cls: type[discord.flags.BaseFlags]) -> list[str]: |
| 14 | + return [ |
| 15 | + f"{flags_cls.__name__} = Literal[", |
| 16 | + *(f' "{option}",' for option in sorted(flags_cls.VALID_FLAGS)), |
| 17 | + "]", |
| 18 | + ] |
| 19 | + |
| 20 | + |
| 21 | +def generate_enum_annotated_lines(enum_cls: type[discord.Enum]) -> list[str]: |
| 22 | + name = enum_cls.__name__ |
| 23 | + return [ |
| 24 | + f"{name} = Annotated[", |
| 25 | + f" discord.{name},", |
| 26 | + f" pydantic.PlainValidator(", |
| 27 | + f" lambda value: discord.{name}[value] if isinstance(value, str) else value", |
| 28 | + f" ),", |
| 29 | + f" pydantic.PlainSerializer(lambda value: value.name),", |
| 30 | + f" pydantic.Field(", |
| 31 | + f" json_schema_extra={{", |
| 32 | + f' "enum": [option.name for option in discord.{name}],', |
| 33 | + f' "type": "string",', |
| 34 | + f" }},", |
| 35 | + f" ),", |
| 36 | + f"]", |
| 37 | + ] |
| 38 | + |
| 39 | + |
| 40 | +lines: list[str] = [ |
| 41 | + "from typing import Annotated, Literal", |
| 42 | + "", |
| 43 | + "import discord", |
| 44 | + "import pydantic", |
| 45 | + "", |
| 46 | + *generate_flag_lines(discord.Permissions), |
| 47 | + *generate_enum_annotated_lines(discord.VerificationLevel), |
| 48 | + *generate_enum_annotated_lines(discord.NotificationLevel), |
| 49 | + *generate_enum_annotated_lines(discord.Locale), |
| 50 | + *generate_enum_annotated_lines(discord.ContentFilter), |
| 51 | + "", |
66 | 52 | ] |
67 | 53 |
|
68 | | -generated_file.write_text("\n".join(generate_lines(outputs)), encoding="UTF-8", newline="\n") |
| 54 | +generated_file.write_text("\n".join(lines), encoding="UTF-8", newline="\n") |
0 commit comments