diff --git a/bot/exts/utilities/githubinfo.py b/bot/exts/utilities/githubinfo.py
index 219d48a2b..dc80afb0d 100644
--- a/bot/exts/utilities/githubinfo.py
+++ b/bot/exts/utilities/githubinfo.py
@@ -34,7 +34,7 @@
 REQUEST_HEADERS["Authorization"] = f"token {Tokens.github.get_secret_value()}"
 
 CODE_BLOCK_RE = re.compile(
-    r"^`([^`\n]+)`"  # Inline codeblock
+    r"`([^`\n]+)`"  # Inline codeblock
     r"|```(.+?)```",  # Multiline codeblock
     re.DOTALL | re.MULTILINE
 )
@@ -48,6 +48,17 @@
     r"((?P<org>[a-zA-Z0-9][a-zA-Z0-9\-]{1,39})\/)?(?P<repo>[\w\-\.]{1,100})#(?P<number>[0-9]+)"
 )
 
+class GithubAPIError(Exception):
+    """Raised when GitHub API returns a non 200 status code."""
+
+    def __init__(self, status: int, message: str = "GitHub API error"):
+        self.status = status
+        self.message = message
+        super().__init__(f"{message} (Status: {status})")
+
+
+class StargazersLimitError(Exception):
+    """Raised when a repository exceeds the searchable stargazer limit."""
 
 @dataclass(eq=True, frozen=True)
 class FoundIssue:
@@ -366,5 +377,257 @@ def build_embed(self, repo_data: dict) -> discord.Embed:
         )
         return embed
 
