Skip to content

Commit 13a1da6

Browse files
romanlutz and Copilot authored
[BREAKING] MAINT: enforce _async suffix on async functions across pyrit/ (#1889)
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 5e8bf1d commit 13a1da6

202 files changed

Lines changed: 2080 additions & 1028 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/instructions/datasets.instructions.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ Each `SeedPrompt` / `SeedObjective` must carry:
5757

5858
## Set class-level dataset metadata when known
5959

60-
`_parse_metadata` on `_RemoteDatasetLoader` reads class attributes matching `SeedDatasetMetadata` fields. Declare what you can know statically as class-level constants so dataset discovery/filtering works:
60+
`_parse_metadata_async` on `_RemoteDatasetLoader` reads class attributes matching `SeedDatasetMetadata` fields. Declare what you can know statically as class-level constants so dataset discovery/filtering works:
6161

6262
```python
6363
class _MyDataset(_RemoteDatasetLoader):

.github/instructions/style-guide.instructions.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ async def _read_audio_async(self, path):
4343
### Async Functions
4444
- **MANDATORY**: All async functions and methods MUST end with `_async` suffix
4545
- This applies to ALL async functions without exception
46+
- Enforced by the `check-async-suffix` pre-commit hook (`build_scripts/check_async_suffix.py`)
4647

4748
```python
4849
# CORRECT
@@ -54,6 +55,13 @@ async def send_prompt(self, prompt: str) -> Message: # Missing _async suffix
5455
...
5556
```
5657

58+
**Exemptions** are limited and explicit:
59+
- Async dunders (`__aenter__`, `__aexit__`, `__aiter__`, `__anext__`) are exempt automatically.
60+
- A small set of framework-mandated names (`lifespan`, `dispatch`, `__call__`) is exempt
61+
automatically; see `_FRAMEWORK_EXEMPT_NAMES` in `build_scripts/check_async_suffix.py`.
62+
- For one-off exemptions (e.g. an external SDK protocol method) add a
63+
`# pyrit-async-suffix-exempt` trailing comment on the `async def` line (any line of a multi-line signature is accepted).
64+
5765
### Private Methods
5866
- Private methods MUST start with underscore
5967
- This clearly indicates internal implementation details

.pre-commit-config.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ repos:
4141
language: python
4242
files: ^pyrit/memory/alembic/versions/.*\.py$
4343
pass_filenames: false
44+
- id: check-async-suffix
45+
name: Enforce _async Suffix on async def
46+
entry: python ./build_scripts/check_async_suffix.py
47+
language: python
48+
files: ^pyrit/.*\.py$
49+
pass_filenames: false
4450
- id: memory-migrations-check
4551
name: Check Memory Migrations
4652
entry: python ./build_scripts/memory_migrations.py check
build_scripts/check_async_suffix.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
Enforce ``.github/instructions/style-guide.instructions.md`` §1: every ``async def`` in
6+
``pyrit/`` must end with the ``_async`` suffix.
7+
8+
Mechanism: walk every ``pyrit/**/*.py`` file with ``ast`` and flag every ``AsyncFunctionDef``
9+
whose name does not end in ``_async`` and is not exempted via either:
10+
11+
1. **Hard-coded framework exemptions** (``_FRAMEWORK_EXEMPT_NAMES``) — names whose meaning
12+
is dictated by an external framework or by the Python data model
13+
(e.g. ``lifespan`` for FastAPI, ``dispatch`` for Starlette middleware, ``__call__``
14+
on Protocol classes). The set is intentionally small; one-off exemptions
15+
should use the per-line ``# pyrit-async-suffix-exempt`` marker instead.
16+
17+
2. **Per-line ``# pyrit-async-suffix-exempt`` marker** on any line of the ``async def``
18+
header (the marker is scanned across the full signature, which the formatter may
19+
split across multiple lines). Common reasons: a deprecation shim that intentionally
20+
keeps the old non-``_async`` name for one release cycle; a one-off external-SDK or
21+
protocol method name.
22+
"""
23+
24+
from __future__ import annotations
25+
26+
import ast
27+
import sys
28+
from pathlib import Path
29+
30+
# Project layout — anchor everything off the repo root (directory containing pyrit/).
31+
_REPO_ROOT = Path(__file__).resolve().parent.parent
32+
_SCAN_ROOTS = ("pyrit",)
33+
34+
# Framework-mandated names: do NOT add to this set for one-off exemptions.
35+
# Use a per-line ``# pyrit-async-suffix-exempt`` marker instead so each exemption is
36+
# visible at the violation site.
37+
_FRAMEWORK_EXEMPT_NAMES: frozenset[str] = frozenset(
38+
{
39+
"lifespan", # FastAPI app lifespan context manager
40+
"dispatch", # Starlette BaseHTTPMiddleware.dispatch override
41+
"__call__", # Python dunder; Protocol classes commonly define async __call__
42+
}
43+
)
44+
45+
_NOQA_MARKER = "# pyrit-async-suffix-exempt"
46+
47+
48+
def _is_violation_name(name: str) -> bool:
    """Return True if ``name`` violates the async-suffix rule.

    A name is compliant when it ends in ``_async``, when it is shaped like an
    async dunder (``__aenter__``, ``__aexit__``, ``__aiter__``, ``__anext__``), or
    when it is listed in ``_FRAMEWORK_EXEMPT_NAMES``.

    Args:
        name: The name of an ``async def`` function or method.

    Returns:
        bool: True when the name must be reported as a violation.
    """
    if name.endswith("_async"):
        return False
    # Require the trailing ``__`` too: a bare ``__a`` prefix would also exempt
    # name-mangled private coroutines such as ``__acquire_lock``, which are not
    # dunders and must still carry the ``_async`` suffix.
    if name.startswith("__a") and name.endswith("__"):
        return False
    return name not in _FRAMEWORK_EXEMPT_NAMES
56+
57+
58+
def _line_has_noqa(source_lines: list[str], lineno: int) -> bool:
    """Report whether the 1-based line ``lineno`` contains the exempt marker.

    A line number that falls outside ``source_lines`` is treated as not carrying
    the marker.
    """
    if 1 <= lineno <= len(source_lines):
        return _NOQA_MARKER in source_lines[lineno - 1]
    return False
63+
64+
65+
def _header_has_noqa(source_lines: list[str], node: ast.AsyncFunctionDef) -> bool:
    """Report whether the exempt marker appears anywhere in the ``async def`` header.

    The header is taken to run from ``node.lineno`` up to the line immediately
    before the first body statement, so a marker that ended up on a later line of
    a signature split across several lines is still found. When the body starts on
    the ``async def`` line itself, only that line is inspected.
    """
    first_line = node.lineno
    last_line = first_line
    if node.body:
        # Never let the range end before it starts (body on the def line).
        last_line = max(first_line, node.body[0].lineno - 1)
    for lineno in range(first_line, last_line + 1):
        if _line_has_noqa(source_lines, lineno):
            return True
    return False
75+
76+
77+
def _scan_file(path: Path) -> list[tuple[str, int, str]]:
    """Collect the async-suffix violations found in a single Python file.

    Each entry is ``(relative_path, line, name)``. ``relative_path`` is the
    POSIX-style path relative to the repo root, so reports read the same on
    Windows and Linux checkouts.

    A file that fails to parse yields one synthetic entry whose name is
    ``<SyntaxError: ...>``. That keeps an unparseable file from slipping past the
    check, whatever other hooks run and in whatever order.
    """
    source = path.read_text(encoding="utf-8")
    try:
        tree = ast.parse(source, filename=str(path))
    except SyntaxError as exc:
        rel = path.relative_to(_REPO_ROOT).as_posix()
        if exc.lineno is None:
            detail = exc.msg
        else:
            detail = f"{exc.msg} (line {exc.lineno})"
        return [(rel, exc.lineno or 0, f"<SyntaxError: {detail}>")]

    rel = path.relative_to(_REPO_ROOT).as_posix()
    lines = source.splitlines()
    # Keep only async defs with a non-compliant name and no exempt marker on the header.
    return [
        (rel, func.lineno, func.name)
        for func in ast.walk(tree)
        if isinstance(func, ast.AsyncFunctionDef)
        and _is_violation_name(func.name)
        and not _header_has_noqa(lines, func)
    ]
105+
106+
107+
def _scan_repo() -> list[tuple[str, int, str]]:
    """Gather violations from every ``*.py`` file under each scan root.

    Files are visited in sorted path order within each root so that the report
    is deterministic.
    """
    found: list[tuple[str, int, str]] = []
    for root_name in _SCAN_ROOTS:
        py_files = sorted((_REPO_ROOT / root_name).rglob("*.py"))
        for py_file in py_files:
            found += _scan_file(py_file)
    return found
114+
115+
116+
def main() -> int:
    """Run the repository scan and print a report of any violations.

    Returns:
        int: 0 when no violations were found, 1 otherwise (used as the exit code).
    """
    violations = _scan_repo()
    if violations:
        print(
            "[ERROR] Async functions are missing the `_async` suffix "
            "(see .github/instructions/style-guide.instructions.md §1):"
        )
        for path, line, name in violations:
            # Parse failures are carried as a synthetic "<SyntaxError: ...>" name.
            if name.startswith("<SyntaxError"):
                detail = f"could not parse file: {name[1:-1]}"
            else:
                detail = f"async def {name}(...)"
            print(f" {path}:{line}: {detail}")
        print("")
        print("Rename each function to end in `_async`, or — if the name is dictated")
        print("by a framework — add `# pyrit-async-suffix-exempt` at the end of the `async def` line.")
        return 1
    return 0
134+
135+
136+
if __name__ == "__main__":
137+
sys.exit(main())

doc/code/datasets/4_dataset_coding.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969
"\n",
7070
" async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset:\n",
7171
" # Fetch from HuggingFace\n",
72-
" data = await self._fetch_from_huggingface(\n",
72+
" data = await self._fetch_from_huggingface_async(\n",
7373
" dataset_name=\"apart/darkbench\",\n",
7474
" config=\"default\",\n",
7575
" split=\"train\",\n",

doc/code/datasets/4_dataset_coding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def dataset_name(self) -> str:
6666

6767
async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset:
6868
# Fetch from HuggingFace
69-
data = await self._fetch_from_huggingface(
69+
data = await self._fetch_from_huggingface_async(
7070
dataset_name="apart/darkbench",
7171
config="default",
7272
split="train",

doc/code/targets/realtime_target.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@
197197
"attack = PromptSendingAttack(objective_target=target)\n",
198198
"result = await attack.execute_with_context_async(context=context) # type: ignore\n",
199199
"await output_attack_async(result)\n",
200-
"await target.cleanup_target() # type: ignore"
200+
"await target.cleanup_target_async() # type: ignore"
201201
]
202202
},
203203
{

doc/code/targets/realtime_target.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474
attack = PromptSendingAttack(objective_target=target)
7575
result = await attack.execute_with_context_async(context=context) # type: ignore
7676
await output_attack_async(result)
77-
await target.cleanup_target() # type: ignore
77+
await target.cleanup_target_async() # type: ignore
7878

7979
# %% [markdown]
8080
# ## Text Conversation

pyrit/auth/azure_auth.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def __init__(self, token_provider: Callable[[], Union[str, Awaitable[str]]]) ->
8585
"""
8686
self._token_provider = token_provider
8787

88-
async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
88+
async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # pyrit-async-suffix-exempt
8989
"""
9090
Get an access token asynchronously.
9191
@@ -104,7 +104,7 @@ async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
104104
expires_on = int(time.time()) + 3600
105105
return AccessToken(str(token), expires_on)
106106

107-
async def close(self) -> None:
107+
async def close(self) -> None: # pyrit-async-suffix-exempt
108108
"""No-op close for protocol compliance. The callable provider does not hold resources."""
109109

110110
async def __aenter__(self) -> AsyncTokenProviderCredential:
@@ -149,7 +149,7 @@ def ensure_async_token_provider(
149149
" Automatically wrapping in async function for compatibility with async client."
150150
)
151151

152-
async def async_token_provider() -> str:
152+
async def async_token_provider() -> str: # pyrit-async-suffix-exempt
153153
"""
154154
Async wrapper for synchronous token provider.
155155

pyrit/auth/azure_storage_auth.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
)
1313
from azure.storage.blob.aio import BlobServiceClient
1414

