Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ repos:
hooks:
- id: mypy
# uses py311 syntax, mypy configured for py39
exclude: tests/(eval|autofix)_files/.*_py311.py
exclude: tests/(eval|autofix)_files/.*_py(310|311).py

- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.396
Expand Down
18 changes: 18 additions & 0 deletions flake8_async/visitors/visitor103_104.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,24 @@ def visit_If(self, node: ast.If):
# if body didn't raise, or it's unraised after else, set unraise
self.unraised = not body_raised or self.unraised

def visit_Match(self, node: ast.Match): # type: ignore[name-defined]
if not self.unraised:
return
all_cases_raise = True
has_fallback = False
for case in node.cases:
# check for "bare pattern", i.e `case varname:`
has_fallback |= (
case.guard is None
and isinstance(case.pattern, ast.MatchAs) # type: ignore[attr-defined]
and case.pattern.pattern is None
)
self.visit_nodes(case.body)
all_cases_raise &= not self.unraised
self.unraised = True

self.unraised = not (all_cases_raise and has_fallback)

# A loop is guaranteed to raise if:
# condition always raises, or
# else always raises, and
Expand Down
64 changes: 64 additions & 0 deletions flake8_async/visitors/visitor91x.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,22 @@ def copy(self):
)


@dataclass
class MatchState:
# TryState, LoopState, and MatchState all do fairly similar things. It would be nice
# to harmonize them and share logic.
base_uncheckpointed_statements: set[Statement] = field(default_factory=set)
case_uncheckpointed_statements: set[Statement] = field(default_factory=set)
has_fallback: bool = False

def copy(self):
return MatchState(
base_uncheckpointed_statements=self.base_uncheckpointed_statements.copy(),
case_uncheckpointed_statements=self.case_uncheckpointed_statements.copy(),
has_fallback=self.has_fallback,
)


def checkpoint_statement(library: str) -> cst.SimpleStatementLine:
# logic before this should stop code from wanting to insert the non-existing
# asyncio.lowlevel.checkpoint
Expand Down Expand Up @@ -373,6 +389,7 @@ def __init__(self, *args: Any, **kwargs: Any):

self.loop_state = LoopState()
self.try_state = TryState()
self.match_state = MatchState()

# ASYNC100
self.has_checkpoint_stack: list[bool] = []
Expand Down Expand Up @@ -894,6 +911,53 @@ def visit_IfExp(self, node: cst.IfExp) -> bool:
self.leave_If(node, node) # type: ignore
return False # libcst shouldn't visit subnodes again

def leave_Match_subject(self, node: cst.Match) -> None:
# We start the match logic after parsing the subject, instead of visit_Match,
# since the subject is always executed and might checkpoint.
if not self.async_function:
return
self.save_state(node, "match_state", copy=True)
self.match_state = MatchState(self.uncheckpointed_statements.copy())

def visit_MatchCase(self, node: cst.MatchCase) -> None:
# enter each case from the state after parsing the subject
self.uncheckpointed_statements = self.match_state.base_uncheckpointed_statements

def leave_MatchCase_guard(self, node: cst.MatchCase) -> None:
# `case _:` is no pattern and no guard, which means we know body is executed.
# But we also know that `case _ if <guard>:` is guaranteed to execute the guard,
# so for later logic we can treat them the same *if* there's no pattern and that
# guard checkpoints.
if (
isinstance(node.pattern, cst.MatchAs)
and node.pattern.pattern is None
and (node.guard is None or not self.uncheckpointed_statements)
):
self.match_state.has_fallback = True

def leave_MatchCase(
self, original_node: cst.MatchCase, updated_node: cst.MatchCase
) -> cst.MatchCase:
# collect the state at the end of each case
self.match_state.case_uncheckpointed_statements.update(
self.uncheckpointed_statements
)
return updated_node

def leave_Match(
self, original_node: cst.Match, updated_node: cst.Match
) -> cst.Match:
# leave the Match with the worst-case of all branches
self.uncheckpointed_statements = self.match_state.case_uncheckpointed_statements
# if no fallback, also add the state at entering the match (after parsing subject)
if not self.match_state.has_fallback:
self.uncheckpointed_statements.update(
self.match_state.base_uncheckpointed_statements
)

self.restore_state(original_node)
return updated_node

