@@ -198,6 +198,22 @@ def copy(self):
198198 )
199199
200200
201+ @dataclass
202+ class MatchState :
203+ # TryState, LoopState, and MatchState all do fairly similar things. It would be nice
204+ # to harmonize them and share logic.
205+ base_uncheckpointed_statements : set [Statement ] = field (default_factory = set )
206+ case_uncheckpointed_statements : set [Statement ] = field (default_factory = set )
207+ has_fallback : bool = False
208+
209+ def copy (self ):
210+ return MatchState (
211+ base_uncheckpointed_statements = self .base_uncheckpointed_statements .copy (),
212+ case_uncheckpointed_statements = self .case_uncheckpointed_statements .copy (),
213+ has_fallback = self .has_fallback ,
214+ )
215+
216+
201217def checkpoint_statement (library : str ) -> cst .SimpleStatementLine :
202218 # logic before this should stop code from wanting to insert the non-existing
203219 # asyncio.lowlevel.checkpoint
@@ -373,6 +389,7 @@ def __init__(self, *args: Any, **kwargs: Any):
373389
374390 self .loop_state = LoopState ()
375391 self .try_state = TryState ()
392+ self .match_state = MatchState ()
376393
377394 # ASYNC100
378395 self .has_checkpoint_stack : list [bool ] = []
@@ -894,6 +911,55 @@ def visit_IfExp(self, node: cst.IfExp) -> bool:
894911 self .leave_If (node , node ) # type: ignore
895912 return False # libcst shouldn't visit subnodes again
896913
914+ def leave_Match_subject (self , node : cst .Match ) -> None :
915+ # We start the match logic after parsing the subject, instead of visit_Match,
916+ # since the subject is always executed and might checkpoint.
917+ if not self .async_function :
918+ return
919+ self .save_state (node , "match_state" , copy = True )
920+ self .match_state = MatchState (
921+ base_uncheckpointed_statements = self .uncheckpointed_statements .copy ()
922+ )
923+
924+ def visit_MatchCase (self , node : cst .MatchCase ) -> None :
925+ # enter each case from the state after parsing the subject
926+ self .uncheckpointed_statements = self .match_state .base_uncheckpointed_statements
927+
928+ def leave_MatchCase_guard (self , node : cst .MatchCase ) -> None :
929+ # `case _:` is no pattern and no guard, which means we know body is executed.
930+ # But we also know that `case _ if <guard>:` is guaranteed to execute the guard,
931+ # so for later logic we can treat them the same *if* there's no pattern and that
932+ # guard checkpoints.
933+ if (
934+ isinstance (node .pattern , cst .MatchAs )
935+ and node .pattern .pattern is None
936+ and (node .guard is None or not self .uncheckpointed_statements )
937+ ):
938+ self .match_state .has_fallback = True
939+
940+ def leave_MatchCase (
941+ self , original_node : cst .MatchCase , updated_node : cst .MatchCase
942+ ) -> cst .MatchCase :
943+ # collect the state at the end of each case
944+ self .match_state .case_uncheckpointed_statements .update (
945+ self .uncheckpointed_statements
946+ )
947+ return updated_node
948+
949+ def leave_Match (
950+ self , original_node : cst .Match , updated_node : cst .Match
951+ ) -> cst .Match :
952+ # leave the Match with the worst-case of all branches
953+ self .uncheckpointed_statements = self .match_state .case_uncheckpointed_statements
954+ # if no fallback, also add the state at entering the match (after parsing subject)
955+ if not self .match_state .has_fallback :
956+ self .uncheckpointed_statements .update (
957+ self .match_state .base_uncheckpointed_statements
958+ )
959+
960+ self .restore_state (original_node )
961+ return updated_node
962+
897963 def visit_While (self , node : cst .While | cst .For ):
898964 self .save_state (
899965 node ,
0 commit comments