@@ -172,6 +172,11 @@ class LoopState:
172172 nodes_needing_checkpoints : list [cst .Return | cst .Yield | ArtificialStatement ] = (
173173 field (default_factory = list [cst .Return | cst .Yield | ArtificialStatement ])
174174 )
175+ # If a missing checkpoint was detected for a return/yield inside an except
176+ # clause, inserting the checkpoint there would trigger ASYNC120. Instead,
177+ # mark the innermost loop so a checkpoint is inserted at the top of the
178+ # loop body.
179+ needs_checkpoint_at_loop_start : bool = False
175180
176181 def copy (self ):
177182 return LoopState (
@@ -182,6 +187,7 @@ def copy(self):
182187 uncheckpointed_before_break = self .uncheckpointed_before_break .copy (),
183188 artificial_errors = self .artificial_errors .copy (),
184189 nodes_needing_checkpoints = self .nodes_needing_checkpoints .copy (),
190+ needs_checkpoint_at_loop_start = self .needs_checkpoint_at_loop_start ,
185191 )
186192
187193
@@ -341,6 +347,12 @@ def __init__(
341347 self .explicitly_imported_library = explicitly_imported
342348 self .nodes_needing_checkpoint = nodes_needing_checkpoint
343349 self .__library = library
350+ # Depth of except handlers we're currently inside, and a flag set if
351+ # we detected a node that would have been fixed inside an except
352+ # clause — in that case the caller should insert a checkpoint at the
353+ # top of the loop body instead.
354+ self .except_depth = 0
355+ self .needs_checkpoint_at_loop_start = False
344356
345357 @property
346358 def library (self ) -> tuple [str , ...]:
@@ -349,6 +361,22 @@ def library(self) -> tuple[str, ...]:
349361 def should_autofix (self , node : cst .CSTNode , code : str | None = None ) -> bool :
350362 return not self .noautofix
351363
364+ def visit_ExceptHandler (
365+ self , node : cst .ExceptHandler | cst .ExceptStarHandler
366+ ) -> None :
367+ self .except_depth += 1
368+
369+ def leave_ExceptHandler (
370+ self ,
371+ original_node : cst .ExceptHandler | cst .ExceptStarHandler ,
372+ updated_node : cst .ExceptHandler | cst .ExceptStarHandler ,
373+ ) -> Any :
374+ self .except_depth -= 1
375+ return updated_node
376+
377+ visit_ExceptStarHandler = visit_ExceptHandler
378+ leave_ExceptStarHandler = leave_ExceptHandler
379+
352380 def leave_Yield (
353381 self ,
354382 original_node : cst .Yield ,
@@ -359,7 +387,15 @@ def leave_Yield(
359387 if original_node in self .nodes_needing_checkpoint and self .should_autofix (
360388 original_node
361389 ):
362- self .add_statement = checkpoint_statement (self .library [0 ])
390+ if self .except_depth > 0 :
391+ # Inserting inside an except clause would trigger ASYNC120.
392+ # Signal to the caller to insert at the top of the loop body.
393+ # ensure_imported_library is called by the caller at the
394+ # insertion site so we don't add an import for a checkpoint
395+ # that ends up suppressed.
396+ self .needs_checkpoint_at_loop_start = True
397+ else :
398+ self .add_statement = checkpoint_statement (self .library [0 ])
363399 return updated_node
364400
365401 # returns handled same as yield, but ofc needs to ignore types
@@ -410,6 +446,15 @@ def __init__(self, *args: Any, **kwargs: Any):
410446 self .try_state = TryState ()
411447 self .match_state = MatchState ()
412448
449+ # Depth of except handlers we're currently inside. Used to avoid
450+ # inserting checkpoints inside an except clause (which would trigger
451+ # ASYNC120).
452+ self .except_depth = 0
453+ # Set when a missing checkpoint is detected inside an except clause
454+ # for a return/yield that is not inside a loop. We then insert a
455+ # single checkpoint at the top of the function body instead.
456+ self .add_checkpoint_at_function_start = False
457+
413458 # ASYNC100
414459 self .has_checkpoint_stack : list [ContextManager ] = []
415460 self .taskgroup_has_start_soon : dict [str , bool ] = {}
@@ -510,6 +555,8 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
510555 # node_dict is cleaned up and don't need to be saved
511556 "taskgroup_has_start_soon" ,
512557 "suppress_imported_as" , # a copy is saved, but state is not reset
558+ "except_depth" ,
559+ "add_checkpoint_at_function_start" ,
513560 copy = True ,
514561 )
515562 self .uncheckpointed_statements = set ()
@@ -518,6 +565,8 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
518565 self .loop_state = LoopState ()
519566 # try_state is reset upon entering try
520567 self .taskgroup_has_start_soon = {}
568+ self .except_depth = 0
569+ self .add_checkpoint_at_function_start = False
521570
522571 self .async_function = (
523572 node .asynchronous is not None
@@ -563,6 +612,19 @@ def leave_FunctionDef(
563612
564613 self .ensure_imported_library ()
565614
615+ if (
616+ self .add_checkpoint_at_function_start
617+ and self .new_body is not None
618+ and isinstance (self .new_body , cst .IndentedBlock )
619+ ):
620+ # insert checkpoint at the top of body (for missing checkpoints
621+ # detected on return/yield inside an except clause)
622+ new_body_block = list (self .new_body .body )
623+ new_body_block .insert (0 , self .checkpoint_statement ())
624+ self .new_body = self .new_body .with_changes (body = new_body_block )
625+
626+ self .ensure_imported_library ()
627+
566628 if self .new_body is not None :
567629 updated_node = updated_node .with_changes (body = self .new_body )
568630 self .restore_state (original_node )
@@ -598,7 +660,15 @@ def check_function_exit(
598660 if len (self .uncheckpointed_statements ) == 1 and self .should_autofix (
599661 original_node
600662 ):
601- self .loop_state .nodes_needing_checkpoints .append (original_node )
663+ if self .except_depth > 0 :
664+ # Inserting a checkpoint inside the except clause would
665+ # trigger ASYNC120. Instead mark the innermost loop so a
666+ # checkpoint is inserted at the top of the loop body.
667+ # ensure_imported_library is called at the actual
668+ # insertion site in leave_While.
669+ self .loop_state .needs_checkpoint_at_loop_start = True
670+ else :
671+ self .loop_state .nodes_needing_checkpoints .append (original_node )
602672 return False
603673
604674 any_errors = False
@@ -618,7 +688,7 @@ def leave_Return(
618688 if self .check_function_exit (original_node ) and self .should_autofix (
619689 original_node
620690 ):
621- self .add_statement = self . checkpoint_statement ()
691+ self ._set_missing_checkpoint_fix ()
622692 # avoid duplicate error messages
623693 # but don't see it as a cancel point for ASYNC100
624694 self .checkpoint_schedule_point ()
@@ -627,6 +697,49 @@ def leave_Return(
627697 assert original_node .deep_equals (updated_node )
628698 return original_node
629699
700+ def _set_missing_checkpoint_fix (self ) -> None :
701+ """Record where to insert a fix for a detected missing checkpoint.
702+
703+ Normally we insert the checkpoint right before the offending
704+ return/yield, but if the statement is inside an `except` clause that
705+ would just trigger ASYNC120 (checkpoint inside except). Instead we
706+ insert a checkpoint at a safe location: the top of the innermost
707+ enclosing loop (so yields between iterations are also covered), or
708+ the top of the function body otherwise.
709+
710+ If the uncheckpointed statements include a prior yield (so the error
711+ would be "yield since prior yield"), neither alternative actually
712+ fixes the path, so we fall back to the old behavior of inserting
713+ before the statement. The resulting checkpoint may trigger ASYNC120
714+ as a secondary diagnostic, but that is a pre-existing limitation of
715+ the autofix in such contorted cases.
716+ """
717+ if self .except_depth > 0 and self ._can_redirect_except_fix ():
718+ if ARTIFICIAL_STATEMENT in self .uncheckpointed_statements :
719+ # we're inside a loop
720+ self .loop_state .needs_checkpoint_at_loop_start = True
721+ else :
722+ self .add_checkpoint_at_function_start = True
723+ # ensure_imported_library is called at the actual insertion site
724+ # (leave_FunctionDef / leave_While) so we don't add an import for
725+ # a checkpoint that ends up suppressed by noqa.
726+ else :
727+ self .add_statement = self .checkpoint_statement ()
728+
729+ def _can_redirect_except_fix (self ) -> bool :
730+ """Check if the missing checkpoint can be fixed at a safe top location.
731+
732+ Redirecting to the top of the function or of the enclosing loop only
733+ actually fixes the uncheckpointed path if every uncheckpointed
734+ statement is "function definition" or the artificial loop-start
735+ marker. A previous yield indicates a path that the redirected
736+ checkpoint would not cover.
737+ """
738+ return all (
739+ isinstance (stmt , ArtificialStatement ) or stmt .name == "function definition"
740+ for stmt in self .uncheckpointed_statements
741+ )
742+
630743 def error_91x (
631744 self ,
632745 node : cst .Return | cst .FunctionDef | cst .Yield ,
@@ -838,7 +951,7 @@ def leave_Yield(
838951 if self .check_function_exit (original_node ) and self .should_autofix (
839952 original_node
840953 ):
841- self .add_statement = self . checkpoint_statement ()
954+ self ._set_missing_checkpoint_fix ()
842955
843956 # mark as requiring checkpoint after
844957 pos = self .get_metadata (PositionProvider , original_node ).start # type: ignore
@@ -883,6 +996,7 @@ def visit_ExceptHandler(self, node: cst.ExceptHandler | cst.ExceptStarHandler):
883996 self .uncheckpointed_statements = (
884997 self .try_state .body_uncheckpointed_statements .copy ()
885998 )
999+ self .except_depth += 1
8861000
8871001 def leave_ExceptHandler (
8881002 self ,
@@ -892,6 +1006,7 @@ def leave_ExceptHandler(
8921006 self .try_state .except_uncheckpointed_statements .update (
8931007 self .uncheckpointed_statements
8941008 )
1009+ self .except_depth -= 1
8951010 return updated_node
8961011
8971012 def visit_Try_orelse (self , node : cst .Try | cst .TryStar ):
@@ -1103,6 +1218,10 @@ def leave_While_body(self, node: cst.For | cst.While):
11031218 # the potential checkpoints
11041219 if not any_error :
11051220 self .loop_state .nodes_needing_checkpoints = []
1221+ # Also clear the loop-top insertion flag; it may have been set by
1222+ # return/yield nodes inside an except clause whose errors were
1223+ # suppressed via noqa.
1224+ self .loop_state .needs_checkpoint_at_loop_start = False
11061225
11071226 if (
11081227 self .loop_state .infinite_loop
@@ -1181,16 +1300,12 @@ def leave_While(
11811300 self .restore_state (original_node )
11821301 return updated_node
11831302
1303+ insert_at_loop_start = self .loop_state .needs_checkpoint_at_loop_start
1304+
11841305 # ASYNC913, indefinite loop with no guaranteed checkpoint
11851306 if self .loop_state .nodes_needing_checkpoints == [ARTIFICIAL_STATEMENT ]:
11861307 if self .should_autofix (original_node , code = "ASYNC913" ):
1187- # insert checkpoint at start of body
1188- new_body = list (updated_node .body .body )
1189- new_body .insert (0 , self .checkpoint_statement ())
1190- indentedblock = updated_node .body .with_changes (body = new_body )
1191- updated_node = updated_node .with_changes (body = indentedblock )
1192-
1193- self .ensure_imported_library ()
1308+ insert_at_loop_start = True
11941309 elif self .loop_state .nodes_needing_checkpoints :
11951310 assert ARTIFICIAL_STATEMENT not in self .loop_state .nodes_needing_checkpoints
11961311 transformer = InsertCheckpointsInLoopBody (
@@ -1207,6 +1322,18 @@ def leave_While(
12071322 # include any necessary import added
12081323 self .add_import .update (transformer .add_import )
12091324
1325+ if transformer .needs_checkpoint_at_loop_start :
1326+ insert_at_loop_start = True
1327+
1328+ if insert_at_loop_start :
1329+ # insert checkpoint at start of body
1330+ new_body = list (updated_node .body .body )
1331+ new_body .insert (0 , self .checkpoint_statement ())
1332+ indentedblock = updated_node .body .with_changes (body = new_body )
1333+ updated_node = updated_node .with_changes (body = indentedblock )
1334+
1335+ self .ensure_imported_library ()
1336+
12101337 self .restore_state (original_node )
12111338 return updated_node
12121339
0 commit comments