def visit_While(self, node: cst.While | cst.For):
self.save_state(
node,
Expand Down
90 changes: 90 additions & 0 deletions tests/autofix_files/async91x_py310.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# ARG --enable=ASYNC910,ASYNC911,ASYNC913
# AUTOFIX
# ASYNCIO_NO_AUTOFIX
import trio


async def foo(): ...


async def match_subject() -> None:
match await foo():
case False:
pass


async def match_not_all_cases() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
None
):
match foo():
case 1:
...
case _:
await foo()
await trio.lowlevel.checkpoint()


async def match_no_fallback() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
None
):
match foo():
case 1:
await foo()
case 2:
await foo()
case _ if True:
await foo()
await trio.lowlevel.checkpoint()


async def match_fallback_is_guarded() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
None
):
match foo():
case 1:
await foo()
case 2:
await foo()
case _ if foo():
await foo()
await trio.lowlevel.checkpoint()


async def match_all_cases() -> None:
match foo():
case 1:
await foo()
case 2:
await foo()
case _:
await foo()


async def match_fallback_await_in_guard() -> None:
# The case guard is only executed if the pattern matches, so we can mostly treat
# it as part of the body, except for a special case for fallback+checkpointing guard.
match foo():
case 1 if await foo():
...
case _ if await foo():
...


async def match_checkpoint_guard() -> None:
# The above pattern is quite cursed, but this seems fairly reasonable to do.
match foo():
case 1 if await foo():
...
case _:
await foo()


async def match_not_checkpoint_in_all_guards() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
None
):
match foo():
case 1:
...
case _ if await foo():
...
await trio.lowlevel.checkpoint()
31 changes: 31 additions & 0 deletions tests/autofix_files/async91x_py310.py.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
---
+++
@@ x,6 x,7 @@
...
case _:
await foo()
+ await trio.lowlevel.checkpoint()


async def match_no_fallback() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
@@ x,6 x,7 @@
await foo()
case _ if True:
await foo()
+ await trio.lowlevel.checkpoint()


async def match_fallback_is_guarded() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
@@ x,6 x,7 @@
await foo()
case _ if foo():
await foo()
+ await trio.lowlevel.checkpoint()


async def match_all_cases() -> None:
@@ x,3 x,4 @@
...
case _ if await foo():
...
+ await trio.lowlevel.checkpoint()
58 changes: 58 additions & 0 deletions tests/eval_files/async103_104_py310.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""Test for ASYNC103/ASYNC104 with structural pattern matching

ASYNC103: no-reraise-cancelled
ASYNC104: cancelled-not-raised
"""

# ARG --enable=ASYNC103,ASYNC104


def foo() -> Any: ...


try:
...
except BaseException as e: # ASYNC103_trio: 7, "BaseException"
match foo():
case True:
raise e
case False:
...
case _:
raise e

try:
...
except BaseException: # ASYNC103_trio: 7, "BaseException"
match foo():
case True:
raise

try:
...
except BaseException: # safe
match foo():
case True:
raise
case False:
raise
case _:
raise
try:
...
except BaseException: # ASYNC103_trio: 7, "BaseException"
match foo():
case _ if foo():
raise
try:
...
except BaseException: # ASYNC103_trio: 7, "BaseException"
match foo():
case 1:
return # ASYNC104: 12
case 2:
raise
case 3:
return # ASYNC104: 12
case blah:
raise
86 changes: 86 additions & 0 deletions tests/eval_files/async91x_py310.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# ARG --enable=ASYNC910,ASYNC911,ASYNC913
# AUTOFIX
# ASYNCIO_NO_AUTOFIX
import trio


async def foo(): ...


async def match_subject() -> None:
match await foo():
case False:
pass


async def match_not_all_cases() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
None
):
match foo():
case 1:
...
case _:
await foo()


async def match_no_fallback() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
None
):
match foo():
case 1:
await foo()
case 2:
await foo()
case _ if True:
await foo()


async def match_fallback_is_guarded() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
None
):
match foo():
case 1:
await foo()
case 2:
await foo()
case _ if foo():
await foo()


async def match_all_cases() -> None:
match foo():
case 1:
await foo()
case 2:
await foo()
case _:
await foo()


async def match_fallback_await_in_guard() -> None:
# The case guard is only executed if the pattern matches, so we can mostly treat
# it as part of the body, except for a special case for fallback+checkpointing guard.
match foo():
case 1 if await foo():
...
case _ if await foo():
...


async def match_checkpoint_guard() -> None:
# The above pattern is quite cursed, but this seems fairly reasonable to do.
match foo():
case 1 if await foo():
...
case _:
await foo()


async def match_not_checkpoint_in_all_guards() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
None
):
match foo():
case 1:
...
case _ if await foo():
...