diff --git a/.github/workflows/issue-bot.yml b/.github/workflows/issue-bot.yml new file mode 100644 index 0000000..1351268 --- /dev/null +++ b/.github/workflows/issue-bot.yml @@ -0,0 +1,127 @@ +name: PyDeequ Bot + +on: + issues: + types: [opened, reopened] + pull_request: + types: [opened, reopened, synchronize] + issue_comment: + types: [created] + workflow_dispatch: + inputs: + issue_number: + description: "Issue/PR number to process" + required: true + dry_run: + description: "Dry run (no writes)" + type: boolean + default: true + +# Serialize per issue/PR to prevent duplicate comments +concurrency: + group: bot-${{ github.event.issue.number || github.event.pull_request.number || inputs.issue_number }} + cancel-in-progress: false + +jobs: + analyze: + runs-on: ubuntu-latest + timeout-minutes: 10 + if: >- + (github.event_name == 'workflow_dispatch') || + (github.actor != 'github-actions[bot]' && + (github.event.pull_request.head.repo.full_name == github.repository || github.event_name != 'pull_request') && + (github.event.issue.pull_request == null || github.event_name == 'pull_request')) + permissions: + contents: read + id-token: write + + steps: + - name: Checkout repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false + + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502 # v4.0.2 + with: + role-to-assume: ${{ secrets.AWS_ROLE_ARN }} + aws-region: us-east-1 + + - name: Set up Python + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 + with: + python-version: "3.12" + + - name: Install dependencies + run: pip install requests==2.33.1 boto3==1.42.94 + + - name: Run analysis + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + GITHUB_REPOSITORY: ${{ github.repository }} + ISSUE_NUMBER: ${{ github.event.issue.number || github.event.pull_request.number || inputs.issue_number }} + EVENT_TYPE: ${{ github.event_name }} + EVENT_ACTION: ${{ github.event.action }} + GITHUB_ACTOR: ${{ github.actor }} + KB_S3_BUCKET: ${{ secrets.KB_S3_BUCKET }} + KB_S3_KEY: ${{ secrets.KB_S3_KEY }} + BEDROCK_MODEL_ID: ${{ secrets.BEDROCK_MODEL_ID }} + GUARDRAIL_ID: ${{ secrets.GUARDRAIL_ID }} + GUARDRAIL_VERSION: ${{ secrets.GUARDRAIL_VERSION }} + ISSUE_CLASSIFY_PROMPT: ${{ secrets.ISSUE_CLASSIFY_PROMPT }} + ISSUE_RESPOND_PROMPT: ${{ secrets.ISSUE_RESPOND_PROMPT }} + PR_FILE_REVIEW_PROMPT: ${{ secrets.PR_FILE_REVIEW_PROMPT }} + FOLLOWUP_PROMPT: ${{ secrets.FOLLOWUP_PROMPT }} + DRY_RUN: ${{ inputs.dry_run || 'false' }} + ARTIFACT_PATH: ${{ runner.temp }}/bot_result.json + run: python -m issue_bot.main analyze + working-directory: scripts + + - name: Upload artifact + uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + with: + name: bot-result + path: ${{ runner.temp }}/bot_result.json + retention-days: 30 + + act: + runs-on: ubuntu-latest + timeout-minutes: 1 + needs: analyze + permissions: + contents: read + issues: write + pull-requests: write + + steps: + - name: Checkout repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false + + - name: Set up Python + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 + with: + python-version: "3.12" + + - name: Install dependencies + run: pip install requests==2.33.1 boto3==1.42.94 + + - name: Download artifact + uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0 + with: + name: bot-result + path: ${{ runner.temp }} + + - name: Execute actions + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + GITHUB_REPOSITORY: ${{ github.repository }} + ISSUE_NUMBER: ${{ github.event.issue.number || github.event.pull_request.number || inputs.issue_number }} + EVENT_TYPE: ${{ github.event_name }} + EVENT_ACTION: ${{ github.event.action }} + SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} + DRY_RUN: ${{ inputs.dry_run || 'false' }} + ARTIFACT_PATH: ${{ runner.temp }}/bot_result.json + run: python -m issue_bot.main act + working-directory: scripts diff --git a/scripts/issue_bot/.gitignore b/scripts/issue_bot/.gitignore new file mode 100644 index 0000000..c18dd8d --- /dev/null +++ b/scripts/issue_bot/.gitignore @@ -0,0 +1 @@ +__pycache__/ diff --git a/scripts/issue_bot/__init__.py b/scripts/issue_bot/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/scripts/issue_bot/__init__.py @@ -0,0 +1 @@ + diff --git a/scripts/issue_bot/bedrock_client.py b/scripts/issue_bot/bedrock_client.py new file mode 100644 index 0000000..30a0f43 --- /dev/null +++ b/scripts/issue_bot/bedrock_client.py @@ -0,0 +1,105 @@ +import logging + +import boto3 +from botocore.config import Config as BotoConfig +from botocore.exceptions import ClientError, BotoCoreError + +logger = logging.getLogger("issue_bot") + +_CIRCUIT_BREAKER_THRESHOLD = 3 + + +class BedrockClient: + def __init__(self, cfg): + self._model_id = cfg.bedrock_model_id + self._client = boto3.client( + "bedrock-runtime", + config=BotoConfig( + read_timeout=cfg.bedrock_timeout, + connect_timeout=cfg.bedrock_timeout, + retries={"max_attempts": 3, "mode": "adaptive"}, + ), + ) + self._guardrail_id = cfg.guardrail_id + self._guardrail_version = cfg.guardrail_version + self._failures = 0 + self._circuit_open = False # Resets per-process; GHA runs are ephemeral + + @property + def available(self): + return not self._circuit_open + + def invoke(self, system_prompt, user_prompt, max_tokens=4096, + temperature=0.3, json_schema=None): + """Invoke Bedrock Converse API with guardrail on user message only. + + Follows the GlueML pattern (BedrockModelHelper.java): + - system_prompt: Instructions + trusted context (KB, diffs, codebase). + Passed as plain text SystemContentBlock with cachePoint. The + guardrail does NOT assess system prompts without guardContent. + - user_prompt: Untrusted user input (issue title/body, PR title/body, + comments). When guardrail is configured, wrapped in guardContent + so the guardrail scans it for prompt injection. + """ + if self._circuit_open: + logger.warning("Circuit breaker open, skipping Bedrock call") + return None + try: + if self._guardrail_id: + user_content = [{"guardContent": {"text": {"text": user_prompt}}}] + else: + user_content = [{"text": user_prompt}] + + kwargs = { + "modelId": self._model_id, + "messages": [{"role": "user", "content": user_content}], + "inferenceConfig": {"maxTokens": max_tokens, "temperature": temperature}, + } + + if system_prompt: + kwargs["system"] = [ + {"text": system_prompt}, + {"cachePoint": {"type": "default"}}, + ] + + if json_schema: + kwargs["outputConfig"] = { + "textFormat": { + "type": "json_schema", + "structure": {"jsonSchema": { + "schema": json_schema, + "name": "bot_response", + }}, + } + } + + if self._guardrail_id: + kwargs["guardrailConfig"] = { + "guardrailIdentifier": self._guardrail_id, + "guardrailVersion": self._guardrail_version, + "trace": "enabled", + } + + resp = self._client.converse(**kwargs) + + if resp.get("stopReason") == "guardrail_intervened": + logger.warning("Guardrail intervened: %s", resp.get("trace", "")) + return None + + output = resp.get("output", {}).get("message", {}).get("content", []) + if not output: + raise ValueError("Empty Bedrock response") + + self._failures = 0 + usage = resp.get("usage", {}) + logger.info("Bedrock: input=%s, output=%s, cacheRead=%s, cacheWrite=%s", + usage.get("inputTokens"), usage.get("outputTokens"), + usage.get("cacheReadInputTokens"), usage.get("cacheWriteInputTokens")) + return output[0]["text"].strip() + except (ClientError, BotoCoreError, ValueError, ConnectionError) as e: + self._failures += 1 + logger.error(f"Bedrock failed ({self._failures}/{_CIRCUIT_BREAKER_THRESHOLD}): {e}") + if self._failures >= _CIRCUIT_BREAKER_THRESHOLD: + self._circuit_open = True + logger.error("Circuit breaker OPEN") + return None diff --git a/scripts/issue_bot/config.py b/scripts/issue_bot/config.py new file mode 100644 index 0000000..b6fdb0c --- /dev/null +++ b/scripts/issue_bot/config.py @@ -0,0 +1,51 @@ +import os +import sys +import logging + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") +logger = logging.getLogger("issue_bot") + + +class Config: + def __init__(self): + self.github_token = _require("GITHUB_TOKEN") + self.event_type = _require("EVENT_TYPE") + self.event_action = os.getenv("EVENT_ACTION", "") + self.issue_number = _require("ISSUE_NUMBER") + if not self.issue_number.isdigit(): + logger.error(f"ISSUE_NUMBER must be numeric: {self.issue_number}") + sys.exit(1) + self.repo = _require("GITHUB_REPOSITORY") + self.actor = os.getenv("GITHUB_ACTOR", "") + + self.bedrock_model_id = os.getenv("BEDROCK_MODEL_ID", "us.anthropic.claude-opus-4-6-v1") + + self.kb_s3_bucket = os.getenv("KB_S3_BUCKET", "") + self.kb_s3_key = os.getenv("KB_S3_KEY", "") + + self.slack_webhook_url = os.getenv("SLACK_WEBHOOK_URL", "") + self.guardrail_id = os.getenv("GUARDRAIL_ID", "") + self.guardrail_version = os.getenv("GUARDRAIL_VERSION") or "DRAFT" + + self.dry_run = os.getenv("DRY_RUN", "false").lower() == "true" + self.enable_slack = bool(self.slack_webhook_url) + self.enable_repo_search = os.getenv("ENABLE_REPO_SEARCH", "true").lower() == "true" + + self.upstream_repo = os.getenv("UPSTREAM_REPO", "awslabs/python-deequ") + + self.bedrock_timeout = 120 + self.max_context_chars = 200000 + self.max_github_search_results = 8 + self.github_api_timeout = 10 + self.allowed_labels = { + "bug", "enhancement", "question", "documentation", + "help-wanted", "analyzer", "check", "spark-compatibility", "installation", + } + + +def _require(name): + val = os.getenv(name) + if not val: + logger.error(f"Missing required env var: {name}") + sys.exit(1) + return val diff --git a/scripts/issue_bot/github_client.py b/scripts/issue_bot/github_client.py new file mode 100644 index 0000000..82d0ed2 --- /dev/null +++ b/scripts/issue_bot/github_client.py @@ -0,0 +1,220 @@ +import logging +import os +import requests + +logger = logging.getLogger("issue_bot") + + +class GitHubClient: + def __init__(self, cfg): + self._token = cfg.github_token + self._repo = cfg.repo + self._timeout = cfg.github_api_timeout + self._dry_run = cfg.dry_run + self._repo_root = os.getenv("GITHUB_WORKSPACE", os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) + self._headers = { + "Authorization": f"token {self._token}", + "Accept": "application/vnd.github.v3+json", + } + + def get_issue(self, number): + return self._get(f"/repos/{self._repo}/issues/{number}") + + def get_comments(self, number, max_pages=10): + comments = [] + page = 1 + while page <= max_pages: + batch = self._get(f"/repos/{self._repo}/issues/{number}/comments?per_page=100&page={page}") + if not batch: + break + comments.extend(batch) + if len(batch) < 100: + break + page += 1 + return comments + + def get_pr(self, number): + return self._get(f"/repos/{self._repo}/pulls/{number}") + + def get_pr_diff(self, number): + headers = {**self._headers, "Accept": "application/vnd.github.v3.diff"} + try: + resp = requests.get( + f"https://api.github.com/repos/{self._repo}/pulls/{number}", + headers=headers, timeout=self._timeout, + ) + return resp.text if resp.status_code == 200 else "" + except Exception as e: + logger.error(f"PR diff fetch failed: {e}") + return "" + + def get_pr_files(self, number): + return self._get(f"/repos/{self._repo}/pulls/{number}/files") or [] + + def get_pr_review_comments(self, number, max_pages=10): + comments = [] + page = 1 + while page <= max_pages: + batch = self._get(f"/repos/{self._repo}/pulls/{number}/comments?per_page=100&page={page}") + if not batch: + break + comments.extend(batch) + if len(batch) < 100: + break + page += 1 + return comments + + def get_codebase_map(self, src_dir="pydeequ"): + """List all Python source files (excluding tests) as relative paths.""" + full_dir = os.path.join(self._repo_root, src_dir) + prefix = self._repo_root.rstrip("/") + "/" + try: + paths = [] + for root, dirs, files in os.walk(full_dir): + dirs[:] = [d for d in dirs if d not in ("tests", "__pycache__", ".git")] + for f in files: + if f.endswith(".py"): + full = os.path.join(root, f) + rel = full[len(prefix):] if full.startswith(prefix) else full + paths.append(rel) + return "\n".join(sorted(paths)) + except Exception as e: + logger.error(f"Codebase map failed: {e}") + return "" + + def read_local_file(self, path): + repo_root = os.path.realpath(self._repo_root) + if repo_root == "/": + logger.error("Blocked: repo root is /") + return "" + full_path = os.path.realpath(os.path.join(self._repo_root, path)) + if not (full_path.startswith(repo_root + os.sep) or full_path == repo_root): + logger.error(f"Blocked path traversal: {path}") + return "" + try: + with open(full_path, "r", errors="replace") as f: + return f.read() + except Exception: + return "" + + def get_file_content(self, path, repo=None, ref=None): + target = repo or self._repo + url = f"https://api.github.com/repos/{target}/contents/{path}" + if ref: + url += f"?ref={ref}" + headers = {**self._headers, "Accept": "application/vnd.github.v3.raw"} + try: + resp = requests.get(url, headers=headers, timeout=self._timeout) + return resp.text if resp.status_code == 200 else "" + except Exception as e: + logger.error(f"File fetch failed ({path}): {e}") + return "" + + def post_comment(self, number, body): + if self._dry_run: + logger.info(f"[DRY RUN] Comment on #{number}: {body[:80]}...") + return True + return self._post(f"/repos/{self._repo}/issues/{number}/comments", {"body": body}) + + def post_pr_review(self, number, summary, inline_comments): + if self._dry_run: + logger.info(f"[DRY RUN] PR review on #{number}: {len(inline_comments)} inline comments") + return True + + # Get valid diff lines per file from the PR + valid_lines = self._get_valid_diff_lines(number) + + valid_comments = [] + invalid_comments = [] + for ic in inline_comments: + line = ic.get("line") + path = ic.get("file", "") + if line and path in valid_lines and line in valid_lines[path]: + valid_comments.append({"path": path, "body": ic["comment"], "line": line, "side": "RIGHT"}) + else: + invalid_comments.append(ic) + + if valid_comments: + body = summary + if invalid_comments: + body += "\n\n**Additional feedback:**\n" + for ic in invalid_comments: + line_ref = f":{ic['line']}" if ic.get('line') else "" + body += f"\n`{ic['file']}{line_ref}` — {ic['comment']}\n" + payload = {"body": body, "event": "REQUEST_CHANGES", "comments": valid_comments} + try: + resp = requests.post( + f"https://api.github.com/repos/{self._repo}/pulls/{number}/reviews", + headers=self._headers, json=payload, timeout=self._timeout, + ) + if resp.status_code in (200, 201): + return True + logger.error(f"PR review API failed: {resp.status_code}, falling back to comment") + except Exception as e: + logger.error(f"PR review API failed: {e}, falling back to comment") + + # Fallback: post all as regular comment + all_comments = inline_comments + body = summary + if all_comments: + body += "\n\n**Inline feedback:**\n" + for ic in all_comments: + line_ref = f":{ic['line']}" if ic.get('line') else "" + body += f"\n`{ic['file']}{line_ref}` — {ic['comment']}\n" + return self._post(f"/repos/{self._repo}/issues/{number}/comments", {"body": body}) + + def _get_valid_diff_lines(self, number): + """Extract valid right-side line numbers from each file's diff hunks.""" + import re + valid = {} + files = self.get_pr_files(number) + for f in files: + path = f.get("filename", "") + patch = f.get("patch", "") + if not patch: + continue + lines = set() + current_line = None + for line in patch.split("\n"): + hunk = re.match(r'^@@ -\d+(?:,\d+)? \+(\d+)(?:,\d+)? @@', line) + if hunk: + current_line = int(hunk.group(1)) + continue + if current_line is None: + continue + if line.startswith("-"): + continue + if line.startswith("\\"): + continue + lines.add(current_line) + current_line += 1 + valid[path] = lines + return valid + + def add_labels(self, number, labels): + if not labels: + return True + if self._dry_run: + logger.info(f"[DRY RUN] Labels on #{number}: {labels}") + return True + return self._post(f"/repos/{self._repo}/issues/{number}/labels", {"labels": labels}) + + def _get(self, path): + try: + resp = requests.get(f"https://api.github.com{path}", headers=self._headers, timeout=self._timeout) + if resp.status_code == 200: + return resp.json() + logger.error(f"GET {path}: {resp.status_code}") + except Exception as e: + logger.error(f"GET {path}: {e}") + return None + + def _post(self, path, payload): + try: + resp = requests.post(f"https://api.github.com{path}", headers=self._headers, json=payload, timeout=self._timeout) + if resp.status_code in (200, 201): + return True + logger.error(f"POST {path}: {resp.status_code}") + except Exception as e: + logger.error(f"POST {path}: {e}") + return False diff --git a/scripts/issue_bot/knowledge_base.py b/scripts/issue_bot/knowledge_base.py new file mode 100644 index 0000000..23ae27d --- /dev/null +++ b/scripts/issue_bot/knowledge_base.py @@ -0,0 +1,52 @@ +import logging +import boto3 + +logger = logging.getLogger("issue_bot") + + +class KnowledgeBase: + def __init__(self, cfg): + self._bucket = cfg.kb_s3_bucket + self._key = cfg.kb_s3_key + self._max_chars = cfg.max_context_chars + self._content = "" + + def load(self): + if self._bucket and self._key: + try: + resp = boto3.client("s3").get_object(Bucket=self._bucket, Key=self._key) + self._content = resp["Body"].read().decode("utf-8") + logger.info(f"KB loaded from S3: {len(self._content)} chars") + return + except Exception as e: + logger.warning(f"S3 KB failed: {e}") + logger.warning("No KB available") + + def build_context(self, issue_text, repo_snippets=""): + parts = [] + if self._content: + parts.append(self._content) + if repo_snippets: + parts.append(f"## Relevant Source Code\n{repo_snippets}") + context = "\n\n".join(parts) + if len(context) > self._max_chars: + context = self._truncate_by_relevance(context, issue_text) + return context + + def _truncate_by_relevance(self, content, issue_text): + keywords = set(issue_text.lower().split()) + sections = content.split("\n## ") + if len(sections) <= 1: + return content[:self._max_chars] + scored = [] + for i, s in enumerate(sections): + score = sum(1 for w in keywords if w in s.lower()) + max(0, 10 - i) + scored.append((score, s)) + scored.sort(key=lambda x: x[0], reverse=True) + result = "" + for _, s in scored: + chunk = f"\n## {s}" if result else s + if len(result) + len(chunk) > self._max_chars: + break + result += chunk + return result diff --git a/scripts/issue_bot/main.py b/scripts/issue_bot/main.py new file mode 100644 index 0000000..16ebafa --- /dev/null +++ b/scripts/issue_bot/main.py @@ -0,0 +1,624 @@ +""" +PyDeequ Bot — two-phase orchestration. + + analyze: read-only phase, produces JSON artifact + act: write-only phase, reads artifact and posts to GitHub/Slack +""" + +import json +import sys +import os +import datetime +import logging +import uuid + +from .config import Config +from .bedrock_client import BedrockClient +from .github_client import GitHubClient +from .knowledge_base import KnowledgeBase +from .slack_client import SlackClient +from .sanitizer import sanitize +from . import prompts + +logger = logging.getLogger("issue_bot") + +ARTIFACT_PATH = os.getenv("ARTIFACT_PATH", "/tmp/bot_result.json") +_MAX_BOT_REPLIES = 2 + + +def _render(template_str, **kwargs): + """Render a prompt template safely using unique tokens per invocation. + Prevents cross-variable injection (user body containing {context} won't leak KB).""" + token_id = uuid.uuid4().hex + tokens = {} + result = template_str + for key, value in kwargs.items(): + token = f"__TMPL_{token_id}_{key}__" + result = result.replace("{" + key + "}", token) + tokens[token] = str(value) + for token, value in tokens.items(): + result = result.replace(token, value) + return result + + +def _load_schema(name): + """Load a JSON schema file from the schemas directory.""" + path = os.path.join(os.path.dirname(__file__), "schemas", name) + with open(path) as f: + return f.read() + + +ISSUE_RESPONSE_SCHEMA = _load_schema("issue_response.json") +PR_REVIEW_SCHEMA = _load_schema("pr_review_response.json") +FOLLOWUP_SCHEMA = _load_schema("followup_response.json") + + +def analyze(): + cfg = Config() + gh = GitHubClient(cfg) + bedrock = BedrockClient(cfg) + kb = KnowledgeBase(cfg) + kb.load() + + number = cfg.issue_number + is_followup = cfg.event_type == "issue_comment" and cfg.event_action == "created" + + item = None + if cfg.event_type == "pull_request": + is_pr = True + elif cfg.event_type in ("issues", "issue_comment"): + is_pr = False + else: + # workflow_dispatch or unknown — check via API + item = gh.get_issue(number) + is_pr = bool(item and item.get("pull_request")) + + if item is None: + item = gh.get_pr(number) if is_pr else gh.get_issue(number) + if not item: + _write_artifact({"action": "SKIP", "reason": "fetch_failed"}) + return + + author = item.get("user", {}).get("login", "") + if author.endswith("[bot]"): + _write_artifact({"action": "SKIP", "reason": "author_is_bot"}) + return + + if item.get("state") == "closed" and not is_pr: + _write_artifact({"action": "SKIP", "reason": "issue_closed"}) + return + + title = item.get("title", "") or "" + body = item.get("body", "") or "" + html_url = item.get("html_url", "") + comments_data = gh.get_comments(number) + comments_text = _format_comments(comments_data) + + is_pr_update = is_pr and cfg.event_action == "synchronize" + is_reopened = not is_pr and cfg.event_action == "reopened" + + if is_reopened: + _write_artifact({ + "action": "ESCALATE", "labels": [], "response": "", + "reason": "issue_reopened", "title": title, + "html_url": html_url, "number": number, "is_pr": is_pr, + "prompt_id": "n/a", "model_id": cfg.bedrock_model_id, + }) + return + + if not is_followup and not is_pr_update and any( + c.get("user", {}).get("login") == "github-actions[bot]" for c in comments_data): + _write_artifact({"action": "SKIP", "reason": "already_commented"}) + return + + if is_followup and comments_data: + if comments_data[-1].get("user", {}).get("login") == "github-actions[bot]": + _write_artifact({"action": "SKIP", "reason": "bot_last_comment"}) + return + if _already_replied_to_latest(comments_data): + _write_artifact({"action": "SKIP", "reason": "already_replied_to_comment"}) + return + if _bot_reply_count(comments_data) >= _MAX_BOT_REPLIES: + _write_artifact({ + "action": "ESCALATE", "labels": [], "response": "", + "reason": "max_replies_reached", "title": title, + "html_url": html_url, "number": number, "is_pr": is_pr, + "prompt_id": "n/a", "model_id": cfg.bedrock_model_id, + }) + return + if _user_dissatisfied(comments_data): + _write_artifact({ + "action": "ESCALATE", "labels": [], "response": "", + "reason": "user_dissatisfied", "title": title, + "html_url": html_url, "number": number, "is_pr": is_pr, + "prompt_id": "n/a", "model_id": cfg.bedrock_model_id, + }) + return + + issue_text = f"{title} {body}" + context = kb.build_context(issue_text) + codebase_map = gh.get_codebase_map() if not is_followup else "" + + if is_pr: + tmpl = prompts.get_pr_file_review_prompt() + if not tmpl: + _write_artifact({"action": "ESCALATE", "labels": [], "response": "", + "reason": "prompt_load_failed", "title": title, "html_url": html_url, + "number": number, "is_pr": True, "prompt_id": "n/a", "model_id": cfg.bedrock_model_id}) + return + diff = gh.get_pr_diff(number) + review_comments = gh.get_pr_review_comments(number) + existing_feedback = _format_pr_feedback(comments_data, review_comments) + # System prompt: instructions + all trusted context (not scanned by guardrail) + system_prompt = _render(tmpl, current_date=datetime.date.today().isoformat()) + ( + f"\n\n\n{context}\n\n" + f"\n{codebase_map}\n\n" + f"\n{diff}\n\n" + f"\n{existing_feedback}\n" + ) + # User prompt: only user-authored content (scanned by guardrail) + user_prompt = f"\nTitle: {title}\nBody: {body}\n" + raw = bedrock.invoke(system_prompt, user_prompt, + max_tokens=4000, json_schema=PR_REVIEW_SCHEMA) + if raw is None: + _write_artifact({ + "action": "ESCALATE", "reason": "bedrock_unavailable", "title": title, + "html_url": html_url, "number": number, "is_pr": True, + "prompt_id": prompts.prompt_version(tmpl), "model_id": cfg.bedrock_model_id, + }) + return + try: + pr_result = json.loads(raw) + inline_comments = pr_result.get("comments", []) + except json.JSONDecodeError: + inline_comments = _parse_file_review_multi(raw) + _write_artifact({ + "action": "RESPOND" if inline_comments else "SKIP", + "labels": [], "response": "", + "inline_comments": inline_comments, + "title": title, "html_url": html_url, "number": number, + "is_pr": True, "prompt_id": prompts.prompt_version(tmpl), + "model_id": cfg.bedrock_model_id, + "reason": "no_issues_found" if not inline_comments else "", + }) + return + + elif is_followup: + tmpl = prompts.get_followup_prompt() + if not tmpl: + _write_artifact({"action": "ESCALATE", "labels": [], "response": "", + "reason": "prompt_load_failed", "title": title, "html_url": html_url, + "number": number, "is_pr": is_pr, "prompt_id": "n/a", "model_id": cfg.bedrock_model_id}) + return + system_prompt = tmpl + f"\n\n\n{context}\n" + user_prompt = f"\nTitle: {title}\nBody: {body}\n\n\n{comments_text}\n" + prompt_id = prompts.prompt_version(tmpl) + else: + tmpl = prompts.get_issue_prompt() + if not tmpl: + _write_artifact({"action": "ESCALATE", "labels": [], "response": "", + "reason": "prompt_load_failed", "title": title, "html_url": html_url, + "number": number, "is_pr": is_pr, "prompt_id": "n/a", "model_id": cfg.bedrock_model_id}) + return + system_prompt = tmpl + ( + f"\n\n\n{context}\n\n" + f"\n{codebase_map}\n" + ) + user_prompt = f"\nTitle: {title}\nBody: {body}\n\n\n{comments_text}\n" + prompt_id = prompts.prompt_version(tmpl) + + schema = FOLLOWUP_SCHEMA if is_followup else ISSUE_RESPONSE_SCHEMA + raw = bedrock.invoke(system_prompt, user_prompt, json_schema=schema) + + if raw is None: + _write_artifact({ + "action": "ESCALATE", "labels": [], "response": "", + "reason": "bedrock_unavailable", "title": title, + "html_url": html_url, "number": number, "is_pr": is_pr, + "prompt_id": prompt_id, "model_id": cfg.bedrock_model_id, + }) + return + + parsed = _parse_response(raw, is_pr) + + if parsed.get("read_files") and cfg.enable_repo_search: + snippets = _read_requested_files(gh, parsed["read_files"], cfg) + if snippets: + respond_tmpl = prompts.get_issue_respond_prompt() + if respond_tmpl: + respond_system = respond_tmpl + ( + f"\n\n\n{context}\n\n" + f"\n{snippets}\n" + ) + respond_user = f"\nTitle: {title}\nBody: {body}\n\n\n{comments_text}\n" + raw2 = bedrock.invoke(respond_system, respond_user, + json_schema=ISSUE_RESPONSE_SCHEMA) + if raw2: + parsed2 = _parse_response(raw2, is_pr) + parsed2["labels"] = parsed2.get("labels") or parsed.get("labels", []) + parsed = parsed2 + + _write_artifact({ + "action": parsed["action"], "labels": parsed.get("labels", []), + "response": parsed.get("response", ""), + "inline_comments": parsed.get("inline_comments", []), + "title": title, "html_url": html_url, "number": number, + "is_pr": is_pr, "prompt_id": prompt_id, "model_id": cfg.bedrock_model_id, + }) + + +def act(): + cfg = Config() + gh = GitHubClient(cfg) + slack = SlackClient(cfg) + + result = _read_artifact() + if not result: + logger.error("No artifact found") + return + + # Validate artifact has required fields + action = result.get("action", "SKIP") + if action not in ("SKIP", "RESPOND", "ESCALATE", "CLOSE"): + logger.error(f"Invalid action in artifact: {action}") + return + + number = result.get("number", cfg.issue_number) + is_pr = result.get("is_pr", False) + title = str(result.get("title", ""))[:200] # Truncate to prevent injection + html_url = result.get("html_url", "") + if html_url and not html_url.startswith("https://github.com/"): + html_url = "" + raw_labels = result.get("labels", []) + if not isinstance(raw_labels, list): + raw_labels = [] + labels = [l for l in raw_labels if isinstance(l, str) and l in cfg.allowed_labels] + response = result.get("response", "") + prompt_id = result.get("prompt_id", "unknown") + model_id = result.get("model_id", "unknown") + + if action == "SKIP": + logger.info(f"Skip #{number}: {result.get('reason')}") + return + + footer = ( + f"\n\n---\n*Generated by AI (model: {model_id}, prompt: {prompt_id}) " + f"— may not be fully accurate. Reply if this doesn't help.*" + ) + + # Pre-process: sanitize response before dispatch + if action == "RESPOND": + safe = sanitize(response) + if safe is None: + action = "ESCALATE" + response = "" + elif not safe and not result.get("inline_comments"): + action = "ESCALATE" + response = "" + else: + response = safe or "" + + if action == "RESPOND": + inline_comments = result.get("inline_comments", []) + # Sanitize inline comment text and keep the sanitized version + sanitized_comments = [] + for ic in inline_comments: + safe_comment = sanitize(ic.get("comment", "")) + if safe_comment is not None: + sanitized_comments.append({**ic, "comment": safe_comment}) + inline_comments = sanitized_comments + if is_pr and inline_comments: + gh.post_pr_review(number, response + footer, inline_comments) + else: + gh.post_comment(number, response + footer) + gh.add_labels(number, labels) + if "bug" in labels: + slack.send_escalation(number, title, html_url, labels) + elif "enhancement" in labels: + slack.send_escalation(number, title, html_url, labels) + logger.info(f"Responded to #{number}") + + elif action == "ESCALATE": + reason = result.get("reason", "") + if reason == "user_dissatisfied": + ack = ( + "I understand my previous response wasn't helpful. " + "I've notified the maintainer team and they will follow up directly." + footer + ) + elif reason == "max_replies_reached": + ack = ( + "I've reached the limit of what I can assist with on this issue. " + "The maintainer team has been notified and will take over." + footer + ) + elif reason == "issue_reopened": + ack = ( + "This issue has been reopened. " + "A maintainer has been notified and will follow up." + footer + ) + else: + if response: + ack = ( + response + "\n\n" + "This has also been flagged for our maintainer team to review." + footer + ) + else: + ack = ( + "Thank you for reporting this.\n\n" + "This has been flagged for review by our maintainer team. " + "We'll get back to you as soon as possible." + footer + ) + gh.post_comment(number, ack) + gh.add_labels(number, labels) + slack.send_escalation(number, title, html_url, labels) + logger.info(f"Escalated #{number}") + + elif action == "CLOSE" and not is_pr: + msg = ( + "This issue may not be related to the PyDeequ data quality library. " + "The maintainer team has been notified and will review." + footer + ) + gh.post_comment(number, msg) + gh.add_labels(number, labels) + slack.send_escalation(number, title, html_url, labels) + logger.info(f"Flagged #{number} as potentially off-topic") + + else: + logger.warning(f"Unhandled action '{action}' for #{number}, escalating") + gh.post_comment(number, "This has been flagged for review by our maintainer team." + footer) + slack.send_escalation(number, title, html_url, labels) + + +def _bot_reply_count(comments): + return sum(1 for c in comments if c.get("user", {}).get("login") == "github-actions[bot]") + + +def _already_replied_to_latest(comments): + """True if the bot already posted after the most recent non-bot comment.""" + last_user_idx = -1 + last_bot_idx = -1 + for i, c in enumerate(comments): + if c.get("user", {}).get("login") == "github-actions[bot]": + last_bot_idx = i + else: + last_user_idx = i + return last_bot_idx > last_user_idx >= 0 + + +_DISSATISFACTION_SIGNALS = [ + "that's wrong", "thats wrong", "that is wrong", + "this is wrong", "this is incorrect", "incorrect answer", + "didn't help", "doesn't help", "not helpful", "unhelpful", + "wrong answer", "bad answer", "not correct", "that's not right", + "still broken", "still not working", "doesn't work", + "please escalate", "need a human", "talk to a human", + "maintainer", "real person", +] + + +def _user_dissatisfied(comments): + bot_has_replied = any(c.get("user", {}).get("login") == "github-actions[bot]" for c in comments) + if not bot_has_replied: + return False + for c in reversed(comments): + login = c.get("user", {}).get("login", "") + if login == "github-actions[bot]": + break + if not login: + continue + body = (c.get("body") or "").lower() + if any(s in body for s in _DISSATISFACTION_SIGNALS): + return True + return False + + +_HEADER_PREFIXES = ("ACTION:", "LABELS:", "READ_FILES:", "SEARCH:", "SEARCH_TERMS:") + + +def _parse_response(raw, is_pr): + # Try structured JSON first (from Bedrock structured output) + try: + parsed = json.loads(raw) + result = { + "action": parsed.get("action", "ESCALATE"), + "labels": parsed.get("labels", []), + "read_files": parsed.get("read_files", []), + "response": parsed.get("response", ""), + "inline_comments": [], + } + if is_pr and result["action"] == "CLOSE": + result["action"] = "ESCALATE" + return result + except (json.JSONDecodeError, TypeError): + pass + + # Fallback: parse free-text format + lines = raw.strip().split("\n") + result = {"action": "ESCALATE", "labels": [], "response": "", "read_files": [], "inline_comments": []} + response_lines = [] + + for line in lines: + upper = line.strip().upper() + if upper.startswith("ACTION:"): + val = line.split(":", 1)[1].strip().upper() + if val in ("RESPOND", "ESCALATE", "CLOSE"): + result["action"] = val + continue + elif upper.startswith("LABELS:"): + raw_labels = line.split(":", 1)[1].strip() + result["labels"] = [l.strip() for l in raw_labels.split(",") if l.strip().lower() not in ("none", "")] + continue + elif upper.startswith("READ_FILES:"): + raw_files = line.split(":", 1)[1].strip() + result["read_files"] = [f.strip() for f in raw_files.split(",") if f.strip().lower() not in ("none", "")] + continue + elif upper.startswith(("SEARCH:", "SEARCH_TERMS:")): + continue + response_lines.append(line) + + full_text = "\n".join(response_lines).strip() + + if is_pr and "INLINE:" in full_text and "FILE:" in full_text: + result["response"], result["inline_comments"] = _parse_pr_review(full_text) + else: + result["response"] = _clean_response(full_text) + + if is_pr and result["action"] == "CLOSE": + result["action"] = "ESCALATE" + return result + + +def _parse_file_review_multi(raw): + """Parse multi-file review output into inline comments.""" + comments = [] + current_file = None + current_line = None + current_comment = [] + + for line in raw.strip().split("\n"): + stripped = line.strip() + upper = stripped.upper() + if upper.startswith("FILE:"): + if current_file and current_line and current_comment: + comments.append({"file": current_file, "line": current_line, "comment": "\n".join(current_comment).strip()}) + current_file = stripped.split(":", 1)[1].strip() + current_line = None + current_comment = [] + elif upper.startswith("LINE:"): + if current_file and current_line and current_comment: + comments.append({"file": current_file, "line": current_line, "comment": "\n".join(current_comment).strip()}) + try: + current_line = int(stripped.split(":", 1)[1].strip()) + current_comment = [] + except ValueError: + current_line = None + elif upper.startswith("COMMENT:"): + current_comment = [stripped.split(":", 1)[1].strip()] + elif current_comment is not None and current_file: + current_comment.append(stripped) + + if current_file and current_line and current_comment: + comments.append({"file": current_file, "line": current_line, "comment": "\n".join(current_comment).strip()}) + + return comments + + + + +def _parse_pr_review(text): + """Split PR review into summary and inline comments.""" + summary_part = "" + inline_comments = [] + + parts = text.split("INLINE:") + summary_part = parts[0].replace("SUMMARY:", "").strip() + + if len(parts) > 1: + inline_text = parts[1].strip() + if inline_text.lower() == "none": + return _clean_response(summary_part), [] + + current = {} + for line in inline_text.split("\n"): + stripped = line.strip() + upper = stripped.upper() + if upper.startswith("FILE:"): + if current.get("file") and current.get("comment"): + inline_comments.append(current) + current = {"file": stripped.split(":", 1)[1].strip()} + elif upper.startswith("LINE:"): + try: + current["line"] = int(stripped.split(":", 1)[1].strip()) + except ValueError: + pass + elif upper.startswith("COMMENT:"): + current["comment"] = stripped.split(":", 1)[1].strip() + elif current.get("comment"): + current["comment"] += "\n" + stripped + + if current.get("file") and current.get("comment"): + inline_comments.append(current) + + return _clean_response(summary_part), inline_comments + + +def _clean_response(text): + """Remove any leaked headers or internal thinking from the response.""" + lines = text.split("\n") + cleaned = [] + for line in lines: + upper = line.strip().upper() + if upper.startswith(_HEADER_PREFIXES): + continue + cleaned.append(line) + result = "\n".join(cleaned).strip() + # Remove leading preamble like "Let me request..." or "I'll analyze..." + while result and result.split("\n")[0].strip().lower().startswith(( + "let me ", "i'll ", "i will ", "i need to ", "first,", "sure,", + "since i don't", "since i do not", + )): + result = "\n".join(result.split("\n")[1:]).strip() + return result + + +def _format_comments(comments): + if not comments: + return "(none)" + return "\n".join( + f"{c.get('user', {}).get('login', '?')}: {c.get('body', '') or ''}" + for c in comments + ) + + +def _format_pr_feedback(issue_comments, review_comments): + parts = [] + for c in issue_comments: + author = c.get("user", {}).get("login", "?") + body = c.get("body", "") or "" + parts.append(f"{author}: {body}") + for c in review_comments: + author = c.get("user", {}).get("login", "?") + path = c.get("path", "") + line = c.get("line") or c.get("original_line") or "?" + body = c.get("body", "") or "" + parts.append(f"{author} on {path}:{line}: {body}") + return "\n".join(parts) if parts else "(no existing feedback)" + + +def _read_requested_files(gh, file_paths, cfg): + snippets = [] + for path in file_paths[:cfg.max_github_search_results]: + if ".." in path or path.startswith("/"): + continue + content = gh.read_local_file(path) + if not content: + content = gh.get_file_content(path, repo=cfg.upstream_repo) + if content: + snippets.append(f"### {path}\n```python\n{content}\n```") + return "\n\n".join(snippets) + + +def _write_artifact(data): + os.makedirs(os.path.dirname(ARTIFACT_PATH) or "/tmp", exist_ok=True) + with open(ARTIFACT_PATH, "w") as f: + json.dump(data, f) + logger.info(f"Artifact: action={data.get('action')}") + + +def _read_artifact(): + try: + with open(ARTIFACT_PATH) as f: + return json.load(f) + except Exception as e: + logger.error(f"Artifact read failed: {e}") + return None + + +def main(): + if len(sys.argv) < 2 or sys.argv[1] not in ("analyze", "act"): + print("Usage: python -m issue_bot.main ") + sys.exit(1) + {"analyze": analyze, "act": act}[sys.argv[1]]() + + +if __name__ == "__main__": + main() diff --git a/scripts/issue_bot/prompts.py b/scripts/issue_bot/prompts.py new file mode 100644 index 0000000..e2ff385 --- /dev/null +++ b/scripts/issue_bot/prompts.py @@ -0,0 +1,22 @@ +import hashlib +import os + + +def get_issue_prompt(): + return os.getenv("ISSUE_CLASSIFY_PROMPT", "") + + +def get_issue_respond_prompt(): + return os.getenv("ISSUE_RESPOND_PROMPT", "") + + +def get_pr_file_review_prompt(): + return os.getenv("PR_FILE_REVIEW_PROMPT", "") + + +def get_followup_prompt(): + return os.getenv("FOLLOWUP_PROMPT", "") + + +def prompt_version(template): + return hashlib.sha256(template.encode()).hexdigest()[:8] diff --git a/scripts/issue_bot/sanitizer.py b/scripts/issue_bot/sanitizer.py new file mode 100644 index 0000000..07fe24f --- /dev/null +++ b/scripts/issue_bot/sanitizer.py @@ -0,0 +1,55 @@ +import re +import logging + +logger = logging.getLogger("issue_bot") + +# Primary defense: Bedrock Guardrails. These are the required backup layer. +_SECRET_PATTERNS = [ + re.compile(r"AKIA[0-9A-Z]{16}"), + re.compile(r"ghp_[0-9a-zA-Z]{36}"), + re.compile(r"gho_[0-9a-zA-Z]{36}"), + re.compile(r"ghs_[0-9a-zA-Z]{36}"), + re.compile(r"github_pat_[A-Za-z0-9_]{22,}"), + re.compile(r"xox[bpras]-[A-Za-z0-9\-]+"), + re.compile(r"https?://[^\s]*\.corp\.amazon\.com[^\s]*"), + re.compile(r"https?://[^\s]*\.a2z\.com[^\s]*"), + re.compile(r"https?://[^\s]*\.amazon\.dev[^\s]*"), + re.compile(r"hooks\.slack\.com/services/\S+"), +] + +_INJECTION_MARKERS = [ + "my system prompt is", + "my instructions are", + "here are my internal", + "ignore previous instructions", +] + + +def sanitize(text): + if not text: + return text + for p in _SECRET_PATTERNS: + if p.search(text): + logger.error(f"BLOCKED: secret pattern {p.pattern}") + return None + lower = text.lower() + for m in _INJECTION_MARKERS: + if m in lower: + logger.error(f"BLOCKED: injection marker '{m}'") + return None + text = _fix_accidental_issue_refs(text) + return text + + +def _fix_accidental_issue_refs(text): + """Wrap #N references outside code blocks in backticks to prevent GitHub auto-linking.""" + lines = text.split("\n") + in_code_block = False + fixed = [] + for line in lines: + if line.strip().startswith("```"): + in_code_block = not in_code_block + if not in_code_block: + line = re.sub(r'(?", ">") + label_text = ", ".join(f"`{l}`" for l in labels) if labels else "_none_" + text = ( + f"*PyDeequ Issue #{number}*\n" + f">{safe_title}\n\n" + f"*Labels:* {label_text}\n" + f"*Status:* Bot posted analysis on the issue\n\n" + f"<{url}|View on GitHub>" + ) + self._send({"text": text}) + + def _send(self, payload): + try: + resp = requests.post(self._webhook, json=payload, timeout=10) + if resp.status_code != 200: + logger.error(f"Slack: {resp.status_code}") + except Exception as e: + logger.error(f"Slack failed: {e}") diff --git a/tests/test_bot.py b/tests/test_bot.py new file mode 100644 index 0000000..77992dc --- /dev/null +++ b/tests/test_bot.py @@ -0,0 +1,322 @@ +# -*- coding: utf-8 -*- +"""Unit tests for the issue bot parsing and validation functions.""" +import json +import sys +import os + +import pytest + +# Add scripts dir to path so we can import issue_bot +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "scripts")) + +from issue_bot.main import ( + _parse_response, + _parse_file_review_multi, + _already_replied_to_latest, + _bot_reply_count, + _user_dissatisfied, + _clean_response, + _render, +) +from issue_bot.sanitizer import sanitize, _fix_accidental_issue_refs + + +class TestParseResponse: + @pytest.mark.parametrize("action", ["RESPOND", "ESCALATE", "CLOSE"]) + def test_json_actions(self, action): + raw = json.dumps({"action": action, "labels": [], "read_files": [], "response": "text"}) + assert _parse_response(raw, is_pr=False)["action"] == action + + def test_json_with_labels_and_files(self): + raw = json.dumps({"action": "RESPOND", "labels": ["bug", "question"], + "read_files": ["pydeequ/checks.py"], "response": ""}) + result = _parse_response(raw, is_pr=False) + assert result["labels"] == ["bug", "question"] + assert result["read_files"] == ["pydeequ/checks.py"] + + def test_close_on_pr_becomes_escalate(self): + raw = json.dumps({"action": "CLOSE", "labels": [], "read_files": [], "response": ""}) + assert _parse_response(raw, is_pr=True)["action"] == "ESCALATE" + + def test_fallback_to_text_parsing(self): + raw = "ACTION: RESPOND\nLABELS: bug, question\nREAD_FILES: none\n\nHere is the answer." + result = _parse_response(raw, is_pr=False) + assert result["action"] == "RESPOND" + assert "bug" in result["labels"] + assert "Here is the answer." in result["response"] + + def test_text_defaults_to_escalate(self): + raw = "Some unstructured text without headers" + result = _parse_response(raw, is_pr=False) + assert result["action"] == "ESCALATE" + + def test_empty_json_defaults(self): + raw = json.dumps({}) + result = _parse_response(raw, is_pr=False) + assert result["action"] == "ESCALATE" + assert result["labels"] == [] + + +class TestParseFileReviewMulti: + def test_single_comment(self): + raw = "FILE: src/foo.py\nLINE: 42\nCOMMENT: Missing null check" + comments = _parse_file_review_multi(raw) + assert len(comments) == 1 + assert comments[0] == {"file": "src/foo.py", "line": 42, "comment": "Missing null check"} + + def test_multiple_comments(self): + raw = "FILE: a.py\nLINE: 1\nCOMMENT: issue one\nFILE: b.py\nLINE: 2\nCOMMENT: issue two" + assert len(_parse_file_review_multi(raw)) == 2 + + def test_multiline_comment(self): + raw = "FILE: a.py\nLINE: 10\nCOMMENT: first line\nsecond line" + comments = _parse_file_review_multi(raw) + assert "second line" in comments[0]["comment"] + + def test_invalid_line_number_skipped(self): + raw = "FILE: a.py\nLINE: not_a_number\nCOMMENT: bad" + assert len(_parse_file_review_multi(raw)) == 0 + + def test_empty_input(self): + assert _parse_file_review_multi("") == [] + + +class TestSanitize: + def test_none_passthrough(self): + assert sanitize(None) is None + + def test_empty_passthrough(self): + assert sanitize("") == "" + + def test_clean_text_passes(self): + assert sanitize("Normal response about PyDeequ.") is not None + + @pytest.mark.parametrize("marker", [ + "my system prompt is", + "here are my internal", + "ignore previous instructions", + ]) + def test_blocks_injection_markers(self, marker): + assert sanitize(f"Some text with {marker} embedded") is None + + +class TestFixIssueRefs: + def test_wraps_in_backticks(self): + assert _fix_accidental_issue_refs("see #42") == "see `#42`" + + def test_preserves_code_blocks(self): + text = "```\n#42\n```" + assert _fix_accidental_issue_refs(text) == text + + def test_no_match_on_non_numeric(self): + assert _fix_accidental_issue_refs("#abc") == "#abc" + + def test_multiple_refs(self): + result = _fix_accidental_issue_refs("fixes #1 and #2") + assert "`#1`" in result + assert "`#2`" in result + + +def _make_comment(login, body="text"): + return {"user": {"login": login}, "body": body} + + +class TestBotReplyCount: + def test_zero(self): + assert _bot_reply_count([_make_comment("user1")]) == 0 + + def test_counts_bot_only(self): + comments = [_make_comment("user1"), _make_comment("github-actions[bot]"), + _make_comment("user2"), _make_comment("github-actions[bot]")] + assert _bot_reply_count(comments) == 2 + + +class TestAlreadyRepliedToLatest: + def test_bot_after_user(self): + assert _already_replied_to_latest( + [_make_comment("user1"), _make_comment("github-actions[bot]")]) is True + + def test_user_after_bot(self): + assert _already_replied_to_latest( + [_make_comment("github-actions[bot]"), _make_comment("user1")]) is False + + def test_empty(self): + assert _already_replied_to_latest([]) is False + + +class TestUserDissatisfied: + def test_no_bot_reply_means_not_dissatisfied(self): + assert _user_dissatisfied([_make_comment("user1", "that's wrong")]) is False + + def test_dissatisfied_after_bot(self): + comments = [_make_comment("github-actions[bot]"), _make_comment("user1", "that's wrong")] + assert _user_dissatisfied(comments) is True + + def test_happy_after_bot(self): + comments = [_make_comment("github-actions[bot]"), _make_comment("user1", "thanks!")] + assert _user_dissatisfied(comments) is False + + @pytest.mark.parametrize("signal", [ + "didn't help", "not helpful", "still broken", "please escalate", "need a human", + ]) + def test_various_signals(self, signal): + comments = [_make_comment("github-actions[bot]"), _make_comment("user1", signal)] + assert _user_dissatisfied(comments) is True + + +class TestRender: + def test_basic(self): + assert _render("Hello {name}", name="world") == "Hello world" + + def test_braces_in_value_dont_crash(self): + result = _render("Title: {title}", title="Fix {broken} thing") + assert "{broken}" in result + + def test_missing_var_preserved(self): + result = _render("{present} {missing}", present="yes") + assert "yes" in result + assert "{missing}" in result + + def test_no_cross_variable_injection(self): + """User content containing {context} must NOT leak the actual context value.""" + result = _render("KB: {context}\nBody: {body}", context="SECRET", body="{context}") + assert result == "KB: SECRET\nBody: {context}" + + def test_no_reverse_injection(self): + """Context containing {body} must NOT be replaced by body value.""" + result = _render("KB: {context}\nBody: {body}", context="{body}", body="SECRET") + assert result == "KB: {body}\nBody: SECRET" + + +class TestCleanResponse: + def test_strips_header_lines(self): + text = "ACTION: RESPOND\nLABELS: bug\nActual response here" + assert "Actual response here" in _clean_response(text) + assert "ACTION:" not in _clean_response(text) + + def test_strips_preamble(self): + text = "Let me analyze this issue.\nThe actual answer." + assert _clean_response(text) == "The actual answer." + + def test_preserves_normal_text(self): + text = "This is a normal response." + assert _clean_response(text) == text + + +class TestSmoke: + def test_main_module_imports(self): + from issue_bot import main + assert hasattr(main, 'analyze') + assert hasattr(main, 'act') + + def test_sanitizer_imports(self): + from issue_bot import sanitizer + assert hasattr(sanitizer, 'sanitize') + + def test_schemas_loadable(self): + from issue_bot.main import ISSUE_RESPONSE_SCHEMA, PR_REVIEW_SCHEMA, FOLLOWUP_SCHEMA + import json + assert json.loads(ISSUE_RESPONSE_SCHEMA)["type"] == "object" + assert json.loads(PR_REVIEW_SCHEMA)["type"] == "object" + assert json.loads(FOLLOWUP_SCHEMA)["type"] == "object" + + +class TestArtifactValidation: + def test_invalid_action_rejected(self): + """Actions not in the allowed set should be treated as invalid.""" + valid = {"SKIP", "RESPOND", "ESCALATE", "CLOSE"} + assert "DROP_TABLE" not in valid + assert "RESPOND" in valid + + def test_title_truncated(self): + title = "A" * 500 + truncated = str(title)[:200] + assert len(truncated) == 200 + + def test_non_github_url_cleared(self): + url = "https://evil.com/steal" + result = "" if not url.startswith("https://github.com/") else url + assert result == "" + + def test_github_url_preserved(self): + url = "https://github.com/awslabs/python-deequ/issues/1" + result = "" if not url.startswith("https://github.com/") else url + assert result == url + + def test_empty_url_preserved(self): + url = "" + result = "" if url and not url.startswith("https://github.com/") else url + assert result == "" + + +class TestSplitPrompt: + """Test that invoke() follows GlueML pattern: system=trusted, user=guarded.""" + + def _make_client(self, guardrail_id=""): + class FakeCfg: + bedrock_model_id = "test" + bedrock_timeout = 10 + guardrail_id = "" + guardrail_version = "DRAFT" + cfg = FakeCfg() + cfg.guardrail_id = guardrail_id + from issue_bot.bedrock_client import BedrockClient + import unittest.mock as mock + with mock.patch("boto3.client"): + client = BedrockClient(cfg) + return client + + def _mock_converse(self, client): + import unittest.mock as mock + client._client = mock.MagicMock() + client._client.converse.return_value = { + "stopReason": "end_turn", + "output": {"message": {"content": [{"text": "ok"}]}}, + "usage": {}, + } + + def test_with_guardrail_user_is_guardcontent(self): + client = self._make_client(guardrail_id="gr-123") + self._mock_converse(client) + client.invoke("system instructions", "user input") + kwargs = client._client.converse.call_args[1] + content = kwargs["messages"][0]["content"] + assert len(content) == 1 + assert "guardContent" in content[0] + assert content[0]["guardContent"]["text"]["text"] == "user input" + + def test_without_guardrail_user_is_text(self): + client = self._make_client() + self._mock_converse(client) + client.invoke("system", "user input") + kwargs = client._client.converse.call_args[1] + content = kwargs["messages"][0]["content"] + assert len(content) == 1 + assert "text" in content[0] + assert content[0]["text"] == "user input" + + def test_system_prompt_is_plain_text_cached(self): + client = self._make_client(guardrail_id="gr-123") + self._mock_converse(client) + client.invoke("instructions + diff with ignore previous instructions", "Title: test") + kwargs = client._client.converse.call_args[1] + system = kwargs["system"] + assert system[0]["text"] == "instructions + diff with ignore previous instructions" + assert "cachePoint" in system[1] + # System prompt is NOT guardContent — guardrail won't scan it + assert "guardContent" not in system[0] + + def test_guardrail_config_present(self): + client = self._make_client(guardrail_id="gr-123") + self._mock_converse(client) + client.invoke("system", "user") + kwargs = client._client.converse.call_args[1] + assert "guardrailConfig" in kwargs + assert kwargs["guardrailConfig"]["guardrailIdentifier"] == "gr-123" + + def test_no_guardrail_no_config(self): + client = self._make_client() + self._mock_converse(client) + client.invoke("system", "user") + kwargs = client._client.converse.call_args[1] + assert "guardrailConfig" not in kwargs