15+
from pyrit.common.deprecation import print_deprecation_message
16+
1517

1618
class AzureStorageAuth:
1719
"""
@@ -20,7 +22,7 @@ class AzureStorageAuth:
2022
"""
2123

2224
@staticmethod
23-
async def get_user_delegation_key(blob_service_client: BlobServiceClient) -> UserDelegationKey:
25+
async def get_user_delegation_key_async(blob_service_client: BlobServiceClient) -> UserDelegationKey:
2426
"""
2527
Retrieve a user delegation key valid for one day.
2628
@@ -39,7 +41,28 @@ async def get_user_delegation_key(blob_service_client: BlobServiceClient) -> Use
3941
)
4042

4143
@staticmethod
42-
async def get_sas_token(container_url: str) -> str:
44+
async def get_user_delegation_key(
    blob_service_client: BlobServiceClient,
) -> UserDelegationKey:  # pyrit-async-suffix-exempt
    """
    Deprecated alias that forwards to ``get_user_delegation_key_async``.

    Emits a deprecation message (removal announced for 0.16.0) and then awaits
    the ``_async`` variant with the same client.

    Args:
        blob_service_client (BlobServiceClient): Client passed through unchanged
            to ``get_user_delegation_key_async``.

    Returns:
        UserDelegationKey: The value returned by ``get_user_delegation_key_async``.
    """
    print_deprecation_message(
        old_item="AzureStorageAuth.get_user_delegation_key",
        new_item="AzureStorageAuth.get_user_delegation_key_async",
        removed_in="0.16.0",
    )
    delegation_key = await AzureStorageAuth.get_user_delegation_key_async(blob_service_client)
    return delegation_key
63+
64+
@staticmethod
65+
async def get_sas_token_async(container_url: str) -> str:
4366
"""
4467
Generate a SAS token for the specified blob using a user delegation key.
4568
@@ -72,7 +95,7 @@ async def get_sas_token(container_url: str) -> str:
7295

