diff --git a/docs/changelog.rst b/docs/changelog.rst
index 1064da0..5ac72f4 100644
--- a/docs/changelog.rst
+++ b/docs/changelog.rst
@@ -4,6 +4,10 @@ Changelog
`CalVer, YY.month.patch `_
+Unreleased
+==========
+- Autofix for :ref:`ASYNC910 ` / :ref:`ASYNC911 ` no longer inserts checkpoints inside ``except`` clauses (which would trigger :ref:`ASYNC120 `); instead the checkpoint is added at the top of the function or of the enclosing loop. `(issue #403) `_
+
25.7.1
======
- :ref:`ASYNC102 ` no longer triggered for asyncio due to different cancellation semantics it uses.
diff --git a/flake8_async/visitors/visitor91x.py b/flake8_async/visitors/visitor91x.py
index 1ce0dbd..67f9d51 100644
--- a/flake8_async/visitors/visitor91x.py
+++ b/flake8_async/visitors/visitor91x.py
@@ -172,6 +172,11 @@ class LoopState:
nodes_needing_checkpoints: list[cst.Return | cst.Yield | ArtificialStatement] = (
field(default_factory=list[cst.Return | cst.Yield | ArtificialStatement])
)
+ # If a missing checkpoint was detected for a return/yield inside an except
+ # clause, inserting the checkpoint there would trigger ASYNC120. Instead,
+ # mark the innermost loop so a checkpoint is inserted at the top of the
+ # loop body.
+ needs_checkpoint_at_loop_start: bool = False
def copy(self):
return LoopState(
@@ -182,6 +187,7 @@ def copy(self):
uncheckpointed_before_break=self.uncheckpointed_before_break.copy(),
artificial_errors=self.artificial_errors.copy(),
nodes_needing_checkpoints=self.nodes_needing_checkpoints.copy(),
+ needs_checkpoint_at_loop_start=self.needs_checkpoint_at_loop_start,
)
@@ -341,6 +347,12 @@ def __init__(
self.explicitly_imported_library = explicitly_imported
self.nodes_needing_checkpoint = nodes_needing_checkpoint
self.__library = library
+ # Depth of except handlers we're currently inside, and a flag set if
+ # we detected a node that would have been fixed inside an except
+ # clause — in that case the caller should insert a checkpoint at the
+ # top of the loop body instead.
+ self.except_depth = 0
+ self.needs_checkpoint_at_loop_start = False
@property
def library(self) -> tuple[str, ...]:
@@ -349,6 +361,22 @@ def library(self) -> tuple[str, ...]:
def should_autofix(self, node: cst.CSTNode, code: str | None = None) -> bool:
return not self.noautofix
+ def visit_ExceptHandler(
+ self, node: cst.ExceptHandler | cst.ExceptStarHandler
+ ) -> None:
+ self.except_depth += 1
+
+ def leave_ExceptHandler(
+ self,
+ original_node: cst.ExceptHandler | cst.ExceptStarHandler,
+ updated_node: cst.ExceptHandler | cst.ExceptStarHandler,
+ ) -> Any:
+ self.except_depth -= 1
+ return updated_node
+
+ visit_ExceptStarHandler = visit_ExceptHandler
+ leave_ExceptStarHandler = leave_ExceptHandler
+
def leave_Yield(
self,
original_node: cst.Yield,
@@ -359,7 +387,15 @@ def leave_Yield(
if original_node in self.nodes_needing_checkpoint and self.should_autofix(
original_node
):
- self.add_statement = checkpoint_statement(self.library[0])
+ if self.except_depth > 0:
+ # Inserting inside an except clause would trigger ASYNC120.
+ # Signal to the caller to insert at the top of the loop body.
+ # ensure_imported_library is called by the caller at the
+ # insertion site so we don't add an import for a checkpoint
+ # that ends up suppressed.
+ self.needs_checkpoint_at_loop_start = True
+ else:
+ self.add_statement = checkpoint_statement(self.library[0])
return updated_node
# returns handled same as yield, but ofc needs to ignore types
@@ -410,6 +446,15 @@ def __init__(self, *args: Any, **kwargs: Any):
self.try_state = TryState()
self.match_state = MatchState()
+ # Depth of except handlers we're currently inside. Used to avoid
+ # inserting checkpoints inside an except clause (which would trigger
+ # ASYNC120).
+ self.except_depth = 0
+ # Set when a missing checkpoint is detected inside an except clause
+ # for a return/yield that is not inside a loop. We then insert a
+ # single checkpoint at the top of the function body instead.
+ self.add_checkpoint_at_function_start = False
+
# ASYNC100
self.has_checkpoint_stack: list[ContextManager] = []
self.taskgroup_has_start_soon: dict[str, bool] = {}
@@ -510,6 +555,8 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
# node_dict is cleaned up and don't need to be saved
"taskgroup_has_start_soon",
"suppress_imported_as", # a copy is saved, but state is not reset
+ "except_depth",
+ "add_checkpoint_at_function_start",
copy=True,
)
self.uncheckpointed_statements = set()
@@ -518,6 +565,8 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
self.loop_state = LoopState()
# try_state is reset upon entering try
self.taskgroup_has_start_soon = {}
+ self.except_depth = 0
+ self.add_checkpoint_at_function_start = False
self.async_function = (
node.asynchronous is not None
@@ -563,6 +612,19 @@ def leave_FunctionDef(
self.ensure_imported_library()
+ if (
+ self.add_checkpoint_at_function_start
+ and self.new_body is not None
+ and isinstance(self.new_body, cst.IndentedBlock)
+ ):
+ # insert checkpoint at the top of body (for missing checkpoints
+ # detected on return/yield inside an except clause)
+ new_body_block = list(self.new_body.body)
+ new_body_block.insert(0, self.checkpoint_statement())
+ self.new_body = self.new_body.with_changes(body=new_body_block)
+
+ self.ensure_imported_library()
+
if self.new_body is not None:
updated_node = updated_node.with_changes(body=self.new_body)
self.restore_state(original_node)
@@ -598,7 +660,15 @@ def check_function_exit(
if len(self.uncheckpointed_statements) == 1 and self.should_autofix(
original_node
):
- self.loop_state.nodes_needing_checkpoints.append(original_node)
+ if self.except_depth > 0:
+ # Inserting a checkpoint inside the except clause would
+ # trigger ASYNC120. Instead mark the innermost loop so a
+ # checkpoint is inserted at the top of the loop body.
+ # ensure_imported_library is called at the actual
+ # insertion site in leave_While.
+ self.loop_state.needs_checkpoint_at_loop_start = True
+ else:
+ self.loop_state.nodes_needing_checkpoints.append(original_node)
return False
any_errors = False
@@ -618,7 +688,7 @@ def leave_Return(
if self.check_function_exit(original_node) and self.should_autofix(
original_node
):
- self.add_statement = self.checkpoint_statement()
+ self._set_missing_checkpoint_fix()
# avoid duplicate error messages
# but don't see it as a cancel point for ASYNC100
self.checkpoint_schedule_point()
@@ -627,6 +697,49 @@ def leave_Return(
assert original_node.deep_equals(updated_node)
return original_node
+ def _set_missing_checkpoint_fix(self) -> None:
+ """Record where to insert a fix for a detected missing checkpoint.
+
+ Normally we insert the checkpoint right before the offending
+ return/yield, but if the statement is inside an `except` clause that
+ would just trigger ASYNC120 (checkpoint inside except). Instead we
+ insert a checkpoint at a safe location: the top of the innermost
+ enclosing loop (so yields between iterations are also covered), or
+ the top of the function body otherwise.
+
+ If the uncheckpointed statements include a prior yield (so the error
+ would be "yield since prior yield"), neither alternative actually
+ fixes the path, so we fall back to the old behavior of inserting
+ before the statement. The resulting checkpoint may trigger ASYNC120
+ as a secondary diagnostic, but that is a pre-existing limitation of
+ the autofix in such contorted cases.
+ """
+ if self.except_depth > 0 and self._can_redirect_except_fix():
+ if ARTIFICIAL_STATEMENT in self.uncheckpointed_statements:
+ # we're inside a loop
+ self.loop_state.needs_checkpoint_at_loop_start = True
+ else:
+ self.add_checkpoint_at_function_start = True
+ # ensure_imported_library is called at the actual insertion site
+ # (leave_FunctionDef / leave_While) so we don't add an import for
+ # a checkpoint that ends up suppressed by noqa.
+ else:
+ self.add_statement = self.checkpoint_statement()
+
+ def _can_redirect_except_fix(self) -> bool:
+ """Check if the missing checkpoint can be fixed at a safe top location.
+
+ Redirecting to the top of the function or of the enclosing loop only
+ actually fixes the uncheckpointed path if every uncheckpointed
+ statement is "function definition" or the artificial loop-start
+ marker. A previous yield indicates a path that the redirected
+ checkpoint would not cover.
+ """
+ return all(
+ isinstance(stmt, ArtificialStatement) or stmt.name == "function definition"
+ for stmt in self.uncheckpointed_statements
+ )
+
def error_91x(
self,
node: cst.Return | cst.FunctionDef | cst.Yield,
@@ -838,7 +951,7 @@ def leave_Yield(
if self.check_function_exit(original_node) and self.should_autofix(
original_node
):
- self.add_statement = self.checkpoint_statement()
+ self._set_missing_checkpoint_fix()
# mark as requiring checkpoint after
pos = self.get_metadata(PositionProvider, original_node).start # type: ignore
@@ -883,6 +996,7 @@ def visit_ExceptHandler(self, node: cst.ExceptHandler | cst.ExceptStarHandler):
self.uncheckpointed_statements = (
self.try_state.body_uncheckpointed_statements.copy()
)
+ self.except_depth += 1
def leave_ExceptHandler(
self,
@@ -892,6 +1006,7 @@ def leave_ExceptHandler(
self.try_state.except_uncheckpointed_statements.update(
self.uncheckpointed_statements
)
+ self.except_depth -= 1
return updated_node
def visit_Try_orelse(self, node: cst.Try | cst.TryStar):
@@ -1103,6 +1218,10 @@ def leave_While_body(self, node: cst.For | cst.While):
# the potential checkpoints
if not any_error:
self.loop_state.nodes_needing_checkpoints = []
+ # Also clear the loop-top insertion flag; it may have been set by
+ # return/yield nodes inside an except clause whose errors were
+ # suppressed via noqa.
+ self.loop_state.needs_checkpoint_at_loop_start = False
if (
self.loop_state.infinite_loop
@@ -1181,16 +1300,12 @@ def leave_While(
self.restore_state(original_node)
return updated_node
+ insert_at_loop_start = self.loop_state.needs_checkpoint_at_loop_start
+
# ASYNC913, indefinite loop with no guaranteed checkpoint
if self.loop_state.nodes_needing_checkpoints == [ARTIFICIAL_STATEMENT]:
if self.should_autofix(original_node, code="ASYNC913"):
- # insert checkpoint at start of body
- new_body = list(updated_node.body.body)
- new_body.insert(0, self.checkpoint_statement())
- indentedblock = updated_node.body.with_changes(body=new_body)
- updated_node = updated_node.with_changes(body=indentedblock)
-
- self.ensure_imported_library()
+ insert_at_loop_start = True
elif self.loop_state.nodes_needing_checkpoints:
assert ARTIFICIAL_STATEMENT not in self.loop_state.nodes_needing_checkpoints
transformer = InsertCheckpointsInLoopBody(
@@ -1207,6 +1322,18 @@ def leave_While(
# include any necessary import added
self.add_import.update(transformer.add_import)
+ if transformer.needs_checkpoint_at_loop_start:
+ insert_at_loop_start = True
+
+ if insert_at_loop_start:
+ # insert checkpoint at start of body
+ new_body = list(updated_node.body.body)
+ new_body.insert(0, self.checkpoint_statement())
+ indentedblock = updated_node.body.with_changes(body=new_body)
+ updated_node = updated_node.with_changes(body=indentedblock)
+
+ self.ensure_imported_library()
+
self.restore_state(original_node)
return updated_node
diff --git a/tests/autofix_files/async911.py b/tests/autofix_files/async911.py
index a91a322..d4edd07 100644
--- a/tests/autofix_files/async911.py
+++ b/tests/autofix_files/async911.py
@@ -373,10 +373,10 @@ async def foo_try_1(): # error: 0, "exit", Statement("function definition", lin
# no checkpoint after yield in ValueError
async def foo_try_2(): # error: 0, "exit", Statement("yield", lineno+5)
+ await trio.lowlevel.checkpoint()
try:
await foo()
except ValueError:
- await trio.lowlevel.checkpoint()
# try might not have checkpointed
yield # error: 8, "yield", Statement("function definition", lineno-5)
except:
@@ -398,10 +398,10 @@ async def foo_try_3(): # error: 0, "exit", Statement("yield", lineno+6)
async def foo_try_4(): # safe
+ await trio.lowlevel.checkpoint()
try:
...
except:
- await trio.lowlevel.checkpoint()
yield # error: 8, "yield", Statement("function definition", lineno-4)
finally:
await foo()
diff --git a/tests/autofix_files/async911.py.diff b/tests/autofix_files/async911.py.diff
index bf853b9..dee5b7d 100644
--- a/tests/autofix_files/async911.py.diff
+++ b/tests/autofix_files/async911.py.diff
@@ -251,7 +251,7 @@
async def foo_while_endless_3():
-@@ x,9 x,11 @@
+@@ x,13 x,16 @@
# try
async def foo_try_1(): # error: 0, "exit", Statement("function definition", lineno) # error: 0, "exit", Statement("yield", lineno+2)
try:
@@ -263,14 +263,12 @@
# no checkpoint after yield in ValueError
-@@ x,12 x,14 @@
+ async def foo_try_2(): # error: 0, "exit", Statement("yield", lineno+5)
++ await trio.lowlevel.checkpoint()
try:
await foo()
except ValueError:
-+ await trio.lowlevel.checkpoint()
- # try might not have checkpointed
- yield # error: 8, "yield", Statement("function definition", lineno-5)
- except:
+@@ x,6 x,7 @@
await foo()
else:
pass
@@ -278,7 +276,7 @@
async def foo_try_3(): # error: 0, "exit", Statement("yield", lineno+6)
-@@ x,13 x,16 @@
+@@ x,10 x,13 @@
except:
await foo()
else:
@@ -288,13 +286,10 @@
async def foo_try_4(): # safe
++ await trio.lowlevel.checkpoint()
try:
...
except:
-+ await trio.lowlevel.checkpoint()
- yield # error: 8, "yield", Statement("function definition", lineno-4)
- finally:
- await foo()
@@ x,6 x,7 @@
try:
await foo()
diff --git a/tests/autofix_files/async91x_autofix.py b/tests/autofix_files/async91x_autofix.py
index 50f1e99..33e36e1 100644
--- a/tests/autofix_files/async91x_autofix.py
+++ b/tests/autofix_files/async91x_autofix.py
@@ -228,3 +228,52 @@ async def match_not_checkpoint_in_all_guards() -> ( # ASYNC910: 0, "exit", Stat
case _ if await foo():
...
await trio.lowlevel.checkpoint()
+
+
+# Issue #403: autofix should not insert checkpoints inside an except clause,
+# as that would trigger ASYNC120. Instead insert at top of function / loop.
+async def except_return():
+ await trio.lowlevel.checkpoint()
+ try:
+ await foo()
+ except ValueError:
+ return # ASYNC910: 8, "return", Statement("function definition", lineno-4)
+
+
+async def except_yield(): # ASYNC911: 0, "exit", Statement("yield", lineno+4)
+ await trio.lowlevel.checkpoint()
+ try:
+ await foo()
+ except ValueError:
+ yield # ASYNC911: 8, "yield", Statement("function definition", lineno-4)
+ await trio.lowlevel.checkpoint()
+
+
+async def except_return_in_for_loop(): # ASYNC910: 0, "exit", Statement("function definition", lineno)
+ for _ in range(10):
+ await trio.lowlevel.checkpoint()
+ try:
+ await foo()
+ except ValueError:
+ return # ASYNC910: 12, "return", Statement("function definition", lineno-5)
+ await trio.lowlevel.checkpoint()
+
+
+async def except_yield_in_for_loop(): # ASYNC911: 0, "exit", Statement("yield", lineno+5)
+ for _ in range(10):
+ await trio.lowlevel.checkpoint()
+ try:
+ await foo()
+ except ValueError:
+ yield # ASYNC911: 12, "yield", Statement("function definition", lineno-5) # ASYNC911: 12, "yield", Statement("yield", lineno)
+ await trio.lowlevel.checkpoint()
+
+
+async def except_nested_in_if():
+ await trio.lowlevel.checkpoint()
+ if bar():
+ try:
+ await foo()
+ except ValueError:
+ return # ASYNC910: 12, "return", Statement("function definition", lineno-5)
+ await foo()
diff --git a/tests/autofix_files/async91x_autofix.py.diff b/tests/autofix_files/async91x_autofix.py.diff
index 30825a7..30744be 100644
--- a/tests/autofix_files/async91x_autofix.py.diff
+++ b/tests/autofix_files/async91x_autofix.py.diff
@@ -110,8 +110,54 @@
async def match_all_cases() -> None:
-@@ x,3 x,4 @@
+@@ x,11 x,13 @@
...
case _ if await foo():
...
+ await trio.lowlevel.checkpoint()
+
+
+ # Issue #403: autofix should not insert checkpoints inside an except clause,
+ # as that would trigger ASYNC120. Instead insert at top of function / loop.
+ async def except_return():
++ await trio.lowlevel.checkpoint()
+ try:
+ await foo()
+ except ValueError:
+@@ x,29 x,36 @@
+
+
+ async def except_yield(): # ASYNC911: 0, "exit", Statement("yield", lineno+4)
++ await trio.lowlevel.checkpoint()
+ try:
+ await foo()
+ except ValueError:
+ yield # ASYNC911: 8, "yield", Statement("function definition", lineno-4)
++ await trio.lowlevel.checkpoint()
+
+
+ async def except_return_in_for_loop(): # ASYNC910: 0, "exit", Statement("function definition", lineno)
+ for _ in range(10):
++ await trio.lowlevel.checkpoint()
+ try:
+ await foo()
+ except ValueError:
+ return # ASYNC910: 12, "return", Statement("function definition", lineno-5)
++ await trio.lowlevel.checkpoint()
+
+
+ async def except_yield_in_for_loop(): # ASYNC911: 0, "exit", Statement("yield", lineno+5)
+ for _ in range(10):
++ await trio.lowlevel.checkpoint()
+ try:
+ await foo()
+ except ValueError:
+ yield # ASYNC911: 12, "yield", Statement("function definition", lineno-5) # ASYNC911: 12, "yield", Statement("yield", lineno)
++ await trio.lowlevel.checkpoint()
+
+
+ async def except_nested_in_if():
++ await trio.lowlevel.checkpoint()
+ if bar():
+ try:
+ await foo()
diff --git a/tests/eval_files/async91x_autofix.py b/tests/eval_files/async91x_autofix.py
index aecf45a..34e70b3 100644
--- a/tests/eval_files/async91x_autofix.py
+++ b/tests/eval_files/async91x_autofix.py
@@ -208,3 +208,44 @@ async def match_not_checkpoint_in_all_guards() -> ( # ASYNC910: 0, "exit", Stat
...
case _ if await foo():
...
+
+
+# Issue #403: autofix should not insert checkpoints inside an except clause,
+# as that would trigger ASYNC120. Instead insert at top of function / loop.
+async def except_return():
+ try:
+ await foo()
+ except ValueError:
+ return # ASYNC910: 8, "return", Statement("function definition", lineno-4)
+
+
+async def except_yield(): # ASYNC911: 0, "exit", Statement("yield", lineno+4)
+ try:
+ await foo()
+ except ValueError:
+ yield # ASYNC911: 8, "yield", Statement("function definition", lineno-4)
+
+
+async def except_return_in_for_loop(): # ASYNC910: 0, "exit", Statement("function definition", lineno)
+ for _ in range(10):
+ try:
+ await foo()
+ except ValueError:
+ return # ASYNC910: 12, "return", Statement("function definition", lineno-5)
+
+
+async def except_yield_in_for_loop(): # ASYNC911: 0, "exit", Statement("yield", lineno+5)
+ for _ in range(10):
+ try:
+ await foo()
+ except ValueError:
+ yield # ASYNC911: 12, "yield", Statement("function definition", lineno-5) # ASYNC911: 12, "yield", Statement("yield", lineno)
+
+
+async def except_nested_in_if():
+ if bar():
+ try:
+ await foo()
+ except ValueError:
+ return # ASYNC910: 12, "return", Statement("function definition", lineno-5)
+ await foo()