diff --git a/docs/changelog.rst b/docs/changelog.rst index 1064da0..5ac72f4 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,10 @@ Changelog `CalVer, YY.month.patch `_ +Unreleased +========== +- Autofix for :ref:`ASYNC910 ` / :ref:`ASYNC911 ` no longer inserts checkpoints inside ``except`` clauses (which would trigger :ref:`ASYNC120 `); instead the checkpoint is added at the top of the function or of the enclosing loop. `(issue #403) `_ + 25.7.1 ====== - :ref:`ASYNC102 ` no longer triggered for asyncio due to different cancellation semantics it uses. diff --git a/flake8_async/visitors/visitor91x.py b/flake8_async/visitors/visitor91x.py index 1ce0dbd..67f9d51 100644 --- a/flake8_async/visitors/visitor91x.py +++ b/flake8_async/visitors/visitor91x.py @@ -172,6 +172,11 @@ class LoopState: nodes_needing_checkpoints: list[cst.Return | cst.Yield | ArtificialStatement] = ( field(default_factory=list[cst.Return | cst.Yield | ArtificialStatement]) ) + # If a missing checkpoint was detected for a return/yield inside an except + # clause, inserting the checkpoint there would trigger ASYNC120. Instead, + # mark the innermost loop so a checkpoint is inserted at the top of the + # loop body. + needs_checkpoint_at_loop_start: bool = False def copy(self): return LoopState( @@ -182,6 +187,7 @@ def copy(self): uncheckpointed_before_break=self.uncheckpointed_before_break.copy(), artificial_errors=self.artificial_errors.copy(), nodes_needing_checkpoints=self.nodes_needing_checkpoints.copy(), + needs_checkpoint_at_loop_start=self.needs_checkpoint_at_loop_start, ) @@ -341,6 +347,12 @@ def __init__( self.explicitly_imported_library = explicitly_imported self.nodes_needing_checkpoint = nodes_needing_checkpoint self.__library = library + # Depth of except handlers we're currently inside, and a flag set if + # we detected a node that would have been fixed inside an except + # clause — in that case the caller should insert a checkpoint at the + # top of the loop body instead. + self.except_depth = 0 + self.needs_checkpoint_at_loop_start = False @property def library(self) -> tuple[str, ...]: @@ -349,6 +361,22 @@ def library(self) -> tuple[str, ...]: def should_autofix(self, node: cst.CSTNode, code: str | None = None) -> bool: return not self.noautofix + def visit_ExceptHandler( + self, node: cst.ExceptHandler | cst.ExceptStarHandler + ) -> None: + self.except_depth += 1 + + def leave_ExceptHandler( + self, + original_node: cst.ExceptHandler | cst.ExceptStarHandler, + updated_node: cst.ExceptHandler | cst.ExceptStarHandler, + ) -> Any: + self.except_depth -= 1 + return updated_node + + visit_ExceptStarHandler = visit_ExceptHandler + leave_ExceptStarHandler = leave_ExceptHandler + def leave_Yield( self, original_node: cst.Yield, @@ -359,7 +387,15 @@ def leave_Yield( if original_node in self.nodes_needing_checkpoint and self.should_autofix( original_node ): - self.add_statement = checkpoint_statement(self.library[0]) + if self.except_depth > 0: + # Inserting inside an except clause would trigger ASYNC120. + # Signal to the caller to insert at the top of the loop body. + # ensure_imported_library is called by the caller at the + # insertion site so we don't add an import for a checkpoint + # that ends up suppressed. + self.needs_checkpoint_at_loop_start = True + else: + self.add_statement = checkpoint_statement(self.library[0]) return updated_node # returns handled same as yield, but ofc needs to ignore types @@ -410,6 +446,15 @@ def __init__(self, *args: Any, **kwargs: Any): self.try_state = TryState() self.match_state = MatchState() + # Depth of except handlers we're currently inside. Used to avoid + # inserting checkpoints inside an except clause (which would trigger + # ASYNC120). + self.except_depth = 0 + # Set when a missing checkpoint is detected inside an except clause + # for a return/yield that is not inside a loop. We then insert a + # single checkpoint at the top of the function body instead. + self.add_checkpoint_at_function_start = False + # ASYNC100 self.has_checkpoint_stack: list[ContextManager] = [] self.taskgroup_has_start_soon: dict[str, bool] = {} @@ -510,6 +555,8 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: # node_dict is cleaned up and don't need to be saved "taskgroup_has_start_soon", "suppress_imported_as", # a copy is saved, but state is not reset + "except_depth", + "add_checkpoint_at_function_start", copy=True, ) self.uncheckpointed_statements = set() @@ -518,6 +565,8 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: self.loop_state = LoopState() # try_state is reset upon entering try self.taskgroup_has_start_soon = {} + self.except_depth = 0 + self.add_checkpoint_at_function_start = False self.async_function = ( node.asynchronous is not None @@ -563,6 +612,19 @@ def leave_FunctionDef( self.ensure_imported_library() + if ( + self.add_checkpoint_at_function_start + and self.new_body is not None + and isinstance(self.new_body, cst.IndentedBlock) + ): + # insert checkpoint at the top of body (for missing checkpoints + # detected on return/yield inside an except clause) + new_body_block = list(self.new_body.body) + new_body_block.insert(0, self.checkpoint_statement()) + self.new_body = self.new_body.with_changes(body=new_body_block) + + self.ensure_imported_library() + if self.new_body is not None: updated_node = updated_node.with_changes(body=self.new_body) self.restore_state(original_node) @@ -598,7 +660,15 @@ def check_function_exit( if len(self.uncheckpointed_statements) == 1 and self.should_autofix( original_node ): - self.loop_state.nodes_needing_checkpoints.append(original_node) + if self.except_depth > 0: + # Inserting a checkpoint inside the except clause would + # trigger ASYNC120. Instead mark the innermost loop so a + # checkpoint is inserted at the top of the loop body. + # ensure_imported_library is called at the actual + # insertion site in leave_While. + self.loop_state.needs_checkpoint_at_loop_start = True + else: + self.loop_state.nodes_needing_checkpoints.append(original_node) return False any_errors = False @@ -618,7 +688,7 @@ def leave_Return( if self.check_function_exit(original_node) and self.should_autofix( original_node ): - self.add_statement = self.checkpoint_statement() + self._set_missing_checkpoint_fix() # avoid duplicate error messages # but don't see it as a cancel point for ASYNC100 self.checkpoint_schedule_point() @@ -627,6 +697,49 @@ def leave_Return( assert original_node.deep_equals(updated_node) return original_node + def _set_missing_checkpoint_fix(self) -> None: + """Record where to insert a fix for a detected missing checkpoint. + + Normally we insert the checkpoint right before the offending + return/yield, but if the statement is inside an `except` clause that + would just trigger ASYNC120 (checkpoint inside except). Instead we + insert a checkpoint at a safe location: the top of the innermost + enclosing loop (so yields between iterations are also covered), or + the top of the function body otherwise. + + If the uncheckpointed statements include a prior yield (so the error + would be "yield since prior yield"), neither alternative actually + fixes the path, so we fall back to the old behavior of inserting + before the statement. The resulting checkpoint may trigger ASYNC120 + as a secondary diagnostic, but that is a pre-existing limitation of + the autofix in such contorted cases. + """ + if self.except_depth > 0 and self._can_redirect_except_fix(): + if ARTIFICIAL_STATEMENT in self.uncheckpointed_statements: + # we're inside a loop + self.loop_state.needs_checkpoint_at_loop_start = True + else: + self.add_checkpoint_at_function_start = True + # ensure_imported_library is called at the actual insertion site + # (leave_FunctionDef / leave_While) so we don't add an import for + # a checkpoint that ends up suppressed by noqa. + else: + self.add_statement = self.checkpoint_statement() + + def _can_redirect_except_fix(self) -> bool: + """Check if the missing checkpoint can be fixed at a safe top location. + + Redirecting to the top of the function or of the enclosing loop only + actually fixes the uncheckpointed path if every uncheckpointed + statement is "function definition" or the artificial loop-start + marker. A previous yield indicates a path that the redirected + checkpoint would not cover. + """ + return all( + isinstance(stmt, ArtificialStatement) or stmt.name == "function definition" + for stmt in self.uncheckpointed_statements + ) + def error_91x( self, node: cst.Return | cst.FunctionDef | cst.Yield, @@ -838,7 +951,7 @@ def leave_Yield( if self.check_function_exit(original_node) and self.should_autofix( original_node ): - self.add_statement = self.checkpoint_statement() + self._set_missing_checkpoint_fix() # mark as requiring checkpoint after pos = self.get_metadata(PositionProvider, original_node).start # type: ignore @@ -883,6 +996,7 @@ def visit_ExceptHandler(self, node: cst.ExceptHandler | cst.ExceptStarHandler): self.uncheckpointed_statements = ( self.try_state.body_uncheckpointed_statements.copy() ) + self.except_depth += 1 def leave_ExceptHandler( self, @@ -892,6 +1006,7 @@ def leave_ExceptHandler( self.try_state.except_uncheckpointed_statements.update( self.uncheckpointed_statements ) + self.except_depth -= 1 return updated_node def visit_Try_orelse(self, node: cst.Try | cst.TryStar): @@ -1103,6 +1218,10 @@ def leave_While_body(self, node: cst.For | cst.While): # the potential checkpoints if not any_error: self.loop_state.nodes_needing_checkpoints = [] + # Also clear the loop-top insertion flag; it may have been set by + # return/yield nodes inside an except clause whose errors were + # suppressed via noqa. + self.loop_state.needs_checkpoint_at_loop_start = False if ( self.loop_state.infinite_loop @@ -1181,16 +1300,12 @@ def leave_While( self.restore_state(original_node) return updated_node + insert_at_loop_start = self.loop_state.needs_checkpoint_at_loop_start + # ASYNC913, indefinite loop with no guaranteed checkpoint if self.loop_state.nodes_needing_checkpoints == [ARTIFICIAL_STATEMENT]: if self.should_autofix(original_node, code="ASYNC913"): - # insert checkpoint at start of body - new_body = list(updated_node.body.body) - new_body.insert(0, self.checkpoint_statement()) - indentedblock = updated_node.body.with_changes(body=new_body) - updated_node = updated_node.with_changes(body=indentedblock) - - self.ensure_imported_library() + insert_at_loop_start = True elif self.loop_state.nodes_needing_checkpoints: assert ARTIFICIAL_STATEMENT not in self.loop_state.nodes_needing_checkpoints transformer = InsertCheckpointsInLoopBody( @@ -1207,6 +1322,18 @@ def leave_While( # include any necessary import added self.add_import.update(transformer.add_import) + if transformer.needs_checkpoint_at_loop_start: + insert_at_loop_start = True + + if insert_at_loop_start: + # insert checkpoint at start of body + new_body = list(updated_node.body.body) + new_body.insert(0, self.checkpoint_statement()) + indentedblock = updated_node.body.with_changes(body=new_body) + updated_node = updated_node.with_changes(body=indentedblock) + + self.ensure_imported_library() + self.restore_state(original_node) return updated_node diff --git a/tests/autofix_files/async911.py b/tests/autofix_files/async911.py index a91a322..d4edd07 100644 --- a/tests/autofix_files/async911.py +++ b/tests/autofix_files/async911.py @@ -373,10 +373,10 @@ async def foo_try_1(): # error: 0, "exit", Statement("function definition", lin # no checkpoint after yield in ValueError async def foo_try_2(): # error: 0, "exit", Statement("yield", lineno+5) + await trio.lowlevel.checkpoint() try: await foo() except ValueError: - await trio.lowlevel.checkpoint() # try might not have checkpointed yield # error: 8, "yield", Statement("function definition", lineno-5) except: @@ -398,10 +398,10 @@ async def foo_try_3(): # error: 0, "exit", Statement("yield", lineno+6) async def foo_try_4(): # safe + await trio.lowlevel.checkpoint() try: ... except: - await trio.lowlevel.checkpoint() yield # error: 8, "yield", Statement("function definition", lineno-4) finally: await foo() diff --git a/tests/autofix_files/async911.py.diff b/tests/autofix_files/async911.py.diff index bf853b9..dee5b7d 100644 --- a/tests/autofix_files/async911.py.diff +++ b/tests/autofix_files/async911.py.diff @@ -251,7 +251,7 @@ async def foo_while_endless_3(): -@@ x,9 x,11 @@ +@@ x,13 x,16 @@ # try async def foo_try_1(): # error: 0, "exit", Statement("function definition", lineno) # error: 0, "exit", Statement("yield", lineno+2) try: @@ -263,14 +263,12 @@ # no checkpoint after yield in ValueError -@@ x,12 x,14 @@ + async def foo_try_2(): # error: 0, "exit", Statement("yield", lineno+5) ++ await trio.lowlevel.checkpoint() try: await foo() except ValueError: -+ await trio.lowlevel.checkpoint() - # try might not have checkpointed - yield # error: 8, "yield", Statement("function definition", lineno-5) - except: +@@ x,6 x,7 @@ await foo() else: pass @@ -278,7 +276,7 @@ async def foo_try_3(): # error: 0, "exit", Statement("yield", lineno+6) -@@ x,13 x,16 @@ +@@ x,10 x,13 @@ except: await foo() else: @@ -288,13 +286,10 @@ async def foo_try_4(): # safe ++ await trio.lowlevel.checkpoint() try: ... except: -+ await trio.lowlevel.checkpoint() - yield # error: 8, "yield", Statement("function definition", lineno-4) - finally: - await foo() @@ x,6 x,7 @@ try: await foo() diff --git a/tests/autofix_files/async91x_autofix.py b/tests/autofix_files/async91x_autofix.py index 50f1e99..33e36e1 100644 --- a/tests/autofix_files/async91x_autofix.py +++ b/tests/autofix_files/async91x_autofix.py @@ -228,3 +228,52 @@ async def match_not_checkpoint_in_all_guards() -> ( # ASYNC910: 0, "exit", Stat case _ if await foo(): ... await trio.lowlevel.checkpoint() + + +# Issue #403: autofix should not insert checkpoints inside an except clause, +# as that would trigger ASYNC120. Instead insert at top of function / loop. +async def except_return(): + await trio.lowlevel.checkpoint() + try: + await foo() + except ValueError: + return # ASYNC910: 8, "return", Statement("function definition", lineno-4) + + +async def except_yield(): # ASYNC911: 0, "exit", Statement("yield", lineno+4) + await trio.lowlevel.checkpoint() + try: + await foo() + except ValueError: + yield # ASYNC911: 8, "yield", Statement("function definition", lineno-4) + await trio.lowlevel.checkpoint() + + +async def except_return_in_for_loop(): # ASYNC910: 0, "exit", Statement("function definition", lineno) + for _ in range(10): + await trio.lowlevel.checkpoint() + try: + await foo() + except ValueError: + return # ASYNC910: 12, "return", Statement("function definition", lineno-5) + await trio.lowlevel.checkpoint() + + +async def except_yield_in_for_loop(): # ASYNC911: 0, "exit", Statement("yield", lineno+5) + for _ in range(10): + await trio.lowlevel.checkpoint() + try: + await foo() + except ValueError: + yield # ASYNC911: 12, "yield", Statement("function definition", lineno-5) # ASYNC911: 12, "yield", Statement("yield", lineno) + await trio.lowlevel.checkpoint() + + +async def except_nested_in_if(): + await trio.lowlevel.checkpoint() + if bar(): + try: + await foo() + except ValueError: + return # ASYNC910: 12, "return", Statement("function definition", lineno-5) + await foo() diff --git a/tests/autofix_files/async91x_autofix.py.diff b/tests/autofix_files/async91x_autofix.py.diff index 30825a7..30744be 100644 --- a/tests/autofix_files/async91x_autofix.py.diff +++ b/tests/autofix_files/async91x_autofix.py.diff @@ -110,8 +110,54 @@ async def match_all_cases() -> None: -@@ x,3 x,4 @@ +@@ x,11 x,13 @@ ... case _ if await foo(): ... + await trio.lowlevel.checkpoint() + + + # Issue #403: autofix should not insert checkpoints inside an except clause, + # as that would trigger ASYNC120. Instead insert at top of function / loop. + async def except_return(): ++ await trio.lowlevel.checkpoint() + try: + await foo() + except ValueError: +@@ x,29 x,36 @@ + + + async def except_yield(): # ASYNC911: 0, "exit", Statement("yield", lineno+4) ++ await trio.lowlevel.checkpoint() + try: + await foo() + except ValueError: + yield # ASYNC911: 8, "yield", Statement("function definition", lineno-4) ++ await trio.lowlevel.checkpoint() + + + async def except_return_in_for_loop(): # ASYNC910: 0, "exit", Statement("function definition", lineno) + for _ in range(10): ++ await trio.lowlevel.checkpoint() + try: + await foo() + except ValueError: + return # ASYNC910: 12, "return", Statement("function definition", lineno-5) ++ await trio.lowlevel.checkpoint() + + + async def except_yield_in_for_loop(): # ASYNC911: 0, "exit", Statement("yield", lineno+5) + for _ in range(10): ++ await trio.lowlevel.checkpoint() + try: + await foo() + except ValueError: + yield # ASYNC911: 12, "yield", Statement("function definition", lineno-5) # ASYNC911: 12, "yield", Statement("yield", lineno) ++ await trio.lowlevel.checkpoint() + + + async def except_nested_in_if(): ++ await trio.lowlevel.checkpoint() + if bar(): + try: + await foo() diff --git a/tests/eval_files/async91x_autofix.py b/tests/eval_files/async91x_autofix.py index aecf45a..34e70b3 100644 --- a/tests/eval_files/async91x_autofix.py +++ b/tests/eval_files/async91x_autofix.py @@ -208,3 +208,44 @@ async def match_not_checkpoint_in_all_guards() -> ( # ASYNC910: 0, "exit", Stat ... case _ if await foo(): ... + + +# Issue #403: autofix should not insert checkpoints inside an except clause, +# as that would trigger ASYNC120. Instead insert at top of function / loop. +async def except_return(): + try: + await foo() + except ValueError: + return # ASYNC910: 8, "return", Statement("function definition", lineno-4) + + +async def except_yield(): # ASYNC911: 0, "exit", Statement("yield", lineno+4) + try: + await foo() + except ValueError: + yield # ASYNC911: 8, "yield", Statement("function definition", lineno-4) + + +async def except_return_in_for_loop(): # ASYNC910: 0, "exit", Statement("function definition", lineno) + for _ in range(10): + try: + await foo() + except ValueError: + return # ASYNC910: 12, "return", Statement("function definition", lineno-5) + + +async def except_yield_in_for_loop(): # ASYNC911: 0, "exit", Statement("yield", lineno+5) + for _ in range(10): + try: + await foo() + except ValueError: + yield # ASYNC911: 12, "yield", Statement("function definition", lineno-5) # ASYNC911: 12, "yield", Statement("yield", lineno) + + +async def except_nested_in_if(): + if bar(): + try: + await foo() + except ValueError: + return # ASYNC910: 12, "return", Statement("function definition", lineno-5) + await foo()