7396
try:
7497
async with BlobServiceClient(account_url=account_url, credential=credential) as blob_service_client:
75-
user_delegation_key = await AzureStorageAuth.get_user_delegation_key(
98+
user_delegation_key = await AzureStorageAuth.get_user_delegation_key_async(
7699
blob_service_client=blob_service_client
77100
)
78101
container_name = parsed_url.path.lstrip("/")
@@ -94,3 +117,21 @@ async def get_sas_token(container_url: str) -> str:
94117
await credential.close()
95118

96119
return sas_token
120+
121+
@staticmethod
122+
async def get_sas_token(container_url: str) -> str:  # pyrit-async-suffix-exempt
    """
    Deprecated alias that forwards to ``get_sas_token_async``.

    Emits a deprecation message (removal announced for 0.16.0) and then awaits
    the ``_async`` variant with the same URL.

    Args:
        container_url (str): Passed through unchanged to ``get_sas_token_async``.

    Returns:
        str: The value returned by ``get_sas_token_async``.
    """
    print_deprecation_message(
        old_item="AzureStorageAuth.get_sas_token",
        new_item="AzureStorageAuth.get_sas_token_async",
        removed_in="0.16.0",
    )
    sas_token = await AzureStorageAuth.get_sas_token_async(container_url)
    return sas_token

0 commit comments

Comments
 (0)