|
1 | 1 | """Wrapper for Github CLI commands.""" |
2 | 2 | # TODO: The following should be built using the builder pattern |
3 | 3 |
|
4 | | -from typing import Optional |
| 4 | +import json |
| 5 | +import re |
| 6 | +from typing import Any, Optional |
5 | 7 |
|
6 | 8 | from exercise_utils.cli import run |
7 | 9 |
|
8 | 10 |
|
| 11 | +_PR_STATES = {"open", "closed", "merged", "all"} |
| 12 | +_PR_MERGE_METHODS = {"merge", "squash", "rebase"} |
| 13 | +_PR_REVIEW_ACTIONS = {"request-changes", "comment"} |
| 14 | + |
| 15 | + |
9 | 16 | def fork_repo( |
10 | 17 | repository_name: str, |
11 | 18 | fork_name: str, |
@@ -127,3 +134,244 @@ def get_remote_url(repository_name: str, verbose: bool) -> str: |
127 | 134 | remote_url = f"git@github.com:{repository_name}.git" |
128 | 135 |
|
129 | 136 | return remote_url |
| 137 | + |
| 138 | + |
| 139 | +def create_pr( |
| 140 | + title: str, |
| 141 | + body: str, |
| 142 | + base: str, |
| 143 | + head: str, |
| 144 | + repo_name: str, |
| 145 | + verbose: bool, |
| 146 | + draft: bool = False, |
| 147 | +) -> Optional[int]: |
| 148 | + """Create a pull request.""" |
| 149 | + command = _build_pr_command("create", repo_name=repo_name) |
| 150 | + command = _append_value_flag(command, "--title", title) |
| 151 | + command = _append_value_flag(command, "--body", body) |
| 152 | + command = _append_value_flag(command, "--base", base) |
| 153 | + command = _append_value_flag(command, "--head", head) |
| 154 | + command = _append_bool_flag(command, draft, "--draft") |
| 155 | + |
| 156 | + result = run(command, verbose) |
| 157 | + if not result.is_success(): |
| 158 | + return None |
| 159 | + |
| 160 | + match = re.search(r"/pull/(\d+)", result.stdout) |
| 161 | + if match is None: |
| 162 | + return None |
| 163 | + |
| 164 | + return int(match.group(1)) |
| 165 | + |
| 166 | + |
| 167 | +def _append_repo_flag(command: list[str], repo_name: str) -> list[str]: |
| 168 | + """Append --repo flag. PR commands require explicit repository context.""" |
| 169 | + if repo_name.strip() == "": |
| 170 | + raise ValueError("repo_name must be provided for deterministic PR commands") |
| 171 | + |
| 172 | + return [*command, "--repo", repo_name] |
| 173 | + |
| 174 | + |
| 175 | +def _validate_choice(value: str, allowed: set[str], field_name: str) -> str: |
| 176 | + """Validate a string argument against a known set of values.""" |
| 177 | + if value not in allowed: |
| 178 | + allowed_values = ", ".join(sorted(allowed)) |
| 179 | + raise ValueError( |
| 180 | + f"Invalid {field_name}: {value}. Allowed values: {allowed_values}" |
| 181 | + ) |
| 182 | + return value |
| 183 | + |
| 184 | + |
| 185 | +def _build_pr_command(subcommand: str, *args: str, repo_name: str) -> list[str]: |
| 186 | + """Build a gh pr command and append deterministic repository context.""" |
| 187 | + return _append_repo_flag(["gh", "pr", subcommand, *args], repo_name) |
| 188 | + |
| 189 | + |
| 190 | +def _append_bool_flag(command: list[str], enabled: bool, flag: str) -> list[str]: |
| 191 | + """Append a CLI flag when the related boolean option is enabled.""" |
| 192 | + return [*command, flag] if enabled else command |
| 193 | + |
| 194 | + |
| 195 | +def _append_value_flag(command: list[str], flag: str, value: str) -> list[str]: |
| 196 | + """Append a value-taking CLI option in --flag=value form.""" |
| 197 | + return [*command, f"{flag}={value}"] |
| 198 | + |
| 199 | + |
| 200 | +def _parse_json_or_default(raw_output: str, default: Any) -> Any: |
| 201 | + """Parse JSON output and return a default value on decode failure.""" |
| 202 | + try: |
| 203 | + return json.loads(raw_output) |
| 204 | + except json.JSONDecodeError: |
| 205 | + return default |
| 206 | + |
| 207 | + |
| 208 | +def view_pr(pr_number: int, repo_name: str, verbose: bool) -> dict[str, Any]: |
| 209 | + """View pull request details.""" |
| 210 | + fields = "title,body,state,author,headRefName,baseRefName,comments,reviews" |
| 211 | + |
| 212 | + command = _build_pr_command( |
| 213 | + "view", |
| 214 | + str(pr_number), |
| 215 | + repo_name=repo_name, |
| 216 | + ) |
| 217 | + command = _append_value_flag(command, "--json", fields) |
| 218 | + |
| 219 | + result = run( |
| 220 | + command, |
| 221 | + verbose, |
| 222 | + ) |
| 223 | + |
| 224 | + if result.is_success(): |
| 225 | + parsed = _parse_json_or_default(result.stdout, {}) |
| 226 | + return parsed if isinstance(parsed, dict) else {} |
| 227 | + return {} |
| 228 | + |
| 229 | + |
| 230 | +def comment_on_pr( |
| 231 | + pr_number: int, |
| 232 | + comment: str, |
| 233 | + repo_name: str, |
| 234 | + verbose: bool, |
| 235 | +) -> bool: |
| 236 | + """Add a comment to a pull request.""" |
| 237 | + command = _build_pr_command("comment", str(pr_number), repo_name=repo_name) |
| 238 | + command = _append_value_flag(command, "--body", comment) |
| 239 | + |
| 240 | + result = run( |
| 241 | + command, |
| 242 | + verbose, |
| 243 | + ) |
| 244 | + return result.is_success() |
| 245 | + |
| 246 | + |
| 247 | +def list_prs( |
| 248 | + state: str, |
| 249 | + repo_name: str, |
| 250 | + verbose: bool, |
| 251 | + limit: int = 30, |
| 252 | + search: Optional[str] = None, |
| 253 | +) -> list[dict[str, Any]]: |
| 254 | + """ |
| 255 | + List pull requests. |
| 256 | + PR state filter ('open', 'closed', 'merged', 'all') |
| 257 | + Optional search query using GitHub search syntax. |
| 258 | + """ |
| 259 | + validated_state = _validate_choice(state, _PR_STATES, "state") |
| 260 | + fields = "number,title,state,author,headRefName,baseRefName" |
| 261 | + command = _build_pr_command("list", repo_name=repo_name) |
| 262 | + command = _append_value_flag(command, "--state", validated_state) |
| 263 | + command = _append_value_flag(command, "--json", fields) |
| 264 | + command = _append_value_flag(command, "--limit", str(limit)) |
| 265 | + |
| 266 | + if search is not None and search.strip() != "": |
| 267 | + command = _append_value_flag(command, "--search", search) |
| 268 | + |
| 269 | + result = run(command, verbose) |
| 270 | + |
| 271 | + if result.is_success(): |
| 272 | + parsed = _parse_json_or_default(result.stdout, []) |
| 273 | + return parsed if isinstance(parsed, list) else [] |
| 274 | + return [] |
| 275 | + |
| 276 | + |
| 277 | +def merge_pr( |
| 278 | + pr_number: int, |
| 279 | + merge_method: str, |
| 280 | + repo_name: str, |
| 281 | + delete_branch: bool = True, |
| 282 | + verbose: bool = False, |
| 283 | +) -> bool: |
| 284 | + """ |
| 285 | + Merge a pull request. |
| 286 | + Merge method ('merge', 'squash', 'rebase') |
| 287 | + """ |
| 288 | + validated_merge_method = _validate_choice( |
| 289 | + merge_method, |
| 290 | + _PR_MERGE_METHODS, |
| 291 | + "merge_method", |
| 292 | + ) |
| 293 | + command = _build_pr_command( |
| 294 | + "merge", |
| 295 | + str(pr_number), |
| 296 | + f"--{validated_merge_method}", |
| 297 | + repo_name=repo_name, |
| 298 | + ) |
| 299 | + |
| 300 | + command = _append_bool_flag(command, delete_branch, "--delete-branch") |
| 301 | + |
| 302 | + result = run(command, verbose) |
| 303 | + return result.is_success() |
| 304 | + |
| 305 | + |
| 306 | +def close_pr( |
| 307 | + pr_number: int, |
| 308 | + repo_name: str, |
| 309 | + comment: Optional[str] = None, |
| 310 | + delete_branch: bool = False, |
| 311 | + verbose: bool = False, |
| 312 | +) -> bool: |
| 313 | + """Close a pull request without merging.""" |
| 314 | + command = _build_pr_command( |
| 315 | + "close", |
| 316 | + str(pr_number), |
| 317 | + repo_name=repo_name, |
| 318 | + ) |
| 319 | + command = _append_bool_flag(command, delete_branch, "--delete-branch") |
| 320 | + |
| 321 | + if comment: |
| 322 | + command = _append_value_flag(command, "--comment", comment) |
| 323 | + |
| 324 | + result = run(command, verbose) |
| 325 | + return result.is_success() |
| 326 | + |
| 327 | + |
| 328 | +def review_pr( |
| 329 | + pr_number: int, |
| 330 | + comment: str, |
| 331 | + action: str, |
| 332 | + repo_name: str, |
| 333 | + verbose: bool, |
| 334 | +) -> bool: |
| 335 | + """ |
| 336 | + Submit a review on a pull request. |
| 337 | + Review action ('request-changes', 'comment') |
| 338 | + """ |
| 339 | + validated_action = _validate_choice(action, _PR_REVIEW_ACTIONS, "action") |
| 340 | + command = _build_pr_command("review", str(pr_number), repo_name=repo_name) |
| 341 | + command = _append_value_flag(command, "--body", comment) |
| 342 | + command.append(f"--{validated_action}") |
| 343 | + |
| 344 | + result = run(command, verbose) |
| 345 | + return result.is_success() |
| 346 | + |
| 347 | + |
| 348 | +def get_pr_numbers_by_author(username: str, repo_name: str, verbose: bool) -> list[int]: |
| 349 | + """Return the latest opened pull request numbers created by username in the repo.""" |
| 350 | + command = _build_pr_command("list", repo_name=repo_name) |
| 351 | + command = _append_value_flag(command, "--author", username) |
| 352 | + command = _append_value_flag(command, "--state", "open") |
| 353 | + command = _append_value_flag(command, "--json", "number") |
| 354 | + |
| 355 | + result = run(command, verbose) |
| 356 | + if not result.is_success(): |
| 357 | + return [] |
| 358 | + |
| 359 | + import json |
| 360 | + |
| 361 | + try: |
| 362 | + prs = json.loads(result.stdout) |
| 363 | + except json.JSONDecodeError: |
| 364 | + return [] |
| 365 | + |
| 366 | + pr_numbers = [pr.get("number") for pr in prs if isinstance(pr.get("number"), int)] |
| 367 | + pr_numbers.sort() |
| 368 | + return pr_numbers |
| 369 | + |
| 370 | + |
| 371 | +def get_latest_pr_number_by_author( |
| 372 | + username: str, repo_full_name: str, verbose: bool |
| 373 | +) -> Optional[int]: |
| 374 | + """Return the latest open pull request number created by username in the repo.""" |
| 375 | + if pr_numbers := get_pr_numbers_by_author(username, repo_full_name, verbose): |
| 376 | + return pr_numbers[-1] |
| 377 | + raise ValueError(f"No open PRs found for user {username} in repo {repo_full_name}.") |
0 commit comments