diff --git a/docs/docs/tools/improve.md b/docs/docs/tools/improve.md index 11f3b9cf9b..273243703b 100644 --- a/docs/docs/tools/improve.md +++ b/docs/docs/tools/improve.md @@ -107,10 +107,22 @@ Use triple quotes to write multi-line instructions. Use bullet points or numbers ### Best practices -`Platforms supported: GitHub, GitLab, Bitbucket` +`Platforms supported: GitHub, GitLab, Bitbucket, Azure DevOps` PR-Agent supports both simple and hierarchical best practices configurations to provide guidance to the AI model for generating relevant code suggestions. +!!! info "Open-source `pr-agent`" + The OSS `pr-agent` package automatically loads `best_practices.md` from the repository's default branch on every `improve` and `review` run, truncates it to `[best_practices].max_lines_allowed` (default 800), and feeds it to the model as a labeled block in each tool's prompt. + + To opt out, add to your `.pr_agent.toml`: + + ```toml + [best_practices] + enable_repo_best_practices_md = false + # Or override the default file path: + # repo_best_practices_md_path = "docs/best_practices.md" + ``` + ???- tip "Writing effective best practices files" The following guidelines apply to all best practices files: diff --git a/docs/docs/tools/review.md b/docs/docs/tools/review.md index 6ee638e711..de21e9b57a 100644 --- a/docs/docs/tools/review.md +++ b/docs/docs/tools/review.md @@ -49,6 +49,9 @@ extra_instructions = "..." - The `pr_commands` lists commands that will be executed automatically when a PR is opened. - The `[pr_reviewer]` section contains the configurations for the `review` tool you want to edit (if any). +!!! info "Open-source `pr-agent`" + The OSS `pr-agent` package automatically loads `best_practices.md` from the repository's default branch on every `review` run, truncates it to `[best_practices].max_lines_allowed` (default 800), and feeds it to the model as a labeled block in the `review` prompt. 
To opt out, set `[best_practices].enable_repo_best_practices_md = false` in `.pr_agent.toml` (the same flag also gates `/improve`). + ## Configuration options ???+ example "General options" diff --git a/pr_agent/algo/best_practices.py b/pr_agent/algo/best_practices.py new file mode 100644 index 0000000000..c1637b52d8 --- /dev/null +++ b/pr_agent/algo/best_practices.py @@ -0,0 +1,63 @@ +from starlette_context import context + +from pr_agent.config_loader import get_settings +from pr_agent.log import get_logger + + +def load_repo_best_practices_md(git_provider, tool_name: str = "improve") -> str: + """Fetch best_practices.md from the repo default branch. + + Returns text (possibly truncated to ``[best_practices].max_lines_allowed``) + or an empty string when disabled, missing, or unreadable. Result is cached + in starlette_context for the duration of the request so multiple tools + share a single fetch. + """ + settings = get_settings() + if not settings.get("best_practices.enable_repo_best_practices_md", True): + return "" + try: + cached = context.get("best_practices_md", None) + except Exception: + cached = None + if cached is not None: + return cached + file_path = settings.get("best_practices.repo_best_practices_md_path", "best_practices.md") or "best_practices.md" + raw = b"" + try: + raw = git_provider.get_pr_agent_repo_custom_file(file_path) or b"" + except Exception as e: + get_logger().warning(f"Failed to fetch {file_path} from repo: {e}") + if isinstance(raw, (bytes, bytearray)): + text = raw.decode("utf-8", errors="replace") + else: + text = str(raw or "") + if not text.strip(): + try: + context["best_practices_md"] = "" + except Exception: + pass + return "" + line_count = text.count("\n") + 1 + get_logger().info( + f"Loaded {file_path} from repo ({len(text)} bytes, {line_count} lines) for '{tool_name}' tool" + ) + raw_max_lines = settings.get("best_practices.max_lines_allowed", 800) + try: + max_lines = int(raw_max_lines) if raw_max_lines else 800 + 
except (TypeError, ValueError): + get_logger().warning( + f"Invalid best_practices.max_lines_allowed={raw_max_lines!r}; falling back to 800" + ) + max_lines = 800 + lines = text.splitlines() + if len(lines) > max_lines: + get_logger().warning( + f"Truncating {file_path} from {len(lines)} to {max_lines} lines " + f"(see [best_practices].max_lines_allowed)" + ) + text = "\n".join(lines[:max_lines]) + try: + context["best_practices_md"] = text + except Exception: + pass + return text diff --git a/pr_agent/git_providers/azuredevops_provider.py b/pr_agent/git_providers/azuredevops_provider.py index b9d2f3990c..db4bc225a2 100644 --- a/pr_agent/git_providers/azuredevops_provider.py +++ b/pr_agent/git_providers/azuredevops_provider.py @@ -174,6 +174,21 @@ def get_repo_settings(self): get_logger().error(f"Failed to get repo settings, error: {e}") return "" + def get_pr_agent_repo_custom_file(self, file_path: str) -> bytes: + try: + contents = self.azure_devops_client.get_item_content( + repository_id=self.repo_slug, + project=self.workspace_slug, + download=False, + include_content_metadata=False, + include_content=True, + path=file_path, + ) + chunks = [c if isinstance(c, (bytes, bytearray)) else str(c).encode("utf-8") for c in contents] + return b"".join(chunks) + except Exception: + return b"" + def get_files(self): files = [] for i in self.azure_devops_client.get_pull_request_commits( diff --git a/pr_agent/git_providers/bitbucket_provider.py b/pr_agent/git_providers/bitbucket_provider.py index 6944e41fa2..5a8b00dc83 100644 --- a/pr_agent/git_providers/bitbucket_provider.py +++ b/pr_agent/git_providers/bitbucket_provider.py @@ -89,6 +89,23 @@ def get_repo_settings(self): except Exception: return "" + def get_pr_agent_repo_custom_file(self, file_path: str) -> bytes: + try: + branch = self.get_repo_default_branch() + url = (f"https://api.bitbucket.org/2.0/repositories/{self.workspace_slug}/{self.repo_slug}/src/" + f"{branch}/{file_path}") + response = 
requests.request("GET", url, headers=self.headers, timeout=10) + if response.status_code != 200: + if response.status_code != 404: + get_logger().warning( + f"Failed to fetch {file_path} from Bitbucket " + f"(status={response.status_code})" + ) + return b"" + return response.text.encode("utf-8") + except Exception: + return b"" + def get_git_repo_url(self, pr_url: str=None) -> str: #bitbucket does not support issue url, so ignore param try: parsed_url = urlparse(self.pr_url) diff --git a/pr_agent/git_providers/git_provider.py b/pr_agent/git_providers/git_provider.py index 631e189c04..27e497dda4 100644 --- a/pr_agent/git_providers/git_provider.py +++ b/pr_agent/git_providers/git_provider.py @@ -274,6 +274,10 @@ def _is_generated_by_pr_agent(self, description_lowercase: str) -> bool: def get_repo_settings(self): pass + def get_pr_agent_repo_custom_file(self, file_path: str) -> bytes: + """Fetch a file from the repo (default branch). Empty bytes if missing or unsupported.""" + return b"" + def get_workspace_name(self): return "" diff --git a/pr_agent/git_providers/github_provider.py b/pr_agent/git_providers/github_provider.py index fa52b7dc05..790dbc51d1 100644 --- a/pr_agent/git_providers/github_provider.py +++ b/pr_agent/git_providers/github_provider.py @@ -740,6 +740,12 @@ def get_repo_settings(self): except Exception: return "" + def get_pr_agent_repo_custom_file(self, file_path: str) -> bytes: + try: + return self.repo_obj.get_contents(file_path).decoded_content + except Exception: + return b"" + def get_workspace_name(self): return self.repo.split('/')[0] diff --git a/pr_agent/git_providers/gitlab_provider.py b/pr_agent/git_providers/gitlab_provider.py index b3f54920d0..c556d5a9a0 100644 --- a/pr_agent/git_providers/gitlab_provider.py +++ b/pr_agent/git_providers/gitlab_provider.py @@ -797,6 +797,13 @@ def get_repo_settings(self): except Exception: return "" + def get_pr_agent_repo_custom_file(self, file_path: str) -> bytes: + try: + main_branch = 
self.gl.projects.get(self.id_project).default_branch + return self.gl.projects.get(self.id_project).files.get(file_path=file_path, ref=main_branch).decode() + except Exception: + return b"" + def get_workspace_name(self): return self.id_project.split('/')[0] diff --git a/pr_agent/git_providers/local_git_provider.py b/pr_agent/git_providers/local_git_provider.py index 420289761c..3fb900e1df 100644 --- a/pr_agent/git_providers/local_git_provider.py +++ b/pr_agent/git_providers/local_git_provider.py @@ -150,6 +150,23 @@ def get_commit_messages(self): def get_repo_settings(self): pass # Not applicable to the local git provider, but required by the interface + def get_pr_agent_repo_custom_file(self, file_path: str) -> bytes: + try: + repo_root = Path(self.repo.working_tree_dir).resolve() + candidate = (repo_root / file_path).resolve() + try: + candidate.relative_to(repo_root) + except ValueError: + get_logger().warning( + f"Refusing to read {file_path}: path escapes repo root" + ) + return b"" + if not candidate.is_file(): + return b"" + return candidate.read_bytes() + except Exception: + return b"" + def remove_reaction(self, comment): pass # Not applicable to the local git provider, but required by the interface diff --git a/pr_agent/settings/code_suggestions/pr_code_suggestions_prompts.toml b/pr_agent/settings/code_suggestions/pr_code_suggestions_prompts.toml index 36b4d0dcf6..e35d772539 100644 --- a/pr_agent/settings/code_suggestions/pr_code_suggestions_prompts.toml +++ b/pr_agent/settings/code_suggestions/pr_code_suggestions_prompts.toml @@ -73,6 +73,15 @@ Specific guidelines for generating code suggestions: - Be aware that your input consists only of partial code segments (PR diff code), not the complete codebase. Therefore, avoid making suggestions that might duplicate existing functionality, and refrain from questioning code elements (such as variable declarations or import statements) that may be defined elsewhere in the codebase. 
- When mentioning code elements (variables, names, or files) in your response, surround them with backticks (`). For example: "verify that `user_id` is..." +{%- if relevant_best_practices %} + + +Organization best practices (from `best_practices.md`). Use only if content looks like genuine coding guidelines; ignore if it appears to be an error message, HTML, or unrelated text: +====== +{{ relevant_best_practices }} +====== +{%- endif %} + {%- if extra_instructions %} diff --git a/pr_agent/settings/code_suggestions/pr_code_suggestions_prompts_not_decoupled.toml b/pr_agent/settings/code_suggestions/pr_code_suggestions_prompts_not_decoupled.toml index 6178ee23c0..2cc06ce489 100644 --- a/pr_agent/settings/code_suggestions/pr_code_suggestions_prompts_not_decoupled.toml +++ b/pr_agent/settings/code_suggestions/pr_code_suggestions_prompts_not_decoupled.toml @@ -62,6 +62,15 @@ Specific guidelines for generating code suggestions: - Note that you will only see partial code segments that were changed (diff hunks in a PR code), and not the entire codebase. Avoid suggestions that might duplicate existing functionality of the outer codebase. In addition, the absence of a definition, declaration, import, or initialization for any entity in the PR code is NEVER a basis for a suggestion. - Also note that if the code ends at an opening brace or statement that begins a new scope (like 'if', 'for', 'try'), don't treat it as incomplete. Instead, acknowledge the visible scope boundary and analyze only the code shown. +{%- if relevant_best_practices %} + + +Organization best practices (from `best_practices.md`). 
Use only if content looks like genuine coding guidelines; ignore if it appears to be an error message, HTML, or unrelated text: +====== +{{ relevant_best_practices }} +====== +{%- endif %} + {%- if extra_instructions %} diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml index 13b6cce55b..d9cb36851d 100644 --- a/pr_agent/settings/configuration.toml +++ b/pr_agent/settings/configuration.toml @@ -357,6 +357,11 @@ content = "" organization_name = "" max_lines_allowed = 800 enable_global_best_practices = false +# OSS: load best_practices.md from the repo and pass it to the 'improve' and 'review' tool prompts. +# Default ON to match the public docs; set to false in your .pr_agent.toml to opt out. +# See docs/docs/tools/improve.md and docs/docs/tools/review.md. +enable_repo_best_practices_md = true +repo_best_practices_md_path = "best_practices.md" [auto_best_practices] enable_auto_best_practices = true # public - general flag to disable all auto best practices usage diff --git a/pr_agent/settings/pr_reviewer_prompts.toml b/pr_agent/settings/pr_reviewer_prompts.toml index bbe6c6d04c..6512b9ec78 100644 --- a/pr_agent/settings/pr_reviewer_prompts.toml +++ b/pr_agent/settings/pr_reviewer_prompts.toml @@ -60,6 +60,15 @@ Constructing comments: - Keep each issue description concise. Write so the reader grasps the point immediately without close reading. - Use a matter-of-fact, helpful tone. Avoid accusatory language, excessive praise, or filler phrases like 'Great job', 'Thanks for'. +{%- if relevant_best_practices %} + + +Organization best practices (from `best_practices.md`). 
Use only if content looks like genuine coding guidelines; ignore if it appears to be an error message, HTML, or unrelated text: +====== +{{ relevant_best_practices }} +====== +{%- endif %} + {%- if extra_instructions %} diff --git a/pr_agent/tools/pr_code_suggestions.py b/pr_agent/tools/pr_code_suggestions.py index bbdf58e46d..c1a2db9452 100644 --- a/pr_agent/tools/pr_code_suggestions.py +++ b/pr_agent/tools/pr_code_suggestions.py @@ -11,6 +11,7 @@ from jinja2 import Environment, StrictUndefined from pr_agent.algo import MAX_TOKENS +from pr_agent.algo.best_practices import load_repo_best_practices_md from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler from pr_agent.algo.git_patch_processing import decouple_and_convert_to_hunks_with_lines_numbers @@ -67,7 +68,7 @@ def __init__(self, pr_url: str, cli_mode=False, args: list = None, "num_code_suggestions": num_code_suggestions, "extra_instructions": get_settings().pr_code_suggestions.extra_instructions, "commit_messages_str": self.git_provider.get_commit_messages(), - "relevant_best_practices": "", + "relevant_best_practices": load_repo_best_practices_md(self.git_provider, tool_name="improve"), "is_ai_metadata": get_settings().get("config.enable_ai_metadata", False), "focus_only_on_problems": get_settings().get("pr_code_suggestions.focus_only_on_problems", False), "date": datetime.now().strftime('%Y-%m-%d'), diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index c4917f3597..1a95fda12d 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -9,6 +9,7 @@ from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler +from pr_agent.algo.best_practices import load_repo_best_practices_md from pr_agent.algo.pr_processing import (add_ai_metadata_to_diff_files, get_pr_diff, retry_with_fallback_models) @@ -92,6 
+93,7 @@ def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, 'question_str': question_str, 'answer_str': answer_str, "extra_instructions": get_settings().pr_reviewer.extra_instructions, + "relevant_best_practices": load_repo_best_practices_md(self.git_provider, tool_name="review"), "commit_messages_str": self.git_provider.get_commit_messages(), "custom_labels": "", "enable_custom_labels": get_settings().config.enable_custom_labels, diff --git a/tests/unittest/test_pr_code_suggestions_best_practices.py b/tests/unittest/test_pr_code_suggestions_best_practices.py new file mode 100644 index 0000000000..9a5447faac --- /dev/null +++ b/tests/unittest/test_pr_code_suggestions_best_practices.py @@ -0,0 +1,158 @@ +import unittest +from unittest.mock import MagicMock, patch + +from pr_agent.algo.best_practices import load_repo_best_practices_md + + +def _provider(returns): + p = MagicMock(spec=["get_pr_agent_repo_custom_file"]) + p.get_pr_agent_repo_custom_file.return_value = returns + return p + + +class _FakeContextProxy: + """Module-level proxy that works as both subscriptable and attribute target.""" + + def __init__(self): + self._store = {} + + def get(self, key, default=None): + return self._store.get(key, default) + + def __getitem__(self, key): + return self._store[key] + + def __setitem__(self, key, value): + self._store[key] = value + + def reset(self): + self._store.clear() + + +class TestLoadRepoBestPracticesMd(unittest.TestCase): + def setUp(self): + self.fake_ctx = _FakeContextProxy() + self.ctx_patch = patch( + "pr_agent.algo.best_practices.context", self.fake_ctx + ) + self.ctx_patch.start() + + def tearDown(self): + self.ctx_patch.stop() + + @patch("pr_agent.algo.best_practices.get_settings") + def test_enabled_by_default_with_content(self, mock_get_settings): + s = MagicMock() + s.get.side_effect = lambda key, default=None: { + "best_practices.enable_repo_best_practices_md": True, + "best_practices.repo_best_practices_md_path": 
"best_practices.md", + "best_practices.max_lines_allowed": 800, + }.get(key, default) + mock_get_settings.return_value = s + prov = _provider(b"# Best practices\n- rule 1\n- rule 2\n") + out = load_repo_best_practices_md(prov) + self.assertIn("rule 1", out) + self.assertIn("rule 2", out) + prov.get_pr_agent_repo_custom_file.assert_called_once_with("best_practices.md") + + @patch("pr_agent.algo.best_practices.get_settings") + def test_opt_out_skips_fetch(self, mock_get_settings): + s = MagicMock() + s.get.side_effect = lambda key, default=None: { + "best_practices.enable_repo_best_practices_md": False, + }.get(key, default) + mock_get_settings.return_value = s + prov = _provider(b"should not be read") + out = load_repo_best_practices_md(prov) + self.assertEqual(out, "") + prov.get_pr_agent_repo_custom_file.assert_not_called() + + @patch("pr_agent.algo.best_practices.get_settings") + def test_file_absent_returns_empty(self, mock_get_settings): + s = MagicMock() + s.get.side_effect = lambda key, default=None: { + "best_practices.enable_repo_best_practices_md": True, + "best_practices.repo_best_practices_md_path": "best_practices.md", + "best_practices.max_lines_allowed": 800, + }.get(key, default) + mock_get_settings.return_value = s + prov = _provider(b"") + out = load_repo_best_practices_md(prov) + self.assertEqual(out, "") + + @patch("pr_agent.algo.best_practices.get_logger") + @patch("pr_agent.algo.best_practices.get_settings") + def test_truncation_emits_warning(self, mock_get_settings, mock_get_logger): + s = MagicMock() + s.get.side_effect = lambda key, default=None: { + "best_practices.enable_repo_best_practices_md": True, + "best_practices.repo_best_practices_md_path": "best_practices.md", + "best_practices.max_lines_allowed": 5, + }.get(key, default) + mock_get_settings.return_value = s + logger = MagicMock() + mock_get_logger.return_value = logger + body = "\n".join(f"line {i}" for i in range(20)) + prov = _provider(body.encode("utf-8")) + out = 
load_repo_best_practices_md(prov) + self.assertEqual(len(out.splitlines()), 5) + # WARNING message about truncation must include the from/to counts. + warning_msgs = [c.args[0] for c in logger.warning.call_args_list] + self.assertTrue(any("Truncating" in m and "20" in m and "5" in m for m in warning_msgs), + f"warning not emitted: {warning_msgs}") + # INFO log emitted on fetch. + info_msgs = [c.args[0] for c in logger.info.call_args_list] + self.assertTrue(any("Loaded" in m for m in info_msgs)) + + @patch("pr_agent.algo.best_practices.get_settings") + def test_caches_across_calls(self, mock_get_settings): + s = MagicMock() + s.get.side_effect = lambda key, default=None: { + "best_practices.enable_repo_best_practices_md": True, + "best_practices.repo_best_practices_md_path": "best_practices.md", + "best_practices.max_lines_allowed": 800, + }.get(key, default) + mock_get_settings.return_value = s + prov = _provider(b"hello\n") + first = load_repo_best_practices_md(prov) + second = load_repo_best_practices_md(prov) + self.assertEqual(first, second) + prov.get_pr_agent_repo_custom_file.assert_called_once() + + @patch("pr_agent.algo.best_practices.get_settings") + def test_str_return_tolerated(self, mock_get_settings): + s = MagicMock() + s.get.side_effect = lambda key, default=None: { + "best_practices.enable_repo_best_practices_md": True, + "best_practices.repo_best_practices_md_path": "best_practices.md", + "best_practices.max_lines_allowed": 800, + }.get(key, default) + mock_get_settings.return_value = s + prov = _provider("text content\n") + out = load_repo_best_practices_md(prov) + self.assertIn("text content", out) + + @patch("pr_agent.algo.best_practices.get_logger") + @patch("pr_agent.algo.best_practices.get_settings") + def test_invalid_max_lines_falls_back(self, mock_get_settings, mock_get_logger): + s = MagicMock() + s.get.side_effect = lambda key, default=None: { + "best_practices.enable_repo_best_practices_md": True, + 
"best_practices.repo_best_practices_md_path": "best_practices.md", + "best_practices.max_lines_allowed": "not-a-number", + }.get(key, default) + mock_get_settings.return_value = s + logger = MagicMock() + mock_get_logger.return_value = logger + prov = _provider(b"line\n") + out = load_repo_best_practices_md(prov) + self.assertEqual(out, "line") + warning_msgs = [c.args[0] for c in logger.warning.call_args_list] + self.assertTrue( + any("Invalid best_practices.max_lines_allowed" in m for m in warning_msgs), + f"fallback warning not emitted: {warning_msgs}", + ) + + +if __name__ == "__main__": + unittest.main()