+    async def get_issue_count(self, repo: str, start: str, end: str, state: str) -> int:
+        """Gets the number of issues opened or closed (based on state) in a given timeframe."""
+        url = f"{GITHUB_API_URL}/search/issues"
+        query = f"repo:{repo} is:issue {state}:{start}..{end}"
+        params = {"q": query}
+
+        async with self.bot.http_session.get(url, headers=REQUEST_HEADERS, params=params) as response:
+            if response.status != 200:
+                raise GithubAPIError(response.status)
+            data = await response.json()
+            return data.get("total_count", 0)
+
+    async def get_pr_count(self, repo: str, start: str, end: str, action: str) -> int:
+        """Gets the number of PRs opened, closed, or merged in a given timeframe."""
+        url = f"{GITHUB_API_URL}/search/issues"
+
+        if action == "opened":
+            state_query = f"created:{start}..{end}"
+        elif action == "merged":
+            state_query = f"is:merged merged:{start}..{end}"
+        elif action == "closed":
+            state_query = f"is:unmerged closed:{start}..{end}"
+        else:
+            return 0
+
+        query = f"repo:{repo} is:pr {state_query}"
+        params = {"q": query}
+
+        async with self.bot.http_session.get(url, headers=REQUEST_HEADERS, params=params) as response:
+            if response.status != 200:
+                raise GithubAPIError(response.status)
+
+            data = await response.json()
+            return data.get("total_count", 0)
+
+    async def get_commit_count(self, repo_str: str, start_str: str, end_str: str) -> int:
+        """Returns the number of commits done to the given repo between the start and end date."""
+        start_iso = f"{start_str}T00:00:00Z"
+        end_iso = f"{end_str}T23:59:59Z"
+
+        url = f"{GITHUB_API_URL}/repos/{repo_str}/commits"
+        params = {"since": start_iso, "until": end_iso, "per_page": 1, "page": 1}
+
+        async with self.bot.http_session.get(url, headers=REQUEST_HEADERS, params=params) as response:
+            if response.status != 200:
+                raise GithubAPIError(response.status)
+
+            commits_json = await response.json()
+            # No commits
+            if not commits_json:
+                return 0
+
+            link_header = response.headers.get("Link")
+            # No link header means only one page
+            if not link_header:
+                return 1
+
+            # Grabbing the number of pages from the Link header
+            match = re.search(r'page=(\d+)>; rel="last"', link_header)
+
+            if match:
+                return int(match.group(1))
+
+            return 1
+
+    async def _fetch_page(self, url: str, headers: dict, page: int, cache: dict) -> list:
+        """Fetch a page of stargazers, using cache to avoid duplicate requests."""
+        if page not in cache:
+            params = {"per_page": 100, "page": page}
+            async with self.bot.http_session.get(url, headers=headers, params=params) as response:
+                if response.status != 200:
+                    raise GithubAPIError(response.status)
+                cache[page] = await response.json()
+        return cache[page]
+
+    async def _get_date_at(self, url: str, headers: dict, i: int, cache: dict) -> str:
+        """Get the starred_at date (YYYY-MM-DD) of the star at global index i (0-based)."""
+        page = (i // 100) + 1
+        pos = i % 100
+        page_data = await self._fetch_page(url, headers, page, cache)
+
+        # FIX: Prevent IndexError if GitHub's cached count is higher than the actual list
+        if page_data and pos < len(page_data):
+            return page_data[pos].get("starred_at", "")[:10]
+        return ""
+
+    async def get_stars_gained(self, repo: str, start: str, end: str) -> int:
+        """Gets the number of stars gained for a given repository in a timeframe."""
+        url = f"{GITHUB_API_URL}/repos/{repo}/stargazers"
+
+        # Copy the global headers but update the Accept header specifically for Stargazers
+        star_headers = REQUEST_HEADERS.copy()
+        star_headers["Accept"] = "application/vnd.github.star+json"
+
+        repo_data, response = await self.fetch_data(f"{GITHUB_API_URL}/repos/{repo}")
+        if response.status != 200:
+            raise GithubAPIError(response.status)
+
+        max_stars = repo_data.get("stargazers_count", 0)
+
+        if max_stars == 0:
+            return 0
+
+        # GitHub API limits stargazers pagination to 40 000 entries (page 400 max),
+        # so repositories above that limit cannot be searched reliably and we raise instead.
+        github_stargazer_limit = 40000
+        if max_stars > github_stargazer_limit:
+            raise StargazersLimitError("Repository exceeds the 40,000 star limit.")
+        searchable_stars = max_stars
+
+        # We use a cache and binary search to limit the number of requests to the GitHub API
+        cache = {}
+        low, high = 0, searchable_stars - 1
+        while low < high:
+            mid = (low + high) // 2
+            lowdate = await self._get_date_at(url, star_headers, mid, cache)
+            if lowdate == "":
+                return -1
+            if lowdate < start:
+                low = mid + 1
+            else:
+                high = mid
+        left = low
+
+        date_left = await self._get_date_at(url, star_headers, left, cache)
+        if date_left < start or date_left > end:
+            return 0
+
+        low, high = left, searchable_stars - 1
+        while low < high:
+            mid = (low + high + 1) // 2
+            highdate = await self._get_date_at(url, star_headers, mid, cache)
+            if highdate == "":
+                return -1
+            if highdate > end:
+                high = mid - 1
+            else:
+                low = mid
+        right = low
+
+        return right - left + 1
+
+    def parse_date(self, date_str: str) -> datetime | None:
+        """Parse a YYYY-MM-DD date string into a UTC datetime."""
+        try:
+            return datetime.strptime(date_str, "%Y-%m-%d").replace(tzinfo=UTC)
+        except ValueError:
+            return None
+
+    def validate_date_format(self, date_str: str) -> bool:
+        """Validates that the date string is formatted correctly."""
+        return self.parse_date(date_str) is not None
+
+    def validate_date_range(self, start_date: str, end_date: str) -> bool:
+        """Validate a date range for correctness and logical ordering."""
+        start = self.parse_date(start_date)
+        end = self.parse_date(end_date)
+
+        if start and end:
+            return start <= end
+        return False
+
+    @github_group.command(name="stats")
+    async def github_stats(self, ctx: commands.Context, start: str, end: str, repo: str) -> None:
+        """
+        Fetches stats for a GitHub repo.
+
+        Usage: !github stats 2023-01-01 2023-12-31 python-discord/bot.
+        """
+        async with ctx.typing():
+            # Validate the date first to spare API calls
+            if not self.validate_date_format(start):
+                embed = discord.Embed(
+                    title=random.choice(NEGATIVE_REPLIES),
+                    description="Start date must be in YYYY-MM-DD format.",
+                    colour=Colours.soft_red,
+                )
+                await ctx.send(embed=embed)
+                return
+
+            if not self.validate_date_format(end):
+                embed = discord.Embed(
+                    title=random.choice(NEGATIVE_REPLIES),
+                    description="End date must be in YYYY-MM-DD format.",
+                    colour=Colours.soft_red,
+                )
+                await ctx.send(embed=embed)
+                return
+
+            if not self.validate_date_range(start, end):
+                embed = discord.Embed(
+                    title=random.choice(NEGATIVE_REPLIES),
+                    description="Invalid date range.",
+                    colour=Colours.soft_red,
+                )
+                await ctx.send(embed=embed)
+                return
+
+            url = f"{GITHUB_API_URL}/repos/{repo}"
+            repo_data, _ = await self.fetch_data(url)
+
+            if "message" in repo_data:
+                embed = discord.Embed(
+                    title=random.choice(NEGATIVE_REPLIES),
+                    description=f"Could not find repository: `{repo}`",
+                    colour=Colours.soft_red,
+                )
+                await ctx.send(embed=embed)
+                return
+
+            try:
+                open_issues = await self.get_issue_count(repo, start, end, state="created")
+                closed_issues = await self.get_issue_count(repo, start, end, state="closed")
+                prs_opened = await self.get_pr_count(repo, start, end, "opened")
+                prs_closed = await self.get_pr_count(repo, start, end, "closed")
+                prs_merged = await self.get_pr_count(repo, start, end, "merged")
+                commits = await self.get_commit_count(repo, start, end)
+
+                try:
+                    stars_gained = await self.get_stars_gained(repo, start, end)
+                    # get_stars_gained returns -1 when star dates could not be read
+                    if stars_gained < 0:
+                        stars = "N/A"
+                    else:
+                        stars = f"+{stars_gained}" if stars_gained > 0 else "0"
+                except StargazersLimitError:
+                    stars = "N/A (repo exceeded API limit)"
+
+            except GithubAPIError as e:
+                embed = discord.Embed(
+                    title=random.choice(NEGATIVE_REPLIES),
+                    description=f"Failed to fetch data from GitHub API. (Status Code: {e.status})",
+                    colour=Colours.soft_red,
+                )
+                await ctx.send(embed=embed)
+                return
+
+            stats_text = (
+                f"Issues opened: {open_issues}\n"
+                f"Issues closed: {closed_issues}\n"
+                f"Pull Requests opened: {prs_opened}\n"
+                f"Pull Requests closed: {prs_closed}\n"
+                f"Pull Requests merged: {prs_merged}\n"
+                f"Stars gained: {stars}\n"
+                f"Commits: {commits}"
+            )
+
+            stats_embed = discord.Embed(
+                title=f"Stats for {repo}", description=stats_text, colour=discord.Colour.og_blurple()
+            )
+            await ctx.send(embed=stats_embed)
+
     @github_group.command(name="repository", aliases=("repo",))
     async def github_repo_info(self, ctx: commands.Context, *repo: str) -> None: