Skip to content

Commit 57f26b3

Browse files
Refactor stats cog (#472)
1 parent 411f0bf commit 57f26b3

3 files changed

Lines changed: 276 additions & 210 deletions

File tree

cogs/stats.py renamed to cogs/stats/__init__.py

Lines changed: 14 additions & 210 deletions
Original file line numberDiff line numberDiff line change
@@ -1,143 +1,29 @@
11
"""Contains cog classes for any stats interactions."""
22

3-
import io
43
import math
54
import re
65
from typing import TYPE_CHECKING
76

87
import discord
9-
import matplotlib.pyplot
10-
import mplcyberpunk
118

129
from config import settings
1310
from db.core.models import LeftDiscordMember
14-
from utils import CommandChecks, TeXBotBaseCog
11+
from utils import (
12+
CommandChecks,
13+
TeXBotBaseCog,
14+
)
1515
from utils.error_capture_decorators import capture_guild_does_not_exist_error
1616

17+
from .counts import get_channel_message_counts, get_server_message_counts
18+
from .graphs import amount_of_time_formatter, plot_bar_chart
19+
1720
if TYPE_CHECKING:
18-
from collections.abc import AsyncIterable, Collection, Sequence
21+
from collections.abc import AsyncIterable, Sequence
1922
from typing import Final
2023

21-
from matplotlib.text import Text as Plot_Text
22-
2324
from utils import TeXBotApplicationContext
2425

25-
__all__: "Sequence[str]" = ("StatsCommandsCog", "amount_of_time_formatter", "plot_bar_chart")
26-
27-
28-
def amount_of_time_formatter(value: float, time_scale: str) -> str:
29-
"""
30-
Format the amount of time value according to the provided time_scale.
31-
32-
E.g. past "1 days" => past "day",
33-
past "2.00 weeks" => past "2 weeks",
34-
past "3.14159 months" => past "3.14 months"
35-
"""
36-
if value == 1 or float(f"{value:.2f}") == 1:
37-
return f"{time_scale}"
38-
39-
if value % 1 == 0 or float(f"{value:.2f}") % 1 == 0:
40-
return f"{int(value)} {time_scale}s"
41-
42-
return f"{value:.2f} {time_scale}s"
43-
44-
45-
def plot_bar_chart(
46-
data: dict[str, int],
47-
x_label: str,
48-
y_label: str,
49-
title: str,
50-
filename: str,
51-
description: str,
52-
extra_text: str = "",
53-
) -> discord.File:
54-
"""Generate an image of a plot bar chart from the given data & format variables."""
55-
matplotlib.pyplot.style.use("cyberpunk")
56-
57-
max_data_value: int = max(data.values()) + 1
58-
59-
# NOTE: The "extra_values" dictionary represents columns of data that should be formatted differently to the standard data columns
60-
extra_values: dict[str, int] = {}
61-
if "Total" in data:
62-
extra_values["Total"] = data.pop("Total")
63-
64-
if len(data) > 4:
65-
data = {
66-
key: value
67-
for index, (key, value) in enumerate(data.items())
68-
if value > 0 or index <= 4
69-
}
70-
71-
bars = matplotlib.pyplot.bar(*zip(*data.items(), strict=True))
72-
73-
if extra_values:
74-
extra_bars = matplotlib.pyplot.bar(*zip(*extra_values.items(), strict=True))
75-
mplcyberpunk.add_bar_gradient(extra_bars)
76-
77-
mplcyberpunk.add_bar_gradient(bars)
78-
79-
x_tick_labels: Collection[Plot_Text] = matplotlib.pyplot.gca().get_xticklabels()
80-
count_x_tick_labels: int = len(x_tick_labels)
81-
82-
index: int
83-
tick_label: Plot_Text
84-
for index, tick_label in enumerate(x_tick_labels):
85-
if tick_label.get_text() == "Total":
86-
tick_label.set_fontweight("bold")
87-
88-
# NOTE: Shifts the y location of every other horizontal label down so that they do not overlap with one-another
89-
if index % 2 == 1 and count_x_tick_labels > 4:
90-
tick_label.set_y(tick_label.get_position()[1] - 0.044)
91-
92-
matplotlib.pyplot.yticks(range(0, max_data_value, math.ceil(max_data_value / 15)))
93-
94-
x_label_obj: Plot_Text = matplotlib.pyplot.xlabel(
95-
x_label,
96-
fontweight="bold",
97-
fontsize="large",
98-
wrap=True,
99-
)
100-
x_label_obj._get_wrap_line_width = lambda: 475 # type: ignore[attr-defined]
101-
102-
y_label_obj: Plot_Text = matplotlib.pyplot.ylabel(
103-
y_label,
104-
fontweight="bold",
105-
fontsize="large",
106-
wrap=True,
107-
)
108-
y_label_obj._get_wrap_line_width = lambda: 375 # type: ignore[attr-defined]
109-
110-
title_obj: Plot_Text = matplotlib.pyplot.title(title, fontsize="x-large", wrap=True)
111-
title_obj._get_wrap_line_width = lambda: 500 # type: ignore[attr-defined]
112-
113-
if extra_text:
114-
extra_text_obj: Plot_Text = matplotlib.pyplot.text(
115-
0.5,
116-
-0.27,
117-
extra_text,
118-
ha="center",
119-
transform=matplotlib.pyplot.gca().transAxes,
120-
wrap=True,
121-
fontstyle="italic",
122-
fontsize="small",
123-
)
124-
extra_text_obj._get_wrap_line_width = lambda: 400 # type: ignore[attr-defined]
125-
matplotlib.pyplot.subplots_adjust(bottom=0.2)
126-
127-
plot_file = io.BytesIO()
128-
matplotlib.pyplot.savefig(plot_file, format="png")
129-
matplotlib.pyplot.close()
130-
plot_file.seek(0)
131-
132-
discord_plot_file: discord.File = discord.File(
133-
plot_file,
134-
filename,
135-
description=description,
136-
)
137-
138-
plot_file.close()
139-
140-
return discord_plot_file
26+
__all__: "Sequence[str]" = ("StatsCommandsCog",)
14127

14228

14329
class StatsCommandsCog(TeXBotBaseCog):
@@ -229,45 +115,11 @@ async def channel_stats(
229115

230116
await ctx.defer(ephemeral=True)
231117

232-
message_counts: dict[str, int] = {"Total": 0}
233-
234-
role_name: str
235-
for role_name in settings["STATISTICS_ROLES"]:
236-
if discord.utils.get(main_guild.roles, name=role_name):
237-
message_counts[f"@{role_name}"] = 0
238-
239-
message_history_period: AsyncIterable[discord.Message] = channel.history(
240-
after=discord.utils.utcnow() - settings["STATISTICS_DAYS"],
241-
)
242-
message: discord.Message
243-
async for message in message_history_period:
244-
if message.author.bot:
245-
continue
246-
247-
message_counts["Total"] += 1
248-
249-
if isinstance(message.author, discord.User):
250-
continue
251-
252-
author_role_names: set[str] = {
253-
author_role.name for author_role in message.author.roles
254-
}
255-
256-
author_role_name: str
257-
for author_role_name in author_role_names:
258-
if f"@{author_role_name}" in message_counts:
259-
is_author_role_name: bool = author_role_name == "Committee"
260-
if is_author_role_name and "Committee-Elect" in author_role_names:
261-
continue
262-
263-
if author_role_name == "Guest" and "Member" in author_role_names:
264-
continue
265-
266-
message_counts[f"@{author_role_name}"] += 1
118+
message_counts: dict[str, int] = await get_channel_message_counts(channel=channel)
267119

268120
if math.ceil(max(message_counts.values()) / 15) < 1:
269121
await self.command_send_error(
270-
ctx,
122+
ctx=ctx,
271123
message="There are not enough messages sent in this channel.",
272124
)
273125
return
@@ -315,57 +167,9 @@ async def server_stats(self, ctx: "TeXBotApplicationContext") -> None:
315167

316168
await ctx.defer(ephemeral=True)
317169

318-
message_counts: dict[str, dict[str, int]] = {
319-
"roles": {"Total": 0},
320-
"channels": {},
321-
}
322-
323-
role_name: str
324-
for role_name in settings["STATISTICS_ROLES"]:
325-
if discord.utils.get(main_guild.roles, name=role_name):
326-
message_counts["roles"][f"@{role_name}"] = 0
327-
328-
channel: discord.TextChannel
329-
for channel in main_guild.text_channels:
330-
member_has_access_to_channel: bool = channel.permissions_for(
331-
guest_role,
332-
).is_superset(
333-
discord.Permissions(send_messages=True),
334-
)
335-
if not member_has_access_to_channel:
336-
continue
337-
338-
message_counts["channels"][f"#{channel.name}"] = 0
339-
340-
message_history_period: AsyncIterable[discord.Message] = channel.history(
341-
after=discord.utils.utcnow() - settings["STATISTICS_DAYS"],
342-
)
343-
message: discord.Message
344-
async for message in message_history_period:
345-
if message.author.bot:
346-
continue
347-
348-
message_counts["channels"][f"#{channel.name}"] += 1
349-
message_counts["roles"]["Total"] += 1
350-
351-
if isinstance(message.author, discord.User):
352-
continue
353-
354-
author_role_names: set[str] = {
355-
author_role.name for author_role in message.author.roles
356-
}
357-
358-
author_role_name: str
359-
for author_role_name in author_role_names:
360-
if f"@{author_role_name}" in message_counts["roles"]:
361-
is_author_role_committee: bool = author_role_name == "Committee"
362-
if is_author_role_committee and "Committee-Elect" in author_role_names:
363-
continue
364-
365-
if author_role_name == "Guest" and "Member" in author_role_names:
366-
continue
367-
368-
message_counts["roles"][f"@{author_role_name}"] += 1
170+
message_counts: dict[str, dict[str, int]] = await get_server_message_counts(
171+
guild=main_guild, guest_role=guest_role
172+
)
369173

370174
TOO_FEW_ROLES_STATS: Final[bool] = (
371175
math.ceil(max(message_counts["roles"].values()) / 15) < 1

cogs/stats/counts.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
"""Contains methods relating to counting messages in channels."""
2+
3+
from typing import TYPE_CHECKING
4+
5+
import discord
6+
7+
from config import settings
8+
9+
if TYPE_CHECKING:
10+
from collections.abc import AsyncIterable, Sequence
11+
12+
13+
__all__: "Sequence[str]" = ("get_channel_message_counts", "get_server_message_counts")
14+
15+
16+
async def get_channel_message_counts(channel: discord.TextChannel) -> dict[str, int]:
17+
"""
18+
Get the message counts for each role in the given channel.
19+
20+
The message counts are stored in a mapping with the role name (prefixed by `@`) as the key
21+
and the number of messages sent by users with that role as the value.
22+
The mapping also includes a "Total" key for the total number of messages.
23+
"""
24+
message_counts: dict[str, int] = {"Total": 0}
25+
26+
role_name: str
27+
for role_name in settings["STATISTICS_ROLES"]:
28+
if discord.utils.get(channel.guild.roles, name=role_name):
29+
message_counts[f"@{role_name}"] = 0
30+
31+
message_history_period: AsyncIterable[discord.Message] = channel.history(
32+
after=discord.utils.utcnow() - settings["STATISTICS_DAYS"],
33+
)
34+
message: discord.Message
35+
async for message in message_history_period:
36+
if message.author.bot:
37+
continue
38+
39+
message_counts["Total"] += 1
40+
41+
if isinstance(message.author, discord.User):
42+
continue
43+
44+
author_role_names: set[str] = {
45+
author_role.name for author_role in message.author.roles
46+
}
47+
48+
author_role_name: str
49+
for author_role_name in author_role_names:
50+
if f"@{author_role_name}" in message_counts:
51+
is_author_role_name: bool = author_role_name == "Committee"
52+
if is_author_role_name and "Committee-Elect" in author_role_names:
53+
continue
54+
55+
if author_role_name == "Guest" and "Member" in author_role_names:
56+
continue
57+
58+
message_counts[f"@{author_role_name}"] += 1
59+
60+
return message_counts
61+
62+
63+
async def get_server_message_counts(
64+
guild: discord.Guild, *, guest_role: discord.Role
65+
) -> dict[str, dict[str, int]]:
66+
"""
67+
Get the message counts for each channel in the given server.
68+
69+
The message counts are stored in a mapping. It contains a key "roles" which is
70+
a mapping of role names (prefixed by `@`) to the message counts
71+
for each role across the entire server.
72+
The mapping also contains a key "channels" which is a mapping with the channel
73+
name as a key and the number of messages sent in that channel as the value.
74+
The "roles" sub-mapping also includes a "Total" key for the total number of messages.
75+
"""
76+
message_counts: dict[str, dict[str, int]] = {
77+
"roles": {"Total": 0},
78+
"channels": {},
79+
}
80+
81+
role_name: str
82+
for role_name in settings["STATISTICS_ROLES"]:
83+
if discord.utils.get(guild.roles, name=role_name):
84+
message_counts["roles"][f"@{role_name}"] = 0
85+
86+
channel: discord.TextChannel
87+
for channel in guild.text_channels:
88+
member_has_access_to_channel: bool = channel.permissions_for(
89+
guest_role,
90+
).is_superset(
91+
discord.Permissions(send_messages=True),
92+
)
93+
if not member_has_access_to_channel:
94+
continue
95+
96+
message_counts["channels"][f"#{channel.name}"] = 0
97+
98+
message_history_period: AsyncIterable[discord.Message] = channel.history(
99+
after=discord.utils.utcnow() - settings["STATISTICS_DAYS"],
100+
)
101+
message: discord.Message
102+
async for message in message_history_period:
103+
if message.author.bot:
104+
continue
105+
106+
message_counts["channels"][f"#{channel.name}"] += 1
107+
message_counts["roles"]["Total"] += 1
108+
109+
if isinstance(message.author, discord.User):
110+
continue
111+
112+
author_role_names: set[str] = {
113+
author_role.name for author_role in message.author.roles
114+
}
115+
116+
author_role_name: str
117+
for author_role_name in author_role_names:
118+
if f"@{author_role_name}" in message_counts["roles"]:
119+
is_author_role_committee: bool = author_role_name == "Committee"
120+
if is_author_role_committee and "Committee-Elect" in author_role_names:
121+
continue
122+
123+
if author_role_name == "Guest" and "Member" in author_role_names:
124+
continue
125+
126+
message_counts["roles"][f"@{author_role_name}"] += 1
127+
128+
return message_counts

0 commit comments

Comments
 (0)