@@ -198,6 +198,22 @@ def copy(self):
198198 )
199199
200200
201+ @dataclass
202+ class MatchState :
203+ # TryState, LoopState, and MatchState all do fairly similar things. It would be nice
204+ # to harmonize them and share logic.
205+ base_uncheckpointed_statements : set [Statement ] = field (default_factory = set )
206+ case_uncheckpointed_statements : set [Statement ] = field (default_factory = set )
207+ has_fallback : bool = False
208+
209+ def copy (self ):
210+ return MatchState (
211+ base_uncheckpointed_statements = self .base_uncheckpointed_statements .copy (),
212+ case_uncheckpointed_statements = self .case_uncheckpointed_statements .copy (),
213+ has_fallback = self .has_fallback ,
214+ )
215+
216+
201217def checkpoint_statement (library : str ) -> cst .SimpleStatementLine :
202218 # logic before this should stop code from wanting to insert the non-existing
203219 # asyncio.lowlevel.checkpoint
@@ -373,6 +389,7 @@ def __init__(self, *args: Any, **kwargs: Any):
373389
374390 self .loop_state = LoopState ()
375391 self .try_state = TryState ()
392+ self .match_state = MatchState ()
376393
377394 # ASYNC100
378395 self .has_checkpoint_stack : list [bool ] = []
@@ -894,6 +911,53 @@ def visit_IfExp(self, node: cst.IfExp) -> bool:
894911 self .leave_If (node , node ) # type: ignore
895912 return False # libcst shouldn't visit subnodes again
896913
914+ def leave_Match_subject (self , node : cst .Match ) -> None :
915+ # We start the match logic after parsing the subject, instead of visit_Match,
916+ # since the subject is always executed and might checkpoint.
917+ if not self .async_function :
918+ return
919+ self .save_state (node , "match_state" , copy = True )
920+ self .match_state = MatchState (self .uncheckpointed_statements .copy ())
921+
922+ def visit_MatchCase (self , node : cst .MatchCase ) -> None :
923+ # enter each case from the state after parsing the subject
924+ self .uncheckpointed_statements = self .match_state .base_uncheckpointed_statements
925+
926+ def leave_MatchCase_guard (self , node : cst .MatchCase ) -> None :
927+ # `case _:` is no pattern and no guard, which means we know body is executed.
928+ # But we also know that `case _ if <guard>:` is guaranteed to execute the guard,
929+ # so for later logic we can treat them the same *if* there's no pattern and that
930+ # guard checkpoints.
931+ if (
932+ isinstance (node .pattern , cst .MatchAs )
933+ and node .pattern .pattern is None
934+ and (node .guard is None or not self .uncheckpointed_statements )
935+ ):
936+ self .match_state .has_fallback = True
937+
938+ def leave_MatchCase (
939+ self , original_node : cst .MatchCase , updated_node : cst .MatchCase
940+ ) -> cst .MatchCase :
941+ # collect the state at the end of each case
942+ self .match_state .case_uncheckpointed_statements .update (
943+ self .uncheckpointed_statements
944+ )
945+ return updated_node
946+
947+ def leave_Match (
948+ self , original_node : cst .Match , updated_node : cst .Match
949+ ) -> cst .Match :
950+ # leave the Match with the worst-case of all branches
951+ self .uncheckpointed_statements = self .match_state .case_uncheckpointed_statements
952+ # if no fallback, also add the state at entering the match (after parsing subject)
953+ if not self .match_state .has_fallback :
954+ self .uncheckpointed_statements .update (
955+ self .match_state .base_uncheckpointed_statements
956+ )
957+
958+ self .restore_state (original_node )
959+ return updated_node
960+
897961 def visit_While (self , node : cst .While | cst .For ):
898962 self .save_state (
899963 node ,
0 commit comments