Skip to content

Commit 143795f

Browse files
authored
Support structural pattern matching in ASYNC103,104 and 91X (#363)
* add match support to 103/104 * add match-case support to visitor91x
1 parent b6bfbb6 commit 143795f

File tree

11 files changed

+357
-4
lines changed

11 files changed

+357
-4
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ repos:
4242
hooks:
4343
- id: mypy
4444
# uses py311 syntax, mypy configured for py39
45-
exclude: tests/(eval|autofix)_files/.*_py311.py
45+
exclude: tests/(eval|autofix)_files/.*_py(310|311).py
4646

4747
- repo: https://github.com/RobertCraigie/pyright-python
4848
rev: v1.1.397

docs/changelog.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ Changelog
44

55
`CalVer, YY.month.patch <https://calver.org/>`_
66

7+
25.4.1
8+
======
9+
- Add match-case (structural pattern matching) support to ASYNC103, 104, 910, 911 & 912.
10+
711
25.3.1
812
======
913
- Add except* support to ASYNC102, 103, 104, 120, 910, 911, 912.

docs/usage.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ adding the following to your ``.pre-commit-config.yaml``:
3333
minimum_pre_commit_version: '2.9.0'
3434
repos:
3535
- repo: https://github.com/python-trio/flake8-async
36-
rev: 25.3.1
36+
rev: 25.4.1
3737
hooks:
3838
- id: flake8-async
3939
# args: ["--enable=ASYNC100,ASYNC112", "--disable=", "--autofix=ASYNC"]

flake8_async/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838

3939

4040
# CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1"
41-
__version__ = "25.3.1"
41+
__version__ = "25.4.1"
4242

4343

4444
# taken from https://github.com/Zac-HD/shed

flake8_async/visitors/visitor103_104.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,24 @@ def visit_If(self, node: ast.If):
199199
# if body didn't raise, or it's unraised after else, set unraise
200200
self.unraised = not body_raised or self.unraised
201201

202+
def visit_Match(self, node: ast.Match): # type: ignore[name-defined]
203+
if not self.unraised:
204+
return
205+
all_cases_raise = True
206+
has_fallback = False
207+
for case in node.cases:
208+
# check for "bare pattern", i.e `case varname:`
209+
has_fallback |= (
210+
case.guard is None
211+
and isinstance(case.pattern, ast.MatchAs) # type: ignore[attr-defined]
212+
and case.pattern.pattern is None
213+
)
214+
self.visit_nodes(case.body)
215+
all_cases_raise &= not self.unraised
216+
self.unraised = True
217+
218+
self.unraised = not (all_cases_raise and has_fallback)
219+
202220
# A loop is guaranteed to raise if:
203221
# condition always raises, or
204222
# else always raises, and

flake8_async/visitors/visitor91x.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,22 @@ def copy(self):
198198
)
199199

200200

201+
@dataclass
202+
class MatchState:
203+
# TryState, LoopState, and MatchState all do fairly similar things. It would be nice
204+
# to harmonize them and share logic.
205+
base_uncheckpointed_statements: set[Statement] = field(default_factory=set)
206+
case_uncheckpointed_statements: set[Statement] = field(default_factory=set)
207+
has_fallback: bool = False
208+
209+
def copy(self):
210+
return MatchState(
211+
base_uncheckpointed_statements=self.base_uncheckpointed_statements.copy(),
212+
case_uncheckpointed_statements=self.case_uncheckpointed_statements.copy(),
213+
has_fallback=self.has_fallback,
214+
)
215+
216+
201217
def checkpoint_statement(library: str) -> cst.SimpleStatementLine:
202218
# logic before this should stop code from wanting to insert the non-existing
203219
# asyncio.lowlevel.checkpoint
@@ -373,6 +389,7 @@ def __init__(self, *args: Any, **kwargs: Any):
373389

374390
self.loop_state = LoopState()
375391
self.try_state = TryState()
392+
self.match_state = MatchState()
376393

377394
# ASYNC100
378395
self.has_checkpoint_stack: list[bool] = []
@@ -894,6 +911,55 @@ def visit_IfExp(self, node: cst.IfExp) -> bool:
894911
self.leave_If(node, node) # type: ignore
895912
return False # libcst shouldn't visit subnodes again
896913

