Skip to content

Commit b276b54

Browse files
committed
refactor: validate RuleTrigger at load time using a pydantic discriminated union
Switch MessageRuleTrigger and ReactionRuleTrigger to pydantic BaseModel with a Literal interaction_type discriminator field. _load_rules now calls TypeAdapter.validate_python, which handles dispatch and validation in one step. This removes all None-guards and type: ignore comments from on_message and on_reaction_add.
1 parent 92bab70 commit b276b54

1 file changed

Lines changed: 30 additions & 31 deletions

File tree

bot/exts/levels/_cog.py

Lines changed: 30 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
import random
33
import re
44
import tomllib
5-
from dataclasses import dataclass
65
from pathlib import Path
7-
from typing import Literal, cast
6+
from typing import Annotated, Literal, cast
87

98
import discord
109
from async_rediscache import RedisCache
1110
from discord.ext import commands, tasks
11+
from pydantic import BaseModel, Field, TypeAdapter, ValidationError
1212
from pydis_core.utils.logging import get_logger
1313

1414
from bot import constants
@@ -67,11 +67,11 @@ def __init__(self, bot: SirRobin):
6767
self.active_rules_num = 3
6868
self.anti_active_rules_num = 1
6969

70-
self.active_reaction_rule_triggers: list[RuleTrigger] = []
71-
self.active_message_rule_triggers: list[RuleTrigger] = []
72-
self.anti_active_message_rule_triggers: list[RuleTrigger] = []
73-
self.anti_active_reaction_rule_triggers: list[RuleTrigger] = []
74-
self.all_message_rule_triggers: list[RuleTrigger] = []
70+
self.active_reaction_rule_triggers: list[ReactionRuleTrigger] = []
71+
self.active_message_rule_triggers: list[MessageRuleTrigger] = []
72+
self.anti_active_message_rule_triggers: list[MessageRuleTrigger] = []
73+
self.anti_active_reaction_rule_triggers: list[ReactionRuleTrigger] = []
74+
self.all_message_rule_triggers: list[MessageRuleTrigger] = []
7575
self.sorted_level_thresholds: list[tuple[int, int]] = []
7676

7777

@@ -109,9 +109,9 @@ async def _load_rules(self) -> None:
109109

110110
rule_name = toml_file.stem
111111
try:
112-
rule_triggers = [RuleTrigger(**rule_trigger) for rule_trigger in rule_dict["rule"]]
113-
rule = LevelRules(rule_name, rule_triggers)
114-
except (TypeError, KeyError):
112+
rule_triggers = _rule_trigger_adapter.validate_python(rule_dict["rule"])
113+
rule = LevelRules(name=rule_name, rule_triggers=rule_triggers)
114+
except (KeyError, ValidationError):
115115
logger.info(f"{toml_file} not properly formatted, skipping.")
116116
continue
117117

@@ -267,20 +267,14 @@ async def on_message(self, msg: discord.Message) -> None:
267267
total_points = 0
268268
rule_matches = 0
269269
for rule_trigger in self.active_message_rule_triggers:
270-
re_pattern = rule_trigger.message_content
271-
if re_pattern is None:
272-
continue
273-
match = re.search(re_pattern, msg.content)
270+
match = re.search(rule_trigger.message_content, msg.content)
274271
if match:
275272
total_points += rule_trigger.points
276273
rule_matches += 1
277274

278275
anti_active_rule_matches = 0
279276
for anti_active_rule_trigger in self.anti_active_message_rule_triggers:
280-
re_pattern = anti_active_rule_trigger.message_content
281-
if re_pattern is None:
282-
continue
283-
match = re.search(re_pattern, msg.content)
277+
match = re.search(anti_active_rule_trigger.message_content, msg.content)
284278
if match:
285279
anti_active_rule_matches += 1
286280
rule_matches += 1
@@ -295,10 +289,7 @@ async def on_message(self, msg: discord.Message) -> None:
295289

296290
total_rule_matches = 0
297291
for rule_trigger in self.all_message_rule_triggers:
298-
re_pattern = rule_trigger.message_content
299-
if re_pattern is None:
300-
continue
301-
match = re.search(re_pattern, msg.content)
292+
match = re.search(rule_trigger.message_content, msg.content)
302293
if match:
303294
total_rule_matches += 1
304295
if total_rule_matches >= 3:
@@ -329,13 +320,13 @@ async def on_reaction_add(self, reaction: discord.Reaction, user: discord.Member
329320
total_points = 0
330321
rule_matches = 0
331322
for rule_trigger in self.active_reaction_rule_triggers:
332-
if rule_trigger.reaction_content and emoji_name in rule_trigger.reaction_content:
323+
if emoji_name in rule_trigger.reaction_content:
333324
total_points += rule_trigger.points
334325
rule_matches += 1
335326

336327
anti_active_rule_matches = 0
337328
for anti_active_rule_trigger in self.anti_active_reaction_rule_triggers:
338-
if anti_active_rule_trigger.reaction_content and emoji_name in anti_active_rule_trigger.reaction_content:
329+
if emoji_name in anti_active_rule_trigger.reaction_content:
339330
anti_active_rule_matches += 1
340331
rule_matches += 1
341332

@@ -438,14 +429,22 @@ async def role_reset(self, ctx: commands.Context, member: discord.Member) -> Non
438429

439430
# Please see ./rules/README.md for how to format rules
440431

441-
@dataclass
442-
class RuleTrigger:
443-
interaction_type: Literal["message", "reaction"]
444-
reaction_content: list[str] | None = None
445-
message_content: str | None = None
432+
class MessageRuleTrigger(BaseModel):
433+
interaction_type: Literal["message"]
434+
message_content: str
446435
points: int = 0
447436

448-
@dataclass
449-
class LevelRules:
437+
class ReactionRuleTrigger(BaseModel):
438+
interaction_type: Literal["reaction"]
439+
reaction_content: list[str]
440+
points: int = 0
441+
442+
RuleTrigger = Annotated[
443+
MessageRuleTrigger | ReactionRuleTrigger,
444+
Field(discriminator="interaction_type"),
445+
]
446+
_rule_trigger_adapter = TypeAdapter(list[RuleTrigger])
447+
448+
class LevelRules(BaseModel):
450449
name: str
451450
rule_triggers: list[RuleTrigger]

0 commit comments

Comments
 (0)