forked from awslabs/python-deequ
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgithub_client.py
More file actions
279 lines (248 loc) · 10.8 KB
/
Copy pathgithub_client.py
File metadata and controls
279 lines (248 loc) · 10.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
import logging
import os
import requests
# Shared module logger; every GitHubClient diagnostic is emitted under the "issue_bot" name.
logger = logging.getLogger(name="issue_bot")
class GitHubClient:
    """Wrapper around the GitHub REST API (plus the local checkout) for one repository.

    Network helpers are best-effort: failures are logged and reported as
    ``None`` / ``""`` / ``[]`` / ``False`` instead of being raised. When
    ``cfg.dry_run`` is set, write operations (comments, reviews, labels) are
    only logged.
    """

    # Base URL of the public GitHub REST API.
    _API = "https://api.github.com"
    # Directory names pruned while building the codebase map.
    _SKIP_DIRS = frozenset({"examples", "__pycache__", ".git", "tests", "test"})

    def __init__(self, cfg):
        """Capture credentials and settings from *cfg*.

        Reads ``github_token``, ``repo`` (interpolated as ``/repos/{repo}/...``),
        ``github_api_timeout`` and ``dry_run`` here; ``codebase_src_dir`` and
        ``codebase_file_ext`` are read later by ``get_codebase_map``. The local
        checkout root is ``$GITHUB_WORKSPACE`` or, if unset, two directories
        above this file.
        """
        self._token = cfg.github_token
        self._repo = cfg.repo
        self._timeout = cfg.github_api_timeout
        self._dry_run = cfg.dry_run
        self._cfg = cfg
        self._repo_root = os.getenv(
            "GITHUB_WORKSPACE",
            os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")),
        )
        self._headers = {
            "Authorization": f"token {self._token}",
            "Accept": "application/vnd.github.v3+json",
        }

    # ----------------------------------------------------------------- reads

    def get_issue(self, number):
        """Return the issue JSON for *number*, or None if the request fails."""
        return self._get(f"/repos/{self._repo}/issues/{number}")

    def get_comments(self, number, max_pages=10):
        """Return up to ``max_pages * 100`` comments on issue/PR *number*."""
        return self._get_paginated(
            f"/repos/{self._repo}/issues/{number}/comments", max_pages=max_pages
        )

    def get_pr(self, number):
        """Return the pull request JSON for *number*, or None if the request fails."""
        return self._get(f"/repos/{self._repo}/pulls/{number}")

    def get_pr_diff(self, number):
        """Return the unified diff of PR *number* as text, or "" on failure."""
        headers = {**self._headers, "Accept": "application/vnd.github.v3.diff"}
        try:
            resp = requests.get(
                f"{self._API}/repos/{self._repo}/pulls/{number}",
                headers=headers, timeout=self._timeout,
            )
            if resp.status_code == 200:
                return resp.text
            logger.warning(f"PR diff #{number}: {resp.status_code}")
        except Exception as e:
            logger.error(f"PR diff fetch failed: {e}")
        return ""

    def get_compare_diff(self, base_sha, head_sha):
        """Fetch the diff between two commits using the Compare API.

        Returns the diff text, or "" on failure (e.g. after a force-push
        where *base_sha* no longer exists).
        """
        headers = {**self._headers, "Accept": "application/vnd.github.v3.diff"}
        try:
            resp = requests.get(
                f"{self._API}/repos/{self._repo}/compare/{base_sha}...{head_sha}",
                headers=headers, timeout=self._timeout,
            )
            if resp.status_code == 200:
                return resp.text
            logger.warning(f"Compare API {base_sha[:7]}...{head_sha[:7]}: {resp.status_code}")
        except Exception as e:
            logger.error(f"Compare diff failed: {e}")
        return ""

    def get_ci_status(self, sha):
        """Check commit statuses and check runs for *sha*.

        Returns ``(passed, summary)`` where ``passed`` is True (all green),
        False (something failed) or None (pending/unknown).
        """
        status = self._get(f"/repos/{self._repo}/commits/{sha}/status")
        if status is None:
            return None, "CI status unavailable"
        runs = self._get_paginated(
            f"/repos/{self._repo}/commits/{sha}/check-runs", key="check_runs"
        )
        return self._evaluate_ci(status, runs)

    @staticmethod
    def _is_own_check(name):
        """Return True for check runs that belong to this bot itself.

        These are excluded from CI evaluation so the bot's own jobs never
        count as pending or failed.
        """
        lower = name.lower()
        return "bot" in lower and ("analyze" in lower or "/ act" in lower)

    @classmethod
    def _evaluate_ci(cls, status, runs):
        """Reduce a combined-status payload plus check runs to ``(passed, summary)``.

        ``status`` is the JSON object returned by ``/commits/{sha}/status`` and
        ``runs`` the list of check-run objects. Check runs are evaluated first,
        then the legacy combined status.
        """
        external = [r for r in runs if not cls._is_own_check(r.get("name", ""))]
        failed = []
        pending = []
        for run in external:
            name = run.get("name", "")
            if run.get("status") != "completed":
                pending.append(name)
            elif run.get("conclusion") not in ("success", "neutral", "skipped"):
                failed.append(name)
        if failed:
            return False, f"CI failing: {', '.join(failed)}"
        if pending:
            return None, f"CI pending: {', '.join(pending)}"

        # The combined-status endpoint reports state "pending" when a commit has
        # no legacy statuses at all (total_count == 0), which is normal for
        # repositories that only use check runs. Only trust the state when at
        # least one status exists (a missing total_count keeps the old behavior).
        combined_state = status.get("state", "pending")
        if status.get("total_count", 1) != 0:
            if combined_state == "failure":
                return False, "CI failing (status checks)"
            if combined_state == "pending":
                return None, "CI pending (status checks)"
        elif not external:
            # Neither statuses nor external check runs yet: still unknown.
            return None, "CI pending (no checks reported)"
        return True, "CI passed"

    def get_pr_files(self, number, max_pages=30):
        """Return the changed-file objects of PR *number* (all pages, [] on failure)."""
        return self._get_paginated(
            f"/repos/{self._repo}/pulls/{number}/files", max_pages=max_pages
        )

    def get_pr_review_comments(self, number, max_pages=10):
        """Return up to ``max_pages * 100`` inline review comments on PR *number*."""
        return self._get_paginated(
            f"/repos/{self._repo}/pulls/{number}/comments", max_pages=max_pages
        )

    def get_codebase_map(self):
        """List source files (excluding tests) as newline-separated relative paths.

        Walks ``cfg.codebase_src_dir`` under the repo root, keeps files ending in
        ``cfg.codebase_file_ext`` and prunes the directories in ``_SKIP_DIRS``.
        Returns "" on error.
        """
        src_dir = self._cfg.codebase_src_dir
        file_ext = self._cfg.codebase_file_ext
        full_dir = os.path.join(self._repo_root, src_dir)
        prefix = self._repo_root.rstrip("/") + "/"
        try:
            paths = []
            for root, dirs, files in os.walk(full_dir):
                # Prune in place so os.walk does not descend into skipped dirs.
                dirs[:] = [d for d in dirs if d not in self._SKIP_DIRS]
                for name in files:
                    if name.endswith(file_ext):
                        full = os.path.join(root, name)
                        paths.append(full[len(prefix):] if full.startswith(prefix) else full)
            return "\n".join(sorted(paths))
        except Exception as e:
            logger.error(f"Codebase map failed: {e}")
            return ""

    def read_local_file(self, path):
        """Read *path* (relative to the repo root) from the local checkout.

        Returns "" when *path* is empty, resolves outside the repo root (after
        symlink resolution), or cannot be read.
        """
        if not path:
            return ""
        repo_root = os.path.realpath(self._repo_root)
        if repo_root == "/":
            logger.error("Blocked: repo root is /")
            return ""
        full_path = os.path.realpath(os.path.join(self._repo_root, path))
        if not (full_path.startswith(repo_root + os.sep) or full_path == repo_root):
            logger.error(f"Blocked path traversal: {path}")
            return ""
        try:
            # Explicit encoding: the default is locale-dependent.
            with open(full_path, "r", encoding="utf-8", errors="replace") as fh:
                return fh.read()
        except (OSError, ValueError) as e:
            logger.warning(f"Local file read failed ({path}): {e}")
            return ""

    def get_file_content(self, path, repo=None, ref=None):
        """Fetch the raw contents of *path* from *repo* (default: this repo) at *ref*.

        Returns "" on any failure.
        """
        from urllib.parse import quote

        target = repo or self._repo
        headers = {**self._headers, "Accept": "application/vnd.github.v3.raw"}
        # Let requests encode the query string instead of concatenating it.
        params = {"ref": ref} if ref else None
        try:
            # quote() keeps "/" but escapes "#", "?", spaces, etc. in file names.
            url = f"{self._API}/repos/{target}/contents/{quote(str(path))}"
            resp = requests.get(url, headers=headers, params=params, timeout=self._timeout)
            return resp.text if resp.status_code == 200 else ""
        except Exception as e:
            logger.error(f"File fetch failed ({path}): {e}")
            return ""

    # ---------------------------------------------------------------- writes

    def post_comment(self, number, body):
        """Post *body* as a comment on issue/PR *number*; True on success."""
        if self._dry_run:
            logger.info(f"[DRY RUN] Comment on #{number}: {body[:80]}...")
            return True
        return self._post(f"/repos/{self._repo}/issues/{number}/comments", {"body": body})

    def post_pr_review(self, number, summary, inline_comments, event="COMMENT"):
        """Submit a PR review with inline comments, falling back to a plain comment.

        Each item of *inline_comments* is a dict with ``file``, ``line`` and
        ``comment`` keys. Comments whose line falls on the right side of the
        PR diff are attached inline; the rest are appended to the review body.
        If the review API rejects the request, everything is posted as one
        regular issue comment instead. Returns True on success.
        """
        inline_comments = list(inline_comments or [])
        if self._dry_run:
            logger.info(f"[DRY RUN] PR review on #{number}: {len(inline_comments)} inline comments, event={event}")
            return True

        # GitHub only accepts inline comments on lines present in the diff.
        valid_lines = self._get_valid_diff_lines(number)
        anchored = []
        unanchored = []
        for ic in inline_comments:
            path = ic.get("file", "")
            line = self._coerce_line(ic.get("line"))
            if line is not None and line in valid_lines.get(path, ()):
                anchored.append({"path": path, "body": ic.get("comment", ""), "line": line, "side": "RIGHT"})
            else:
                unanchored.append(ic)

        payload = {
            "body": summary + self._format_feedback("Additional feedback", unanchored),
            "event": event,
        }
        if anchored:
            payload["comments"] = anchored
        try:
            resp = requests.post(
                f"{self._API}/repos/{self._repo}/pulls/{number}/reviews",
                headers=self._headers, json=payload, timeout=self._timeout,
            )
            if resp.status_code in (200, 201):
                return True
            logger.error(f"PR review API failed: {resp.status_code}, falling back to comment")
            logger.error(f"Response: {resp.text[:500]}")
        except Exception as e:
            logger.error(f"PR review API failed: {e}, falling back to comment")

        # Fallback: post everything as a regular comment if the review API fails.
        body = summary + self._format_feedback("Inline feedback", inline_comments)
        return self._post(f"/repos/{self._repo}/issues/{number}/comments", {"body": body})

    def _get_valid_diff_lines(self, number):
        """Map each changed file of PR *number* to its right-side diff line numbers."""
        valid = {}
        for f in self.get_pr_files(number):
            patch = f.get("patch", "")
            # Binary or oversized files come without a patch; nothing to anchor to.
            if patch:
                valid[f.get("filename", "")] = self._diff_lines_from_patch(patch)
        return valid

    @staticmethod
    def _diff_lines_from_patch(patch):
        """Return the set of new-file line numbers (added or context) in *patch*."""
        import re

        hunk_re = re.compile(r'^@@ -\d+(?:,\d+)? \+(\d+)(?:,\d+)? @@')
        rows = patch.split("\n")
        # A trailing newline yields an empty final element that is not a diff line.
        if rows and rows[-1] == "":
            rows.pop()
        lines = set()
        current = None
        for row in rows:
            hunk = hunk_re.match(row)
            if hunk:
                current = int(hunk.group(1))
                continue
            # Skip text before the first hunk, deletions (no right-side line)
            # and "\ No newline at end of file" markers.
            if current is None or row.startswith(("-", "\\")):
                continue
            lines.add(current)
            current += 1
        return lines

    @staticmethod
    def _coerce_line(value):
        """Return *value* as a positive int line number, or None if it is not one."""
        if isinstance(value, bool):
            return None
        if isinstance(value, int):
            line = value
        elif isinstance(value, str):
            try:
                line = int(value.strip())
            except ValueError:
                return None
        else:
            return None
        return line if line > 0 else None

    @staticmethod
    def _format_feedback(heading, comments):
        """Render *comments* as a markdown section under *heading* ("" if none)."""
        if not comments:
            return ""
        parts = [f"\n\n**{heading}:**\n"]
        for ic in comments:
            line_ref = f":{ic['line']}" if ic.get("line") else ""
            parts.append(f"\n`{ic.get('file', '')}{line_ref}` — {ic.get('comment', '')}\n")
        return "".join(parts)

    def add_labels(self, number, labels):
        """Add *labels* to issue/PR *number*; True on success or when empty."""
        if not labels:
            return True
        if self._dry_run:
            logger.info(f"[DRY RUN] Labels on #{number}: {labels}")
            return True
        return self._post(f"/repos/{self._repo}/issues/{number}/labels", {"labels": labels})

    # ------------------------------------------------------------- transport

    def _get(self, path):
        """GET ``_API + path``; return decoded JSON on HTTP 200, else None."""
        try:
            resp = requests.get(f"{self._API}{path}", headers=self._headers, timeout=self._timeout)
            if resp.status_code == 200:
                return resp.json()
            logger.error(f"GET {path}: {resp.status_code}")
        except Exception as e:
            logger.error(f"GET {path}: {e}")
        return None

    def _get_paginated(self, path, max_pages=10, per_page=100, key=None):
        """Collect items from a paginated list endpoint.

        Requests pages 1..max_pages and stops early on an empty/failed page or
        a short page. When *key* is given, each page is a JSON object and its
        ``key`` field holds the list (e.g. ``check_runs``).
        """
        sep = "&" if "?" in path else "?"
        items = []
        for page in range(1, max_pages + 1):
            data = self._get(f"{path}{sep}per_page={per_page}&page={page}")
            if key is not None and isinstance(data, dict):
                data = data.get(key)
            if not data:
                break
            items.extend(data)
            if len(data) < per_page:
                break
        return items

    def _post(self, path, payload):
        """POST *payload* as JSON to ``_API + path``; True on HTTP 200/201."""
        try:
            resp = requests.post(f"{self._API}{path}", headers=self._headers, json=payload, timeout=self._timeout)
            if resp.status_code in (200, 201):
                return True
            logger.error(f"POST {path}: {resp.status_code}")
        except Exception as e:
            logger.error(f"POST {path}: {e}")
        return False