914+
def leave_Match_subject(self, node: cst.Match) -> None:
915+
# We start the match logic after parsing the subject, instead of visit_Match,
916+
# since the subject is always executed and might checkpoint.
917+
if not self.async_function:
918+
return
919+
self.save_state(node, "match_state", copy=True)
920+
self.match_state = MatchState(
921+
base_uncheckpointed_statements=self.uncheckpointed_statements.copy()
922+
)
923+
924+
def visit_MatchCase(self, node: cst.MatchCase) -> None:
925+
# enter each case from the state after parsing the subject
926+
self.uncheckpointed_statements = self.match_state.base_uncheckpointed_statements
927+
928+
def leave_MatchCase_guard(self, node: cst.MatchCase) -> None:
929+
# `case _:` is no pattern and no guard, which means we know body is executed.
930+
# But we also know that `case _ if <guard>:` is guaranteed to execute the guard,
931+
# so for later logic we can treat them the same *if* there's no pattern and that
932+
# guard checkpoints.
933+
if (
934+
isinstance(node.pattern, cst.MatchAs)
935+
and node.pattern.pattern is None
936+
and (node.guard is None or not self.uncheckpointed_statements)
937+
):
938+
self.match_state.has_fallback = True
939+
940+
def leave_MatchCase(
941+
self, original_node: cst.MatchCase, updated_node: cst.MatchCase
942+
) -> cst.MatchCase:
943+
# collect the state at the end of each case
944+
self.match_state.case_uncheckpointed_statements.update(
945+
self.uncheckpointed_statements
946+
)
947+
return updated_node
948+
949+
def leave_Match(
950+
self, original_node: cst.Match, updated_node: cst.Match
951+
) -> cst.Match:
952+
# leave the Match with the worst-case of all branches
953+
self.uncheckpointed_statements = self.match_state.case_uncheckpointed_statements
954+
# if no fallback, also add the state at entering the match (after parsing subject)
955+
if not self.match_state.has_fallback:
956+
self.uncheckpointed_statements.update(
957+
self.match_state.base_uncheckpointed_statements
958+
)
959+
960+
self.restore_state(original_node)
961+
return updated_node
962+
897963
def visit_While(self, node: cst.While | cst.For):
898964
self.save_state(
899965
node,

flake8_async/visitors/visitor_utility.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from re import Match
1818

1919
import libcst as cst
20-
from libcst._position import CodeRange
20+
from libcst.metadata import CodeRange
2121

2222

2323
@utility_visitor
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# ARG --enable=ASYNC910,ASYNC911,ASYNC913
2+
# AUTOFIX
3+
# ASYNCIO_NO_AUTOFIX
4+
import trio
5+
6+
7+
async def foo(): ...
8+
9+
10+
async def match_subject() -> None:
11+
match await foo():
12+
case False:
13+
pass
14+
15+
16+
async def match_not_all_cases() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
17+
None
18+
):
19+
match foo():
20+
case 1:
21+
...
22+
case _:
23+
await foo()
24+
await trio.lowlevel.checkpoint()
25+
26+
27+
async def match_no_fallback() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
28+
None
29+
):
30+
match foo():
31+
case 1:
32+
await foo()
33+
case 2:
34+
await foo()
35+
case _ if True:
36+
await foo()
37+
await trio.lowlevel.checkpoint()
38+
39+
40+
async def match_fallback_is_guarded() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
41+
None
42+
):
43+
match foo():
44+
case 1:
45+
await foo()
46+
case 2:
47+
await foo()
48+
case _ if foo():
49+
await foo()
50+
await trio.lowlevel.checkpoint()
51+
52+
53+
async def match_all_cases() -> None:
54+
match foo():
55+
case 1:
56+
await foo()
57+
case 2:
58+
await foo()
59+
case _:
60+
await foo()
61+
62+
63+
async def match_fallback_await_in_guard() -> None:
64+
# The case guard is only executed if the pattern matches, so we can mostly treat
65+
# it as part of the body, except for a special case for fallback+checkpointing guard.
66+
match foo():
67+
case 1 if await foo():
68+
...
69+
case _ if await foo():
70+
...
71+
72+
73+
async def match_checkpoint_guard() -> None:
74+
# The above pattern is quite cursed, but this seems fairly reasonable to do.
75+
match foo():
76+
case 1 if await foo():
77+
...
78+
case _:
79+
await foo()
80+
81+
82+
async def match_not_checkpoint_in_all_guards() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
83+
None
84+
):
85+
match foo():
86+
case 1:
87+
...
88+
case _ if await foo():
89+
...
90+
await trio.lowlevel.checkpoint()
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
---
2+
+++
3+
@@ x,6 x,7 @@
4+
...
5+
case _:
6+
await foo()
7+
+ await trio.lowlevel.checkpoint()
8+
9+
10+
async def match_no_fallback() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
11+
@@ x,6 x,7 @@
12+
await foo()
13+
case _ if True:
14+
await foo()
15+
+ await trio.lowlevel.checkpoint()
16+
17+
18+
async def match_fallback_is_guarded() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
19+
@@ x,6 x,7 @@
20+
await foo()
21+
case _ if foo():
22+
await foo()
23+
+ await trio.lowlevel.checkpoint()
24+
25+
26+
async def match_all_cases() -> None:
27+
@@ x,3 x,4 @@
28+
...
29+
case _ if await foo():
30+
...
31+
+ await trio.lowlevel.checkpoint()
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
"""Test for ASYNC103/ASYNC104 with structural pattern matching
2+
3+
ASYNC103: no-reraise-cancelled
4+
ASYNC104: cancelled-not-raised
5+
"""
6+
7+
# ARG --enable=ASYNC103,ASYNC104
8+
9+
10+
def foo() -> Any: ...
11+
12+
13+
try:
14+
...
15+
except BaseException as e: # ASYNC103_trio: 7, "BaseException"
16+
match foo():
17+
case True:
18+
raise e
19+
case False:
20+
...
21+
case _:
22+
raise e
23+
24+
try:
25+
...
26+
except BaseException: # ASYNC103_trio: 7, "BaseException"
27+
match foo():
28+
case True:
29+
raise
30+
31+
try:
32+
...
33+
except BaseException: # safe
34+
match foo():
35+
case True:
36+
raise
37+
case False:
38+
raise
39+
case _:
40+
raise
41+
try:
42+
...
43+
except BaseException: # ASYNC103_trio: 7, "BaseException"
44+
match foo():
45+
case _ if foo():
46+
raise
47+
try:
48+
...
49+
except BaseException: # ASYNC103_trio: 7, "BaseException"
50+
match foo():
51+
case 1:
52+
return # ASYNC104: 12
53+
case 2:
54+
raise
55+
case 3:
56+
return # ASYNC104: 12
57+
case blah:
58+
raise

0 commit comments

Comments
 (0)